From 73098227ecbef54946d8b86a7deeba9e3ef750be Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:16:28 +0000 Subject: [PATCH 1/9] [REFACTOR][IR] Phase out OpNode::arguments and AttrFieldInfo OpNode::arguments only stored metadata for self-documentation; no Python tooling, no test, and no C++ caller other than internal sanity checks read it. Removing it deletes AttrFieldInfo (which existed solely to type that array) and the ~335 add_argument() chain calls that populated dead metadata. The 12 internal consumers that read op->arguments.size() now read op->num_inputs (always set by the same TVM_REGISTER_OP() chain via set_num_inputs). Error messages that read op->arguments[i]->name now report the index ("input[i]") instead of the per-arg name. --- include/tvm/ir/attrs.h | 31 ------- include/tvm/ir/op.h | 22 ----- python/tvm/ir/op.py | 14 --- src/ir/attrs.cc | 5 +- src/ir/op.cc | 5 -- src/relax/op/ccl/ccl.cc | 5 -- src/relax/op/distributed/distributed.cc | 8 -- src/relax/op/image/resize.cc | 8 -- src/relax/op/nn/attention.cc | 14 --- src/relax/op/nn/convolution.cc | 12 --- src/relax/op/nn/nn.cc | 45 ++-------- src/relax/op/nn/pooling.cc | 9 -- src/relax/op/op.cc | 89 ------------------- src/relax/op/op_common.cc | 10 +-- src/relax/op/op_common.h | 9 +- src/relax/op/tensor/binary.cc | 2 +- src/relax/op/tensor/binary.h | 2 - src/relax/op/tensor/create.cc | 26 ------ src/relax/op/tensor/datatype.cc | 2 - src/relax/op/tensor/grad.cc | 14 --- src/relax/op/tensor/index.cc | 7 -- src/relax/op/tensor/inspect.cc | 10 --- src/relax/op/tensor/linear_algebra.cc | 5 -- src/relax/op/tensor/manipulate.cc | 44 --------- src/relax/op/tensor/qdq.cc | 6 -- src/relax/op/tensor/sampling.cc | 3 - src/relax/op/tensor/search.cc | 9 -- src/relax/op/tensor/set.cc | 18 ---- src/relax/op/tensor/sorting.cc | 3 - src/relax/op/tensor/statistical.cc | 3 - src/relax/op/tensor/statistical.h | 1 - src/relax/op/tensor/ternary.cc | 3 - src/relax/op/tensor/unary.cc | 3 - src/relax/op/vision/multibox_transform_loc.cc | 4 - src/relax/op/vision/nms.cc | 14 --- src/relax/op/vision/roi_align.cc | 3 - src/relax/op/vision/roi_pool.cc | 3 - src/target/cuda/intrin_rule_cuda.cc | 16 ---- src/target/metal/intrin_rule_metal.cc | 6 -- src/target/webgpu/intrin_rule_webgpu.cc | 6 -- 40 files changed, 18 insertions(+), 481 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index c549fcdbc138..e0651f742ca1 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -43,37 +43,6 @@ namespace tvm { -/*! - * \brief Information about attribute fields in string representations. - */ -class AttrFieldInfoNode : public ffi::Object { - public: - /*! \brief name of the field */ - ffi::String name; - /*! \brief type docstring information in str. */ - ffi::String type_info; - /*! \brief detailed description of the type */ - ffi::String description; - - static void RegisterReflection() { - namespace rfl = ffi::reflection; - rfl::ObjectDef() - .def_ro("name", &AttrFieldInfoNode::name) - .def_ro("type_info", &AttrFieldInfoNode::type_info) - .def_ro("description", &AttrFieldInfoNode::description); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.AttrFieldInfo", AttrFieldInfoNode, ffi::Object); -}; - -/*! \brief AttrFieldInfo */ -class AttrFieldInfo : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrFieldInfo, ffi::ObjectRef, AttrFieldInfoNode); -}; - /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index dc8f99cd4789..619ab08f79c4 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -67,8 +67,6 @@ class OpNode : public RelaxExprNode { * This can be used to generate docstring automatically for the operator. */ ffi::String description; - /* \brief Information of input arguments to the operator */ - ffi::Array arguments; /*! * \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. @@ -97,7 +95,6 @@ class OpNode : public RelaxExprNode { .def_ro("name", &OpNode::name) .def_ro("op_type", &OpNode::op_type, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("description", &OpNode::description, refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("arguments", &OpNode::arguments, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("attrs_type_key", &OpNode::attrs_type_key, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("num_inputs", &OpNode::num_inputs, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore()); @@ -179,15 +176,6 @@ class OpRegEntry { * \return reference to self. */ inline OpRegEntry& describe(const std::string& descr); // NOLINT(*) - /*! - * \brief Add argument information to the function. - * \param name Name of the argument. - * \param type Type of the argument. - * \param description Description of the argument. - * \return reference to self. - */ - inline OpRegEntry& add_argument(const std::string& name, const std::string& type, - const std::string& description); /*! * \brief Set the attrs type key and index to be AttrsType. * \tparam AttrsType the attribute type to b set. @@ -328,16 +316,6 @@ inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(* return *this; } -inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, - const std::string& description) { - auto n = ffi::make_object(); - n->name = name; - n->type_info = type; - n->description = description; - get()->arguments.push_back(AttrFieldInfo(n)); - return *this; -} - inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; return *this; diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 6c0912f86476..ca1979ffabc4 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -102,20 +102,6 @@ def reset_attr(self, attr_name): """ _ffi_api.OpResetAttr(self, attr_name) - def add_argument(self, name, type, description): # pylint: disable=redefined-builtin - """Add arguments information to the function. - - Parameters - ---------- - name : str - The argument name. - type : str - The argument type. - description : str - The argument description. - """ - _ffi_api.OpAddArgument(self, name, type, description) - def set_support_level(self, level): """Set the support level of op. diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index e7d9b9082809..11d8d4f77595 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -26,10 +26,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK() { - AttrFieldInfoNode::RegisterReflection(); - DictAttrsNode::RegisterReflection(); -} +TVM_FFI_STATIC_INIT_BLOCK() { DictAttrsNode::RegisterReflection(); } DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs) { if (new_attrs.empty()) { diff --git a/src/ir/op.cc b/src/ir/op.cc index f6078e30d964..9f0c20c92090 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -113,11 +113,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); op.describe(descr); }) - .def("ir.OpAddArgument", - [](Op op, ffi::String name, ffi::String type, ffi::String description) { - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.add_argument(name, type, description); - }) .def("ir.OpSetSupportLevel", [](Op op, int level) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index 7f7eb3c8935d..d27353eef050 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -56,7 +56,6 @@ StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.ccl.allreduce") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "Input to which allreduce will be applied.") .set_attr("FInferStructInfo", InferStructInfoAllReduce) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -95,7 +94,6 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.ccl.allgather") .set_num_inputs(1) - .add_argument("x", "Tensor", "Input to which allgather will be applied.") .set_attr("FInferStructInfo", InferStructInfoAllGather) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -118,7 +116,6 @@ StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0") .set_num_inputs(1) - .add_argument("x", "Tensor", "Input to be broadcast.") .set_attr("FInferStructInfo", InferStructInfoBroadcastFromZero) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -166,8 +163,6 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.ccl.scatter_from_worker0") .set_num_inputs(1) - .add_argument("x", "Tensor", - "The buffer to be divided into equal parts and sent to each worker accordingly.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoScatter) .set_attr("FPurity", true); diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index bee2751564d9..c737ef392984 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -62,7 +62,6 @@ StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.dist.annotate_sharding") .set_num_inputs(1) - .add_argument("input", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoAnnotateSharding) .set_attr("dist.FInferStructInfo", InferStructInfoAnnotateSharding) .set_attr("FPurity", true); @@ -93,7 +92,6 @@ StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.dist.redistribute") .set_num_inputs(1) - .add_argument("input", "Tensor", "The input tensor.") .set_attr("dist.FInferStructInfo", InferDistStructInfoRedistribute) .set_attr("FPurity", true); @@ -111,11 +109,6 @@ StructInfo InferStructInfoCallTIRLocalView(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.dist.call_tir_local_view") .set_num_inputs(3) - .add_argument("func", "Expr", "The destination-passing-style function.") - .add_argument("args", "Tuple", "The input arguments.") - .add_argument("packed_ints", "Expr", - "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " - "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIRLocalView) .set_attr("FPurity", true); @@ -228,7 +221,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") .set_num_inputs(1) - .add_argument("input", "Tensor", "The buffer to be sliced.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoRtoS) .set_attr("dist.FInferStructInfo", InferDistStructInfoRtoS) diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index db8a8c3c43ee..3279b16f0679 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -143,8 +143,6 @@ InferLayoutOutput InferLayoutResize2d( TVM_REGISTER_OP("relax.image.resize2d") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("size", "Shape", "The output image shape.") .set_attr("FInferStructInfo", InferStructInfoResize2D) .set_attr("FRelaxInferLayout", InferLayoutResize2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -256,8 +254,6 @@ InferLayoutOutput InferLayoutResize3d( TVM_REGISTER_OP("relax.image.resize3d") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("size", "Shape", "The output image shape.") .set_attr("FInferStructInfo", InferStructInfoResize3D) .set_attr("FRelaxInferLayout", InferLayoutResize3d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -335,8 +331,6 @@ StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.image.grid_sample") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("grid", "Tensor", "The grid tensor for sampling.") .set_attr("FInferStructInfo", InferStructInfoGridSample) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -427,8 +421,6 @@ StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.image.affine_grid") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input affine matrix tensor.") - .add_argument("size", "Shape", "The target output shape (H, W).") .set_attr("FInferStructInfo", InferStructInfoAffineGrid) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index f19c55b5d2ec..e7ceb2fd68dd 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -151,9 +151,6 @@ Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.attention") .set_attrs_type() .set_num_inputs(3) - .add_argument("query", "Tensor", "The input queries tensor.") - .add_argument("key", "Tensor", "The input keys tensor.") - .add_argument("value", "Tensor", "The input values tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) .set_attr("FInferStructInfo", InferStructInfoAttention) @@ -162,10 +159,6 @@ TVM_REGISTER_OP("relax.nn.attention") TVM_REGISTER_OP("relax.nn.attention_bias") .set_attrs_type() .set_num_inputs(4) - .add_argument("query", "Tensor", "The input queries tensor.") - .add_argument("key", "Tensor", "The input keys tensor.") - .add_argument("value", "Tensor", "The input values tensor.") - .add_argument("bias", "Tensor", "The input bias tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) .set_attr("FInferStructInfo", InferStructInfoAttention) @@ -174,13 +167,6 @@ TVM_REGISTER_OP("relax.nn.attention_bias") TVM_REGISTER_OP("relax.nn.attention_var_len") .set_attrs_type() .set_num_inputs(7) - .add_argument("query", "Tensor", "The input queries tensor.") - .add_argument("key", "Tensor", "The input keys tensor.") - .add_argument("value", "Tensor", "The input values tensor.") - .add_argument("seqstart_q", "Tensor", "The cumsum of query sequence lengths, prepended with 0.") - .add_argument("seqstart_k", "Tensor", "The cumsum of key sequence lengths, prepended with 0.") - .add_argument("max_seqlen_q", "Tensor", "The maximum query sequence length in the batch.") - .add_argument("max_seqlen_k", "Tensor", "The maximum key sequence length in the batch.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) .set_attr("FInferStructInfo", InferStructInfoAttention) diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 1b77b4225203..367fe762112d 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -195,8 +195,6 @@ Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.conv1d") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv1d) .set_attr("FRelaxInferLayout", InferLayoutConv1d) @@ -403,8 +401,6 @@ Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.conv2d") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv2d) .set_attr("FRelaxInferLayout", InferLayoutConv2d) @@ -585,8 +581,6 @@ Call InferMixedPrecisionConv3d(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.conv3d") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv3d) .set_attr("FRelaxInferLayout", InferLayoutConv3d) @@ -767,8 +761,6 @@ Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dt TVM_REGISTER_OP("relax.nn.conv1d_transpose") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv1dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv1dTranspose) @@ -998,8 +990,6 @@ Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dt TVM_REGISTER_OP("relax.nn.conv2d_transpose") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv2dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv2dTranspose) @@ -1240,8 +1230,6 @@ Call InferMixedPrecisionConv3dTranspose(const Call& call, const DataType& out_dt TVM_REGISTER_OP("relax.nn.conv3d_transpose") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv3dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv3dTranspose) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b6e2051a68f7..cc2e7b406c90 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -75,7 +75,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoUnaryArith) @@ -98,7 +97,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoUnaryArith) @@ -161,8 +159,6 @@ InferLayoutOutput InferLayoutPRelu( TVM_REGISTER_OP("relax.nn.prelu") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("alpha", "Tensor", "The channel-wise learnable slope.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPRelu) .set_attr("FRelaxInferLayout", InferLayoutPRelu) @@ -224,7 +220,6 @@ InferLayoutOutput InferLayoutSoftmax( TVM_REGISTER_OP("relax.nn.softmax") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax) .set_attr("FRelaxInferLayout", InferLayoutSoftmax) @@ -245,7 +240,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax) .set_attr("FPurity", true); @@ -292,7 +286,6 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nn.pad") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPad) .set_attr("FPurity", true); @@ -364,7 +357,6 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.nn.pixel_shuffle") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPixelShuffle) .set_attr("FPurity", true); @@ -374,7 +366,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const ffi::Array& input_sinfo, ffi::Array axes) { Op op = Downcast(call->op); - int n_input = op->arguments.size(); + int n_input = op->num_inputs; TensorStructInfo data_sinfo = input_sinfo[0]; @@ -394,13 +386,13 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << op - << " requires all the input tensors to have the same dtype. However, the " - << op->arguments[i]->name << " has dtype " << input_sinfo[i]->dtype + << " requires all the input tensors to have the same dtype. However, input[" + << i << "] has dtype " << input_sinfo[i]->dtype << " which is other than the input data's dtype " << data_sinfo->dtype); } else if (input_sinfo[i]->ndim != n_axis) { ctx->ReportFatal(Diagnostic::Error(call) - << op << " requires the input " << op->arguments[i]->name - << " to have as many dimensions as the length of input axes. However, the " + << op << " requires input[" << i + << "] to have as many dimensions as the length of input axes. However, the " "given one has ndim " << input_sinfo[i]->ndim << ", which is other than the length of axes " << n_axis); @@ -514,11 +506,6 @@ InferLayoutOutput InferLayoutBatchNorm( TVM_REGISTER_OP("relax.nn.batch_norm") .set_attrs_type() .set_num_inputs(5) - .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") - .add_argument("gamma", "Tensor", "The gamma scale factor.") - .add_argument("beta", "Tensor", "The beta offset factor.") - .add_argument("moving_mean", "Tensor", "Running mean of input.") - .add_argument("moving_var", "Tensor", "Running variance of input.") .set_attr("FInferStructInfo", InferStructInfoBatchNorm) .set_attr("FRelaxInferLayout", InferLayoutBatchNorm) .set_attr("FPurity", true); @@ -583,9 +570,6 @@ InferLayoutOutput InferLayoutLayerNorm( TVM_REGISTER_OP("relax.nn.layer_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") - .add_argument("gamma", "Tensor", "The gamma scale factor.") - .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoLayerNorm) .set_attr("FRelaxInferLayout", InferLayoutLayerNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -642,7 +626,7 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { << op << " expects that the size of channel_axis must be divisible by " << attrs->num_groups << ", but got " << data_shape->values[channel_axis]); } - for (int i = 1; i < static_cast(op->arguments.size()); ++i) { + for (int i = 1; i < op->num_inputs; ++i) { if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << op << " expects that all inputs must have the same dtype, but got " @@ -697,9 +681,6 @@ InferLayoutOutput InferLayoutGroupNorm( TVM_REGISTER_OP("relax.nn.group_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which group_norm will be applied.") - .add_argument("gamma", "Tensor", "The gamma scale factor.") - .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoGroupNorm) .set_attr("FRelaxInferLayout", InferLayoutGroupNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -746,7 +727,7 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx } const auto* data_shape = data_sinfo->shape.as(); arith::Analyzer* analyzer = ctx->GetAnalyzer(); - for (int i = 1; i < static_cast(op->arguments.size()); ++i) { + for (int i = 1; i < op->num_inputs; ++i) { if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << op << " expects that all inputs must have the same dtype, but got " @@ -800,9 +781,6 @@ InferLayoutOutput InferLayoutInstanceNorm( TVM_REGISTER_OP("relax.nn.instance_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") - .add_argument("gamma", "Tensor", "The gamma scale factor.") - .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoInstanceNorm) .set_attr("FRelaxInferLayout", InferLayoutInstanceNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -861,8 +839,6 @@ InferLayoutOutput InferLayoutRMSNorm( TVM_REGISTER_OP("relax.nn.rms_norm") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "Input to which rms_norm will be applied.") - .add_argument("weight", "Tensor", "The scale factor.") .set_attr("FInferStructInfo", InferStructInfoRMSNorm) .set_attr("FRelaxInferLayout", InferLayoutRMSNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -891,7 +867,6 @@ StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nn.dropout") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_attr("FInferStructInfo", InferStructInfoDropout) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -956,8 +931,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") .set_num_inputs(2) - .add_argument("predictions", "Tensor", "The predictions.") - .add_argument("labels", "Tensor", "The labels.") .set_attr("FInferStructInfo", InferStructInfoCrossEntropy) .set_attr("FPurity", true); @@ -1187,9 +1160,6 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nn.nll_loss") .set_attrs_type() .set_num_inputs(3) - .add_argument("predictions", "Tensor", "The prediction tensor.") - .add_argument("targets", "Tensor", "The target tensor.") - .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLoss) .set_attr("FPurity", true); @@ -1238,7 +1208,6 @@ StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.nn.batch_flatten") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoBatchFlatten) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 60430519111d..16dba700b0da 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -144,7 +144,6 @@ InferLayoutOutput InferLayoutPool1d( TVM_REGISTER_OP("relax.nn.max_pool1d") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool1D) .set_attr("FRelaxInferLayout", InferLayoutPool1d) @@ -295,7 +294,6 @@ InferLayoutOutput InferLayoutPool2d( TVM_REGISTER_OP("relax.nn.max_pool2d") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) @@ -441,7 +439,6 @@ InferLayoutOutput InferLayoutPool3d( TVM_REGISTER_OP("relax.nn.max_pool3d") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool3D) .set_attr("FRelaxInferLayout", InferLayoutPool3d) @@ -463,7 +460,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool1D) .set_attr("FRelaxInferLayout", InferLayoutPool1d) @@ -485,7 +481,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) @@ -507,7 +502,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool3D) .set_attr("FRelaxInferLayout", InferLayoutPool3d) @@ -590,7 +584,6 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool1D( TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool1D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool1D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -692,7 +685,6 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D( TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -779,7 +771,6 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool3D( TVM_REGISTER_OP("relax.nn.adaptive_avg_pool3d") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool3D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool3D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 8a28ab361af2..21b336db1edc 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -115,9 +115,6 @@ StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.call_pure_packed") .set_num_inputs(-1) - .add_argument("args", "ffi::Array", - "The first argument is the function being called. The rest are the " - "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallPurePacked) .set_attr("FPurity", true); @@ -231,9 +228,6 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.call_inplace_packed") .set_num_inputs(-1) .set_attrs_type() - .add_argument("args", "ffi::Array", - "The first argument is the function being called. The rest are the " - "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallInplacePacked) // Warning: considered pure, but it has the potential to create visible effects! // This should only be used if it has been *checked* that it is safe (no aliases, in-place @@ -576,11 +570,6 @@ void ValidateCallTIR(Call call) { TVM_REGISTER_OP("relax.call_tir") .set_num_inputs(3) - .add_argument("func", "Expr", "The destination-passing-style function.") - .add_argument("args", "Tuple", "The input arguments.") - .add_argument("packed_ints", "Expr", - "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " - "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FValidate", ValidateCallTIR) @@ -624,11 +613,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.call_tir_with_grad") .set_num_inputs(3) .set_attrs_type() - .add_argument("func", "Expr", "The destination-passing-style function.") - .add_argument("args", "Tuple", "The input arguments.") - .add_argument("packed_ints", "Expr", - "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " - "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FValidate", ValidateCallTIR) @@ -766,11 +750,6 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { TVM_REGISTER_OP("relax.call_tir_inplace") .set_num_inputs(3) .set_attrs_type() - .add_argument("func", "Expr", "The destination-passing-style function.") - .add_argument("args", "Tuple", "The input arguments.") - .add_argument("packed_ints", "Expr", - "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " - "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIRInPlace) .set_attr("FValidate", ValidateCallTIR) @@ -828,8 +807,6 @@ StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.call_dps_packed") .set_num_inputs(2) - .add_argument("func", "Expr", "The destination-passing-style function.") - .add_argument("args", "Tuple", "The input arguments.") .set_attr("FInferStructInfo", InferStructInfoCallDPSPacked) // technically, an impure op could be used with this, but there is // little reason to use DPS with an impure op @@ -894,8 +871,6 @@ void ValidateCallPyFunc(Call call) { TVM_REGISTER_OP("relax.call_py_func") .set_num_inputs(2) - .add_argument("func_name", "StringImm", "The name of the Python function to call.") - .add_argument("args", "Tuple", "The input arguments.") .set_attr("FInferStructInfo", InferStructInfoCallPyFunc) .set_attr("FValidate", ValidateCallPyFunc) .set_attr("FPurity", true); @@ -938,8 +913,6 @@ StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilde TVM_REGISTER_OP("relax.call_builtin_with_ctx") .set_num_inputs(4) - .add_argument("func", "Expr", "The builtin packed func.") - .add_argument("args", "Tuple", "The input arguments.") .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx) // Most builtins are pure, but some are not, like `vm.builtin.attention_kv_cache_append` .set_attr("FPurity", false); @@ -973,9 +946,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.print") .set_num_inputs(-1) - .add_argument("vals", "ffi::Array", - "The first value is Python-style format string to use to print. The others " - "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", false); @@ -1018,10 +988,6 @@ StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.assert_op") .set_num_inputs(-1) - .add_argument("vals", "ffi::Array", - "The first value is used as the assertion condition. The second value is " - "Python-style format string to use for displaying an error message, if the " - "assert fails. The others are used as format arguments if there is an error.") .set_attr("FInferStructInfo", InferAssertStructInfo) .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", false); @@ -1045,8 +1011,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.make_closure") .set_num_inputs(2) - .add_argument("func", "Expr", "The closure.") - .add_argument("args", "Tuple", "The captured variables.") .set_attr("FInferStructInfo", ReturnObjectStructInfo) .set_attr("FPurity", true); @@ -1074,8 +1038,6 @@ StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.invoke_closure") .set_num_inputs(2) - .add_argument("closure", "Expr", "The VMClosure.") - .add_argument("args", "Tuple", "The captured variables.") .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) // Not all closures are pure. Use invoke_pure_closure for specifying purity .set_attr("FPurity", false); @@ -1094,8 +1056,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.invoke_pure_closure") .set_num_inputs(2) - .add_argument("closure", "Expr", "The VMClosure.") - .add_argument("args", "Tuple", "The captured variables.") .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) .set_attr("FPurity", true); @@ -1113,7 +1073,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.shape_of") .set_num_inputs(1) - .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", InferStructInfoShapeOf) .set_attr("FPurity", true); @@ -1139,7 +1098,6 @@ StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.size") .set_num_inputs(1) - .add_argument("input", "Expr", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoSize) .set_attr("FPurity", true); @@ -1176,7 +1134,6 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.tensor_to_shape") .set_num_inputs(1) - .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo) .set_attr("FPurity", true); @@ -1202,7 +1159,6 @@ StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) - .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) .set_attr("FCallPacked", "relax.run.shape_to_tensor") .set_attr("FPurity", true); @@ -1243,13 +1199,6 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.builtin.alloc_tensor") .set_num_inputs(4) - .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .add_argument("runtime_device_index", "PrimValue", - "The device index indicating on which device the tensor is to be " - "allocated at runtime. Index -1 is reserved for the host device.") - .add_argument("storage_scope", "StringImm", - "The storage scope of the storage to allocate. Default is global.") .set_attr("FInferStructInfo", InferStructInfoAllocateTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1270,14 +1219,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.alloc_storage") .set_num_inputs(4) - .add_argument("total_space", "Expr", "The total space of the storage to allocate.") - .add_argument( - "virtual_device_index", "PrimValue", - "The virtual device index indicating on which device the storage is to be allocated, " - "Index -1 is reserved for the host device.") - .add_argument("storage_scope", "StringImm", - "The storage scope of the storage to allocate. Default is global.") - .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") .set_attr("FInferStructInfo", ReturnObjectStructInfo) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1321,13 +1262,6 @@ StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.memory.alloc_tensor") .set_num_inputs(5) - .add_argument("storage", "Expr", "The storage to allocate the tensor to.") - .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") - .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .add_argument("runtime_device_index", "PrimValue", - "The device index indicating on which device the tensor is to be " - "allocated at runtime. Index -1 is reserved for the host device.") .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1359,7 +1293,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) - .add_argument("storage", "Expr", "The storage to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1378,7 +1311,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) - .add_argument("tensor", "Expr", "The tensor to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1397,13 +1329,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.vm.alloc_storage") .set_num_inputs(4) - .add_argument("size", "Expr", "The size of the storage to allocate.") - .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .add_argument("runtime_device_index", "PrimValue", - "The device index indicating on which device the tensor is " - "to be allocated at runtime.") - .add_argument("storage_scope", "StringImm", - "The storage scope of the storage to allocate. Default is global.") .set_attr("FInferStructInfo", ReturnObjectStructInfo) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1448,13 +1373,6 @@ StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.vm.alloc_tensor") .set_num_inputs(5) - .add_argument("storage", "Expr", "The storage to allocate the tensor to.") - .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") - .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .add_argument("runtime_device_index", "PrimValue", - "The device index indicating on which device the tensor is " - "to be allocated at runtime.") .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1484,7 +1402,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { // vm kill_object TVM_REGISTER_OP("relax.vm.kill_object") .set_num_inputs(1) - .add_argument("obj", "Expr", "The object to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1503,9 +1420,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.vm.call_tir_dyn") .set_num_inputs(2) - .add_argument("func", "Expr", "The destination-passing-style function.") - .add_argument("args", "Tuple", - "The input arguments (list of tensors and last argument is ShapeExpr)") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // "relax.vm.call_tir_dyn" works in an in-place way, which is impure. .set_attr("FPurity", false); @@ -1527,7 +1441,6 @@ StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.builtin.stop_lift_params") .set_num_inputs(1) - .add_argument("x", "Expr", "The input data") .set_attr("FInferStructInfo", InferStructInfoStopLiftParams) .set_attr("FPurity", true); @@ -1558,7 +1471,6 @@ StructInfo InferToVDeviceStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.to_vdevice") .set_num_inputs(1) .set_attrs_type() - .add_argument("data", "Expr", "The input expression to be copied") .set_attr("FInferStructInfo", InferToVDeviceStructInfo) .set_attr("FPurity", true); @@ -1586,7 +1498,6 @@ StructInfo InferHintOnDeviceStructInfo(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.hint_on_device") .set_num_inputs(1) .set_attrs_type() - .add_argument("data", "Expr", "The input expression") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", true); diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 61485b09112b..a700b40bd4ea 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -39,7 +39,7 @@ ffi::Array GetCallArgs(const Call& call) { void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - int expected_input = op->arguments.size(); + int expected_input = op->num_inputs; if (static_cast(call->args.size()) != expected_input) { ctx->ReportFatal(Diagnostic::Error(call) << "Operator " << op << " expects " << expected_input << " arguments" @@ -50,10 +50,10 @@ void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) { Op op = Downcast(call->op); - TVM_FFI_ICHECK_EQ(op->arguments.size(), call->args.size()) + TVM_FFI_ICHECK_EQ(static_cast(op->num_inputs), call->args.size()) << "Failure caught by this check " << "should have previously been caught by `CheckNumArguments`"; - TVM_FFI_ICHECK_LT(i_arg, op->arguments.size()); + TVM_FFI_ICHECK_LT(i_arg, static_cast(op->num_inputs)); auto arg = call->args[i_arg]; auto sinfo = GetStructInfo(arg); @@ -62,8 +62,8 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const return tensor_sinfo.value(); } else { ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << op << " requires argument " << i_arg << " (" - << op->arguments[i_arg]->name << ") to be a tensor. " + << "Operator " << op << " requires argument input[" << i_arg + << "] to be a tensor. " << "However, the argument " << arg << " is instead of type " << sinfo); // Unreachable, but [[noreturn]] attribute on virtual function // `ReportFatal` is insufficient to silence -Wreturn-type, as diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 774eccfd58dd..88630a9043d6 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -142,10 +142,10 @@ std::tuple GetArgStructInfoHelper(const Call& call, const Op& op, template std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - size_t n_input = op->arguments.size(); + size_t n_input = op->num_inputs; - // Unfortunately, because the `.add_argument()` calls in - // TVM_REGISTER_OP occur during initialization of globals and are + // Unfortunately, because the `.set_num_inputs()` call in + // TVM_REGISTER_OP occurs during initialization of globals and is // not available at compile-time, this cannot be a static_assert. TVM_FFI_ICHECK_EQ(n_input, sizeof...(ArgTypes)) << "Internal error: " << op << " op defines " << n_input @@ -166,7 +166,6 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c #define RELAX_REGISTER_UNARY_OP(OpRegName) \ TVM_REGISTER_OP("relax." OpRegName) \ .set_num_inputs(1) \ - .add_argument("x", "Tensor", "The input tensor.") \ .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) \ .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ .set_attr("FPurity", true) @@ -235,7 +234,7 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx template StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - int n_input = op->arguments.size(); + int n_input = op->num_inputs; if (static_cast(call->args.size()) != n_input) { ctx->ReportFatal(Diagnostic::Error(call) << op << " op should have " << n_input << " arguments"); diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 07c3364a9f35..4575fba43385 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -35,7 +35,7 @@ template StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { Op op = Downcast(call->op); - size_t n_input = op->arguments.size(); + size_t n_input = op->num_inputs; if (call->args.size() != n_input) { ctx->ReportFatal(Diagnostic::Error(call) << call->op << " op should have " << n_input << " arguments"); diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index a234a30bc221..6f4068b3ccb8 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -47,8 +47,6 @@ namespace relax { } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(2) \ - .add_argument("x1", "Tensor", "The first input tensor.") \ - .add_argument("x2", "Tensor", "The second input tensor.") \ .set_attr("FRelaxInferLayout", InferLayoutBinaryEwise) \ .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ .set_attr("FPurity", true) diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 885f7c87257e..72de90b55024 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -94,8 +94,6 @@ StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.full") .set_attrs_type() .set_num_inputs(2) - .add_argument("shape", "Shape", "The shape of the created tensor.") - .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") .set_attr("FInferStructInfo", InferStructInfoFull) .set_attr("RequiresArgumentShapes", false) .set_attr("FDataDependent", true) @@ -138,8 +136,6 @@ StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.full_like") .set_attrs_type() .set_num_inputs(2) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("fill_value", "Tensor", "The scalar value to fill.") .set_attr("FInferStructInfo", InferStructInfoFullLike) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -199,7 +195,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.ones") .set_attrs_type() .set_num_inputs(1) - .add_argument("shape", "Shape", "The shape of the created tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesZeros) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -207,7 +202,6 @@ TVM_REGISTER_OP("relax.ones") TVM_REGISTER_OP("relax.ones_like") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) .set_attr("FPurity", true); @@ -236,7 +230,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.zeros") .set_attrs_type() .set_num_inputs(1) - .add_argument("shape", "Shape", "The shape of the created tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesZeros) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -244,7 +237,6 @@ TVM_REGISTER_OP("relax.zeros") TVM_REGISTER_OP("relax.zeros_like") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) .set_attr("FPurity", true); @@ -319,9 +311,6 @@ StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.eye") .set_attrs_type() .set_num_inputs(3) - .add_argument("n", "PrimValue", "Number of rows in the output.") - .add_argument("m", "PrimValue", "Number of columns in the output.") - .add_argument("k", "PrimValue", "Index of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoEye) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -329,8 +318,6 @@ TVM_REGISTER_OP("relax.eye") TVM_REGISTER_OP("relax.eye_like") .set_attrs_type() .set_num_inputs(2) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("k", "PrimValue", "Index of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoEyeLike) .set_attr("FPurity", true); @@ -382,9 +369,6 @@ StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.arange") .set_attrs_type() .set_num_inputs(3) - .add_argument("start", "PrimValue", "The starting value for the set of points.") - .add_argument("end", "PrimValue", "The ending value for the set of points.") - .add_argument("step", "PrimValue", "The gap between each pair of adjacent points.") .set_attr("FInferStructInfo", InferStructInfoArange) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -433,12 +417,6 @@ StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.hamming_window") .set_attrs_type() .set_num_inputs(4) - .add_argument("window_size", "PrimValue", "The size of the window") - .add_argument("periodic", "PrimValue", - "If True, returns a window to be used as periodic function. If False, return a " - "symmetric window") - .add_argument("alpha", "PrimValue", "The coefficient alpha") - .add_argument("beta", "PrimValue", "The coefficient beta") .set_attr("FInferStructInfo", InferStructInfoHammingWindow) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -480,15 +458,11 @@ StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.tril") .set_num_inputs(2) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("k", "PrimValue", "The offset of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoTrilTriu) .set_attr("FPurity", true); TVM_REGISTER_OP("relax.triu") .set_num_inputs(2) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("k", "PrimValue", "The offset of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoTrilTriu) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 50624355c8fe..8f51de26b60b 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -63,7 +63,6 @@ StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.astype") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAstype) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -96,7 +95,6 @@ StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.wrap_param") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoWrapParam) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index b05757c7de5e..74d5660fa09b 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -48,7 +48,6 @@ StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.grad.no_grad") .set_num_inputs(1) - .add_argument("x", "Expr", "The corresponding input tensor.") .set_attr("FInferStructInfo", InferStructInfoNoGrad) .set_attr("FPurity", true); @@ -73,7 +72,6 @@ StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.grad.start_checkpoint") .set_num_inputs(1) - .add_argument("x", "Expr", "The tensor marking the input of the checkpoint stage.") .set_attr("FInferStructInfo", InferStructInfoStartCheckpoint) .set_attr("FPurity", true); @@ -98,7 +96,6 @@ StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.grad.end_checkpoint") .set_num_inputs(1) - .add_argument("x", "Expr", "The output of the checkpoint stage.") .set_attr("FInferStructInfo", InferStructInfoEndCheckpoint) .set_attr("FPurity", true); @@ -133,10 +130,6 @@ StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.grad.nll_loss_backward") .set_attrs_type() .set_num_inputs(4) - .add_argument("output_grad", "Tensor", "The output gradient.") - .add_argument("predictions", "Tensor", "The prediction tensor.") - .add_argument("targets", "Tensor", "The target tensor.") - .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward) .set_attr("FPurity", true); @@ -169,8 +162,6 @@ StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.grad.max_pool2d_backward") .set_num_inputs(2) - .add_argument("output_grad", "Tensor", "The output gradient.") - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoMaxPool2DBackward) .set_attr("FPurity", true); @@ -204,8 +195,6 @@ StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") .set_num_inputs(2) - .add_argument("output_grad", "Tensor", "The output gradient.") - .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoAvgPool2DBackward) .set_attr("FPurity", true); @@ -232,9 +221,6 @@ StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.grad.take_backward") .set_attrs_type() .set_num_inputs(3) - .add_argument("output_grad", "Tensor", "The output gradient.") - .add_argument("x", "Tensor", "The source tensor.") - .add_argument("indices", "Tensor", "The indices of the values to extract.") .set_attr("FInferStructInfo", InferStructInfoTakeBackward) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 6b02ca050bea..f3d948212538 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -130,8 +130,6 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.take") .set_attrs_type() .set_num_inputs(2) - .add_argument("x", "Tensor", "The source tensor.") - .add_argument("indices", "Tensor", "The indices of the values to extract.") .set_attr("FInferStructInfo", InferStructInfoTake) .set_attr("FPurity", true); @@ -480,7 +478,6 @@ InferLayoutOutput InferLayoutStridedSlice( TVM_REGISTER_OP("relax.strided_slice") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "The source tensor to be sliced.") .set_attr("FInferStructInfo", InferStructInfoStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -575,10 +572,6 @@ InferLayoutOutput InferLayoutDynStridedSlice( TVM_REGISTER_OP("relax.dynamic_strided_slice") .set_num_inputs(4) - .add_argument("x", "Tensor", "The source tensor to be sliced.") - .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") - .add_argument("end", "Tensor", "Indices indicating end of the slice.") - .add_argument("strides", "Tensor", "The stride values.") .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutDynStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 3988e0ba2359..f88342c93c11 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -150,7 +150,6 @@ Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_code") .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorDtypeCode) .set_attr("FLegalize", LegalizeTensorDtypeCode) .set_attr("RequiresArgumentShapes", false) @@ -188,7 +187,6 @@ Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits") .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorDtypeBits) .set_attr("FLegalize", LegalizeTensorDtypeBits) .set_attr("RequiresArgumentShapes", false) @@ -226,7 +224,6 @@ Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes") .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorDtypeLanes) .set_attr("FLegalize", LegalizeTensorDtypeLanes) .set_attr("RequiresArgumentShapes", false) @@ -264,7 +261,6 @@ Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_ndim") .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorNDim) .set_attr("FLegalize", LegalizeTensorNDim) .set_attr("RequiresArgumentShapes", false) @@ -342,8 +338,6 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_shape_i") .set_num_inputs(2) - .add_argument("tensor", "Tensor", "The tensor to be inspected") - .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") .set_attr("FInferStructInfo", InferStructInfoTensorShape) .set_attr("FLegalize", LegalizeTensorShape) .set_attr("RequiresArgumentShapes", false) @@ -391,8 +385,6 @@ StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { TVM_REGISTER_OP("relax.inspect.tensor_stride_i") .set_num_inputs(2) - .add_argument("tensor", "Tensor", "The tensor to be inspected") - .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") .set_attr("FInferStructInfo", InferStructInfoTensorStride) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -423,7 +415,6 @@ StructInfo InferStructInfoTensorByteOffset(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.inspect.tensor_byte_offset") .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorByteOffset) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -454,7 +445,6 @@ StructInfo InferStructInfoTensorElemOffset(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.inspect.tensor_elem_offset") .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorElemOffset) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 6936fa04348b..108248e10f74 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -166,8 +166,6 @@ Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.matmul") .set_num_inputs(2) - .add_argument("x1", "Tensor", "The first input tensor.") - .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoMatmul) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionMatmul) @@ -257,7 +255,6 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.einsum") .set_attrs_type() .set_num_inputs(1) - .add_argument("operands", "Tensor", "The input tensors.") .set_attr("FInferStructInfo", InferStructInfoEinsum) .set_attr("FPurity", true); @@ -296,8 +293,6 @@ StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.outer") .set_num_inputs(2) - .add_argument("x1", "Tensor", "The first input tensor.") - .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoOuter) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 763e37ae6815..ca60c2476f06 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -135,8 +135,6 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.broadcast_to") .set_num_inputs(2) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("shape", "Shape", "The target shape.") .set_attr("FInferStructInfo", InferStructInfoBroadcastTo) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -398,7 +396,6 @@ InferLayoutOutput InferLayoutConcat( TVM_REGISTER_OP("relax.concat") .set_attrs_type() .set_num_inputs(1) - .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat) .set_attr("FRelaxInferLayout", InferLayoutConcat) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -502,7 +499,6 @@ InferLayoutOutput InferLayoutExpandDims( TVM_REGISTER_OP("relax.expand_dims") .set_num_inputs(1) .set_attrs_type() - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoExpandDims) .set_attr("FRelaxInferLayout", InferLayoutExpandDims) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -549,7 +545,6 @@ StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.flatten") .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoFlatten) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -698,8 +693,6 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.index_tensor") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input data.") - .add_argument("indices", "List of Tensors", "The indices used to index.") .set_attr("FInferStructInfo", InferStructInfoIndexTensor) .set_attr("FPurity", true); @@ -772,7 +765,6 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.layout_transform") .set_num_inputs(1) .set_attrs_type() - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoLayoutTransform) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -892,7 +884,6 @@ InferLayoutOutput InferLayoutPermuteDims( TVM_REGISTER_OP("relax.permute_dims") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoPermuteDims) .set_attr("FRelaxInferLayout", InferLayoutPermuteDims) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1055,8 +1046,6 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.reshape") .set_num_inputs(2) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("shape", "Shape", "The input new shape.") .set_attr("FInferStructInfo", InferStructInfoReshape) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -1234,7 +1223,6 @@ InferLayoutOutput InferLayoutSplit( TVM_REGISTER_OP("relax.split") .set_attrs_type() .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSplit) .set_attr("FRelaxInferLayout", InferLayoutSplit) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1394,7 +1382,6 @@ InferLayoutOutput InferLayoutSqueeze( TVM_REGISTER_OP("relax.squeeze") .set_num_inputs(1) .set_attrs_type() - .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSqueeze) .set_attr("FRelaxInferLayout", InferLayoutSqueeze) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1647,7 +1634,6 @@ InferLayoutOutput InferLayoutStack( TVM_REGISTER_OP("relax.stack") .set_attrs_type() .set_num_inputs(1) - .add_argument("tensors", "Tuple of Tensors", "The input list of tensors to stack") .set_attr("FInferStructInfo", InferStructInfoStack) .set_attr("FRelaxInferLayout", InferLayoutStack) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1696,9 +1682,6 @@ StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.collapse_sum_like") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("collapse_target", "Tensor", - "The tensor whose shape is the shape to collapse to.") .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike) .set_attr("FPurity", true); @@ -1749,8 +1732,6 @@ StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.collapse_sum_to") .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("shape", "Shape", "The shape to collapse to.") .set_attr("FInferStructInfo", InferStructInfoCollapseSumTo) .set_attr("FPurity", true); @@ -1875,7 +1856,6 @@ InferLayoutOutput InferLayoutRepeat( TVM_REGISTER_OP("relax.repeat") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoRepeat) .set_attr("FRelaxInferLayout", InferLayoutRepeat) .set_attr("FPurity", true); @@ -2020,7 +2000,6 @@ InferLayoutOutput InferLayoutTile( TVM_REGISTER_OP("relax.tile") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoTile) .set_attr("FRelaxInferLayout", InferLayoutTile) .set_attr("FPurity", true); @@ -2092,7 +2071,6 @@ InferLayoutOutput InferLayoutFlip( TVM_REGISTER_OP("relax.flip") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoFlip) .set_attr("FRelaxInferLayout", InferLayoutFlip) .set_attr("FPurity", true); @@ -2196,8 +2174,6 @@ InferLayoutOutput InferLayoutGatherElements( TVM_REGISTER_OP("relax.gather_elements") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices tensor.") .set_attr("FInferStructInfo", InferStructInfoGatherElements) .set_attr("FRelaxInferLayout", InferLayoutGatherElements) .set_attr("FPurity", true); @@ -2293,8 +2269,6 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.gather_nd") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices tensor.") .set_attr("FInferStructInfo", InferStructInfoGatherND) .set_attr("FPurity", true); @@ -2441,9 +2415,6 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.index_put") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices tensor(s).") - .add_argument("values", "Tensor", "The values to put.") .set_attr("FInferStructInfo", InferStructInfoIndexPut) .set_attr("FPurity", true); @@ -2547,7 +2518,6 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.meshgrid") .set_attrs_type() .set_num_inputs(1) - .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoMeshgrid) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -2689,9 +2659,6 @@ InferLayoutOutput InferLayoutScatterElements( TVM_REGISTER_OP("relax.scatter_elements") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices tensor.") - .add_argument("updates", "Tensor", "The input tensor of updates.") .set_attr("FInferStructInfo", InferStructInfoScatterElements) .set_attr("FRelaxInferLayout", InferLayoutScatterElements) .set_attr("FPurity", true); @@ -2866,9 +2833,6 @@ InferLayoutOutput InferLayoutScatterND( TVM_REGISTER_OP("relax.scatter_nd") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices tensor.") - .add_argument("updates", "Tensor", "The input tensor of updates.") .set_attr("FInferStructInfo", InferStructInfoScatterND) .set_attr("FRelaxInferLayout", InferLayoutScatterND) .set_attr("FPurity", true); @@ -3022,11 +2986,6 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.slice_scatter") .set_attrs_type() .set_num_inputs(5) - .add_argument("input", "Tensor", "The input tensor.") - .add_argument("src", "Tensor", "The source tensor to scatter.") - .add_argument("start", "PrimValue", "The starting index of the slice (inclusive).") - .add_argument("end", "PrimValue", "The ending index of the slice (exclusive).") - .add_argument("step", "PrimValue", "The step of the slice.") .set_attr("FInferStructInfo", InferStructInfoSliceScatter) .set_attr("FPurity", true); @@ -3101,9 +3060,6 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.one_hot") .set_attrs_type() .set_num_inputs(3) - .add_argument("indices", "Tensor", "The indices tensor.") - .add_argument("on_value", "PrimValue", "The value to fill at specified indices.") - .add_argument("off_value", "PrimValue", "The value to fill at other indices.") .set_attr("FInferStructInfo", InferStructInfoOneHot) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 99cb5810e1ab..9386f9cc9ff8 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -135,9 +135,6 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.quantize") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("scale", "Tensor", "The quantization scale of the output tensor.") - .add_argument("zero_point", "Tensor", "The quantization zero_point of the output tensor.") .set_attr("FInferStructInfo", InferStructInfoQuantize) .set_attr("FPurity", true); @@ -242,9 +239,6 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.dequantize") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("scale", "Tensor", "The quantization scale of the input tensor.") - .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") .set_attr("FInferStructInfo", InferStructInfoDequantize) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index febe4d521d3d..331aabd6703c 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -139,9 +139,6 @@ StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBu TVM_REGISTER_OP("relax.multinomial_from_uniform") .set_attrs_type() .set_num_inputs(3) - .add_argument("prob", "Tensor", "The probability tensor.") - .add_argument("uniform_sample", "Tensor", "The uniform sample tensor.") - .add_argument("sample_indices", "Tensor", "The sample indices tensor.") .set_attr("FInferStructInfo", InferStructInfoMultinomialFromUniform) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 5aa1e49557be..653e4bda6a0e 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -79,11 +79,6 @@ StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.bucketize") .set_num_inputs(2) - .add_argument("input_tensor", "Tensor", - " N-D tensor or a Scalar containing the search value(s).") - .add_argument("boundaries", "Tensor", - "1-D tensor, must contain a strictly increasing sequence, or the return value is " - "undefined.") .set_attr("FInferStructInfo", InferStructInfoBucketize) .set_attr("FPurity", true); @@ -180,9 +175,6 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.where") .set_num_inputs(3) - .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") - .add_argument("x1", "Tensor", "The first input tensor.") - .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoWhere) .set_attr("FPurity", true); @@ -260,7 +252,6 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ - .add_argument("x", "Tensor", "The input data tensor") \ .set_attr("FInferStructInfo", InferStructInfoArgmaxArgmin) \ .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index a2743ab574c6..cb5e25184a69 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -148,23 +148,6 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.unique") .set_num_inputs(6) - .add_argument("x", "Tensor", "The input tensor") - .add_argument( - "sorted", "Tensor", - "Whether to sort the unique elements in ascending order before returning as output.") - .add_argument( - "return_index", "Tensor", - "Whether to return an additional tensor with indices for where elements in the unique " - "tensor come from the original input.") - .add_argument("return_inverse", "Tensor", - "Whether to return an additional tensor with indices for where elements in the " - "original input ended up in the returned unique list.") - .add_argument("return_counts", "Tensor", - "Whether to return an additional tensor with counts of each unique elements") - .add_argument("axis", "Tensor", - "The dimension to apply unique. If it is std::nullopt, the unique values of the " - "flattened input " - "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", true); @@ -187,7 +170,6 @@ StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nonzero") .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoNonzero) .set_attr("FCallPacked", "relax.run.nonzero") .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 7b8a310c65d9..d0e82613c1b1 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -60,7 +60,6 @@ StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.sort") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSort) .set_attr("FPurity", true); @@ -94,7 +93,6 @@ StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.argsort") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoArgsort) .set_attr("FPurity", true); @@ -162,7 +160,6 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.topk") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoTopK) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index d6f3a15005f3..c046aa04fa4f 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -262,7 +262,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoScan) .set_attr("FPurity", true); @@ -285,7 +284,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoScan) .set_attr("FPurity", true); @@ -305,7 +303,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.median") .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoStatisticalExtension) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index ee4138f133b1..a99d5bb6062c 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -55,7 +55,6 @@ namespace relax { } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ - .add_argument("x", "Tensor", "The input data tensor") \ .set_attr("FInferStructInfo", InferStructInfoStatistical) \ .set_attr("FRelaxInferLayout", InferLayoutStatistical) \ .set_attr("FPurity", true) diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index 523c694ff5e8..a9ae3e867f2b 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -132,9 +132,6 @@ InferLayoutOutput InferLayoutEwiseFMA( TVM_REGISTER_OP("relax.ewise_fma") .set_num_inputs(3) - .add_argument("x1", "Tensor", "The left hand operand of the multiplication") - .add_argument("x2", "Tensor", "The right hand operand of the multiplication") - .add_argument("x3", "Tensor", "The operand of the addition") .set_attr("FInferStructInfo", InferStructInfoEwiseFMA) .set_attr("FRelaxInferLayout", InferLayoutEwiseFMA) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 16a0bc305f17..bbaceedae61c 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -70,9 +70,6 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true); // relax.clip TVM_REGISTER_OP("relax.clip") .set_num_inputs(3) - .add_argument("x", "Tensor", "The input tensor.") - .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") - .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index 070c81bbe97d..0a4cb437db8c 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -194,10 +194,6 @@ TVM_REGISTER_OP("relax.vision.multibox_transform_loc") "inference. Very large variances (w,h) can overflow exp in half box sizes.") .set_attrs_type() .set_num_inputs(3) - .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).") - .add_argument("loc_pred", "Tensor", - "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped to xywh.") - .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb (left,top,right,bottom).") .set_attr("FInferStructInfo", InferStructInfoMultiboxTransformLoc) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index dbfe0d63aff5..675c7a05721d 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -104,14 +104,6 @@ StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") .set_attrs_type() .set_num_inputs(5) - .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].") - .add_argument("scores", "Tensor", - "Scores for each box and class in the format [batch, num_classes, num_boxes].") - .add_argument("max_output_boxes_per_class", "Tensor", - "The maximum number of output boxes per class.") - .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.") - .add_argument("score_threshold", "Tensor", - "The score threshold to filter out low score boxes early.") .set_attr("FInferStructInfo", InferStructInfoAllClassNMS) .set_attr("FPurity", true); @@ -186,8 +178,6 @@ StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.vision.get_valid_counts") .set_attrs_type() .set_num_inputs(1) - .add_argument("data", "Tensor", - "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") .set_attr("FInferStructInfo", InferStructInfoGetValidCounts) .set_attr("FPurity", true); @@ -366,10 +356,6 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.vision.non_max_suppression") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", - "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") - .add_argument("valid_count", "Tensor", "1-D tensor for valid number of boxes.") - .add_argument("indices", "Tensor", "2-D tensor with shape [batch_size, num_anchors].") .set_attr("FInferStructInfo", InferStructInfoNMS) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index e1be949fce52..4ad5e999acee 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -130,9 +130,6 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.vision.roi_align") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("rois", "Tensor", - "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") .set_attr("FInferStructInfo", InferStructInfoROIAlign) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index ffba294c5a77..21ef2b09469b 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -117,9 +117,6 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.vision.roi_pool") .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("rois", "Tensor", - "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") .set_attr("FInferStructInfo", InferStructInfoROIPool) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/target/cuda/intrin_rule_cuda.cc index c56da3046bc0..2835c0cfc802 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/target/cuda/intrin_rule_cuda.cc @@ -262,40 +262,24 @@ TVM_REGISTER_OP("tirx.fmod") // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .set_num_inputs(4) - .add_argument("mask", "Expr", "The thread mask.") - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("lane", "Expr", "The source thread id.") - .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .set_num_inputs(4) - .add_argument("mask", "Expr", "The thread mask.") - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("delta", "Expr", "The source lane id offset to be added.") - .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_up_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .set_num_inputs(4) - .add_argument("mask", "Expr", "The thread mask.") - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") - .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_down_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__shfl_xor_sync") .set_num_inputs(4) - .add_argument("mask", "Expr", "The thread mask.") - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("lane_mask", "Expr", "The lane mask.") - .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_xor_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); diff --git a/src/target/metal/intrin_rule_metal.cc b/src/target/metal/intrin_rule_metal.cc index 217ad164b8e1..e1e0552f212c 100644 --- a/src/target/metal/intrin_rule_metal.cc +++ b/src/target/metal/intrin_rule_metal.cc @@ -141,22 +141,16 @@ TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") // Register low-level builtin ops. TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_num_inputs(2) - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("lane", "Expr", "The source thread id.") .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_num_inputs(2) - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("delta", "Expr", "The source lane id offset to be added.") .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") .set_num_inputs(2) - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .set_attr("TGlobalSymbol", "simd_shuffle_down") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); diff --git a/src/target/webgpu/intrin_rule_webgpu.cc b/src/target/webgpu/intrin_rule_webgpu.cc index 889b85e56aad..4f27a81e1c15 100644 --- a/src/target/webgpu/intrin_rule_webgpu.cc +++ b/src/target/webgpu/intrin_rule_webgpu.cc @@ -161,22 +161,16 @@ TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") // Register low-level builtin ops. TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle") .set_num_inputs(2) - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("lane", "Expr", "The source thread id.") .set_attr("TGlobalSymbol", "subgroupShuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up") .set_num_inputs(2) - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("delta", "Expr", "The source lane id offset to be added.") .set_attr("TGlobalSymbol", "subgroupShuffleUp") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down") .set_num_inputs(2) - .add_argument("var", "Expr", "The variable to sync.") - .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .set_attr("TGlobalSymbol", "subgroupShuffleDown") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); From e04f910c472978f245d6896851481a3d846384ce Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:20:30 +0000 Subject: [PATCH 2/9] [REFACTOR][IR] Drop virtual dtor on BaseAttrsNode and inline DictAttrs ctor BaseAttrsNode no longer overrides any virtual method after #19607; ffi::Object destroys objects through a type-erased deleter captured at make_object time (see tvm-ffi/include/tvm/ffi/memory.h Deleter_::tptr->T::~T()) so the explicit virtual dtor adds nothing. The DictAttrs(Map) constructor is three lines and warrants header placement now that the file is otherwise unchanged. --- include/tvm/ir/attrs.h | 15 +++++++-------- src/ir/attrs.cc | 6 ------ 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index e0651f742ca1..86b0f947776b 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -45,15 +45,10 @@ namespace tvm { /*! * \brief Base class of all attribute class - * \note Do not subclass AttrBaseNode directly, - * subclass AttrsNode instead. - * \sa AttrsNode + * \sa Attrs */ class BaseAttrsNode : public ffi::Object { public: - /*! \brief virtual destructor */ - virtual ~BaseAttrsNode() {} - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object); }; @@ -98,10 +93,14 @@ class DictAttrs : public Attrs { */ explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {} /*! - * \brief Consruct a Attrs backed by DictAttrsNode. + * \brief Construct a Attrs backed by DictAttrsNode. * \param dict The attributes. */ - TVM_DLL explicit DictAttrs(ffi::Map dict = {}); + explicit DictAttrs(ffi::Map dict = {}) { + ffi::ObjectPtr n = ffi::make_object(); + n->dict = std::move(dict); + data_ = std::move(n); + } // Utils for accessing attributes // This needs to be on DictAttrs, not DictAttrsNode because we return the default diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 11d8d4f77595..e2ea70a9f663 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -50,12 +50,6 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { return attrs; } -DictAttrs::DictAttrs(ffi::Map dict) { - ffi::ObjectPtr n = ffi::make_object(); - n->dict = std::move(dict); - data_ = std::move(n); -} - TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef(); } TVM_FFI_STATIC_INIT_BLOCK() { From 62fa8fe0a99c1e23bdd27cef14518bb3b1fe435b Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:25:31 +0000 Subject: [PATCH 3/9] [REFACTOR][IR] Rename BaseAttrsNode to AttrsNode AttrsNodeReflAdapter was the historical 'in between' layer that gave BaseAttrsNode its Base prefix. With the adapter removed in canonical name now used in every comment. The FFI registry key "ir.Attrs" is unchanged, so Python sees no difference. --- include/tvm/ir/attrs.h | 16 ++-- include/tvm/relax/attrs/ccl.h | 12 +-- include/tvm/relax/attrs/create.h | 8 +- include/tvm/relax/attrs/datatype.h | 8 +- include/tvm/relax/attrs/distributed.h | 5 +- include/tvm/relax/attrs/image.h | 12 +-- include/tvm/relax/attrs/index.h | 9 +- include/tvm/relax/attrs/linear_algebra.h | 8 +- include/tvm/relax/attrs/manipulate.h | 74 ++++++++-------- include/tvm/relax/attrs/nn.h | 106 +++++++++++------------ include/tvm/relax/attrs/op.h | 21 +++-- include/tvm/relax/attrs/qdq.h | 4 +- include/tvm/relax/attrs/sampling.h | 4 +- include/tvm/relax/attrs/search.h | 9 +- include/tvm/relax/attrs/sorting.h | 12 +-- include/tvm/relax/attrs/statistical.h | 9 +- include/tvm/relax/attrs/vision.h | 24 ++--- include/tvm/target/virtual_device.h | 4 +- src/ir/attrs.cc | 2 +- 19 files changed, 169 insertions(+), 178 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 86b0f947776b..bb62910f2f16 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -23,7 +23,7 @@ * This module enables declaration of named attributes * which support default value setup and bound checking. * - * \sa BaseAttrsNode, AttrsWithDefaultValues + * \sa AttrsNode, AttrsWithDefaultValues */ #ifndef TVM_IR_ATTRS_H_ #define TVM_IR_ATTRS_H_ @@ -47,19 +47,19 @@ namespace tvm { * \brief Base class of all attribute class * \sa Attrs */ -class BaseAttrsNode : public ffi::Object { +class AttrsNode : public ffi::Object { public: static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", AttrsNode, ffi::Object); }; /*! - * \brief Managed reference to BaseAttrsNode. - * \sa AttrsNode, BaseAttrsNode + * \brief Managed reference to AttrsNode. + * \sa AttrsNode */ class Attrs : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef, BaseAttrsNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef, AttrsNode); }; /*! @@ -68,7 +68,7 @@ class Attrs : public ffi::ObjectRef { * its fields are directly accessible via object.field_name * like other normal nodes. */ -class DictAttrsNode : public BaseAttrsNode { +class DictAttrsNode : public AttrsNode { public: /*! \brief internal attrs map */ ffi::Map dict; @@ -79,7 +79,7 @@ class DictAttrsNode : public BaseAttrsNode { } // type info - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, AttrsNode); }; /*! diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index 7e0624706b0c..031a1de49311 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in allreduce operators */ -struct AllReduceAttrs : public tvm::BaseAttrsNode { +struct AllReduceAttrs : public tvm::AttrsNode { ffi::String op_type; bool in_group; @@ -45,11 +45,11 @@ struct AllReduceAttrs : public tvm::BaseAttrsNode { "Whether the reduction operation performs in group or globally or in group as " "default."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllReduceAttrs", AllReduceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllReduceAttrs", AllReduceAttrs, AttrsNode); }; // struct AllReduceAttrs /*! \brief Attributes used in allgather operators */ -struct AllGatherAttrs : public tvm::BaseAttrsNode { +struct AllGatherAttrs : public tvm::AttrsNode { int num_workers; bool in_group; @@ -63,11 +63,11 @@ struct AllGatherAttrs : public tvm::BaseAttrsNode { "Whether the allgather operation performs in group or globally or in group as " "default."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllGatherAttrs", AllGatherAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllGatherAttrs", AllGatherAttrs, AttrsNode); }; // struct AllGatherAttrs /*! \brief Attributes used in scatter operators */ -struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode { +struct ScatterCollectiveAttrs : public tvm::AttrsNode { int num_workers; int axis; @@ -82,7 +82,7 @@ struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode { "this axis."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterCollectiveAttrs", ScatterCollectiveAttrs, - BaseAttrsNode); + AttrsNode); }; // struct ScatterCollectiveAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h index 9a9e453263a0..14a3402f2503 100644 --- a/include/tvm/relax/attrs/create.h +++ b/include/tvm/relax/attrs/create.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ -struct InitAttrs : public BaseAttrsNode { +struct InitAttrs : public AttrsNode { DataType dtype; static void RegisterReflection() { @@ -38,11 +38,11 @@ struct InitAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("dtype", &InitAttrs::dtype, "The data type of the created tensor."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InitAttrs", InitAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InitAttrs", InitAttrs, AttrsNode); }; // struct InitAttrs /*! \brief Attributes used in tril and triu operator */ -struct TriluAttrs : public BaseAttrsNode { +struct TriluAttrs : public AttrsNode { int k; static void RegisterReflection() { @@ -51,7 +51,7 @@ struct TriluAttrs : public BaseAttrsNode { "k", &TriluAttrs::k, "The number of diagonals above or below the main diagonal to exclude or include."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TriluAttrs", TriluAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TriluAttrs", TriluAttrs, AttrsNode); }; // struct TriluAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h index a1870597033e..f67223edb546 100644 --- a/include/tvm/relax/attrs/datatype.h +++ b/include/tvm/relax/attrs/datatype.h @@ -30,25 +30,25 @@ namespace tvm { namespace relax { /*! \brief Attributes used in astype operator */ -struct AstypeAttrs : public BaseAttrsNode { +struct AstypeAttrs : public AttrsNode { DataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("dtype", &AstypeAttrs::dtype, "Target data type"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AstypeAttrs", AstypeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AstypeAttrs", AstypeAttrs, AttrsNode); }; // struct AstypeAttrs. /*! \brief Attributes used in wrap_param operator */ -struct WrapParamAttrs : public BaseAttrsNode { +struct WrapParamAttrs : public AttrsNode { DataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("dtype", &WrapParamAttrs::dtype, "Target data type"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.WrapParamAttrs", WrapParamAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.WrapParamAttrs", WrapParamAttrs, AttrsNode); }; // struct WrapParamAttrs. } // namespace relax diff --git a/include/tvm/relax/attrs/distributed.h b/include/tvm/relax/attrs/distributed.h index cce508ef1d50..23b698eb3604 100644 --- a/include/tvm/relax/attrs/distributed.h +++ b/include/tvm/relax/attrs/distributed.h @@ -32,7 +32,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for redistribute and annotate_sharding operator */ -struct DistributionAttrs : public BaseAttrsNode { +struct DistributionAttrs : public AttrsNode { distributed::DeviceMesh device_mesh; distributed::Placement placement; @@ -44,8 +44,7 @@ struct DistributionAttrs : public BaseAttrsNode { .def_ro("placement", &DistributionAttrs::placement, "The placement of a tensor's distribution plan"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DistributionAttrs", DistributionAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DistributionAttrs", DistributionAttrs, AttrsNode); }; // struct DistributionAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index 8cc5e36734b6..eacbea7180bb 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in image resize2d operator */ -struct Resize2DAttrs : public BaseAttrsNode { +struct Resize2DAttrs : public AttrsNode { ffi::Array roi; ffi::String layout; ffi::String method; @@ -75,11 +75,11 @@ struct Resize2DAttrs : public BaseAttrsNode { "The dtype of the output tensor. It it is not specified, the output will have the same " "dtype as input if not specified."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs", Resize2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs", Resize2DAttrs, AttrsNode); }; // struct Resize2dAttrs /*! \brief Attributes used in image resize3d operator */ -struct Resize3DAttrs : public BaseAttrsNode { +struct Resize3DAttrs : public AttrsNode { ffi::Array roi; ffi::String layout; ffi::String method; @@ -124,11 +124,11 @@ struct Resize3DAttrs : public BaseAttrsNode { "The dtype of the output tensor. It it is not specified, the output will have the same " "dtype as input if not specified."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize3DAttrs", Resize3DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize3DAttrs", Resize3DAttrs, AttrsNode); }; // struct Resize3DAttrs /*! \brief Attributes used in image grid_sample operator */ -struct GridSampleAttrs : public BaseAttrsNode { +struct GridSampleAttrs : public AttrsNode { ffi::String method; ffi::String layout; ffi::String padding_mode; @@ -146,7 +146,7 @@ struct GridSampleAttrs : public BaseAttrsNode { .def_ro("align_corners", &GridSampleAttrs::align_corners, "If True, the corner pixels of the input and output tensors are aligned."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs", GridSampleAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs", GridSampleAttrs, AttrsNode); }; // struct GridSampleAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index 7b4c446bb80c..6133a6f580e4 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in take operator */ -struct TakeAttrs : public BaseAttrsNode { +struct TakeAttrs : public AttrsNode { ffi::Optional axis; ffi::String mode; @@ -41,11 +41,11 @@ struct TakeAttrs : public BaseAttrsNode { .def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds indices.", refl::DefaultValue("fast")); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TakeAttrs", TakeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TakeAttrs", TakeAttrs, AttrsNode); }; // struct TakeAttrs /*! \brief Attributes used in strided_slice operator */ -struct StridedSliceAttrs : public BaseAttrsNode { +struct StridedSliceAttrs : public AttrsNode { bool assume_inbound; static void RegisterReflection() { @@ -56,8 +56,7 @@ struct StridedSliceAttrs : public BaseAttrsNode { "out of bound indices will be clipped to the bound.", refl::DefaultValue(true)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StridedSliceAttrs", StridedSliceAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StridedSliceAttrs", StridedSliceAttrs, AttrsNode); }; // struct StridedSliceAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index 2627dafcf6b3..817885edb871 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for matmul operator */ -struct MatmulAttrs : public BaseAttrsNode { +struct MatmulAttrs : public AttrsNode { DataType out_dtype; static void RegisterReflection() { @@ -38,11 +38,11 @@ struct MatmulAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("out_dtype", &MatmulAttrs::out_dtype, "The data type of the output tensor"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MatmulAttrs", MatmulAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MatmulAttrs", MatmulAttrs, AttrsNode); }; // struct MatmulAttrs /*! \brief Attributes used in einsum operator */ -struct EinsumAttrs : public BaseAttrsNode { +struct EinsumAttrs : public AttrsNode { ffi::String subscripts; static void RegisterReflection() { @@ -50,7 +50,7 @@ struct EinsumAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("subscripts", &EinsumAttrs::subscripts, "The einsum expression string"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.EinsumAttrs", EinsumAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.EinsumAttrs", EinsumAttrs, AttrsNode); }; // struct EinsumAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index cc651207fa3d..7897b860e1f7 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in concat operators */ -struct ConcatAttrs : public BaseAttrsNode { +struct ConcatAttrs : public AttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -40,11 +40,11 @@ struct ConcatAttrs : public BaseAttrsNode { "The axis at which the input arrays are concatenated." "Should lie in range `[-ndim, ndim)`."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ConcatAttrs", ConcatAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ConcatAttrs", ConcatAttrs, AttrsNode); }; // struct ConcatAttrs /*! \brief Attributes used in expand_dims operators */ -struct ExpandDimsAttrs : public BaseAttrsNode { +struct ExpandDimsAttrs : public AttrsNode { ffi::Array axis; static void RegisterReflection() { @@ -55,11 +55,11 @@ struct ExpandDimsAttrs : public BaseAttrsNode { "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, " "with the convention of negative indexing."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ExpandDimsAttrs", ExpandDimsAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ExpandDimsAttrs", ExpandDimsAttrs, AttrsNode); }; // struct ExpandDimsAttrs /*! \brief Attributes used in layout_transform operator */ -struct LayoutTransformAttrs : public BaseAttrsNode { +struct LayoutTransformAttrs : public AttrsNode { tirx::IndexMap index_map; // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. @@ -93,11 +93,11 @@ struct LayoutTransformAttrs : public BaseAttrsNode { "The separators between axes to regenerate output"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayoutTransformAttrs", LayoutTransformAttrs, - BaseAttrsNode); + AttrsNode); }; // struct LayoutTransformAttrs /*! \brief Attributes used in permute_dims operator */ -struct PermuteDimsAttrs : public BaseAttrsNode { +struct PermuteDimsAttrs : public AttrsNode { ffi::Optional> axes; static void RegisterReflection() { @@ -105,12 +105,11 @@ struct PermuteDimsAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro( "axes", &PermuteDimsAttrs::axes, "The target axes order, reverse order if not specified."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs", PermuteDimsAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs", PermuteDimsAttrs, AttrsNode); }; // struct PermuteDimsAttrs /*! \brief Attributes used in split operator */ -struct SplitAttrs : public BaseAttrsNode { +struct SplitAttrs : public AttrsNode { ffi::ObjectRef indices_or_sections; int axis; @@ -121,11 +120,11 @@ struct SplitAttrs : public BaseAttrsNode { "The input array of indices or the number of split sections.") .def_ro("axis", &SplitAttrs::axis, "The axis to be splitted"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SplitAttrs", SplitAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SplitAttrs", SplitAttrs, AttrsNode); }; // struct SplitAttrs /*! \brief Attributes used in squeeze operators */ -struct SqueezeAttrs : public BaseAttrsNode { +struct SqueezeAttrs : public AttrsNode { ffi::Optional> axis; static void RegisterReflection() { @@ -136,11 +135,11 @@ struct SqueezeAttrs : public BaseAttrsNode { "Else, the dimension in axes get squeezed." "It is an error if an axis does not has dimension 1."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SqueezeAttrs", SqueezeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SqueezeAttrs", SqueezeAttrs, AttrsNode); }; // struct SqueezeAttrs /*! \brief Attributes used in stack operators */ -struct StackAttrs : public BaseAttrsNode { +struct StackAttrs : public AttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -152,11 +151,11 @@ struct StackAttrs : public BaseAttrsNode { "so it must be in range [-ndim-1, ndim] where ndim is the " "number of dimensions of the input tensors."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StackAttrs", StackAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StackAttrs", StackAttrs, AttrsNode); }; // struct StackAttrs /*! \brief Attributes used in repeat operators */ -struct RepeatAttrs : public BaseAttrsNode { +struct RepeatAttrs : public AttrsNode { int repeats; ffi::Optional axis; @@ -169,11 +168,11 @@ struct RepeatAttrs : public BaseAttrsNode { "counting from the backward. By default, use the flattened input array, and " "return a flat output array."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RepeatAttrs", RepeatAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RepeatAttrs", RepeatAttrs, AttrsNode); }; // struct RepeatAttrs /*! \brief Attributes used in tile operators */ -struct TileAttrs : public BaseAttrsNode { +struct TileAttrs : public AttrsNode { ffi::Array repeats; static void RegisterReflection() { @@ -181,11 +180,11 @@ struct TileAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("repeats", &TileAttrs::repeats, "The number of repetitions of data along each axis."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TileAttrs", TileAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TileAttrs", TileAttrs, AttrsNode); }; // struct TileAttrs /*! \brief Attributes used in flip operators */ -struct FlipAttrs : public BaseAttrsNode { +struct FlipAttrs : public AttrsNode { int64_t axis; static void RegisterReflection() { @@ -193,11 +192,11 @@ struct FlipAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("axis", &FlipAttrs::axis, "The axis along which to flip over."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, AttrsNode); }; // struct FlipAttrs /*! \brief Attributes used in gather_elements operators */ -struct GatherElementsAttrs : public BaseAttrsNode { +struct GatherElementsAttrs : public AttrsNode { int64_t axis; static void RegisterReflection() { @@ -207,11 +206,11 @@ struct GatherElementsAttrs : public BaseAttrsNode { refl::DefaultValue(0)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherElementsAttrs", GatherElementsAttrs, - BaseAttrsNode); + AttrsNode); }; // struct GatherElementsAttrs /*! \brief Attributes used in gather_nd operators */ -struct GatherNDAttrs : public BaseAttrsNode { +struct GatherNDAttrs : public AttrsNode { int64_t batch_dims; static void RegisterReflection() { @@ -219,11 +218,11 @@ struct GatherNDAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("batch_dims", &GatherNDAttrs::batch_dims, "The number of batch dims.", refl::DefaultValue(0)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherNDAttrs", GatherNDAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherNDAttrs", GatherNDAttrs, AttrsNode); }; // struct GatherNDAttrs /*! \brief Attributes used in index_put operator */ -struct IndexPutAttrs : public BaseAttrsNode { +struct IndexPutAttrs : public AttrsNode { bool accumulate; static void RegisterReflection() { @@ -235,11 +234,11 @@ struct IndexPutAttrs : public BaseAttrsNode { "otherwise performs tensor[indices] = values.", refl::DefaultValue(false)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.IndexPutAttrs", IndexPutAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.IndexPutAttrs", IndexPutAttrs, AttrsNode); }; // struct IndexPutAttrs /*! \brief Attribute used in meshgrid operator */ -struct MeshgridAttrs : public BaseAttrsNode { +struct MeshgridAttrs : public AttrsNode { ffi::Optional indexing; static void RegisterReflection() { @@ -247,11 +246,11 @@ struct MeshgridAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("indexing", &MeshgridAttrs::indexing, "Specifies how the grid dimensions are ordered."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MeshgridAttrs", MeshgridAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MeshgridAttrs", MeshgridAttrs, AttrsNode); }; /*! \brief Attributes used in scatter_elements operators */ -struct ScatterElementsAttrs : public BaseAttrsNode { +struct ScatterElementsAttrs : public AttrsNode { int64_t axis; ffi::String reduction; @@ -266,11 +265,11 @@ struct ScatterElementsAttrs : public BaseAttrsNode { refl::DefaultValue("update")); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterElementsAttrs", ScatterElementsAttrs, - BaseAttrsNode); + AttrsNode); }; // struct ScatterElementsAttrs /*! \brief Attributes used in scatter_nd operators */ -struct ScatterNDAttrs : public BaseAttrsNode { +struct ScatterNDAttrs : public AttrsNode { ffi::String reduction; static void RegisterReflection() { @@ -281,11 +280,11 @@ struct ScatterNDAttrs : public BaseAttrsNode { "either \"update\", \"add\", \"mul\", \"min\" or \"max\".", refl::DefaultValue("update")); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterNDAttrs", ScatterNDAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterNDAttrs", ScatterNDAttrs, AttrsNode); }; // struct ScatterNDAttrs /*! \brief Attributes used in slice_scatter operator */ -struct SliceScatterAttrs : public BaseAttrsNode { +struct SliceScatterAttrs : public AttrsNode { int axis; static void RegisterReflection() { @@ -294,12 +293,11 @@ struct SliceScatterAttrs : public BaseAttrsNode { "the dimension to insert the slice into ", refl::DefaultValue(0)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs", SliceScatterAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs", SliceScatterAttrs, AttrsNode); }; // struct SliceScatterAttrs /*! \brief Attributes used in one_hot operator */ -struct OneHotAttrs : public BaseAttrsNode { +struct OneHotAttrs : public AttrsNode { int depth; int axis; @@ -309,7 +307,7 @@ struct OneHotAttrs : public BaseAttrsNode { .def_ro("depth", &OneHotAttrs::depth, "Depth of the one hot dimension.") .def_ro("axis", &OneHotAttrs::axis, "Axis to fill.", refl::DefaultValue(-1)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.OneHotAttrs", OneHotAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.OneHotAttrs", OneHotAttrs, AttrsNode); }; // struct OneHotAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index b483d3e2339d..52d9c40d742d 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in Conv1d operator */ -struct Conv1DAttrs : public BaseAttrsNode { +struct Conv1DAttrs : public AttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -70,11 +70,11 @@ struct Conv1DAttrs : public BaseAttrsNode { .def_ro("out_dtype", &Conv1DAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DAttrs", Conv1DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DAttrs", Conv1DAttrs, AttrsNode); }; // struct Conv1dAttrs /*! \brief Attributes used in Conv2d operator */ -struct Conv2DAttrs : public BaseAttrsNode { +struct Conv2DAttrs : public AttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -116,11 +116,11 @@ struct Conv2DAttrs : public BaseAttrsNode { .def_ro("out_dtype", &Conv2DAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DAttrs", Conv2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DAttrs", Conv2DAttrs, AttrsNode); }; // struct Conv2dAttrs /*! \brief Attributes used in Conv3d operator */ -struct Conv3DAttrs : public BaseAttrsNode { +struct Conv3DAttrs : public AttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -164,11 +164,11 @@ struct Conv3DAttrs : public BaseAttrsNode { .def_ro("out_dtype", &Conv3DAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DAttrs", Conv3DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DAttrs", Conv3DAttrs, AttrsNode); }; // struct Conv3dAttrs /*! \brief Attributes used in Conv1DTranspose operator */ -struct Conv1DTransposeAttrs : public BaseAttrsNode { +struct Conv1DTransposeAttrs : public AttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -213,11 +213,11 @@ struct Conv1DTransposeAttrs : public BaseAttrsNode { "Output data type, set to explicit type under mixed precision setting"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DTransposeAttrs", Conv1DTransposeAttrs, - BaseAttrsNode); + AttrsNode); }; // struct Conv1DTransposeAttrs /*! \brief Attributes used in Conv2d operator */ -struct Conv2DTransposeAttrs : public BaseAttrsNode { +struct Conv2DTransposeAttrs : public AttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -264,11 +264,11 @@ struct Conv2DTransposeAttrs : public BaseAttrsNode { "Output data type, set to explicit type under mixed precision setting"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DTransposeAttrs", Conv2DTransposeAttrs, - BaseAttrsNode); + AttrsNode); }; // struct Conv2DTransposeAttrs /*! \brief Attributes used in Conv3dTranspose operator */ -struct Conv3DTransposeAttrs : public BaseAttrsNode { +struct Conv3DTransposeAttrs : public AttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -317,11 +317,11 @@ struct Conv3DTransposeAttrs : public BaseAttrsNode { "Output data type, set to explicit type under mixed precision setting"); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DTransposeAttrs", Conv3DTransposeAttrs, - BaseAttrsNode); + AttrsNode); }; // struct Conv3DTransposeAttrs /*! \brief Attributes used in max_pool1d and avg_pool1d operator */ -struct Pool1DAttrs : public BaseAttrsNode { +struct Pool1DAttrs : public AttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -358,11 +358,11 @@ struct Pool1DAttrs : public BaseAttrsNode { "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the 'W' dimensions."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool1DAttrs", Pool1DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool1DAttrs", Pool1DAttrs, AttrsNode); }; // struct Pool1dAttrs /*! \brief Attributes used in max_pool2d and avg_pool2d operator */ -struct Pool2DAttrs : public BaseAttrsNode { +struct Pool2DAttrs : public AttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -401,11 +401,11 @@ struct Pool2DAttrs : public BaseAttrsNode { "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool2DAttrs", Pool2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool2DAttrs", Pool2DAttrs, AttrsNode); }; // struct Pool2dAttrs /*! \brief Attributes used in max_pool3d and avg_pool3d operator */ -struct Pool3DAttrs : public BaseAttrsNode { +struct Pool3DAttrs : public AttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -444,11 +444,11 @@ struct Pool3DAttrs : public BaseAttrsNode { "dimensions respectively. Pooling is applied on the 'D', 'H' and" "'W' dimensions."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool3DAttrs", Pool3DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool3DAttrs", Pool3DAttrs, AttrsNode); }; // struct Pool3dAttrs /*! \brief Attributes for 1d adaptive pool operator */ -struct AdaptivePool1DAttrs : public BaseAttrsNode { +struct AdaptivePool1DAttrs : public AttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -469,11 +469,11 @@ struct AdaptivePool1DAttrs : public BaseAttrsNode { "'W' dimensions."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool1DAttrs", AdaptivePool1DAttrs, - BaseAttrsNode); + AttrsNode); }; // struct AdaptivePool1DAttrs /*! \brief Attributes for 2d adaptive pool operator */ -struct AdaptivePool2DAttrs : public BaseAttrsNode { +struct AdaptivePool2DAttrs : public AttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -494,11 +494,11 @@ struct AdaptivePool2DAttrs : public BaseAttrsNode { "'W' dimensions."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool2DAttrs", AdaptivePool2DAttrs, - BaseAttrsNode); + AttrsNode); }; // struct AdaptivePool2DAttrs /*! \brief Attributes for 3d adaptive pool operator */ -struct AdaptivePool3DAttrs : public BaseAttrsNode { +struct AdaptivePool3DAttrs : public AttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -519,11 +519,11 @@ struct AdaptivePool3DAttrs : public BaseAttrsNode { "'W' dimensions."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool3DAttrs", AdaptivePool3DAttrs, - BaseAttrsNode); + AttrsNode); }; // struct AdaptivePool3DAttrs /*! \brief Attributes used in softmax operators */ -struct SoftmaxAttrs : public BaseAttrsNode { +struct SoftmaxAttrs : public AttrsNode { int axis; static void RegisterReflection() { @@ -531,11 +531,11 @@ struct SoftmaxAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("axis", &SoftmaxAttrs::axis, "The axis to sum over when computing softmax."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftmaxAttrs", SoftmaxAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftmaxAttrs", SoftmaxAttrs, AttrsNode); }; /*! \brief Attributes used in softmax operators */ -struct LeakyReluAttrs : public BaseAttrsNode { +struct LeakyReluAttrs : public AttrsNode { double alpha; static void RegisterReflection() { @@ -543,11 +543,11 @@ struct LeakyReluAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("alpha", &LeakyReluAttrs::alpha, "The slope of the negative part."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LeakyReluAttrs", LeakyReluAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LeakyReluAttrs", LeakyReluAttrs, AttrsNode); }; /*! \brief Attributes used in softplus operators */ -struct SoftplusAttrs : public BaseAttrsNode { +struct SoftplusAttrs : public AttrsNode { double beta; double threshold; @@ -559,11 +559,11 @@ struct SoftplusAttrs : public BaseAttrsNode { .def_ro("threshold", &SoftplusAttrs::threshold, "Value determining when to use linear approximation for numerical stability."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftplusAttrs", SoftplusAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftplusAttrs", SoftplusAttrs, AttrsNode); }; /*! \brief Attributes used in PReLU operator */ -struct PReluAttrs : public BaseAttrsNode { +struct PReluAttrs : public AttrsNode { int axis; static void RegisterReflection() { @@ -571,11 +571,11 @@ struct PReluAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("axis", &PReluAttrs::axis, "The axis along which the alpha values are applied."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PReluAttrs", PReluAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PReluAttrs", PReluAttrs, AttrsNode); }; /*! \brief Attributes used in batch_norm operator */ -struct BatchNormAttrs : public BaseAttrsNode { +struct BatchNormAttrs : public AttrsNode { int axis; double epsilon; bool center; @@ -598,11 +598,11 @@ struct BatchNormAttrs : public BaseAttrsNode { .def_ro("training", &BatchNormAttrs::training, "Whether we are training (i.e., not in eval mode)."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BatchNormAttrs", BatchNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BatchNormAttrs", BatchNormAttrs, AttrsNode); }; // struct BatchNormAttrs /*! \brief Attributes used in layer_norm operator */ -struct LayerNormAttrs : public BaseAttrsNode { +struct LayerNormAttrs : public AttrsNode { ffi::Array axes; double epsilon; bool center; @@ -620,11 +620,11 @@ struct LayerNormAttrs : public BaseAttrsNode { .def_ro("scale", &LayerNormAttrs::scale, "Indicating if the gamma scale will be multiplied."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayerNormAttrs", LayerNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayerNormAttrs", LayerNormAttrs, AttrsNode); }; // struct LayerNormAttrs /*! \brief Attributes used in group_norm operator */ -struct GroupNormAttrs : public BaseAttrsNode { +struct GroupNormAttrs : public AttrsNode { int num_groups; int channel_axis; ffi::Array axes; @@ -649,11 +649,11 @@ struct GroupNormAttrs : public BaseAttrsNode { .def_ro("scale", &GroupNormAttrs::scale, "Indicating if the gamma scale will be multiplied."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GroupNormAttrs", GroupNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GroupNormAttrs", GroupNormAttrs, AttrsNode); }; // struct GroupNormAttrs /*! \brief Attributes used in instance_norm operator */ -struct InstanceNormAttrs : public BaseAttrsNode { +struct InstanceNormAttrs : public AttrsNode { int channel_axis; ffi::Array axes; double epsilon; @@ -674,12 +674,11 @@ struct InstanceNormAttrs : public BaseAttrsNode { .def_ro("scale", &InstanceNormAttrs::scale, "Indicating if the gamma scale will be multiplied."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InstanceNormAttrs", InstanceNormAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InstanceNormAttrs", InstanceNormAttrs, AttrsNode); }; // struct InstanceNormAttrs /*! \brief Attributes used in rms_norm operator */ -struct RMSNormAttrs : public BaseAttrsNode { +struct RMSNormAttrs : public AttrsNode { ffi::Array axes; double epsilon; @@ -691,11 +690,11 @@ struct RMSNormAttrs : public BaseAttrsNode { .def_ro("epsilon", &RMSNormAttrs::epsilon, "Small float added to variance to avoid dividing by zero"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RMSNormAttrs", RMSNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RMSNormAttrs", RMSNormAttrs, AttrsNode); }; // struct RMSNormAttrs /*! \brief Attributes used in nll_loss operator */ -struct NLLLossAttrs : public BaseAttrsNode { +struct NLLLossAttrs : public AttrsNode { ffi::String reduction; int ignore_index; @@ -708,11 +707,11 @@ struct NLLLossAttrs : public BaseAttrsNode { refl::DefaultValue("mean")) .def_ro("ignore_index", &NLLLossAttrs::ignore_index, "The target value to ignore."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NLLLossAttrs", NLLLossAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NLLLossAttrs", NLLLossAttrs, AttrsNode); }; // struct NLLLossAttrs /*! \brief Attributes used in dropout operator */ -struct DropoutAttrs : public BaseAttrsNode { +struct DropoutAttrs : public AttrsNode { double rate; static void RegisterReflection() { @@ -721,11 +720,11 @@ struct DropoutAttrs : public BaseAttrsNode { "rate", &DropoutAttrs::rate, "Fraction of the input that gets dropped out during training time"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DropoutAttrs", DropoutAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DropoutAttrs", DropoutAttrs, AttrsNode); }; // struct DropoutAttrs /*! \brief Attributes used in Attention operator */ -struct AttentionAttrs : public BaseAttrsNode { +struct AttentionAttrs : public AttrsNode { ffi::Optional scale; ffi::Optional causal_mask; ffi::Optional window_size; @@ -741,11 +740,11 @@ struct AttentionAttrs : public BaseAttrsNode { .def_ro("window_size", &AttentionAttrs::window_size, "The size of the window for sliding-window attention."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AttentionAttrs", AttentionAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AttentionAttrs", AttentionAttrs, AttrsNode); }; // struct AttentionAttrs /*! \brief Attributes used for the padding operator */ -struct PadAttrs : public BaseAttrsNode { +struct PadAttrs : public AttrsNode { ffi::Array pad_width; double pad_value = 0.0; tvm::ffi::String pad_mode; @@ -764,11 +763,11 @@ struct PadAttrs : public BaseAttrsNode { "\"reflect\" pads by reflecting values with respect to the edges.", refl::DefaultValue("constant")); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PadAttrs", PadAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PadAttrs", PadAttrs, AttrsNode); }; /*! \brief Attributes used for the pixel shuffle operator */ -struct PixelShuffleAttrs : public BaseAttrsNode { +struct PixelShuffleAttrs : public AttrsNode { int upscale_factor; static void RegisterReflection() { @@ -777,8 +776,7 @@ struct PixelShuffleAttrs : public BaseAttrsNode { &PixelShuffleAttrs::upscale_factor, "Scale factor for spatial upsampling."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PixelShuffleAttrs", PixelShuffleAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PixelShuffleAttrs", PixelShuffleAttrs, AttrsNode); }; } // namespace relax diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 54970e0eab18..4c1451c3dc29 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in call_tir_with_grad */ -struct CallTIRWithGradAttrs : public BaseAttrsNode { +struct CallTIRWithGradAttrs : public AttrsNode { ffi::String te_grad_name; ffi::Map te_grad_kwargs; @@ -45,11 +45,11 @@ struct CallTIRWithGradAttrs : public BaseAttrsNode { "The keyword arguments passed to the te gradient function."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallTIRWithGradAttrs", CallTIRWithGradAttrs, - BaseAttrsNode); + AttrsNode); }; // struct CallTIRAttrs /*! \brief Attributes used in call_tir_inplace */ -struct CallTIRInplaceAttrs : public BaseAttrsNode { +struct CallTIRInplaceAttrs : public AttrsNode { /*! * \brief Indices that describe which input corresponds to which output. * @@ -65,11 +65,11 @@ struct CallTIRInplaceAttrs : public BaseAttrsNode { &CallTIRInplaceAttrs::inplace_indices); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallTIRInplaceAttrs", CallTIRInplaceAttrs, - BaseAttrsNode); + AttrsNode); }; // struct CallTIRInplaceAttrs /*! \brief Attributes used in call_inplace_packed */ -struct CallInplacePackedAttrs : public BaseAttrsNode { +struct CallInplacePackedAttrs : public AttrsNode { /*! * \brief Indices that describe which input corresponds to which output. * @@ -85,11 +85,11 @@ struct CallInplacePackedAttrs : public BaseAttrsNode { &CallInplacePackedAttrs::inplace_indices); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallInplacePackedAttrs", CallInplacePackedAttrs, - BaseAttrsNode); + AttrsNode); }; // struct CallInplacePackedAttrs /*! \brief Attributes used in to_vdevice */ -struct ToVDeviceAttrs : public BaseAttrsNode { +struct ToVDeviceAttrs : public AttrsNode { VDevice dst_vdevice; static void RegisterReflection() { @@ -97,11 +97,11 @@ struct ToVDeviceAttrs : public BaseAttrsNode { refl::ObjectDef().def_ro("dst_vdevice", &ToVDeviceAttrs::dst_vdevice, "The destination device where the data is copied to."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ToVDeviceAttrs", ToVDeviceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ToVDeviceAttrs", ToVDeviceAttrs, AttrsNode); }; // struct ToVDeviceAttrs /*! \brief Attributes used in hint_on_device */ -struct HintOnDeviceAttrs : public BaseAttrsNode { +struct HintOnDeviceAttrs : public AttrsNode { int32_t device_type; int32_t index; MemoryScope memory_scope; @@ -114,8 +114,7 @@ struct HintOnDeviceAttrs : public BaseAttrsNode { .def_ro("index", &HintOnDeviceAttrs::index, "The device id.") .def_ro("memory_scope", &HintOnDeviceAttrs::memory_scope, "The device memory scope."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs", HintOnDeviceAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs", HintOnDeviceAttrs, AttrsNode); }; // struct HintOnDeviceAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/qdq.h b/include/tvm/relax/attrs/qdq.h index 08bc054dc54f..83ec2223c3c7 100644 --- a/include/tvm/relax/attrs/qdq.h +++ b/include/tvm/relax/attrs/qdq.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for relax.quantize/relax.dequantize operator */ -struct QuantizeAttrs : public BaseAttrsNode { +struct QuantizeAttrs : public AttrsNode { DataType out_dtype; int axis; @@ -43,7 +43,7 @@ struct QuantizeAttrs : public BaseAttrsNode { "Default value is -1, which corresponds to the last axis.", refl::DefaultValue(-1)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.QuantizeAttrs", QuantizeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.QuantizeAttrs", QuantizeAttrs, AttrsNode); }; // QuantizeAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h index 2d7421cc20e8..11bbfb6eba31 100644 --- a/include/tvm/relax/attrs/sampling.h +++ b/include/tvm/relax/attrs/sampling.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in multinomial_from_uniform operator */ -struct MultinomialFromUniformAttrs : public BaseAttrsNode { +struct MultinomialFromUniformAttrs : public AttrsNode { DataType dtype; static void RegisterReflection() { @@ -40,7 +40,7 @@ struct MultinomialFromUniformAttrs : public BaseAttrsNode { refl::DefaultValue(DataType::Int(64))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultinomialFromUniformAttrs", - MultinomialFromUniformAttrs, BaseAttrsNode); + MultinomialFromUniformAttrs, AttrsNode); }; // struct MultinomialFromUniformAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index 015e5d8edc1c..6b3ee4860a3f 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for search operators */ -struct ArgmaxArgminAttrs : public BaseAttrsNode { +struct ArgmaxArgminAttrs : public AttrsNode { ffi::Optional axis; bool keepdims; @@ -44,12 +44,11 @@ struct ArgmaxArgminAttrs : public BaseAttrsNode { "with size " "one."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgmaxArgminAttrs", ArgmaxArgminAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgmaxArgminAttrs", ArgmaxArgminAttrs, AttrsNode); }; // struct ArgmaxArgminAttrs /*! \brief Attributes for bucketize operator */ -struct BucketizeAttrs : public tvm::BaseAttrsNode { +struct BucketizeAttrs : public tvm::AttrsNode { bool out_int32; bool right; @@ -61,7 +60,7 @@ struct BucketizeAttrs : public tvm::BaseAttrsNode { .def_ro("right", &BucketizeAttrs::right, "Determines the behavior for values in boundaries"); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BucketizeAttrs", BucketizeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BucketizeAttrs", BucketizeAttrs, AttrsNode); }; // struct BucketizeAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index e32d47239f35..e8bf65d55a43 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in sort operator */ -struct SortAttrs : public BaseAttrsNode { +struct SortAttrs : public AttrsNode { int axis; bool descending; @@ -47,11 +47,11 @@ struct SortAttrs : public BaseAttrsNode { "If it is not specified, it defaults to the ascending order.", refl::DefaultValue(false)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SortAttrs", SortAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SortAttrs", SortAttrs, AttrsNode); }; // struct SortAttrs /*! \brief Attributes used in argsort operator */ -struct ArgsortAttrs : public BaseAttrsNode { +struct ArgsortAttrs : public AttrsNode { int axis; bool descending; DataType dtype; @@ -70,11 +70,11 @@ struct ArgsortAttrs : public BaseAttrsNode { .def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.", refl::DefaultValue(DataType::Void())); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, AttrsNode); }; // struct ArgsortAttrs /*! \brief Attributes used in topk operator */ -struct TopKAttrs : public BaseAttrsNode { +struct TopKAttrs : public AttrsNode { int k; int axis; bool largest; @@ -100,7 +100,7 @@ struct TopKAttrs : public BaseAttrsNode { .def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.", refl::DefaultValue(DataType::Void())); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, AttrsNode); }; // struct TopKAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index 884946402a9e..66996c802cc3 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for statistical operators */ -struct StatisticalAttrs : public BaseAttrsNode { +struct StatisticalAttrs : public AttrsNode { ffi::Optional> axis; bool keepdims; @@ -44,12 +44,11 @@ struct StatisticalAttrs : public BaseAttrsNode { "with size " "one."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StatisticalAttrs", StatisticalAttrs, - BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StatisticalAttrs", StatisticalAttrs, AttrsNode); }; // struct StatisticalAttrs /*! \brief Attributes used in scan operators like cumsum, cumprod */ -struct ScanopAttrs : public BaseAttrsNode { +struct ScanopAttrs : public AttrsNode { ffi::Optional axis; DataType dtype; bool exclusive = false; @@ -66,7 +65,7 @@ struct ScanopAttrs : public BaseAttrsNode { .def_ro("exclusive", &ScanopAttrs::exclusive, "The first element is not included", refl::DefaultValue(false)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScanopAttrs", ScanopAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScanopAttrs", ScanopAttrs, AttrsNode); }; // struct ScanopAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 37ec77cbbff6..f4b1830669c7 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -32,7 +32,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in AllClassNonMaximumSuppression operator */ -struct AllClassNonMaximumSuppressionAttrs : public BaseAttrsNode { +struct AllClassNonMaximumSuppressionAttrs : public AttrsNode { ffi::String output_format; static void RegisterReflection() { @@ -43,11 +43,11 @@ struct AllClassNonMaximumSuppressionAttrs : public BaseAttrsNode { "consumed by each frontend."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllClassNonMaximumSuppressionAttrs", - AllClassNonMaximumSuppressionAttrs, BaseAttrsNode); + AllClassNonMaximumSuppressionAttrs, AttrsNode); }; // struct AllClassNonMaximumSuppressionAttrs /*! \brief Attributes used in ROIAlign operator */ -struct ROIAlignAttrs : public BaseAttrsNode { +struct ROIAlignAttrs : public AttrsNode { ffi::Array pooled_size; double spatial_scale; int sample_ratio; @@ -68,11 +68,11 @@ struct ROIAlignAttrs : public BaseAttrsNode { .def_ro("layout", &ROIAlignAttrs::layout, "Dimension ordering of the input data.") .def_ro("mode", &ROIAlignAttrs::mode, "Mode for ROI Align. Can be 'avg' or 'max'."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, AttrsNode); }; // struct ROIAlignAttrs /*! \brief Attributes used in ROIPool operator */ -struct ROIPoolAttrs : public BaseAttrsNode { +struct ROIPoolAttrs : public AttrsNode { ffi::Array pooled_size; double spatial_scale; ffi::String layout; @@ -85,11 +85,11 @@ struct ROIPoolAttrs : public BaseAttrsNode { "Ratio of input feature map height (or width) to raw image height (or width).") .def_ro("layout", &ROIPoolAttrs::layout, "Dimension ordering of the input data."); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs, AttrsNode); }; // struct ROIPoolAttrs /*! \brief Attributes used in GetValidCounts operator */ -struct GetValidCountsAttrs : public BaseAttrsNode { +struct GetValidCountsAttrs : public AttrsNode { double score_threshold; int id_index; int score_index; @@ -105,11 +105,11 @@ struct GetValidCountsAttrs : public BaseAttrsNode { "Index of the scores/confidence of boxes."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GetValidCountsAttrs", GetValidCountsAttrs, - BaseAttrsNode); + AttrsNode); }; // struct GetValidCountsAttrs /*! \brief Attributes used in NonMaximumSuppression operator */ -struct NonMaximumSuppressionAttrs : public BaseAttrsNode { +struct NonMaximumSuppressionAttrs : public AttrsNode { int max_output_size; double iou_threshold; bool force_suppress; @@ -149,11 +149,11 @@ struct NonMaximumSuppressionAttrs : public BaseAttrsNode { "Score threshold for soft-NMS validity check; 0.0 when unused."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs", - NonMaximumSuppressionAttrs, BaseAttrsNode); + NonMaximumSuppressionAttrs, AttrsNode); }; // struct NonMaximumSuppressionAttrs /*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode). */ -struct MultiboxTransformLocAttrs : public BaseAttrsNode { +struct MultiboxTransformLocAttrs : public AttrsNode { bool clip; double threshold; ffi::Array variances; @@ -173,7 +173,7 @@ struct MultiboxTransformLocAttrs : public BaseAttrsNode { "If false, force output scores[:,0,:] to 0 (background class)."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultiboxTransformLocAttrs", - MultiboxTransformLocAttrs, BaseAttrsNode); + MultiboxTransformLocAttrs, AttrsNode); }; // struct MultiboxTransformLocAttrs } // namespace relax diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index b791387306da..83c7f5655a73 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -169,7 +169,7 @@ constexpr int kInvalidDeviceType = -1; * These operations are needed during device planning. */ -class VirtualDeviceNode : public BaseAttrsNode { +class VirtualDeviceNode : public AttrsNode { private: /*! * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is @@ -257,7 +257,7 @@ class VirtualDeviceNode : public BaseAttrsNode { "The area of memory w.r.t. the virtual device where data is stored.", refl::DefaultValue("")); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.VirtualDevice", VirtualDeviceNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.VirtualDevice", VirtualDeviceNode, AttrsNode); friend class VirtualDevice; }; diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index e2ea70a9f663..b4c36881c999 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -50,7 +50,7 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { return attrs; } -TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef(); } +TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef(); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; From d9dd2bcaca06faeb4e0159461c6a923b1c6e00a2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:41:03 +0000 Subject: [PATCH 4/9] [REFACTOR][IR] Promote DictAttrs to NOTNULLABLE The DictAttrs no-arg constructor already creates an always-defined empty backing (post-#19607 inlined ctor), so every existing call site that constructed a DictAttrs already produced a defined object. Switching to TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE makes that property part of the type contract. This removes the wall of explicit copy/move and operator-> declarations on DictAttrs and lets us drop ~15 defensive attrs.defined() checks that could never fire. Python wrappers that previously passed attrs=None for an absent attrs (Function and Function.create_empty) now construct an empty DictAttrs explicitly. --- include/tvm/ir/attrs.h | 22 +++++--------------- python/tvm/relax/expr.py | 4 ++++ src/relax/ir/dataflow_matcher.cc | 2 +- src/relax/ir/expr.cc | 4 ---- src/relax/script/printer/function.cc | 5 ++--- src/script/printer/ir/ir.cc | 2 +- src/target/cuda/codegen_cuda.cc | 2 +- src/tirx/analysis/verify_tirx_well_formed.cc | 3 +-- src/tirx/ir/function.cc | 4 ---- src/tirx/script/printer/buffer.cc | 2 +- src/tirx/script/printer/function.cc | 10 ++++----- src/tirx/transform/ir_utils.cc | 4 ---- src/tirx/transform/split_host_device.cc | 2 +- 13 files changed, 22 insertions(+), 44 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index bb62910f2f16..d1fd220bdf57 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -89,11 +89,9 @@ class DictAttrsNode : public AttrsNode { class DictAttrs : public Attrs { public: /*! - * \brief constructor with UnsafeInit - */ - explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {} - /*! - * \brief Construct a Attrs backed by DictAttrsNode. + * \brief Construct a DictAttrs backed by DictAttrsNode. + * + * The no-argument form constructs an empty (but always defined) DictAttrs. * \param dict The attributes. */ explicit DictAttrs(ffi::Map dict = {}) { @@ -103,8 +101,6 @@ class DictAttrs : public Attrs { } // Utils for accessing attributes - // This needs to be on DictAttrs, not DictAttrsNode because we return the default - // value if DictAttrsNode is not defined. /*! * \brief Get a function attribute. * @@ -128,8 +124,7 @@ class DictAttrs : public Attrs { ffi::Optional GetAttr( const std::string& attr_key, ffi::Optional default_value = ffi::Optional(std::nullopt)) const { - if (!defined()) return default_value; - const DictAttrsNode* node = this->as(); + const DictAttrsNode* node = get(); auto it = node->dict.find(attr_key); if (it != node->dict.end()) { return (*it).second.cast(); @@ -165,14 +160,7 @@ class DictAttrs : public Attrs { return GetAttr(attr_key, 0).value_or(0) != 0; } - explicit DictAttrs(::tvm::ffi::ObjectPtr n) : Attrs(n) {} - DictAttrs(const DictAttrs&) = default; - DictAttrs(DictAttrs&&) = default; - DictAttrs& operator=(const DictAttrs&) = default; - DictAttrs& operator=(DictAttrs&&) = default; - const DictAttrsNode* operator->() const { return static_cast(data_.get()); } - const DictAttrsNode* get() const { return operator->(); } - using ContainerType = DictAttrsNode; + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 5a75e43b1284..6dffaab8f4a3 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1010,6 +1010,8 @@ def __init__( attrs: tvm.ir.DictAttrs | None = None, span: Span | None = None, ) -> None: + if attrs is None: + attrs = tvm.ir.DictAttrs({}) self.__init_handle_by_constructor__( _ffi_api.Function, params, @@ -1029,6 +1031,8 @@ def create_empty( span: Span | None = None, ): """Construct a relax.Function but without body""" + if attrs is None: + attrs = tvm.ir.DictAttrs({}) return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, is_pure, attrs, span) # type: ignore def __call__(self, *args): diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 22e3a7bbc31a..e8eafde31747 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -209,7 +209,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } else if (auto* op = expr.as()) { matches = true; for (auto kv : attributes) { - if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) { + if (matches && op->attrs->dict.count(kv.first)) { matches &= ffi::StructuralEqual()(kv.second, op->attrs->dict[kv.first]); } else { matches = false; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index c2f404d41f63..5c2419209b42 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -542,10 +542,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { Function::Function(ffi::Array params, Expr body, ffi::Optional ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { - if (!attrs.defined()) { - attrs = DictAttrs(); - } - // Set the function type. // For function, we take a conservative approach and require the function type // to be known at construction time. diff --git a/src/relax/script/printer/function.cc b/src/relax/script/printer/function.cc index e30a2b0bf432..4c0d84f9f6af 100644 --- a/src/relax/script/printer/function.cc +++ b/src/relax/script/printer/function.cc @@ -84,7 +84,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 3. Clean up func variables (*f)->func_vars = nullptr; // Step 4. Print attributes - if (n->attrs.defined() && !n->attrs->dict.empty()) { + if (!n->attrs->dict.empty()) { // If the function is a global function and has a global symbol, // then don't print the global symbol (it will be implicit from not being private). // For a function without an IR module whose global symbol @@ -119,8 +119,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // if the function is global or is not in a module and does not have a global symbol, // indicate that it's private - if (AtTopLevelFunction(d) && - (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { + if (AtTopLevelFunction(d) && !n->attrs->dict.count(tvm::attr::kGlobalSymbol)) { dec_keys.push_back("private"); dec_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); } diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index d49f3123d908..640bc6c57e85 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->AddDispatchToken(d, "ir"); IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module")); (*f)->global_infos = &mod->global_infos; - if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + if (!mod->attrs->dict.empty()) { (*f)->stmts.push_back( ExprStmtDoc(IR(d, "module_attrs") // ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc index 863cc4eb2061..27cad36735b6 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/target/cuda/codegen_cuda.cc @@ -211,7 +211,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { extractor(f->body); // Also check PrimFunc attrs for persistent kernel (decorator-level) bool is_persistent = extractor.is_persistent_kernel; - if (!is_persistent && f->attrs.defined() && f->attrs->dict.count(tirx::attr::kPersistentKernel)) { + if (!is_persistent && f->attrs->dict.count(tirx::attr::kPersistentKernel)) { is_persistent = true; } arith::Analyzer analyzer; diff --git a/src/tirx/analysis/verify_tirx_well_formed.cc b/src/tirx/analysis/verify_tirx_well_formed.cc index 64ede04f2075..f9063bd2d2e3 100644 --- a/src/tirx/analysis/verify_tirx_well_formed.cc +++ b/src/tirx/analysis/verify_tirx_well_formed.cc @@ -251,8 +251,7 @@ bool VerifyTIRxWellFormed(const IRModule& mod, bool assert_mode, bool device_fun for (const auto& [gvar, base_func] : mod->functions) { if (auto prim_func = base_func.as()) { // s_tir=True PrimFuncs use s_tir semantics — defer to VerifyWellFormed. - if (prim_func.value()->attrs.defined() && - prim_func.value()->attrs->dict.count(tvm::attr::kSTir)) { + if (prim_func.value()->attrs->dict.count(tvm::attr::kSTir)) { if (!VerifyWellFormed(prim_func.value(), assert_mode)) return false; continue; } diff --git a/src/tirx/ir/function.cc b/src/tirx/ir/function.cc index a92767c85aaa..273ed1ae3c99 100644 --- a/src/tirx/ir/function.cc +++ b/src/tirx/ir/function.cc @@ -77,10 +77,6 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { // Get the function type of a PrimFunc PrimFunc::PrimFunc(ffi::Array params, Stmt body, Type ret_type, ffi::Map buffer_map, DictAttrs attrs, Span span) { - if (!attrs.defined()) { - attrs = DictAttrs(); - } - if (!ret_type.defined()) { ret_type = VoidType(); } diff --git a/src/tirx/script/printer/buffer.cc b/src/tirx/script/printer/buffer.cc index 72f3f9f9df41..32d50a8f8d6d 100644 --- a/src/tirx/script/printer/buffer.cc +++ b/src/tirx/script/printer/buffer.cc @@ -193,7 +193,7 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath for (const auto& f : d->frames) { if (const auto* tir_f = f.as()) { if (auto func = tir_f->tirx.as()) { - if (func->attrs.defined() && func->attrs->dict.count(tvm::attr::kSTir)) { + if (func->attrs->dict.count(tvm::attr::kSTir)) { enclosing_s_tir = true; } break; diff --git a/src/tirx/script/printer/function.cc b/src/tirx/script/printer/function.cc index 41b561e739eb..30912034da7d 100644 --- a/src/tirx/script/printer/function.cc +++ b/src/tirx/script/printer/function.cc @@ -106,7 +106,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { tirx::Buffer buffer = func->buffer_map[var]; - bool s_tir = func->attrs.defined() && func->attrs->dict.count(tvm::attr::kSTir); + bool s_tir = func->attrs->dict.count(tvm::attr::kSTir); if (IsSimpleBuffer(buffer, s_tir) && buffer_data_counter.at(buffer->data.get()) == 1) { AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); IdDoc lhs = DefineBuffer(buffer, *f, d); @@ -120,7 +120,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) args.push_back(AssignDoc(DefineVar(var, *f, d), std::nullopt, a)); } // Step 2. Handle `func->attrs` - if (func->attrs.defined() && !func->attrs->dict.empty()) { + if (!func->attrs->dict.empty()) { // for global symbol, don't display it if it matches the func name std::unordered_set keys_to_remove; if (func->attrs->dict.count(tvm::attr::kGlobalSymbol) && @@ -214,15 +214,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array kwargs_keys; ffi::Array kwargs_values; // mark private if there is no global symbol - if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { + if (!func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { kwargs_keys.push_back("private"); kwargs_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); } - if (func->attrs.defined() && func->attrs->dict.count(tvm::attr::kSTir)) { + if (func->attrs->dict.count(tvm::attr::kSTir)) { kwargs_keys.push_back("s_tir"); kwargs_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); } - if (func->attrs.defined() && func->attrs->dict.count(tirx::attr::kPersistentKernel)) { + if (func->attrs->dict.count(tirx::attr::kPersistentKernel)) { kwargs_keys.push_back("persistent"); kwargs_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); } diff --git a/src/tirx/transform/ir_utils.cc b/src/tirx/transform/ir_utils.cc index 8582968f0e58..281e53d76c4c 100644 --- a/src/tirx/transform/ir_utils.cc +++ b/src/tirx/transform/ir_utils.cc @@ -158,10 +158,6 @@ class IRConvertSSA final : public StmtExprMutator { }(); auto attrs = [&]() -> DictAttrs { - if (!func->attrs.defined()) { - return DictAttrs(); - } - ffi::Map dict; bool made_change = false; diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index 6a07306b38ae..70c44ba66c98 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -112,7 +112,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tirx::attr::kNoAlias, true}, {tirx::attr::kIsGlobalFunc, true}}); - if (cur_func_->attrs.defined() && cur_func_->attrs->dict.count(tvm::attr::kSTir)) { + if (cur_func_->attrs->dict.count(tvm::attr::kSTir)) { device_func = WithAttr(std::move(device_func), tvm::attr::kSTir, true); } auto num_inputs = cur_func_->GetAttr(tvm::attr::kNumInputs); From 9651feb83950c7c871fa1d075d876923ea834810 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:45:03 +0000 Subject: [PATCH 5/9] [REFACTOR][IR] Inline DictAttrs-form With{,out}Attr overloads The DictAttrs-form WithAttr/WithAttrs/WithoutAttr free functions were not TVM_DLL-exported, not bound to Python, and had no external C++ callers - they exist only as one-hop delegations from the TFunc-template wrappers in the same header. Inlining the dict mutation into the templates removes a layer of indirection. --- include/tvm/ir/attrs.h | 54 ++++++------------------------------------ src/ir/attrs.cc | 22 ----------------- 2 files changed, 7 insertions(+), 69 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index d1fd220bdf57..ecc97a05c576 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -164,46 +164,6 @@ class DictAttrs : public Attrs { TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; -/*! - * \brief Copy the DictAttrs, but overrides attributes with the - * entries from \p attrs. - * - * \param attrs The DictAttrs to update - * - * \param new_attrs Key/values attributes to add to \p attrs. - * - * \returns The new DictAttrs with updated attributes. - */ -DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs); - -/*! - * \brief Copy the DictAttrs, but overrides a single attribute. - * - * \param attrs The DictAttrs to update - * - * \param key The update to insert or update. - * - * \param value The new value of the attribute - * - * \returns The new DictAttrs with updated attributes. - */ -DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value); - -inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) { - return WithAttr(std::move(attrs), ffi::String(key), std::move(value)); -} - -/*! - * \brief Copy the DictAttrs, but without a specific attribute. - * - * \param attrs The DictAttrs to update - * - * \param key The key to remove - * - * \returns The new DictAttrs with updated attributes. - */ -DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key); - /*! * \brief Copy the function or module, but overrides * the attribute value key with the value. @@ -236,7 +196,7 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value) using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value); + node->attrs.CopyOnWrite()->dict.Set(attr_key, std::move(attr_value)); return input; } @@ -254,10 +214,12 @@ template inline TFunc WithAttrs(TFunc input, ffi::Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + if (attrs.empty()) return input; TNode* node = input.CopyOnWrite(); - - node->attrs = WithAttrs(std::move(node->attrs), attrs); - + auto* dict_node = node->attrs.CopyOnWrite(); + for (const auto& [k, v] : attrs) { + dict_node->dict.Set(k, v); + } return input; } @@ -291,10 +253,8 @@ template inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = input.CopyOnWrite(); - node->attrs = WithoutAttr(std::move(node->attrs), attr_key); - + node->attrs.CopyOnWrite()->dict.erase(attr_key); return input; } diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index b4c36881c999..b58c183c7aec 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -28,28 +28,6 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { DictAttrsNode::RegisterReflection(); } -DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs) { - if (new_attrs.empty()) { - return attrs; - } - - auto* write_ptr = attrs.CopyOnWrite(); - for (const auto& [key, value] : new_attrs) { - write_ptr->dict.Set(key, value); - } - return attrs; -} - -DictAttrs WithAttr(DictAttrs attrs, ffi::String key, ffi::Any value) { - attrs.CopyOnWrite()->dict.Set(key, value); - return attrs; -} - -DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { - attrs.CopyOnWrite()->dict.erase(key); - return attrs; -} - TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef(); } TVM_FFI_STATIC_INIT_BLOCK() { From 768e4018cc42039d02a7b1fcdc9cfb8c47ea1105 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:54:56 +0000 Subject: [PATCH 6/9] [REFACTOR][IR] Rename AttrsWithDefaultValues to PassConfigWithDefaults and move to transform.h After #19607 every consumer of AttrsWithDefaultValues is a pass-config class registered via TVM_REGISTER_PASS_CONFIG_OPTION; none are Attrs. Renaming to PassConfigWithDefaults and relocating next to PassContext makes the helper's domain explicit and shrinks attrs.h further. --- include/tvm/ir/attrs.h | 19 +---------------- include/tvm/ir/transform.h | 21 +++++++++++++++++++ src/relax/backend/contrib/clml/codegen.cc | 2 +- src/relax/backend/contrib/tensorrt/codegen.cc | 2 +- src/s_tir/transform/hoist_expression.cc | 4 ++-- src/s_tir/transform/inject_double_buffer.cc | 2 +- src/s_tir/transform/loop_partition.cc | 2 +- src/tirx/transform/remove_no_op.cc | 5 +++-- src/tirx/transform/stmt_simplify.cc | 2 +- src/tirx/transform/unroll_loop.cc | 2 +- 10 files changed, 33 insertions(+), 28 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index ecc97a05c576..06d913731f6f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -23,7 +23,7 @@ * This module enables declaration of named attributes * which support default value setup and bound checking. * - * \sa AttrsNode, AttrsWithDefaultValues + * \sa AttrsNode */ #ifndef TVM_IR_ATTRS_H_ #define TVM_IR_ATTRS_H_ @@ -258,22 +258,5 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { return input; } -/*! - * \brief Create an object with all default values, using the reflection defaults. - * \tparam TObj the ObjectRef type to be created. - * \return An instance with all reflection-defined default values applied. - */ -template -inline TObj AttrsWithDefaultValues() { - static_assert(std::is_base_of_v, "Can only create ObjectRef-derived types"); - using ContainerType = typename TObj::ContainerType; - static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); - AnyView packed_args[1]; - packed_args[0] = ContainerType::RuntimeTypeIndex(); - ffi::Any rv; - finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); - return rv.cast(); -} - } // namespace tvm #endif // TVM_IR_ATTRS_H_ diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 436987ae784d..f929f1654b81 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -57,6 +57,7 @@ #define TVM_IR_TRANSFORM_H_ #include +#include #include #include #include @@ -66,6 +67,7 @@ #include #include +#include #include namespace tvm { @@ -300,6 +302,25 @@ class PassContext : public ffi::ObjectRef { friend class With; }; +/*! + * \brief Create a pass-config object with all default values, using the + * reflection defaults. + * \tparam TConfig the ObjectRef type to be created. + * \return An instance with all reflection-defined default values applied. + */ +template +inline TConfig PassConfigWithDefaults() { + static_assert(std::is_base_of_v, + "Can only create ObjectRef-derived types"); + using ContainerType = typename TConfig::ContainerType; + static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); + ffi::AnyView packed_args[1]; + packed_args[0] = ContainerType::RuntimeTypeIndex(); + ffi::Any rv; + finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); + return rv.cast(); +} + #define TVM_PASS_CTX_CONFIG_VAR_DEF [[maybe_unused]] static uint32_t __make_PassContext_tid /*! diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 5fd04c05bfdc..c58c2ee9aa92 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -267,7 +267,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer { auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relax.ext.clml.options"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = transform::PassConfigWithDefaults(); } node->SetAttr("clml_version", static_cast(cfg.value()->clml_version.IntValue())); } diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 8720c77b4388..7fa6d48bdc24 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -180,7 +180,7 @@ class TensorRTJSONSerializer : public JSONSerializer { auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relax.ext.tensorrt.options"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = transform::PassConfigWithDefaults(); } TVM_FFI_ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3); ffi::Array tensorrt_version = {cfg.value()->tensorrt_version[0], diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 8fe18450290f..5cb851ca2a52 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -568,7 +568,7 @@ Pass HoistExpression() { auto cfg = ctx->GetConfig("s_tir.HoistExpression"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = tvm::transform::PassConfigWithDefaults(); } n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value()); return f; @@ -602,7 +602,7 @@ static Pass HoistIfThenElseImpl() { return f; } if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = tvm::transform::PassConfigWithDefaults(); } int block_var = static_cast(cfg.value()->support_block_scope_hoisting ? HoistedConditionals::kUsingBlockVar diff --git a/src/s_tir/transform/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc index 0c934ddbcdb6..ac2f25a62972 100644 --- a/src/s_tir/transform/inject_double_buffer.cc +++ b/src/s_tir/transform/inject_double_buffer.cc @@ -332,7 +332,7 @@ Pass InjectDoubleBuffer() { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("s_tir.InjectDoubleBuffer"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = tvm::transform::PassConfigWithDefaults(); } n->body = DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body)); return f; diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index bf2dca776cfe..8eb444dcfd53 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -817,7 +817,7 @@ Pass LoopPartition() { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("s_tir.LoopPartition"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = tvm::transform::PassConfigWithDefaults(); } n->body = s_tir::LoopPartition(std::move(n->body), cfg.value()->partition_const_loop, cfg.value()->no_unroll_loop_with_extent_one, diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc index aa2280215471..133cfa9d9a56 100644 --- a/src/tirx/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -271,8 +271,9 @@ namespace transform { Pass RemoveNoOp() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - RemoveNoOpConfig config = ctx->GetConfig("tirx.RemoveNoOp") - .value_or(AttrsWithDefaultValues()); + RemoveNoOpConfig config = + ctx->GetConfig("tirx.RemoveNoOp") + .value_or(tvm::transform::PassConfigWithDefaults()); arith::Analyzer analyzer; analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps); diff --git a/src/tirx/transform/stmt_simplify.cc b/src/tirx/transform/stmt_simplify.cc index 2238625255cd..9ebbcab9e133 100644 --- a/src/tirx/transform/stmt_simplify.cc +++ b/src/tirx/transform/stmt_simplify.cc @@ -89,7 +89,7 @@ class StmtSimplifyConfig : public ffi::ObjectRef { }; static StmtSimplifyConfig MakeDefaultStmtSimplifyConfig() { - return AttrsWithDefaultValues(); + return tvm::transform::PassConfigWithDefaults(); } TVM_FFI_STATIC_INIT_BLOCK() { StmtSimplifyConfigNode::RegisterReflection(); } diff --git a/src/tirx/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc index faf1ec2d677d..ae99410ceea0 100644 --- a/src/tirx/transform/unroll_loop.cc +++ b/src/tirx/transform/unroll_loop.cc @@ -285,7 +285,7 @@ Pass UnrollLoop() { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("tirx.UnrollLoop"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = tvm::transform::PassConfigWithDefaults(); } n->body = UnrollLoop(std::move(f->body), cfg.value()); return f; From c676224bfbc106864d88c38c3cb8f89c4a5c298d Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 18:34:17 +0000 Subject: [PATCH 7/9] [REVERT][IR] Restore field-info machinery as FieldInfoNode in op.h PR #19615 Commit A deleted AttrFieldInfoNode and OpNode::arguments on the rationale that the only consumer (GetArgStructInfo) could be satisfied with "input[i]" indices instead of stored names. On second thought we want the named-argument metadata back -- it's worth the ~340 lines of .add_argument(...) registrations for the clearer error messages and downstream tooling that wants to introspect Op argument schemas. Restore with two adjustments: relocate from include/tvm/ir/attrs.h to include/tvm/ir/op.h (closer to OpNode), and rename to FieldInfoNode (drop the "Attr" prefix that referenced the old Attrs hierarchy this class never actually used). The FFI type-key becomes "ir.FieldInfo" to match the new C++ name. --- include/tvm/ir/op.h | 57 ++++++++++++ python/tvm/ir/op.py | 14 +++ src/ir/op.cc | 10 ++- src/relax/op/ccl/ccl.cc | 5 ++ src/relax/op/distributed/distributed.cc | 8 ++ src/relax/op/image/resize.cc | 8 ++ src/relax/op/nn/attention.cc | 14 +++ src/relax/op/nn/convolution.cc | 12 +++ src/relax/op/nn/nn.cc | 45 ++++++++-- src/relax/op/nn/pooling.cc | 9 ++ src/relax/op/op.cc | 89 +++++++++++++++++++ src/relax/op/op_common.cc | 10 +-- src/relax/op/op_common.h | 9 +- src/relax/op/tensor/binary.cc | 2 +- src/relax/op/tensor/binary.h | 2 + src/relax/op/tensor/create.cc | 26 ++++++ src/relax/op/tensor/datatype.cc | 2 + src/relax/op/tensor/grad.cc | 14 +++ src/relax/op/tensor/index.cc | 7 ++ src/relax/op/tensor/inspect.cc | 10 +++ src/relax/op/tensor/linear_algebra.cc | 5 ++ src/relax/op/tensor/manipulate.cc | 44 +++++++++ src/relax/op/tensor/qdq.cc | 6 ++ src/relax/op/tensor/sampling.cc | 3 + src/relax/op/tensor/search.cc | 9 ++ src/relax/op/tensor/set.cc | 18 ++++ src/relax/op/tensor/sorting.cc | 3 + src/relax/op/tensor/statistical.cc | 3 + src/relax/op/tensor/statistical.h | 1 + src/relax/op/tensor/ternary.cc | 3 + src/relax/op/tensor/unary.cc | 3 + src/relax/op/vision/multibox_transform_loc.cc | 4 + src/relax/op/vision/nms.cc | 14 +++ src/relax/op/vision/roi_align.cc | 3 + src/relax/op/vision/roi_pool.cc | 3 + src/target/cuda/intrin_rule_cuda.cc | 16 ++++ src/target/metal/intrin_rule_metal.cc | 6 ++ src/target/webgpu/intrin_rule_webgpu.cc | 6 ++ 38 files changed, 485 insertions(+), 18 deletions(-) diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 619ab08f79c4..6a8591443b9a 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -44,6 +44,41 @@ namespace tvm { template class OpAttrMap; +/*! + * \brief Information about an input field of an Op (name, type, description). + * + * Populated via OpRegEntry::add_argument and consumed both by + * internal sanity checks / error messages and by external tooling + * that wants to introspect an Op's argument schema. + */ +class FieldInfoNode : public ffi::Object { + public: + /*! \brief name of the field */ + ffi::String name; + /*! \brief type docstring information in str. */ + ffi::String type_info; + /*! \brief detailed description of the type */ + ffi::String description; + + static void RegisterReflection() { + namespace rfl = ffi::reflection; + rfl::ObjectDef() + .def_ro("name", &FieldInfoNode::name) + .def_ro("type_info", &FieldInfoNode::type_info) + .def_ro("description", &FieldInfoNode::description); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FieldInfo", FieldInfoNode, ffi::Object); +}; + +/*! \brief Managed reference to FieldInfoNode. */ +class FieldInfo : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FieldInfo, ffi::ObjectRef, FieldInfoNode); +}; + // TODO(tvm-team): migrate low-level intrinsics to use Op /*! * \brief Primitive Op(builtin intrinsics) @@ -67,6 +102,8 @@ class OpNode : public RelaxExprNode { * This can be used to generate docstring automatically for the operator. */ ffi::String description; + /* \brief Information of input arguments to the operator */ + ffi::Array arguments; /*! * \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. @@ -95,6 +132,7 @@ class OpNode : public RelaxExprNode { .def_ro("name", &OpNode::name) .def_ro("op_type", &OpNode::op_type, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("description", &OpNode::description, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("arguments", &OpNode::arguments, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("attrs_type_key", &OpNode::attrs_type_key, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("num_inputs", &OpNode::num_inputs, refl::AttachFieldFlag::SEqHashIgnore()) .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore()); @@ -176,6 +214,15 @@ class OpRegEntry { * \return reference to self. */ inline OpRegEntry& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline OpRegEntry& add_argument(const std::string& name, const std::string& type, + const std::string& description); /*! * \brief Set the attrs type key and index to be AttrsType. * \tparam AttrsType the attribute type to b set. @@ -316,6 +363,16 @@ inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(* return *this; } +inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, + const std::string& description) { + auto n = ffi::make_object(); + n->name = name; + n->type_info = type; + n->description = description; + get()->arguments.push_back(FieldInfo(n)); + return *this; +} + inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; return *this; diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index ca1979ffabc4..6c0912f86476 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -102,6 +102,20 @@ def reset_attr(self, attr_name): """ _ffi_api.OpResetAttr(self, attr_name) + def add_argument(self, name, type, description): # pylint: disable=redefined-builtin + """Add arguments information to the function. + + Parameters + ---------- + name : str + The argument name. + type : str + The argument type. + description : str + The argument description. + """ + _ffi_api.OpAddArgument(self, name, type, description) + def set_support_level(self, level): """Set the support level of op. diff --git a/src/ir/op.cc b/src/ir/op.cc index 9f0c20c92090..7722b0d30730 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -33,7 +33,10 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK() { OpNode::RegisterReflection(); } +TVM_FFI_STATIC_INIT_BLOCK() { + FieldInfoNode::RegisterReflection(); + OpNode::RegisterReflection(); +} using ffi::Any; using ffi::Function; @@ -113,6 +116,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); op.describe(descr); }) + .def("ir.OpAddArgument", + [](Op op, ffi::String name, ffi::String type, ffi::String description) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.add_argument(name, type, description); + }) .def("ir.OpSetSupportLevel", [](Op op, int level) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index d27353eef050..7f7eb3c8935d 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -56,6 +56,7 @@ StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.ccl.allreduce") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "Input to which allreduce will be applied.") .set_attr("FInferStructInfo", InferStructInfoAllReduce) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -94,6 +95,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.ccl.allgather") .set_num_inputs(1) + .add_argument("x", "Tensor", "Input to which allgather will be applied.") .set_attr("FInferStructInfo", InferStructInfoAllGather) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -116,6 +118,7 @@ StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0") .set_num_inputs(1) + .add_argument("x", "Tensor", "Input to be broadcast.") .set_attr("FInferStructInfo", InferStructInfoBroadcastFromZero) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -163,6 +166,8 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.ccl.scatter_from_worker0") .set_num_inputs(1) + .add_argument("x", "Tensor", + "The buffer to be divided into equal parts and sent to each worker accordingly.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoScatter) .set_attr("FPurity", true); diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index c737ef392984..bee2751564d9 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -62,6 +62,7 @@ StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.dist.annotate_sharding") .set_num_inputs(1) + .add_argument("input", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoAnnotateSharding) .set_attr("dist.FInferStructInfo", InferStructInfoAnnotateSharding) .set_attr("FPurity", true); @@ -92,6 +93,7 @@ StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.dist.redistribute") .set_num_inputs(1) + .add_argument("input", "Tensor", "The input tensor.") .set_attr("dist.FInferStructInfo", InferDistStructInfoRedistribute) .set_attr("FPurity", true); @@ -109,6 +111,11 @@ StructInfo InferStructInfoCallTIRLocalView(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.dist.call_tir_local_view") .set_num_inputs(3) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIRLocalView) .set_attr("FPurity", true); @@ -221,6 +228,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") .set_num_inputs(1) + .add_argument("input", "Tensor", "The buffer to be sliced.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoRtoS) .set_attr("dist.FInferStructInfo", InferDistStructInfoRtoS) diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 3279b16f0679..db8a8c3c43ee 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -143,6 +143,8 @@ InferLayoutOutput InferLayoutResize2d( TVM_REGISTER_OP("relax.image.resize2d") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("size", "Shape", "The output image shape.") .set_attr("FInferStructInfo", InferStructInfoResize2D) .set_attr("FRelaxInferLayout", InferLayoutResize2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -254,6 +256,8 @@ InferLayoutOutput InferLayoutResize3d( TVM_REGISTER_OP("relax.image.resize3d") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("size", "Shape", "The output image shape.") .set_attr("FInferStructInfo", InferStructInfoResize3D) .set_attr("FRelaxInferLayout", InferLayoutResize3d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -331,6 +335,8 @@ StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.image.grid_sample") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("grid", "Tensor", "The grid tensor for sampling.") .set_attr("FInferStructInfo", InferStructInfoGridSample) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -421,6 +427,8 @@ StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.image.affine_grid") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input affine matrix tensor.") + .add_argument("size", "Shape", "The target output shape (H, W).") .set_attr("FInferStructInfo", InferStructInfoAffineGrid) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index e7ceb2fd68dd..f19c55b5d2ec 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -151,6 +151,9 @@ Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.attention") .set_attrs_type() .set_num_inputs(3) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) .set_attr("FInferStructInfo", InferStructInfoAttention) @@ -159,6 +162,10 @@ TVM_REGISTER_OP("relax.nn.attention") TVM_REGISTER_OP("relax.nn.attention_bias") .set_attrs_type() .set_num_inputs(4) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .add_argument("bias", "Tensor", "The input bias tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) .set_attr("FInferStructInfo", InferStructInfoAttention) @@ -167,6 +174,13 @@ TVM_REGISTER_OP("relax.nn.attention_bias") TVM_REGISTER_OP("relax.nn.attention_var_len") .set_attrs_type() .set_num_inputs(7) + .add_argument("query", "Tensor", "The input queries tensor.") + .add_argument("key", "Tensor", "The input keys tensor.") + .add_argument("value", "Tensor", "The input values tensor.") + .add_argument("seqstart_q", "Tensor", "The cumsum of query sequence lengths, prepended with 0.") + .add_argument("seqstart_k", "Tensor", "The cumsum of key sequence lengths, prepended with 0.") + .add_argument("max_seqlen_q", "Tensor", "The maximum query sequence length in the batch.") + .add_argument("max_seqlen_k", "Tensor", "The maximum key sequence length in the batch.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) .set_attr("FInferStructInfo", InferStructInfoAttention) diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 367fe762112d..1b77b4225203 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -195,6 +195,8 @@ Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.conv1d") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv1d) .set_attr("FRelaxInferLayout", InferLayoutConv1d) @@ -401,6 +403,8 @@ Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.conv2d") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv2d) .set_attr("FRelaxInferLayout", InferLayoutConv2d) @@ -581,6 +585,8 @@ Call InferMixedPrecisionConv3d(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.nn.conv3d") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv3d) .set_attr("FRelaxInferLayout", InferLayoutConv3d) @@ -761,6 +767,8 @@ Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dt TVM_REGISTER_OP("relax.nn.conv1d_transpose") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv1dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv1dTranspose) @@ -990,6 +998,8 @@ Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dt TVM_REGISTER_OP("relax.nn.conv2d_transpose") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv2dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv2dTranspose) @@ -1230,6 +1240,8 @@ Call InferMixedPrecisionConv3dTranspose(const Call& call, const DataType& out_dt TVM_REGISTER_OP("relax.nn.conv3d_transpose") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoConv3dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv3dTranspose) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index cc2e7b406c90..b6e2051a68f7 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -75,6 +75,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoUnaryArith) @@ -97,6 +98,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoUnaryArith) @@ -159,6 +161,8 @@ InferLayoutOutput InferLayoutPRelu( TVM_REGISTER_OP("relax.nn.prelu") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("alpha", "Tensor", "The channel-wise learnable slope.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPRelu) .set_attr("FRelaxInferLayout", InferLayoutPRelu) @@ -220,6 +224,7 @@ InferLayoutOutput InferLayoutSoftmax( TVM_REGISTER_OP("relax.nn.softmax") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax) .set_attr("FRelaxInferLayout", InferLayoutSoftmax) @@ -240,6 +245,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax) .set_attr("FPurity", true); @@ -286,6 +292,7 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nn.pad") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPad) .set_attr("FPurity", true); @@ -357,6 +364,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.nn.pixel_shuffle") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPixelShuffle) .set_attr("FPurity", true); @@ -366,7 +374,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const ffi::Array& input_sinfo, ffi::Array axes) { Op op = Downcast(call->op); - int n_input = op->num_inputs; + int n_input = op->arguments.size(); TensorStructInfo data_sinfo = input_sinfo[0]; @@ -386,13 +394,13 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << op - << " requires all the input tensors to have the same dtype. However, input[" - << i << "] has dtype " << input_sinfo[i]->dtype + << " requires all the input tensors to have the same dtype. However, the " + << op->arguments[i]->name << " has dtype " << input_sinfo[i]->dtype << " which is other than the input data's dtype " << data_sinfo->dtype); } else if (input_sinfo[i]->ndim != n_axis) { ctx->ReportFatal(Diagnostic::Error(call) - << op << " requires input[" << i - << "] to have as many dimensions as the length of input axes. However, the " + << op << " requires the input " << op->arguments[i]->name + << " to have as many dimensions as the length of input axes. However, the " "given one has ndim " << input_sinfo[i]->ndim << ", which is other than the length of axes " << n_axis); @@ -506,6 +514,11 @@ InferLayoutOutput InferLayoutBatchNorm( TVM_REGISTER_OP("relax.nn.batch_norm") .set_attrs_type() .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") .set_attr("FInferStructInfo", InferStructInfoBatchNorm) .set_attr("FRelaxInferLayout", InferLayoutBatchNorm) .set_attr("FPurity", true); @@ -570,6 +583,9 @@ InferLayoutOutput InferLayoutLayerNorm( TVM_REGISTER_OP("relax.nn.layer_norm") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoLayerNorm) .set_attr("FRelaxInferLayout", InferLayoutLayerNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -626,7 +642,7 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { << op << " expects that the size of channel_axis must be divisible by " << attrs->num_groups << ", but got " << data_shape->values[channel_axis]); } - for (int i = 1; i < op->num_inputs; ++i) { + for (int i = 1; i < static_cast(op->arguments.size()); ++i) { if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << op << " expects that all inputs must have the same dtype, but got " @@ -681,6 +697,9 @@ InferLayoutOutput InferLayoutGroupNorm( TVM_REGISTER_OP("relax.nn.group_norm") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoGroupNorm) .set_attr("FRelaxInferLayout", InferLayoutGroupNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -727,7 +746,7 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx } const auto* data_shape = data_sinfo->shape.as(); arith::Analyzer* analyzer = ctx->GetAnalyzer(); - for (int i = 1; i < op->num_inputs; ++i) { + for (int i = 1; i < static_cast(op->arguments.size()); ++i) { if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << op << " expects that all inputs must have the same dtype, but got " @@ -781,6 +800,9 @@ InferLayoutOutput InferLayoutInstanceNorm( TVM_REGISTER_OP("relax.nn.instance_norm") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoInstanceNorm) .set_attr("FRelaxInferLayout", InferLayoutInstanceNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -839,6 +861,8 @@ InferLayoutOutput InferLayoutRMSNorm( TVM_REGISTER_OP("relax.nn.rms_norm") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "Input to which rms_norm will be applied.") + .add_argument("weight", "Tensor", "The scale factor.") .set_attr("FInferStructInfo", InferStructInfoRMSNorm) .set_attr("FRelaxInferLayout", InferLayoutRMSNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -867,6 +891,7 @@ StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nn.dropout") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_attr("FInferStructInfo", InferStructInfoDropout) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -931,6 +956,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") .set_attr("FInferStructInfo", InferStructInfoCrossEntropy) .set_attr("FPurity", true); @@ -1160,6 +1187,9 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nn.nll_loss") .set_attrs_type() .set_num_inputs(3) + .add_argument("predictions", "Tensor", "The prediction tensor.") + .add_argument("targets", "Tensor", "The target tensor.") + .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLoss) .set_attr("FPurity", true); @@ -1208,6 +1238,7 @@ StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.nn.batch_flatten") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoBatchFlatten) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 16dba700b0da..60430519111d 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -144,6 +144,7 @@ InferLayoutOutput InferLayoutPool1d( TVM_REGISTER_OP("relax.nn.max_pool1d") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool1D) .set_attr("FRelaxInferLayout", InferLayoutPool1d) @@ -294,6 +295,7 @@ InferLayoutOutput InferLayoutPool2d( TVM_REGISTER_OP("relax.nn.max_pool2d") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) @@ -439,6 +441,7 @@ InferLayoutOutput InferLayoutPool3d( TVM_REGISTER_OP("relax.nn.max_pool3d") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool3D) .set_attr("FRelaxInferLayout", InferLayoutPool3d) @@ -460,6 +463,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool1D) .set_attr("FRelaxInferLayout", InferLayoutPool1d) @@ -481,6 +485,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) @@ -502,6 +507,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPool3D) .set_attr("FRelaxInferLayout", InferLayoutPool3d) @@ -584,6 +590,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool1D( TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool1D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool1D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -685,6 +692,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D( TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -771,6 +779,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool3D( TVM_REGISTER_OP("relax.nn.adaptive_avg_pool3d") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool3D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool3D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 21b336db1edc..8a28ab361af2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -115,6 +115,9 @@ StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.call_pure_packed") .set_num_inputs(-1) + .add_argument("args", "ffi::Array", + "The first argument is the function being called. The rest are the " + "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallPurePacked) .set_attr("FPurity", true); @@ -228,6 +231,9 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.call_inplace_packed") .set_num_inputs(-1) .set_attrs_type() + .add_argument("args", "ffi::Array", + "The first argument is the function being called. The rest are the " + "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallInplacePacked) // Warning: considered pure, but it has the potential to create visible effects! // This should only be used if it has been *checked* that it is safe (no aliases, in-place @@ -570,6 +576,11 @@ void ValidateCallTIR(Call call) { TVM_REGISTER_OP("relax.call_tir") .set_num_inputs(3) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FValidate", ValidateCallTIR) @@ -613,6 +624,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.call_tir_with_grad") .set_num_inputs(3) .set_attrs_type() + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FValidate", ValidateCallTIR) @@ -750,6 +766,11 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { TVM_REGISTER_OP("relax.call_tir_inplace") .set_num_inputs(3) .set_attrs_type() + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIRInPlace) .set_attr("FValidate", ValidateCallTIR) @@ -807,6 +828,8 @@ StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.call_dps_packed") .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") .set_attr("FInferStructInfo", InferStructInfoCallDPSPacked) // technically, an impure op could be used with this, but there is // little reason to use DPS with an impure op @@ -871,6 +894,8 @@ void ValidateCallPyFunc(Call call) { TVM_REGISTER_OP("relax.call_py_func") .set_num_inputs(2) + .add_argument("func_name", "StringImm", "The name of the Python function to call.") + .add_argument("args", "Tuple", "The input arguments.") .set_attr("FInferStructInfo", InferStructInfoCallPyFunc) .set_attr("FValidate", ValidateCallPyFunc) .set_attr("FPurity", true); @@ -913,6 +938,8 @@ StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilde TVM_REGISTER_OP("relax.call_builtin_with_ctx") .set_num_inputs(4) + .add_argument("func", "Expr", "The builtin packed func.") + .add_argument("args", "Tuple", "The input arguments.") .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx) // Most builtins are pure, but some are not, like `vm.builtin.attention_kv_cache_append` .set_attr("FPurity", false); @@ -946,6 +973,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.print") .set_num_inputs(-1) + .add_argument("vals", "ffi::Array", + "The first value is Python-style format string to use to print. The others " + "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", false); @@ -988,6 +1018,10 @@ StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.assert_op") .set_num_inputs(-1) + .add_argument("vals", "ffi::Array", + "The first value is used as the assertion condition. The second value is " + "Python-style format string to use for displaying an error message, if the " + "assert fails. The others are used as format arguments if there is an error.") .set_attr("FInferStructInfo", InferAssertStructInfo) .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", false); @@ -1011,6 +1045,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.make_closure") .set_num_inputs(2) + .add_argument("func", "Expr", "The closure.") + .add_argument("args", "Tuple", "The captured variables.") .set_attr("FInferStructInfo", ReturnObjectStructInfo) .set_attr("FPurity", true); @@ -1038,6 +1074,8 @@ StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.invoke_closure") .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) // Not all closures are pure. Use invoke_pure_closure for specifying purity .set_attr("FPurity", false); @@ -1056,6 +1094,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.invoke_pure_closure") .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) .set_attr("FPurity", true); @@ -1073,6 +1113,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.shape_of") .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", InferStructInfoShapeOf) .set_attr("FPurity", true); @@ -1098,6 +1139,7 @@ StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.size") .set_num_inputs(1) + .add_argument("input", "Expr", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoSize) .set_attr("FPurity", true); @@ -1134,6 +1176,7 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.tensor_to_shape") .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo) .set_attr("FPurity", true); @@ -1159,6 +1202,7 @@ StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) .set_attr("FCallPacked", "relax.run.shape_to_tensor") .set_attr("FPurity", true); @@ -1199,6 +1243,13 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.builtin.alloc_tensor") .set_num_inputs(4) + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", + "The device index indicating on which device the tensor is to be " + "allocated at runtime. Index -1 is reserved for the host device.") + .add_argument("storage_scope", "StringImm", + "The storage scope of the storage to allocate. Default is global.") .set_attr("FInferStructInfo", InferStructInfoAllocateTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1219,6 +1270,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.alloc_storage") .set_num_inputs(4) + .add_argument("total_space", "Expr", "The total space of the storage to allocate.") + .add_argument( + "virtual_device_index", "PrimValue", + "The virtual device index indicating on which device the storage is to be allocated, " + "Index -1 is reserved for the host device.") + .add_argument("storage_scope", "StringImm", + "The storage scope of the storage to allocate. Default is global.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") .set_attr("FInferStructInfo", ReturnObjectStructInfo) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1262,6 +1321,13 @@ StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.memory.alloc_tensor") .set_num_inputs(5) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", + "The device index indicating on which device the tensor is to be " + "allocated at runtime. Index -1 is reserved for the host device.") .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1293,6 +1359,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) + .add_argument("storage", "Expr", "The storage to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1311,6 +1378,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) + .add_argument("tensor", "Expr", "The tensor to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1329,6 +1397,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.vm.alloc_storage") .set_num_inputs(4) + .add_argument("size", "Expr", "The size of the storage to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", + "The device index indicating on which device the tensor is " + "to be allocated at runtime.") + .add_argument("storage_scope", "StringImm", + "The storage scope of the storage to allocate. Default is global.") .set_attr("FInferStructInfo", ReturnObjectStructInfo) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1373,6 +1448,13 @@ StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.vm.alloc_tensor") .set_num_inputs(5) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "PrimValue", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", + "The device index indicating on which device the tensor is " + "to be allocated at runtime.") .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) @@ -1402,6 +1484,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // vm kill_object TVM_REGISTER_OP("relax.vm.kill_object") .set_num_inputs(1) + .add_argument("obj", "Expr", "The object to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1420,6 +1503,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.vm.call_tir_dyn") .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", + "The input arguments (list of tensors and last argument is ShapeExpr)") .set_attr("FInferStructInfo", ReturnVoidStructInfo) // "relax.vm.call_tir_dyn" works in an in-place way, which is impure. .set_attr("FPurity", false); @@ -1441,6 +1527,7 @@ StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.builtin.stop_lift_params") .set_num_inputs(1) + .add_argument("x", "Expr", "The input data") .set_attr("FInferStructInfo", InferStructInfoStopLiftParams) .set_attr("FPurity", true); @@ -1471,6 +1558,7 @@ StructInfo InferToVDeviceStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.to_vdevice") .set_num_inputs(1) .set_attrs_type() + .add_argument("data", "Expr", "The input expression to be copied") .set_attr("FInferStructInfo", InferToVDeviceStructInfo) .set_attr("FPurity", true); @@ -1498,6 +1586,7 @@ StructInfo InferHintOnDeviceStructInfo(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.hint_on_device") .set_num_inputs(1) .set_attrs_type() + .add_argument("data", "Expr", "The input expression") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", true); diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index a700b40bd4ea..61485b09112b 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -39,7 +39,7 @@ ffi::Array GetCallArgs(const Call& call) { void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - int expected_input = op->num_inputs; + int expected_input = op->arguments.size(); if (static_cast(call->args.size()) != expected_input) { ctx->ReportFatal(Diagnostic::Error(call) << "Operator " << op << " expects " << expected_input << " arguments" @@ -50,10 +50,10 @@ void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) { Op op = Downcast(call->op); - TVM_FFI_ICHECK_EQ(static_cast(op->num_inputs), call->args.size()) + TVM_FFI_ICHECK_EQ(op->arguments.size(), call->args.size()) << "Failure caught by this check " << "should have previously been caught by `CheckNumArguments`"; - TVM_FFI_ICHECK_LT(i_arg, static_cast(op->num_inputs)); + TVM_FFI_ICHECK_LT(i_arg, op->arguments.size()); auto arg = call->args[i_arg]; auto sinfo = GetStructInfo(arg); @@ -62,8 +62,8 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const return tensor_sinfo.value(); } else { ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << op << " requires argument input[" << i_arg - << "] to be a tensor. " + << "Operator " << op << " requires argument " << i_arg << " (" + << op->arguments[i_arg]->name << ") to be a tensor. " << "However, the argument " << arg << " is instead of type " << sinfo); // Unreachable, but [[noreturn]] attribute on virtual function // `ReportFatal` is insufficient to silence -Wreturn-type, as diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 88630a9043d6..774eccfd58dd 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -142,10 +142,10 @@ std::tuple GetArgStructInfoHelper(const Call& call, const Op& op, template std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - size_t n_input = op->num_inputs; + size_t n_input = op->arguments.size(); - // Unfortunately, because the `.set_num_inputs()` call in - // TVM_REGISTER_OP occurs during initialization of globals and is + // Unfortunately, because the `.add_argument()` calls in + // TVM_REGISTER_OP occur during initialization of globals and are // not available at compile-time, this cannot be a static_assert. TVM_FFI_ICHECK_EQ(n_input, sizeof...(ArgTypes)) << "Internal error: " << op << " op defines " << n_input @@ -166,6 +166,7 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c #define RELAX_REGISTER_UNARY_OP(OpRegName) \ TVM_REGISTER_OP("relax." OpRegName) \ .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input tensor.") \ .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) \ .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ .set_attr("FPurity", true) @@ -234,7 +235,7 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx template StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - int n_input = op->num_inputs; + int n_input = op->arguments.size(); if (static_cast(call->args.size()) != n_input) { ctx->ReportFatal(Diagnostic::Error(call) << op << " op should have " << n_input << " arguments"); diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 4575fba43385..07c3364a9f35 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -35,7 +35,7 @@ template StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { Op op = Downcast(call->op); - size_t n_input = op->num_inputs; + size_t n_input = op->arguments.size(); if (call->args.size() != n_input) { ctx->ReportFatal(Diagnostic::Error(call) << call->op << " op should have " << n_input << " arguments"); diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 6f4068b3ccb8..a234a30bc221 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -47,6 +47,8 @@ namespace relax { } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(2) \ + .add_argument("x1", "Tensor", "The first input tensor.") \ + .add_argument("x2", "Tensor", "The second input tensor.") \ .set_attr("FRelaxInferLayout", InferLayoutBinaryEwise) \ .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ .set_attr("FPurity", true) diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 72de90b55024..885f7c87257e 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -94,6 +94,8 @@ StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.full") .set_attrs_type() .set_num_inputs(2) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") .set_attr("FInferStructInfo", InferStructInfoFull) .set_attr("RequiresArgumentShapes", false) .set_attr("FDataDependent", true) @@ -136,6 +138,8 @@ StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.full_like") .set_attrs_type() .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("fill_value", "Tensor", "The scalar value to fill.") .set_attr("FInferStructInfo", InferStructInfoFullLike) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -195,6 +199,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.ones") .set_attrs_type() .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesZeros) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -202,6 +207,7 @@ TVM_REGISTER_OP("relax.ones") TVM_REGISTER_OP("relax.ones_like") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) .set_attr("FPurity", true); @@ -230,6 +236,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.zeros") .set_attrs_type() .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesZeros) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -237,6 +244,7 @@ TVM_REGISTER_OP("relax.zeros") TVM_REGISTER_OP("relax.zeros_like") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) .set_attr("FPurity", true); @@ -311,6 +319,9 @@ StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.eye") .set_attrs_type() .set_num_inputs(3) + .add_argument("n", "PrimValue", "Number of rows in the output.") + .add_argument("m", "PrimValue", "Number of columns in the output.") + .add_argument("k", "PrimValue", "Index of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoEye) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -318,6 +329,8 @@ TVM_REGISTER_OP("relax.eye") TVM_REGISTER_OP("relax.eye_like") .set_attrs_type() .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("k", "PrimValue", "Index of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoEyeLike) .set_attr("FPurity", true); @@ -369,6 +382,9 @@ StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.arange") .set_attrs_type() .set_num_inputs(3) + .add_argument("start", "PrimValue", "The starting value for the set of points.") + .add_argument("end", "PrimValue", "The ending value for the set of points.") + .add_argument("step", "PrimValue", "The gap between each pair of adjacent points.") .set_attr("FInferStructInfo", InferStructInfoArange) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -417,6 +433,12 @@ StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.hamming_window") .set_attrs_type() .set_num_inputs(4) + .add_argument("window_size", "PrimValue", "The size of the window") + .add_argument("periodic", "PrimValue", + "If True, returns a window to be used as periodic function. If False, return a " + "symmetric window") + .add_argument("alpha", "PrimValue", "The coefficient alpha") + .add_argument("beta", "PrimValue", "The coefficient beta") .set_attr("FInferStructInfo", InferStructInfoHammingWindow) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -458,11 +480,15 @@ StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.tril") .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("k", "PrimValue", "The offset of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoTrilTriu) .set_attr("FPurity", true); TVM_REGISTER_OP("relax.triu") .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("k", "PrimValue", "The offset of the diagonal.") .set_attr("FInferStructInfo", InferStructInfoTrilTriu) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 8f51de26b60b..50624355c8fe 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -63,6 +63,7 @@ StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.astype") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoAstype) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -95,6 +96,7 @@ StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.wrap_param") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoWrapParam) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 74d5660fa09b..b05757c7de5e 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -48,6 +48,7 @@ StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.grad.no_grad") .set_num_inputs(1) + .add_argument("x", "Expr", "The corresponding input tensor.") .set_attr("FInferStructInfo", InferStructInfoNoGrad) .set_attr("FPurity", true); @@ -72,6 +73,7 @@ StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.grad.start_checkpoint") .set_num_inputs(1) + .add_argument("x", "Expr", "The tensor marking the input of the checkpoint stage.") .set_attr("FInferStructInfo", InferStructInfoStartCheckpoint) .set_attr("FPurity", true); @@ -96,6 +98,7 @@ StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.grad.end_checkpoint") .set_num_inputs(1) + .add_argument("x", "Expr", "The output of the checkpoint stage.") .set_attr("FInferStructInfo", InferStructInfoEndCheckpoint) .set_attr("FPurity", true); @@ -130,6 +133,10 @@ StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.grad.nll_loss_backward") .set_attrs_type() .set_num_inputs(4) + .add_argument("output_grad", "Tensor", "The output gradient.") + .add_argument("predictions", "Tensor", "The prediction tensor.") + .add_argument("targets", "Tensor", "The target tensor.") + .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward) .set_attr("FPurity", true); @@ -162,6 +169,8 @@ StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.grad.max_pool2d_backward") .set_num_inputs(2) + .add_argument("output_grad", "Tensor", "The output gradient.") + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoMaxPool2DBackward) .set_attr("FPurity", true); @@ -195,6 +204,8 @@ StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") .set_num_inputs(2) + .add_argument("output_grad", "Tensor", "The output gradient.") + .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoAvgPool2DBackward) .set_attr("FPurity", true); @@ -221,6 +232,9 @@ StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.grad.take_backward") .set_attrs_type() .set_num_inputs(3) + .add_argument("output_grad", "Tensor", "The output gradient.") + .add_argument("x", "Tensor", "The source tensor.") + .add_argument("indices", "Tensor", "The indices of the values to extract.") .set_attr("FInferStructInfo", InferStructInfoTakeBackward) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index f3d948212538..6b02ca050bea 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -130,6 +130,8 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.take") .set_attrs_type() .set_num_inputs(2) + .add_argument("x", "Tensor", "The source tensor.") + .add_argument("indices", "Tensor", "The indices of the values to extract.") .set_attr("FInferStructInfo", InferStructInfoTake) .set_attr("FPurity", true); @@ -478,6 +480,7 @@ InferLayoutOutput InferLayoutStridedSlice( TVM_REGISTER_OP("relax.strided_slice") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "The source tensor to be sliced.") .set_attr("FInferStructInfo", InferStructInfoStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -572,6 +575,10 @@ InferLayoutOutput InferLayoutDynStridedSlice( TVM_REGISTER_OP("relax.dynamic_strided_slice") .set_num_inputs(4) + .add_argument("x", "Tensor", "The source tensor to be sliced.") + .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") + .add_argument("end", "Tensor", "Indices indicating end of the slice.") + .add_argument("strides", "Tensor", "The stride values.") .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutDynStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index f88342c93c11..3988e0ba2359 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -150,6 +150,7 @@ Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_code") .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorDtypeCode) .set_attr("FLegalize", LegalizeTensorDtypeCode) .set_attr("RequiresArgumentShapes", false) @@ -187,6 +188,7 @@ Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits") .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorDtypeBits) .set_attr("FLegalize", LegalizeTensorDtypeBits) .set_attr("RequiresArgumentShapes", false) @@ -224,6 +226,7 @@ Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes") .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorDtypeLanes) .set_attr("FLegalize", LegalizeTensorDtypeLanes) .set_attr("RequiresArgumentShapes", false) @@ -261,6 +264,7 @@ Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_ndim") .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorNDim) .set_attr("FLegalize", LegalizeTensorNDim) .set_attr("RequiresArgumentShapes", false) @@ -338,6 +342,8 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_shape_i") .set_num_inputs(2) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") .set_attr("FInferStructInfo", InferStructInfoTensorShape) .set_attr("FLegalize", LegalizeTensorShape) .set_attr("RequiresArgumentShapes", false) @@ -385,6 +391,8 @@ StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { TVM_REGISTER_OP("relax.inspect.tensor_stride_i") .set_num_inputs(2) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") .set_attr("FInferStructInfo", InferStructInfoTensorStride) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -415,6 +423,7 @@ StructInfo InferStructInfoTensorByteOffset(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.inspect.tensor_byte_offset") .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorByteOffset) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -445,6 +454,7 @@ StructInfo InferStructInfoTensorElemOffset(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.inspect.tensor_elem_offset") .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") .set_attr("FInferStructInfo", InferStructInfoTensorElemOffset) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 108248e10f74..6936fa04348b 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -166,6 +166,8 @@ Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) { TVM_REGISTER_OP("relax.matmul") .set_num_inputs(2) + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoMatmul) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionMatmul) @@ -255,6 +257,7 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.einsum") .set_attrs_type() .set_num_inputs(1) + .add_argument("operands", "Tensor", "The input tensors.") .set_attr("FInferStructInfo", InferStructInfoEinsum) .set_attr("FPurity", true); @@ -293,6 +296,8 @@ StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.outer") .set_num_inputs(2) + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoOuter) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ca60c2476f06..763e37ae6815 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -135,6 +135,8 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.broadcast_to") .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The target shape.") .set_attr("FInferStructInfo", InferStructInfoBroadcastTo) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -396,6 +398,7 @@ InferLayoutOutput InferLayoutConcat( TVM_REGISTER_OP("relax.concat") .set_attrs_type() .set_num_inputs(1) + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat) .set_attr("FRelaxInferLayout", InferLayoutConcat) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -499,6 +502,7 @@ InferLayoutOutput InferLayoutExpandDims( TVM_REGISTER_OP("relax.expand_dims") .set_num_inputs(1) .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoExpandDims) .set_attr("FRelaxInferLayout", InferLayoutExpandDims) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -545,6 +549,7 @@ StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.flatten") .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoFlatten) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -693,6 +698,8 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.index_tensor") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input data.") + .add_argument("indices", "List of Tensors", "The indices used to index.") .set_attr("FInferStructInfo", InferStructInfoIndexTensor) .set_attr("FPurity", true); @@ -765,6 +772,7 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.layout_transform") .set_num_inputs(1) .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoLayoutTransform) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -884,6 +892,7 @@ InferLayoutOutput InferLayoutPermuteDims( TVM_REGISTER_OP("relax.permute_dims") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoPermuteDims) .set_attr("FRelaxInferLayout", InferLayoutPermuteDims) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1046,6 +1055,8 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.reshape") .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The input new shape.") .set_attr("FInferStructInfo", InferStructInfoReshape) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -1223,6 +1234,7 @@ InferLayoutOutput InferLayoutSplit( TVM_REGISTER_OP("relax.split") .set_attrs_type() .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSplit) .set_attr("FRelaxInferLayout", InferLayoutSplit) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1382,6 +1394,7 @@ InferLayoutOutput InferLayoutSqueeze( TVM_REGISTER_OP("relax.squeeze") .set_num_inputs(1) .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSqueeze) .set_attr("FRelaxInferLayout", InferLayoutSqueeze) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1634,6 +1647,7 @@ InferLayoutOutput InferLayoutStack( TVM_REGISTER_OP("relax.stack") .set_attrs_type() .set_num_inputs(1) + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors to stack") .set_attr("FInferStructInfo", InferStructInfoStack) .set_attr("FRelaxInferLayout", InferLayoutStack) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -1682,6 +1696,9 @@ StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& TVM_REGISTER_OP("relax.collapse_sum_like") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_target", "Tensor", + "The tensor whose shape is the shape to collapse to.") .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike) .set_attr("FPurity", true); @@ -1732,6 +1749,8 @@ StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ct TVM_REGISTER_OP("relax.collapse_sum_to") .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The shape to collapse to.") .set_attr("FInferStructInfo", InferStructInfoCollapseSumTo) .set_attr("FPurity", true); @@ -1856,6 +1875,7 @@ InferLayoutOutput InferLayoutRepeat( TVM_REGISTER_OP("relax.repeat") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoRepeat) .set_attr("FRelaxInferLayout", InferLayoutRepeat) .set_attr("FPurity", true); @@ -2000,6 +2020,7 @@ InferLayoutOutput InferLayoutTile( TVM_REGISTER_OP("relax.tile") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoTile) .set_attr("FRelaxInferLayout", InferLayoutTile) .set_attr("FPurity", true); @@ -2071,6 +2092,7 @@ InferLayoutOutput InferLayoutFlip( TVM_REGISTER_OP("relax.flip") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoFlip) .set_attr("FRelaxInferLayout", InferLayoutFlip) .set_attr("FPurity", true); @@ -2174,6 +2196,8 @@ InferLayoutOutput InferLayoutGatherElements( TVM_REGISTER_OP("relax.gather_elements") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") .set_attr("FInferStructInfo", InferStructInfoGatherElements) .set_attr("FRelaxInferLayout", InferLayoutGatherElements) .set_attr("FPurity", true); @@ -2269,6 +2293,8 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.gather_nd") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") .set_attr("FInferStructInfo", InferStructInfoGatherND) .set_attr("FPurity", true); @@ -2415,6 +2441,9 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.index_put") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor(s).") + .add_argument("values", "Tensor", "The values to put.") .set_attr("FInferStructInfo", InferStructInfoIndexPut) .set_attr("FPurity", true); @@ -2518,6 +2547,7 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.meshgrid") .set_attrs_type() .set_num_inputs(1) + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoMeshgrid) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -2659,6 +2689,9 @@ InferLayoutOutput InferLayoutScatterElements( TVM_REGISTER_OP("relax.scatter_elements") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") .set_attr("FInferStructInfo", InferStructInfoScatterElements) .set_attr("FRelaxInferLayout", InferLayoutScatterElements) .set_attr("FPurity", true); @@ -2833,6 +2866,9 @@ InferLayoutOutput InferLayoutScatterND( TVM_REGISTER_OP("relax.scatter_nd") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") .set_attr("FInferStructInfo", InferStructInfoScatterND) .set_attr("FRelaxInferLayout", InferLayoutScatterND) .set_attr("FPurity", true); @@ -2986,6 +3022,11 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax.slice_scatter") .set_attrs_type() .set_num_inputs(5) + .add_argument("input", "Tensor", "The input tensor.") + .add_argument("src", "Tensor", "The source tensor to scatter.") + .add_argument("start", "PrimValue", "The starting index of the slice (inclusive).") + .add_argument("end", "PrimValue", "The ending index of the slice (exclusive).") + .add_argument("step", "PrimValue", "The step of the slice.") .set_attr("FInferStructInfo", InferStructInfoSliceScatter) .set_attr("FPurity", true); @@ -3060,6 +3101,9 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.one_hot") .set_attrs_type() .set_num_inputs(3) + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("on_value", "PrimValue", "The value to fill at specified indices.") + .add_argument("off_value", "PrimValue", "The value to fill at other indices.") .set_attr("FInferStructInfo", InferStructInfoOneHot) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 9386f9cc9ff8..99cb5810e1ab 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -135,6 +135,9 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.quantize") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("zero_point", "Tensor", "The quantization zero_point of the output tensor.") .set_attr("FInferStructInfo", InferStructInfoQuantize) .set_attr("FPurity", true); @@ -239,6 +242,9 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.dequantize") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") .set_attr("FInferStructInfo", InferStructInfoDequantize) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 331aabd6703c..febe4d521d3d 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -139,6 +139,9 @@ StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBu TVM_REGISTER_OP("relax.multinomial_from_uniform") .set_attrs_type() .set_num_inputs(3) + .add_argument("prob", "Tensor", "The probability tensor.") + .add_argument("uniform_sample", "Tensor", "The uniform sample tensor.") + .add_argument("sample_indices", "Tensor", "The sample indices tensor.") .set_attr("FInferStructInfo", InferStructInfoMultinomialFromUniform) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 653e4bda6a0e..5aa1e49557be 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -79,6 +79,11 @@ StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.bucketize") .set_num_inputs(2) + .add_argument("input_tensor", "Tensor", + " N-D tensor or a Scalar containing the search value(s).") + .add_argument("boundaries", "Tensor", + "1-D tensor, must contain a strictly increasing sequence, or the return value is " + "undefined.") .set_attr("FInferStructInfo", InferStructInfoBucketize) .set_attr("FPurity", true); @@ -175,6 +180,9 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.where") .set_num_inputs(3) + .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") .set_attr("FInferStructInfo", InferStructInfoWhere) .set_attr("FPurity", true); @@ -252,6 +260,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ .set_attr("FInferStructInfo", InferStructInfoArgmaxArgmin) \ .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index cb5e25184a69..a2743ab574c6 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -148,6 +148,23 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.unique") .set_num_inputs(6) + .add_argument("x", "Tensor", "The input tensor") + .add_argument( + "sorted", "Tensor", + "Whether to sort the unique elements in ascending order before returning as output.") + .add_argument( + "return_index", "Tensor", + "Whether to return an additional tensor with indices for where elements in the unique " + "tensor come from the original input.") + .add_argument("return_inverse", "Tensor", + "Whether to return an additional tensor with indices for where elements in the " + "original input ended up in the returned unique list.") + .add_argument("return_counts", "Tensor", + "Whether to return an additional tensor with counts of each unique elements") + .add_argument("axis", "Tensor", + "The dimension to apply unique. If it is std::nullopt, the unique values of the " + "flattened input " + "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", true); @@ -170,6 +187,7 @@ StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.nonzero") .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") .set_attr("FInferStructInfo", InferStructInfoNonzero) .set_attr("FCallPacked", "relax.run.nonzero") .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index d0e82613c1b1..7b8a310c65d9 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -60,6 +60,7 @@ StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.sort") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSort) .set_attr("FPurity", true); @@ -93,6 +94,7 @@ StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.argsort") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoArgsort) .set_attr("FPurity", true); @@ -160,6 +162,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.topk") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoTopK) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index c046aa04fa4f..d6f3a15005f3 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -262,6 +262,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoScan) .set_attr("FPurity", true); @@ -284,6 +285,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoScan) .set_attr("FPurity", true); @@ -303,6 +305,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.median") .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoStatisticalExtension) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index a99d5bb6062c..ee4138f133b1 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -55,6 +55,7 @@ namespace relax { } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ .set_attr("FInferStructInfo", InferStructInfoStatistical) \ .set_attr("FRelaxInferLayout", InferLayoutStatistical) \ .set_attr("FPurity", true) diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index a9ae3e867f2b..523c694ff5e8 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -132,6 +132,9 @@ InferLayoutOutput InferLayoutEwiseFMA( TVM_REGISTER_OP("relax.ewise_fma") .set_num_inputs(3) + .add_argument("x1", "Tensor", "The left hand operand of the multiplication") + .add_argument("x2", "Tensor", "The right hand operand of the multiplication") + .add_argument("x3", "Tensor", "The operand of the addition") .set_attr("FInferStructInfo", InferStructInfoEwiseFMA) .set_attr("FRelaxInferLayout", InferLayoutEwiseFMA) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index bbaceedae61c..16a0bc305f17 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -70,6 +70,9 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true); // relax.clip TVM_REGISTER_OP("relax.clip") .set_num_inputs(3) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") + .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index 0a4cb437db8c..070c81bbe97d 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -194,6 +194,10 @@ TVM_REGISTER_OP("relax.vision.multibox_transform_loc") "inference. Very large variances (w,h) can overflow exp in half box sizes.") .set_attrs_type() .set_num_inputs(3) + .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).") + .add_argument("loc_pred", "Tensor", + "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped to xywh.") + .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb (left,top,right,bottom).") .set_attr("FInferStructInfo", InferStructInfoMultiboxTransformLoc) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index 675c7a05721d..dbfe0d63aff5 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -104,6 +104,14 @@ StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") .set_attrs_type() .set_num_inputs(5) + .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].") + .add_argument("scores", "Tensor", + "Scores for each box and class in the format [batch, num_classes, num_boxes].") + .add_argument("max_output_boxes_per_class", "Tensor", + "The maximum number of output boxes per class.") + .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.") + .add_argument("score_threshold", "Tensor", + "The score threshold to filter out low score boxes early.") .set_attr("FInferStructInfo", InferStructInfoAllClassNMS) .set_attr("FPurity", true); @@ -178,6 +186,8 @@ StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.vision.get_valid_counts") .set_attrs_type() .set_num_inputs(1) + .add_argument("data", "Tensor", + "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") .set_attr("FInferStructInfo", InferStructInfoGetValidCounts) .set_attr("FPurity", true); @@ -356,6 +366,10 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.vision.non_max_suppression") .set_attrs_type() .set_num_inputs(3) + .add_argument("data", "Tensor", + "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") + .add_argument("valid_count", "Tensor", "1-D tensor for valid number of boxes.") + .add_argument("indices", "Tensor", "2-D tensor with shape [batch_size, num_anchors].") .set_attr("FInferStructInfo", InferStructInfoNMS) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index 4ad5e999acee..e1be949fce52 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -130,6 +130,9 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.vision.roi_align") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", + "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") .set_attr("FInferStructInfo", InferStructInfoROIAlign) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index 21ef2b09469b..ffba294c5a77 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -117,6 +117,9 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.vision.roi_pool") .set_attrs_type() .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", + "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") .set_attr("FInferStructInfo", InferStructInfoROIPool) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/target/cuda/intrin_rule_cuda.cc index 2835c0cfc802..c56da3046bc0 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/target/cuda/intrin_rule_cuda.cc @@ -262,24 +262,40 @@ TVM_REGISTER_OP("tirx.fmod") // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane", "Expr", "The source thread id.") + .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be added.") + .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_up_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") + .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_down_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__shfl_xor_sync") .set_num_inputs(4) + .add_argument("mask", "Expr", "The thread mask.") + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane_mask", "Expr", "The lane mask.") + .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TGlobalSymbol", "__shfl_xor_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); diff --git a/src/target/metal/intrin_rule_metal.cc b/src/target/metal/intrin_rule_metal.cc index e1e0552f212c..217ad164b8e1 100644 --- a/src/target/metal/intrin_rule_metal.cc +++ b/src/target/metal/intrin_rule_metal.cc @@ -141,16 +141,22 @@ TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") // Register low-level builtin ops. TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane", "Expr", "The source thread id.") .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be added.") .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .set_attr("TGlobalSymbol", "simd_shuffle_down") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); diff --git a/src/target/webgpu/intrin_rule_webgpu.cc b/src/target/webgpu/intrin_rule_webgpu.cc index 4f27a81e1c15..889b85e56aad 100644 --- a/src/target/webgpu/intrin_rule_webgpu.cc +++ b/src/target/webgpu/intrin_rule_webgpu.cc @@ -161,16 +161,22 @@ TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") // Register low-level builtin ops. TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle") .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane", "Expr", "The source thread id.") .set_attr("TGlobalSymbol", "subgroupShuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up") .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be added.") .set_attr("TGlobalSymbol", "subgroupShuffleUp") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down") .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .set_attr("TGlobalSymbol", "subgroupShuffleDown") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); From 9468cc364b6b0b0cdeedab411d79b283385b5578 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 18:40:39 +0000 Subject: [PATCH 8/9] =?UTF-8?q?[REFACTOR][IR]=20Rename=20FieldInfo=20?= =?UTF-8?q?=E2=86=92=20ArgumentInfo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Matches the host field name OpNode::arguments. The prior restoration commit used FieldInfo on first pass; this aligns the naming. --- include/tvm/ir/op.h | 24 ++++++++++++------------ src/ir/op.cc | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 6a8591443b9a..3fd39c1060ce 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -51,7 +51,7 @@ class OpAttrMap; * internal sanity checks / error messages and by external tooling * that wants to introspect an Op's argument schema. */ -class FieldInfoNode : public ffi::Object { +class ArgumentInfoNode : public ffi::Object { public: /*! \brief name of the field */ ffi::String name; @@ -62,21 +62,21 @@ class FieldInfoNode : public ffi::Object { static void RegisterReflection() { namespace rfl = ffi::reflection; - rfl::ObjectDef() - .def_ro("name", &FieldInfoNode::name) - .def_ro("type_info", &FieldInfoNode::type_info) - .def_ro("description", &FieldInfoNode::description); + rfl::ObjectDef() + .def_ro("name", &ArgumentInfoNode::name) + .def_ro("type_info", &ArgumentInfoNode::type_info) + .def_ro("description", &ArgumentInfoNode::description); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FieldInfo", FieldInfoNode, ffi::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.ArgumentInfo", ArgumentInfoNode, ffi::Object); }; -/*! \brief Managed reference to FieldInfoNode. */ -class FieldInfo : public ffi::ObjectRef { +/*! \brief Managed reference to ArgumentInfoNode. */ +class ArgumentInfo : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FieldInfo, ffi::ObjectRef, FieldInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ArgumentInfo, ffi::ObjectRef, ArgumentInfoNode); }; // TODO(tvm-team): migrate low-level intrinsics to use Op @@ -103,7 +103,7 @@ class OpNode : public RelaxExprNode { */ ffi::String description; /* \brief Information of input arguments to the operator */ - ffi::Array arguments; + ffi::Array arguments; /*! * \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. @@ -365,11 +365,11 @@ inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(* inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, const std::string& description) { - auto n = ffi::make_object(); + auto n = ffi::make_object(); n->name = name; n->type_info = type; n->description = description; - get()->arguments.push_back(FieldInfo(n)); + get()->arguments.push_back(ArgumentInfo(n)); return *this; } diff --git a/src/ir/op.cc b/src/ir/op.cc index 7722b0d30730..3684298e4a76 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -34,7 +34,7 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK() { - FieldInfoNode::RegisterReflection(); + ArgumentInfoNode::RegisterReflection(); OpNode::RegisterReflection(); } From 3369d300994ee6033ef63222d8ff2b7ee7c89498 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 22:01:20 +0000 Subject: [PATCH 9/9] [FIX][IR] Close source paths for strict NOTNULLABLE DictAttrs Followup to #19615 commit D, addressing Gemini review concerns that the removed defensive `attrs.defined()` checks at 15 call sites could segfault if a DictAttrs ever became null. Rather than re-add per-site checks (which defeats NOTNULLABLE), this change closes the upstream source paths so the type invariant holds end-to-end: - Move ctor and move assignment now leave the moved-from DictAttrs in a defined-but-empty state (allocating a fresh empty backing) rather than the null state default move would yield. This removes the one realistic C++ path that could leak a null DictAttrs. - WithAttr / WithAttrs / WithoutAttr templates gain a cheap belt+suspenders guard that re-initializes node->attrs to DictAttrs() if a caller somehow produced a null one. These three templates are central to the codebase and called from third-party code, so the extra check is worth the cost. - The class-level doxygen now documents the invariant and how it is enforced (default ctor allocates, move members preserve definedness, FFI type traits reject None at deserialization, the ir.IRModule FFI lambda normalizes None to DictAttrs() explicitly). Notes: - The default constructor (`DictAttrs dict;`) already produced an empty-backed instance via `explicit DictAttrs(Map = {})`, so no change was needed there. - The FFI type traits already reject None for non-nullable types (`_type_is_nullable == false` makes CheckAnyStrict return false), so reflection-driven deserializers cannot inject a null DictAttrs. - The IRModule FFI lambda explicitly normalizes a missing/None attrs parameter before forwarding to the C++ constructor; the Function and PrimFunc Python wrappers do the same on the Python side. - The 15 Gemini-flagged access sites are safe under the closed invariant without per-site `defined()` checks. --- include/tvm/ir/attrs.h | 66 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 06d913731f6f..96eec4616b4d 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -85,6 +85,20 @@ class DictAttrsNode : public AttrsNode { /*! * \brief Managed reference to DictAttrsNode * \sa DictAttrsNode. + * + * \note DictAttrs is NOTNULLABLE: every instance must hold a backing + * DictAttrsNode. The class enforces this end-to-end by: + * - the default constructor (no args) allocating an empty backing, + * - the copy/move ctors and assignments leaving the moved-from + * instance in a defined-but-empty state rather than null, + * - the FFI type traits rejecting None at deserialization boundaries + * (since `_type_is_nullable == false`), and + * - the FFI lambda for ``ir.IRModule`` explicitly normalizing a + * missing/None attrs argument to ``DictAttrs()`` before forwarding + * to the C++ constructor. + * Callers (including third-party code via templates like ``WithAttr``) + * can therefore rely on ``attrs->dict`` being safe to dereference + * without a ``.defined()`` guard. */ class DictAttrs : public Attrs { public: @@ -100,6 +114,34 @@ class DictAttrs : public Attrs { data_ = std::move(n); } + /*! + * \brief Move constructor that leaves the source in a defined-but-empty + * state rather than null, preserving the NOTNULLABLE invariant + * even after `std::move`. + */ + DictAttrs(DictAttrs&& other) noexcept : Attrs(ffi::UnsafeInit{}) { + data_ = std::move(other.data_); + other.data_ = ffi::make_object(); + } + + /*! + * \brief Move assignment that leaves the source in a defined-but-empty + * state rather than null, preserving the NOTNULLABLE invariant + * even after `std::move`. + */ + DictAttrs& operator=(DictAttrs&& other) noexcept { + if (this != &other) { + data_ = std::move(other.data_); + other.data_ = ffi::make_object(); + } + return *this; + } + + // Explicit copy ctor/assign defaults. Declaring the move members above + // would otherwise suppress the implicit copy members. + DictAttrs(const DictAttrs& other) = default; + DictAttrs& operator=(const DictAttrs& other) = default; + // Utils for accessing attributes /*! * \brief Get a function attribute. @@ -160,7 +202,16 @@ class DictAttrs : public Attrs { return GetAttr(attr_key, 0).value_or(0) != 0; } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DictAttrs, Attrs, DictAttrsNode); + // Inline-expand TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE here, minus + // the default copy/move it normally injects (we define our own move members + // above so the moved-from instance stays defined-but-empty). + explicit DictAttrs(::tvm::ffi::UnsafeInit tag) : Attrs(tag) {} + using __PtrType = + std::conditional_t; + __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } + __PtrType get() const { return static_cast<__PtrType>(data_.get()); } + static constexpr bool _type_is_nullable = false; + using ContainerType = DictAttrsNode; TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -196,6 +247,9 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value) using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); + // node->attrs is NOTNULLABLE by contract, but defend against a caller + // that left a moved-from DictAttrs in place by re-initializing here. + if (!node->attrs.defined()) node->attrs = DictAttrs(); node->attrs.CopyOnWrite()->dict.Set(attr_key, std::move(attr_value)); return input; } @@ -216,6 +270,9 @@ inline TFunc WithAttrs(TFunc input, ffi::Map attrs) { static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); if (attrs.empty()) return input; TNode* node = input.CopyOnWrite(); + // node->attrs is NOTNULLABLE by contract, but defend against a caller + // that left a moved-from DictAttrs in place by re-initializing here. + if (!node->attrs.defined()) node->attrs = DictAttrs(); auto* dict_node = node->attrs.CopyOnWrite(); for (const auto& [k, v] : attrs) { dict_node->dict.Set(k, v); @@ -254,6 +311,13 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); + // node->attrs is NOTNULLABLE by contract, but defend against a caller + // that left a moved-from DictAttrs in place; nothing to erase from an + // empty dict. + if (!node->attrs.defined()) { + node->attrs = DictAttrs(); + return input; + } node->attrs.CopyOnWrite()->dict.erase(attr_key); return input; }