Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,8 @@ class AllocateFrameNode : public TIRFrameNode {
PrimExpr condition;
/*! \brief Additional annotation hints. */
Map<String, ObjectRef> annotations;
/*! \brief The buffer. */
tvm::tir::Buffer buffer;
/*! \brief The buffer var. */
tvm::tir::Var buffer_var;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
Expand All @@ -463,7 +463,7 @@ class AllocateFrameNode : public TIRFrameNode {
v->Visit("storage_scope", &storage_scope);
v->Visit("condition", &condition);
v->Visit("annotations", &annotations);
v->Visit("buffer", &buffer);
v->Visit("buffer_var", &buffer_var);
}

static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
Expand Down Expand Up @@ -500,8 +500,8 @@ class AllocateConstFrameNode : public TIRFrameNode {
Array<PrimExpr> extents;
/*! \brief The data associated with the constant. */
tvm::runtime::NDArray data;
/*! \brief The buffer */
tvm::tir::Buffer buffer;
/*! \brief The buffer var */
tvm::tir::Var buffer_var;
/*! \brief Additional annotations about the allocation. */
Map<String, ObjectRef> annotations;

Expand All @@ -510,7 +510,7 @@ class AllocateConstFrameNode : public TIRFrameNode {
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("data", &data);
v->Visit("buffer", &buffer);
v->Visit("buffer_var", &buffer_var);
v->Visit("annotations", &annotations);
}

Expand Down Expand Up @@ -723,11 +723,15 @@ class ElseFrame : public TIRFrame {

class DeclBufferFrameNode : public TIRFrameNode {
public:
/*! \brief The declared buffer. */
tvm::tir::Buffer buffer;
/*! \brief The buffer allocated or not. */
bool allocated;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("buffer", &buffer);
v->Visit("allocated", &allocated);
}

static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
Expand Down
46 changes: 28 additions & 18 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,8 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s
* \param annotations Additional annotation hints.
* \return The created AllocateConstFrame.
*/
AllocateConstFrame AllocateConst(
NDArray data, DataType dtype, Array<PrimExpr> extents,
Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array<PrimExpr> extents,
Optional<Map<String, ObjectRef>> annotations = NullOpt);

/*!
* \brief Create an attribute.
Expand Down Expand Up @@ -449,21 +448,32 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
}

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ class RealizeFrame(TIRFrame):
class AllocateFrame(TIRFrame):
def __enter__(self) -> Buffer:
super().__enter__()
return self.buffer
return self.buffer_var


@_register_object("script.ir_builder.tir.AllocateConstFrame")
class AllocateConstFrame(TIRFrame):
def __enter__(self) -> Buffer:
super().__enter__()
return self.buffer
return self.buffer_var


@_register_object("script.ir_builder.tir.AttrFrame")
Expand Down
Loading