diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index a330ccbbdf65..c49fde1746bc 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -148,6 +148,8 @@ class DataType { bool is_fixed_length_vector() const { return static_cast(data_.lanes) > 1; } /*! \return Whether the type is a scalable vector. */ bool is_scalable_vector() const { return static_cast(data_.lanes) < -1; } + /*! \return whether type is a vector type. */ + bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; } /*! \return whether type is a Void type. */ diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 132992c57dc7..806ddcb662f9 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { if (expr.dtype().lanes() == type.lanes()) { return expr; - } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { + } else if (expr.dtype().lanes() == 1 && type.is_vector()) { return tvm::tir::Broadcast(expr, type.lanes()); } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index e21436e556ee..3d6d3a9461d3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper( if (const RampNode* ramp = last_index.as()) { PrimExpr offset = ramp->base + (ramp->stride * i); last_index_value = MakeValue(offset); - } else if (last_index.dtype().lanes() > 1) { + } else if (last_index.dtype().is_vector()) { if (i == 0) { cached_vector_index = MakeValue(last_index); } diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index 7c4b38c1d702..2661f2fa6591 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { // Enable QHL library for FP16 data type const PrimExpr& x = call->args[0]; - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { return TVMExternCall(call, tvm_wrapper); } #endif @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr( } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid") const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); return TVMExternCall(new_call.get(), tvm_wrapper); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f012f8a1b35e..8eda537579e7 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t size = static_cast(op->ConstantAllocationSize()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } - if (op->dtype.lanes() > 1) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - if (op->dtype.lanes() > 1) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - if (op->dtype.lanes() > 1) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) { - if (op->value->dtype.lanes() > 1) { + if (op->value->dtype.is_vector()) { if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { std::stringstream s;