diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 13a54a16b378..2a2ac5fe07d9 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -31,9 +31,9 @@ namespace relax { /*! \brief Attributes used in Conv1d operator */ struct Conv1DAttrs : public AttrsNodeReflAdapter { - ffi::Array strides; - ffi::Array padding; - ffi::Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; ffi::String data_layout; ffi::String kernel_layout; @@ -75,9 +75,9 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in Conv2d operator */ struct Conv2DAttrs : public AttrsNodeReflAdapter { - ffi::Array strides; - ffi::Array padding; - ffi::Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; ffi::String data_layout; ffi::String kernel_layout; @@ -121,9 +121,9 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in Conv3d operator */ struct Conv3DAttrs : public AttrsNodeReflAdapter { - ffi::Array strides; - ffi::Array padding; - ffi::Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; ffi::String data_layout; ffi::String kernel_layout; @@ -169,10 +169,10 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in Conv1DTranspose operator */ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter { - ffi::Array strides; - ffi::Array padding; - ffi::Array output_padding; - ffi::Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array output_padding; + ffi::Array dilation; int groups; ffi::String data_layout; ffi::String kernel_layout; @@ -218,10 +218,10 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in Conv2d operator */ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter { - ffi::Array strides; - ffi::Array padding; - ffi::Array output_padding; - ffi::Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array output_padding; + ffi::Array dilation; int groups; ffi::String data_layout; ffi::String kernel_layout; @@ -269,10 +269,10 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in max_pool1d and avg_pool1d operator */ struct Pool1DAttrs : public AttrsNodeReflAdapter { - ffi::Array pool_size; - ffi::Array strides; - ffi::Array padding; - ffi::Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; ffi::String layout; @@ -310,10 +310,10 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in max_pool2d and avg_pool2d operator */ struct Pool2DAttrs : public AttrsNodeReflAdapter { - ffi::Array pool_size; - ffi::Array strides; - ffi::Array padding; - ffi::Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; ffi::String layout; @@ -353,10 +353,10 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in max_pool3d and avg_pool3d operator */ struct Pool3DAttrs : public AttrsNodeReflAdapter { - ffi::Array pool_size; - ffi::Array strides; - ffi::Array padding; - ffi::Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; ffi::String layout; @@ -396,7 +396,7 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes for 1d adaptive pool operator */ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { - ffi::Optional> output_size; + ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -421,7 +421,7 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes for 2d adaptive pool operator */ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { - ffi::Optional> output_size; + ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -446,7 +446,7 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes for 3d adaptive pool operator */ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { - ffi::Optional> output_size; + ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index ef7516f31f46..1a08b66f2a33 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -656,8 +656,8 @@ def __len__(self): @register_df_node class AttrPattern(DFPattern): - """Get match an expression with a certain attributes. - Currently only supports Op Attributes, not call Attributes. + """Match an expression with certain attributes. + Supports Op attributes, Call attributes, and Function attributes. Parameters ---------- diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 61ab45d308c8..5f7c2ab75290 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2567,7 +2567,7 @@ def _impl_v1(cls, bb, inputs, attr, params): pads = [] if cls.name == "avg_pool": for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] + axis_shape = int(input_shape[2 + axis]) stride = strides[axis] kernel = kernel_shape[axis] pad = cls.get_pad_pair(axis_shape, kernel, stride, auto_pad) diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index fd80f1e31333..bf267473f8ad 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -1202,7 +1202,7 @@ def conv2d_grad( out_h = (grad_h - 1) * stride_h - pad_top - pad_bottom + filter_h out_w = (grad_w - 1) * stride_w - pad_left - pad_right + filter_w - output_padding = (in_h - out_h, in_w - out_w) + output_padding = (int(in_h - out_h), int(in_w - out_w)) data_grad = conv2d_transpose( # type: ignore output_grad, diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index bc70c809af7c..b917bc47a24e 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -275,11 +275,13 @@ const ffi::String StringUtils::ToString(const ffi::Any& obj) { obj_string = *opt_str; } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); + } else if (obj.type_index() == kTVMFFIInt) { + obj_string = std::to_string(obj.cast()); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { for (size_t i = 0; i < n->size(); i++) { - obj_string = obj_string + ToString((*n)[i].cast()); + obj_string = obj_string + ToString((*n)[i]); if (n->size() == 1 || i < n->size() - 1) { obj_string = obj_string + ","; } diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index e3579ec7ef77..8b20d5957474 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -429,10 +429,9 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, // change to conv2d static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); auto conv_attrs = ffi::make_object(); - conv_attrs->strides = ffi::Array{src_attrs->strides[0], Integer(1)}; - conv_attrs->padding = - ffi::Array{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]}; - conv_attrs->dilation = ffi::Array{src_attrs->dilation[0], Integer(1)}; + conv_attrs->strides = ffi::Array{src_attrs->strides[0], 1}; + conv_attrs->padding = ffi::Array{0, src_attrs->padding[0], 0, src_attrs->padding[1]}; + conv_attrs->dilation = ffi::Array{src_attrs->dilation[0], 1}; conv_attrs->groups = src_attrs->groups; conv_attrs->data_layout = "NCHW"; conv_attrs->kernel_layout = "OIHW"; @@ -706,9 +705,9 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, // to conv2d static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); auto conv_attrs = ffi::make_object(); - conv_attrs->strides = ffi::Array{Integer(1), Integer(1)}; - conv_attrs->padding = ffi::Array{Integer(0), Integer(0), Integer(0), Integer(0)}; - conv_attrs->dilation = ffi::Array{Integer(1), Integer(1)}; + conv_attrs->strides = ffi::Array{1, 1}; + conv_attrs->padding = ffi::Array{0, 0, 0, 0}; + conv_attrs->dilation = ffi::Array{1, 1}; conv_attrs->groups = 1; conv_attrs->data_layout = "NCHW"; conv_attrs->kernel_layout = "OIHW"; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 505696254209..6076de75bc5b 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -115,8 +115,12 @@ class OpAttrExtractor { if (const auto* an = (*value).as()) { std::vector attr; for (size_t i = 0; i < an->size(); ++i) { - if (const auto* im = (*an)[i].as()) { + if (auto opt_int = (*an)[i].try_cast()) { + attr.push_back(std::to_string(opt_int.value())); + } else if (const auto* im = (*an)[i].as()) { attr.push_back(std::to_string(im->value)); + } else if (auto opt_float = (*an)[i].try_cast()) { + attr.push_back(Fp2String(opt_float.value())); } else if (const auto* fm = (*an)[i].as()) { attr.push_back(Fp2String(fm->value)); } else if (auto opt_str = (*an)[i].as()) { diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index 92933ba070b9..0ea05b986324 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -107,10 +107,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector strides; if (!conv2d_attr->strides.empty()) { for (auto stride : conv2d_attr->strides) { - const auto* stride_val = stride.as(); - ICHECK(stride_val) << "convertion failed"; - - strides.push_back(std::to_string(stride_val->value)); + strides.push_back(std::to_string(stride)); } } else { strides = {"1", "1"}; @@ -118,9 +115,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector padding; for (auto pad : conv2d_attr->padding) { - const auto* padding_val = pad.as(); - - padding.push_back(std::to_string(padding_val->value)); + padding.push_back(std::to_string(pad)); } std::vector groups; @@ -147,10 +142,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector strides; if (!max_pool_2d_attr->strides.empty()) { for (auto stride : max_pool_2d_attr->strides) { - const auto* stride_val = stride.as(); - ICHECK(stride_val) << "convertion failed"; - - strides.push_back(std::to_string(stride_val->value)); + strides.push_back(std::to_string(stride)); } } else { strides.push_back("1"); @@ -159,16 +151,12 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector padding; for (auto pad : max_pool_2d_attr->padding) { - const auto* padding_val = pad.as(); - - padding.push_back(std::to_string(padding_val->value)); + padding.push_back(std::to_string(pad)); } std::vector pool_size; for (auto size : max_pool_2d_attr->pool_size) { - const auto* pooling_val = size.as(); - - pool_size.push_back(std::to_string(pooling_val->value)); + pool_size.push_back(std::to_string(size)); } std::vector strides_attr; diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 3fba58ede2be..5368db79d262 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -41,8 +41,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.nn.conv1d */ -Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); @@ -125,15 +125,15 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_OIW_shape[2]; - PrimExpr padding_w = attrs->padding[0] + attrs->padding[1]; + PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = weight_OIW_shape[0]; - PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; - out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); + PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) * (kernel_w - 1) - 1; + out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[0])) + 1); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -202,8 +202,8 @@ TVM_REGISTER_OP("relax.nn.conv1d") /* relax.nn.conv2d */ -Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); @@ -294,18 +294,18 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_OIHW_shape[2]; PrimExpr kernel_w = weight_OIHW_shape[3]; - PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; - PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]); + PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = weight_OIHW_shape[0]; - PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; - out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); - out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) * (kernel_w - 1) - 1; + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, Integer(attrs->strides[0])) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[1])) + 1); ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -409,8 +409,8 @@ TVM_REGISTER_OP("relax.nn.conv2d") /* relax.nn.conv3d */ -Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype) { padding = GetCompletePadding3D(std::move(padding)); @@ -506,21 +506,21 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { PrimExpr kernel_d = weight_OIDHW_shape[2]; PrimExpr kernel_h = weight_OIDHW_shape[3]; PrimExpr kernel_w = weight_OIDHW_shape[4]; - PrimExpr padding_d = attrs->padding[0] + attrs->padding[3]; - PrimExpr padding_h = attrs->padding[1] + attrs->padding[4]; - PrimExpr padding_w = attrs->padding[2] + attrs->padding[5]; + PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]); + PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]); + PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = weight_OIDHW_shape[0]; - PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d - 1) - 1; - PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w - 1) - 1; - out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, attrs->strides[0]) + 1); - out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1); - out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1); + PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) * (kernel_d - 1) - 1; + PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) * (kernel_w - 1) - 1; + out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, Integer(attrs->strides[0])) + 1); + out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, Integer(attrs->strides[1])) + 1); + out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[2])) + 1); ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -587,9 +587,9 @@ TVM_REGISTER_OP("relax.nn.conv3d") .set_attr("FInferMixedPrecision", InferMixedPrecisionConv3d) .set_attr("FPurity", Bool(true)); -Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, - ffi::Array padding, ffi::Array output_padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); @@ -607,10 +607,10 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, << dilation; auto attrs = ffi::make_object(); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->output_padding = ConvertIntImmToInt64(output_padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->output_padding = std::move(output_padding); + attrs->dilation = std::move(dilation); attrs->groups = groups; attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); @@ -680,27 +680,28 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& // Todo(relax-team): Trust the input shape at this moment, and revisit // this condition with runtime shape check } - if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value)) { + if (attrs->output_padding[0] >= attrs->strides[0]) { ctx->ReportFatal(Diagnostic::Error(call) << "Conv1dTranspose expects the output padding less than the strides, but the " "output padding is" << attrs->output_padding << " while the strides are" << attrs->strides); - } else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value)) { + } else if (!(attrs->output_padding[0] < attrs->strides[0])) { // Todo(relax-team): Trust the input padding at this moment, and revisit // this condition with runtime shape check } PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_IOW_shape[2]; - PrimExpr padding_w = attrs->padding[0] + attrs->padding[1]; + PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups; - PrimExpr out_w = (input_w - 1) * attrs->strides[0] - padding_w + - attrs->dilation[0] * (kernel_w - 1) + attrs->output_padding[0] + 1; + PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[0]) - padding_w + + Integer(attrs->dilation[0]) * (kernel_w - 1) + + Integer(attrs->output_padding[0]) + 1; out_NCW_shape[2] = analyzer->Simplify(out_w); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); @@ -767,9 +768,9 @@ TVM_REGISTER_OP("relax.nn.conv1d_transpose") /* relax.nn.conv2d_transpose */ -Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, - ffi::Array padding, ffi::Array output_padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); @@ -796,10 +797,10 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, << dilation; auto attrs = ffi::make_object(); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->output_padding = ConvertIntImmToInt64(output_padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->output_padding = std::move(output_padding); + attrs->dilation = std::move(dilation); attrs->groups = groups; attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); @@ -870,14 +871,14 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& // Todo(relax-team): Trust the input shape at this moment, and revisit // this condition with runtime shape check } - if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value || - attrs->output_padding[1]->value >= attrs->strides[1]->value)) { + if (attrs->output_padding[0] >= attrs->strides[0] || + attrs->output_padding[1] >= attrs->strides[1]) { ctx->ReportFatal(Diagnostic::Error(call) << "Conv2dTranspose expects the output padding less than the strides, but the " "output padding is" << attrs->output_padding << " while the strides are" << attrs->strides); - } else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value && - attrs->output_padding[1]->value < attrs->strides[1]->value)) { + } else if (!(attrs->output_padding[0] < attrs->strides[0] && + attrs->output_padding[1] < attrs->strides[1])) { // Todo(relax-team): Trust the input padding at this moment, and revisit // this condition with runtime shape check } @@ -886,18 +887,20 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_IOHW_shape[2]; PrimExpr kernel_w = weight_IOHW_shape[3]; - PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; - PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]); + PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups; - PrimExpr out_h = (input_h - 1) * attrs->strides[0] - padding_h + - attrs->dilation[0] * (kernel_h - 1) + attrs->output_padding[0] + 1; - PrimExpr out_w = (input_w - 1) * attrs->strides[1] - padding_w + - attrs->dilation[1] * (kernel_w - 1) + attrs->output_padding[1] + 1; + PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[0]) - padding_h + + Integer(attrs->dilation[0]) * (kernel_h - 1) + + Integer(attrs->output_padding[0]) + 1; + PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[1]) - padding_w + + Integer(attrs->dilation[1]) * (kernel_w - 1) + + Integer(attrs->output_padding[1]) + 1; out_NCHW_shape[2] = analyzer->Simplify(out_h); out_NCHW_shape[3] = analyzer->Simplify(out_w); diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index 4fc175b5aa07..a5704d3f701a 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -36,14 +36,14 @@ namespace tvm { namespace relax { template -inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, - ffi::String kernel_layout, ffi::String out_layout, DataType out_dtype, - std::string op_name) { +inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, int groups, + ffi::String data_layout, ffi::String kernel_layout, ffi::String out_layout, + DataType out_dtype, std::string op_name) { auto attrs = ffi::make_object(); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); attrs->groups = groups; attrs->data_layout = std::move(data_layout); attrs->kernel_layout = std::move(kernel_layout); @@ -54,20 +54,20 @@ inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, ffi::Ar } /*! \brief 1D convolution */ -Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype); /*! \brief 2D convolution */ -Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype); /*! \brief 3D convolution */ -Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype); @@ -77,9 +77,9 @@ Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array strides, - ffi::Array padding, ffi::Array output_padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype); @@ -89,9 +89,9 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, * This operator is intended to be the backward operator of conv2d. It can be used to calculate the * gradient of the result of conv2d w.r.t. the input of conv2d. */ -Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, - ffi::Array padding, ffi::Array output_padding, - ffi::Array dilation, int groups, ffi::String data_layout, +Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, ffi::Optional out_dtype); diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 1a19872c2794..2397bf009866 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -38,10 +38,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.nn.max_pool1d */ -Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, ffi::Array dilation, - bool ceil_mode, bool count_include_pad, ffi::String layout, - ffi::Optional out_layout) { +Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { padding = GetCompletePadding1D(std::move(padding)); CHECK_EQ(pool_size.size(), 1) @@ -54,10 +54,10 @@ Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, << dilation; auto attrs = ffi::make_object(); - attrs->pool_size = ConvertIntImmToInt64(pool_size); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; attrs->layout = layout; @@ -66,8 +66,8 @@ Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, return Call(op, {std::move(data)}, Attrs(attrs), {}); } -Expr max_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr max_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); @@ -98,8 +98,8 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); PrimExpr input_w = data_NCW_shape[2]; - PrimExpr kernel_w = attrs->pool_size[0]; - PrimExpr padding_w = attrs->padding[0] + attrs->padding[1]; + PrimExpr kernel_w = Integer(attrs->pool_size[0]); + PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCW_shape; @@ -107,13 +107,14 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = data_NCW_shape[1]; - PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_w += attrs->strides[0] - 1; + numerator_w += Integer(attrs->strides[0]) - 1; } - PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[0])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w + attrs->padding[0]; + PrimExpr invalid_last_w = + (raw_out_w - 1) * Integer(attrs->strides[0]) >= input_w + Integer(attrs->padding[0]); out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { out_NCW_shape[2] = analyzer->Simplify(raw_out_w); @@ -151,10 +152,10 @@ TVM_REGISTER_OP("relax.nn.max_pool1d") /* relax.nn.max_pool2d */ -Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, ffi::Array dilation, - bool ceil_mode, bool count_include_pad, ffi::String layout, - ffi::Optional out_layout) { +Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { padding = GetCompletePadding2D(std::move(padding)); if (pool_size.size() == 1) { pool_size.push_back(pool_size[0]); @@ -176,10 +177,10 @@ Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, << dilation; auto attrs = ffi::make_object(); - attrs->pool_size = ConvertIntImmToInt64(pool_size); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; attrs->layout = layout; @@ -188,8 +189,8 @@ Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, return Call(op, {std::move(data)}, Attrs(attrs), {}); } -Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); @@ -221,10 +222,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_h = data_NCHW_shape[2]; PrimExpr input_w = data_NCHW_shape[3]; - PrimExpr kernel_h = attrs->pool_size[0]; - PrimExpr kernel_w = attrs->pool_size[1]; - PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; - PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + PrimExpr kernel_h = Integer(attrs->pool_size[0]); + PrimExpr kernel_w = Integer(attrs->pool_size[1]); + PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]); + PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCHW_shape; @@ -232,17 +233,19 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = data_NCHW_shape[1]; - PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_h += attrs->strides[0] - 1; - numerator_w += attrs->strides[1] - 1; + numerator_h += Integer(attrs->strides[0]) - 1; + numerator_w += Integer(attrs->strides[1]) - 1; } - PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[0]) + 1; - PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[1]) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[0])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[1])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[0] >= input_h + attrs->padding[0]; - PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[1] >= input_w + attrs->padding[1]; + PrimExpr invalid_last_h = + (raw_out_h - 1) * Integer(attrs->strides[0]) >= input_h + Integer(attrs->padding[0]); + PrimExpr invalid_last_w = + (raw_out_w - 1) * Integer(attrs->strides[1]) >= input_w + Integer(attrs->padding[1]); out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { @@ -300,10 +303,10 @@ TVM_REGISTER_OP("relax.nn.max_pool2d") /* relax.nn.max_pool3d */ -Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, ffi::Array dilation, - bool ceil_mode, bool count_include_pad, ffi::String layout, - ffi::Optional out_layout) { +Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { padding = GetCompletePadding3D(std::move(padding)); if (pool_size.size() == 1) { pool_size.push_back(pool_size[0]); @@ -328,10 +331,10 @@ Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, << dilation; auto attrs = ffi::make_object(); - attrs->pool_size = ConvertIntImmToInt64(pool_size); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; attrs->layout = layout; @@ -340,8 +343,8 @@ Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, return Call(op, {std::move(data)}, Attrs(attrs), {}); } -Expr max_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr max_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); @@ -374,12 +377,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_d = data_NCDHW_shape[2]; PrimExpr input_h = data_NCDHW_shape[3]; PrimExpr input_w = data_NCDHW_shape[4]; - PrimExpr kernel_d = attrs->pool_size[0]; - PrimExpr kernel_h = attrs->pool_size[1]; - PrimExpr kernel_w = attrs->pool_size[2]; - PrimExpr padding_d = attrs->padding[0] + attrs->padding[3]; - PrimExpr padding_h = attrs->padding[1] + attrs->padding[4]; - PrimExpr padding_w = attrs->padding[2] + attrs->padding[5]; + PrimExpr kernel_d = Integer(attrs->pool_size[0]); + PrimExpr kernel_h = Integer(attrs->pool_size[1]); + PrimExpr kernel_w = Integer(attrs->pool_size[2]); + PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]); + PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]); + PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCDHW_shape; @@ -387,21 +390,24 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = data_NCDHW_shape[1]; - PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d - 1) - 1; - PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w - 1) - 1; + PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) * (kernel_d - 1) - 1; + PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_d += attrs->strides[0] - 1; - numerator_h += attrs->strides[1] - 1; - numerator_w += attrs->strides[2] - 1; + numerator_d += Integer(attrs->strides[0]) - 1; + numerator_h += Integer(attrs->strides[1]) - 1; + numerator_w += Integer(attrs->strides[2]) - 1; } - PrimExpr raw_out_d = floordiv(numerator_d, attrs->strides[0]) + 1; - PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[1]) + 1; - PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[2]) + 1; + PrimExpr raw_out_d = floordiv(numerator_d, Integer(attrs->strides[0])) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[1])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[2])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_d = (raw_out_d - 1) * attrs->strides[0] >= input_d + attrs->padding[0]; - PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[1] >= input_h + attrs->padding[1]; - PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[2] >= input_w + attrs->padding[2]; + PrimExpr invalid_last_d = + (raw_out_d - 1) * Integer(attrs->strides[0]) >= input_d + Integer(attrs->padding[0]); + PrimExpr invalid_last_h = + (raw_out_h - 1) * Integer(attrs->strides[1]) >= input_h + Integer(attrs->padding[1]); + PrimExpr invalid_last_w = + (raw_out_w - 1) * Integer(attrs->strides[2]) >= input_w + Integer(attrs->padding[2]); out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d)); out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); @@ -442,8 +448,8 @@ TVM_REGISTER_OP("relax.nn.max_pool3d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool1d */ -Expr avg_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr avg_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); @@ -464,8 +470,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool2d */ -Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); @@ -486,8 +492,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool3d */ -Expr avg_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr avg_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); @@ -509,13 +515,13 @@ TVM_REGISTER_OP("relax.nn.avg_pool3d") /* relax.nn.adaptive_avg_pool1d */ -Expr adaptive_avg_pool1d(Expr data, ffi::Optional> output_size, +Expr adaptive_avg_pool1d(Expr data, ffi::Optional> output_size, ffi::String layout, ffi::Optional out_layout) { ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - ffi::Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); CHECK_EQ(_output_size.size(), 1) << "The output_size length is expected to be 1. However, the given output_size is " << _output_size; @@ -556,7 +562,7 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); ffi::Array out_NCW_shape(data_NCW_shape); if (attrs->output_size.defined()) { - out_NCW_shape.Set(2, attrs->output_size.value()[0]); + out_NCW_shape.Set(2, Integer(attrs->output_size.value()[0])); } ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); @@ -591,13 +597,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d") /* relax.nn.adaptive_avg_pool2d */ -Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, +Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, ffi::String layout, ffi::Optional out_layout) { ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - ffi::Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } @@ -641,8 +647,8 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); ffi::Array out_NCHW_shape(data_NCHW_shape); if (attrs->output_size.defined()) { - out_NCHW_shape.Set(2, attrs->output_size.value()[0]); - out_NCHW_shape.Set(3, attrs->output_size.value()[1]); + out_NCHW_shape.Set(2, Integer(attrs->output_size.value()[0])); + out_NCHW_shape.Set(3, Integer(attrs->output_size.value()[1])); } ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); @@ -693,13 +699,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") /* relax.nn.adaptive_avg_pool3d */ -Expr adaptive_avg_pool3d(Expr data, ffi::Optional> output_size, +Expr adaptive_avg_pool3d(Expr data, ffi::Optional> output_size, ffi::String layout, ffi::Optional out_layout) { ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - ffi::Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } @@ -743,9 +749,9 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); ffi::Array out_NCDHW_shape(data_NCDHW_shape); if (attrs->output_size.defined()) { - out_NCDHW_shape.Set(2, attrs->output_size.value()[0]); - out_NCDHW_shape.Set(3, attrs->output_size.value()[1]); - out_NCDHW_shape.Set(4, attrs->output_size.value()[2]); + out_NCDHW_shape.Set(2, Integer(attrs->output_size.value()[0])); + out_NCDHW_shape.Set(3, Integer(attrs->output_size.value()[1])); + out_NCDHW_shape.Set(4, Integer(attrs->output_size.value()[2])); } ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h index c5435303e82b..d1fbc834ee10 100644 --- a/src/relax/op/nn/pooling.h +++ b/src/relax/op/nn/pooling.h @@ -33,17 +33,17 @@ namespace tvm { namespace relax { /*! \brief 2D maximum pooling operator. */ -Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief 2D average pooling operator. */ -Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, - ffi::Array padding, ffi::Array dilation, bool ceil_mode, +Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief 2D adaptive average pooling operator. */ -Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, +Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, ffi::String layout, ffi::Optional out_layout); } // namespace relax diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 5c4f563bebee..69e573faf17d 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -465,7 +465,7 @@ inline ffi::Array ConvertIntImmToInt64(const ffi::Array& int_imm * \return The completed padding. * \throws Throws error if the input padding length is neither 1 or 2. */ -inline ffi::Array GetCompletePadding1D(ffi::Array padding) { +inline ffi::Array GetCompletePadding1D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0]}; } else if (padding.size() == 2) { @@ -486,7 +486,7 @@ inline ffi::Array GetCompletePadding1D(ffi::Array padding) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1, 2 or 4. */ -inline ffi::Array GetCompletePadding2D(ffi::Array padding) { +inline ffi::Array GetCompletePadding2D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0], padding[0], padding[0]}; } else if (padding.size() == 2) { @@ -511,7 +511,7 @@ inline ffi::Array GetCompletePadding2D(ffi::Array padding) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1, 3 or 6. */ -inline ffi::Array GetCompletePadding3D(ffi::Array padding) { +inline ffi::Array GetCompletePadding3D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0], padding[0], padding[0], padding[0], padding[0]}; } else if (padding.size() == 3) { diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 52a218b730d0..0594ef75bd40 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -141,15 +141,15 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward") .set_attr("FPurity", Bool(true)); /* relax.grad.max_pool2d_backward */ -Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, - ffi::Array dilation, bool ceil_mode, bool count_include_pad, +Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { auto attrs = ffi::make_object(); attrs->pool_size = std::move(pool_size); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; attrs->layout = layout; @@ -176,15 +176,15 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward") .set_attr("FPurity", Bool(true)); /* relax.grad.avg_pool2d_backward */ -Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, - ffi::Array dilation, bool ceil_mode, bool count_include_pad, +Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { auto attrs = ffi::make_object(); attrs->pool_size = std::move(pool_size); - attrs->strides = ConvertIntImmToInt64(strides); - attrs->padding = ConvertIntImmToInt64(padding); - attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; attrs->layout = layout; diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h index 406d7a2f779e..911049475d2c 100644 --- a/src/relax/op/tensor/grad.h +++ b/src/relax/op/tensor/grad.h @@ -46,16 +46,16 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, /*! \brief Backward operator of relax.max_pool2d. All parameters except output_grad is the same as * relax.max_pool2d. Returns the gradient w.r.t. data. */ -Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, - ffi::Array dilation, bool ceil_mode, bool count_include_pad, +Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief Backward operator of relax.avg_pool2d. All parameters except output_grad is the same as * relax.avg_pool2d. Returns the gradient w.r.t. data. */ -Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, - ffi::Array strides, ffi::Array padding, - ffi::Array dilation, bool ceil_mode, bool count_include_pad, +Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief Backward operator of relax.take. All parameters except output_grad is the same as diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index ee21c14c6f44..1d2a6e467615 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -277,10 +277,8 @@ def test_op_attr(): conv2d = rx.op.nn.conv2d(x, y, strides=(3, 3)) xp = is_var("x") yp = is_var("y") - # TODO(@yuchen): reenable the assert after figuring out why it fails - # assert is_op("nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d) + assert is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d) assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [4, 3]}).match(conv2d) - assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d) def test_match_call_attr(): diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py index 9b913138df12..4ee8226bc3f2 100644 --- a/tests/python/relax/test_op_nn_convolution.py +++ b/tests/python/relax/test_op_nn_convolution.py @@ -352,10 +352,10 @@ def test_conv1d_stride_padding_dilation_int64(): w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) conv1d = relax.op.nn.conv1d(x, w, strides=(1,), padding=(1, 1), dilation=(1,)) - assert conv1d.attrs.strides[0].dtype == "int64" - assert conv1d.attrs.padding[0].dtype == "int64" - assert conv1d.attrs.padding[1].dtype == "int64" - assert conv1d.attrs.dilation[0].dtype == "int64" + assert isinstance(conv1d.attrs.strides[0], int) + assert isinstance(conv1d.attrs.padding[0], int) + assert isinstance(conv1d.attrs.padding[1], int) + assert isinstance(conv1d.attrs.dilation[0], int) def test_conv1d_wrong_strides_padding_dilation_length(): @@ -711,9 +711,9 @@ def test_conv1d_transpose_stride_padding_dilation_int64(): w = relax.Var("w", R.Tensor((3, 4, 3), "float32")) conv1d = relax.op.nn.conv1d_transpose(x, w, strides=1, padding=1, dilation=1) - assert conv1d.attrs.strides[0].dtype == "int64" - assert conv1d.attrs.padding[0].dtype == "int64" - assert conv1d.attrs.dilation[0].dtype == "int64" + assert isinstance(conv1d.attrs.strides[0], int) + assert isinstance(conv1d.attrs.padding[0], int) + assert isinstance(conv1d.attrs.dilation[0], int) def test_conv1d_transpose_wrong_strides_padding_dilation_length(): @@ -1122,14 +1122,14 @@ def test_conv2d_stride_padding_dilation_int64(): w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1), dilation=(1, 1)) - assert conv2d.attrs.strides[0].dtype == "int64" - assert conv2d.attrs.strides[1].dtype == "int64" - assert conv2d.attrs.padding[0].dtype == "int64" - assert conv2d.attrs.padding[1].dtype == "int64" - assert conv2d.attrs.padding[2].dtype == "int64" - assert conv2d.attrs.padding[3].dtype == "int64" - assert conv2d.attrs.dilation[0].dtype == "int64" - assert conv2d.attrs.dilation[1].dtype == "int64" + assert isinstance(conv2d.attrs.strides[0], int) + assert isinstance(conv2d.attrs.strides[1], int) + assert isinstance(conv2d.attrs.padding[0], int) + assert isinstance(conv2d.attrs.padding[1], int) + assert isinstance(conv2d.attrs.padding[2], int) + assert isinstance(conv2d.attrs.padding[3], int) + assert isinstance(conv2d.attrs.dilation[0], int) + assert isinstance(conv2d.attrs.dilation[1], int) def test_conv2d_wrong_strides_padding_dilation_length(): @@ -1510,16 +1510,16 @@ def test_conv2d_transpose_stride_padding_dilation_int64(): x, w, strides=(1, 1), padding=(1, 1), output_padding=(1, 2), dilation=(1, 1) ) - assert conv2d_transpose.attrs.strides[0].dtype == "int64" - assert conv2d_transpose.attrs.strides[1].dtype == "int64" - assert conv2d_transpose.attrs.padding[0].dtype == "int64" - assert conv2d_transpose.attrs.padding[1].dtype == "int64" - assert conv2d_transpose.attrs.padding[2].dtype == "int64" - assert conv2d_transpose.attrs.padding[3].dtype == "int64" - assert conv2d_transpose.attrs.output_padding[0].dtype == "int64" - assert conv2d_transpose.attrs.output_padding[1].dtype == "int64" - assert conv2d_transpose.attrs.dilation[0].dtype == "int64" - assert conv2d_transpose.attrs.dilation[1].dtype == "int64" + assert isinstance(conv2d_transpose.attrs.strides[0], int) + assert isinstance(conv2d_transpose.attrs.strides[1], int) + assert isinstance(conv2d_transpose.attrs.padding[0], int) + assert isinstance(conv2d_transpose.attrs.padding[1], int) + assert isinstance(conv2d_transpose.attrs.padding[2], int) + assert isinstance(conv2d_transpose.attrs.padding[3], int) + assert isinstance(conv2d_transpose.attrs.output_padding[0], int) + assert isinstance(conv2d_transpose.attrs.output_padding[1], int) + assert isinstance(conv2d_transpose.attrs.dilation[0], int) + assert isinstance(conv2d_transpose.attrs.dilation[1], int) def test_conv2d_transpose_wrong_strides_padding_dilation_length(): diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index d4461a122de8..12d099bb334f 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -183,10 +183,10 @@ def test_max_pool1d_stride_padding_dilation_int64(): x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1, dilation=1) - assert max_pool1d.attrs.strides[0].dtype == "int64" - assert max_pool1d.attrs.padding[0].dtype == "int64" - assert max_pool1d.attrs.padding[1].dtype == "int64" - assert max_pool1d.attrs.dilation[0].dtype == "int64" + assert isinstance(max_pool1d.attrs.strides[0], int) + assert isinstance(max_pool1d.attrs.padding[0], int) + assert isinstance(max_pool1d.attrs.padding[1], int) + assert isinstance(max_pool1d.attrs.dilation[0], int) def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length(): @@ -412,14 +412,14 @@ def test_max_pool2d_stride_padding_dilation_int64(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) - assert max_pool2d.attrs.strides[0].dtype == "int64" - assert max_pool2d.attrs.strides[1].dtype == "int64" - assert max_pool2d.attrs.padding[0].dtype == "int64" - assert max_pool2d.attrs.padding[1].dtype == "int64" - assert max_pool2d.attrs.padding[2].dtype == "int64" - assert max_pool2d.attrs.padding[3].dtype == "int64" - assert max_pool2d.attrs.dilation[0].dtype == "int64" - assert max_pool2d.attrs.dilation[1].dtype == "int64" + assert isinstance(max_pool2d.attrs.strides[0], int) + assert isinstance(max_pool2d.attrs.strides[1], int) + assert isinstance(max_pool2d.attrs.padding[0], int) + assert isinstance(max_pool2d.attrs.padding[1], int) + assert isinstance(max_pool2d.attrs.padding[2], int) + assert isinstance(max_pool2d.attrs.padding[3], int) + assert isinstance(max_pool2d.attrs.dilation[0], int) + assert isinstance(max_pool2d.attrs.dilation[1], int) def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length(): @@ -660,17 +660,17 @@ def test_max_pool3d_stride_padding_dilation_int64(): x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1) ) - assert max_pool3d.attrs.strides[0].dtype == "int64" - assert max_pool3d.attrs.strides[1].dtype == "int64" - assert max_pool3d.attrs.strides[2].dtype == "int64" - assert max_pool3d.attrs.padding[0].dtype == "int64" - assert max_pool3d.attrs.padding[1].dtype == "int64" - assert max_pool3d.attrs.padding[2].dtype == "int64" - assert max_pool3d.attrs.padding[3].dtype == "int64" - assert max_pool3d.attrs.padding[4].dtype == "int64" - assert max_pool3d.attrs.dilation[0].dtype == "int64" - assert max_pool3d.attrs.dilation[1].dtype == "int64" - assert max_pool3d.attrs.dilation[2].dtype == "int64" + assert isinstance(max_pool3d.attrs.strides[0], int) + assert isinstance(max_pool3d.attrs.strides[1], int) + assert isinstance(max_pool3d.attrs.strides[2], int) + assert isinstance(max_pool3d.attrs.padding[0], int) + assert isinstance(max_pool3d.attrs.padding[1], int) + assert isinstance(max_pool3d.attrs.padding[2], int) + assert isinstance(max_pool3d.attrs.padding[3], int) + assert isinstance(max_pool3d.attrs.padding[4], int) + assert isinstance(max_pool3d.attrs.dilation[0], int) + assert isinstance(max_pool3d.attrs.dilation[1], int) + assert isinstance(max_pool3d.attrs.dilation[2], int) def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length(): @@ -875,10 +875,10 @@ def test_avg_pool1d_stride_padding_dilation_int64(): x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) avg_pool1d = relax.op.nn.avg_pool1d(x, 3, strides=1, padding=1, dilation=1) - assert avg_pool1d.attrs.strides[0].dtype == "int64" - assert avg_pool1d.attrs.padding[0].dtype == "int64" - assert avg_pool1d.attrs.padding[1].dtype == "int64" - assert avg_pool1d.attrs.dilation[0].dtype == "int64" + assert isinstance(avg_pool1d.attrs.strides[0], int) + assert isinstance(avg_pool1d.attrs.padding[0], int) + assert isinstance(avg_pool1d.attrs.padding[1], int) + assert isinstance(avg_pool1d.attrs.dilation[0], int) def test_avg_pool1d_wrong_pool_size_strides_padding_dilation_length(): @@ -1101,14 +1101,14 @@ def test_avg_pool2d_stride_padding_dilation_int64(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) - assert avg_pool2d.attrs.strides[0].dtype == "int64" - assert avg_pool2d.attrs.strides[1].dtype == "int64" - assert avg_pool2d.attrs.padding[0].dtype == "int64" - assert avg_pool2d.attrs.padding[1].dtype == "int64" - assert avg_pool2d.attrs.padding[2].dtype == "int64" - assert avg_pool2d.attrs.padding[3].dtype == "int64" - assert avg_pool2d.attrs.dilation[0].dtype == "int64" - assert avg_pool2d.attrs.dilation[1].dtype == "int64" + assert isinstance(avg_pool2d.attrs.strides[0], int) + assert isinstance(avg_pool2d.attrs.strides[1], int) + assert isinstance(avg_pool2d.attrs.padding[0], int) + assert isinstance(avg_pool2d.attrs.padding[1], int) + assert isinstance(avg_pool2d.attrs.padding[2], int) + assert isinstance(avg_pool2d.attrs.padding[3], int) + assert isinstance(avg_pool2d.attrs.dilation[0], int) + assert isinstance(avg_pool2d.attrs.dilation[1], int) def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length(): @@ -1356,15 +1356,15 @@ def test_avg_pool3d_stride_padding_dilation_int64(): x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1) ) - assert avg_pool3d.attrs.strides[0].dtype == "int64" - assert avg_pool3d.attrs.strides[1].dtype == "int64" - assert avg_pool3d.attrs.strides[2].dtype == "int64" - assert avg_pool3d.attrs.padding[0].dtype == "int64" - assert avg_pool3d.attrs.padding[1].dtype == "int64" - assert avg_pool3d.attrs.padding[2].dtype == "int64" - assert avg_pool3d.attrs.dilation[0].dtype == "int64" - assert avg_pool3d.attrs.dilation[1].dtype == "int64" - assert avg_pool3d.attrs.dilation[2].dtype == "int64" + assert isinstance(avg_pool3d.attrs.strides[0], int) + assert isinstance(avg_pool3d.attrs.strides[1], int) + assert isinstance(avg_pool3d.attrs.strides[2], int) + assert isinstance(avg_pool3d.attrs.padding[0], int) + assert isinstance(avg_pool3d.attrs.padding[1], int) + assert isinstance(avg_pool3d.attrs.padding[2], int) + assert isinstance(avg_pool3d.attrs.dilation[0], int) + assert isinstance(avg_pool3d.attrs.dilation[1], int) + assert isinstance(avg_pool3d.attrs.dilation[2], int) def test_avg_pool3d_wrong_pool_size_strides_padding_dilation_length(): diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index cf9361d7c78a..294cea71de58 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -272,17 +272,17 @@ def main(output_grad: R.Tensor((3, 2, 6, 5), "float32"), data: R.Tensor((3, 2, 1 @I.ir_module class Expected: @T.prim_func(private=True) - def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")): + def avg_pool2d_backward(output_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), data: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")): T.func_attr({"tir.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)): with T.sblock("T_pool_grad"): v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww]) - T.reads(rxplaceholder[v_ax0, v_ax1, T.Div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh, T.Div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww]) + T.reads(output_grad[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww]) T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) - T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww < T.int64(5), rxplaceholder[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww] / T.Cast("float32", T.max((T.min(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) + T.int64(3) - v_wh * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh - T.int64(1), T.int64(0)) * T.int64(2)) * (T.min(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) + T.int64(4) - v_ww * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) - v_ww * T.int64(2) - T.int64(1), T.int64(0))), T.int64(1))), T.float32(0.0)) + T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww < T.int64(5), output_grad[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww] / T.Cast("float32", T.max((T.min(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) + T.int64(3) - T.Cast("int64", v_wh) * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax2 + T.int64(2), T.int64(2)) - T.Cast("int64", v_wh) - T.int64(1), T.int64(0)) * T.int64(2)) * (T.min(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) + T.int64(4) - T.Cast("int64", v_ww) * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) - T.Cast("int64", v_ww) * T.int64(2) - T.int64(1), T.int64(0))), T.int64(1))), T.float32(0.0)) @R.function def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"):