diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 5cb577e82bda..9eed78d270bf 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -319,9 +319,19 @@ class WellFormedChecker : public relax::ExprVisitor, if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) { auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); - auto before_normalize = GetRef(call); - auto after_normalize = func_normalize(dummy_builder, before_normalize); - if (!before_normalize.same_as(after_normalize)) { + Call before_normalize = GetRef(call); + Optional after_normalize = NullOpt; + try { + after_normalize = func_normalize(dummy_builder, before_normalize); + } catch (std::exception& err) { + Malformed( + Diagnostic::Error(call) + << "If an operator defines an operator-specific normalization function (FNormalize), " + << "calls to that operator must be normalized with it. " + << "However, normalization of " << before_normalize << " resulted in the error: \n" + << err.what()); + } + if (after_normalize && !before_normalize.same_as(after_normalize)) { Malformed( Diagnostic::Error(call) << "If an operator defines an operator-specific normalization function (FNormalize), " diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index fe74286a51f1..f51d2cc74f51 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -253,11 +253,70 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { return call->sinfo_args[0]; } -Expr NormalizeCallTIR(const BlockBuilder&, Call call) { - // Temporary implementation to ensure that at least one op has a - // registered value for FNormalize. This temporary implementation - // is fully implemented in follow-up PR - // https://github.com/apache/tvm/pull/16068. 
+Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { + // This function is used for normalization of `relax.call_tir`, + // along with the variants `relax.call_tir_with_grad` and + // `relax.call_tir_inplace`. Therefore, all error messages should + // be written in terms of `call->op`, and should not explicitly + // reference the `relax.call_tir` operator. + CHECK(call->args.size() == 2 || call->args.size() == 3) + << "Operation " << call->op << " expects either two arguments [callee, arg_tuple], " + << "or three arguments [callee, arg_tuple, tir_args], " + << "but " << call << " has " << call->args.size() << " arguments."; + + Expr arg_expr = call->args[1]; + + CHECK(arg_expr->struct_info_.as()) + << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " + << "However, the second argument " << arg_expr << " has struct info " + << arg_expr->struct_info_ << "."; + + if (arg_expr.as()) { + return std::move(call); + } + + CHECK(arg_expr.as()) + << "Operation " << call->op << " must hold its arguments as an in-line tuple. " + << "However, " << call << " has arguments " << arg_expr + << ", which is neither an in-line tuple, " + << "nor a variable binding that may be normalized to an in-line tuple."; + + auto unwrap_binding = [&ctx](Expr expr) -> Optional { + if (auto var = expr.as()) { + if (auto bound_value = ctx->LookupBinding(var.value())) { + return bound_value.value(); + } + } + return NullOpt; + }; + + while (auto unwrapped = unwrap_binding(arg_expr)) { + arg_expr = unwrapped.value(); + } + + Tuple new_arg_expr = [&]() { + // Preferred replacement. The argument tuple is provided as a + // variable, but we know the value bound to that variable. + if (auto opt = arg_expr.as()) { + return opt.value(); + } + + // Fallback case. The argument tuple is provided as a variable, + // and we don't know the value bound to that variable.
For + // example, if a relax function accepted a tuple as a parameter, + // then provided that same tuple as an argument to call_tir. + Array tuple_elements; + size_t num_fields = Downcast(arg_expr->struct_info_)->fields.size(); + for (size_t i = 0; i < num_fields; i++) { + tuple_elements.push_back(TupleGetItem(arg_expr, i)); + } + return Tuple(tuple_elements); + }(); + + auto new_args = call->args; + new_args.Set(1, new_arg_expr); + call.CopyOnWrite()->args = new_args; + return std::move(call); } @@ -314,6 +373,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad") "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, @@ -353,14 +413,12 @@ TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWit // call_tir_inplace -StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "sinfo_args should have exactly 1 output struct info."); - } - CHECK(call->args[0]->IsInstance()) - << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " - << "However, gets " << call->args[0]; +Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { + // Apply normalization before error checks. This allows the error + // checks to safely apply `Downcast(call->args[1])`, which + // may result in an error if performed before normalization.
+ call = Downcast(NormalizeCallTIR(ctx, std::move(call))); + // there must be an inplace index for each output const auto* attrs = call->attrs.as(); size_t num_outputs = 1U; @@ -443,7 +501,7 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c } } - return call->sinfo_args[0]; + return std::move(call); } TVM_REGISTER_NODE_TYPE(CallTIRInplaceAttrs); @@ -456,7 +514,8 @@ RELAY_REGISTER_OP("relax.call_tir_inplace") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") - .set_attr("FInferStructInfo", InferStructInfoCallTIRInplace) + .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FNormalize", NormalizeCallTIRInPlace) // Warning: considered pure, but it has the potential to create visible effects! // This should only be used if it has been *checked* that it is safe (no aliases, in-place // arguments will no longer be live) diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index d69ec61b5c1f..04de5bd499b3 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -290,5 +290,52 @@ def foo(x: R.Tensor((2, 3), dtype="float32")): verify(Before, Expected) +def test_call_tir_tuple_arg(): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16, 16], "int32"), B: R.Tensor([16, 16], "int32")): + cls = Before + Prod = R.call_tir(cls.product, [A, B], out_sinfo=R.Tensor([16, 16], "int32")) + Sum = R.call_tir(cls.sum, [A, B], out_sinfo=R.Tensor([16, 16], "int32")) + return (Prod, Sum) + + @T.prim_func(private=True) + def product( + A: T.Buffer([16, 16], "int32"), + B: T.Buffer([16, 16], "int32"), + C: T.Buffer([16, 16], "int32"), + ): + for iters in T.grid(*A.shape): + with T.block("compute"): + i, j = T.axis.remap("SS", iters) + C[i, j] = A[i, j] * B[i, j] + + @T.prim_func(private=True) + def sum( + A: T.Buffer([16, 16], "int32"), + B: 
T.Buffer([16, 16], "int32"), + C: T.Buffer([16, 16], "int32"), + ): + for iters in T.grid(*A.shape): + with T.block("compute"): + i, j = T.axis.remap("SS", iters) + C[i, j] = A[i, j] + B[i, j] + + Expected = Before + + # If EliminateCommonSubexpr produces unnormalized expressions, + # normalization of those expressions may produce additional + # variable bindings. This test case should be agnostic to those + # additional bindings, so DCE is applied after CSE. + After = tvm.ir.transform.Sequential( + [ + EliminateCommonSubexpr(), + tvm.relax.transform.DeadCodeElimination(), + ] + )(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py index 07d541ab1ed5..4ee17166452f 100644 --- a/tests/python/relax/test_transform_operator_specific_normalization.py +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -22,7 +22,7 @@ import tvm.relax.testing.transform from tvm import relax -from tvm.script.parser import ir as I, relax as R +from tvm.script.parser import ir as I, relax as R, tir as T import pytest @@ -167,5 +167,211 @@ def main(A: R.Tensor): assert relax.analysis.well_formed(Module) +@pytest.mark.skip_well_formed_check_before_transform +def test_normalize_to_inline_tuple_for_call_tir(custom_op): + """FNormalize in-lines the argument tuple for R.call_tir""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Before + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [cls.multiply_by_two, args], + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + cls =
Expected + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [cls.multiply_by_two, relax.Tuple([A])], + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) + + assert not tvm.ir.structural_equal(Before, After) + tvm.ir.assert_structural_equal(Expected, After) + + +@pytest.mark.skip_well_formed_check_before_transform +def test_normalize_argument_to_inline_tuple_for_call_tir(custom_op): + """FNormalize in-lines the argument tuple for R.call_tir + + Like `test_normalize_to_inline_tuple_for_call_tir`, but the + argument tuple is provided as a relax function argument. + """ + + @I.ir_module + class Before: + @R.function + def main(args: R.Tuple([R.Tensor([16], "float32")])): + cls = Before + return relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [cls.multiply_by_two, args], + sinfo_args=[args[0].struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @R.function + def main(args: R.Tuple([R.Tensor([16], "float32")])): + cls = Expected + return relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [cls.multiply_by_two, relax.Tuple([args[0]])], + sinfo_args=[args[0].struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) + + assert not tvm.ir.structural_equal(Before, After) + tvm.ir.assert_structural_equal(Expected, After) + + +@pytest.mark.skip_well_formed_check_before_transform +def test_normalize_to_inline_tuple_for_call_tir_inplace(custom_op): + """FNormalize in-lines the argument tuple for R.call_tir_inplace""" + + # The 
CallTIRInplaceAttrs cannot be constructed from the Python + # API. Therefore, the Expected output is declared first, so that + # the attributes can be used for the non-normalized Before. + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Expected + args = (A,) + return R.call_tir_inplace( + cls.multiply_by_two, + A, + inplace_indices=[0], + out_sinfo=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Before + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir_inplace"), + [cls.multiply_by_two, args], + attrs=inplace_attrs, + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) + + assert not tvm.ir.structural_equal(Before, After) + tvm.ir.assert_structural_equal(Expected, After) + + +@pytest.mark.skip_well_formed_check_before_transform +def test_normalize_to_inline_tuple_for_call_tir_with_grad(custom_op): + """FNormalize in-lines the argument tuple for R.call_tir_with_grad""" + + # The CallTIRWithGradAttrs cannot be constructed from the Python + # API. Therefore, the Expected output is declared first, so that + # the attributes can be used for the non-normalized Before.
+ @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Expected + args = (A,) + return R.call_tir_with_grad( + cls.multiply_by_two, + A, + out_sinfo=[A.struct_info], + te_grad_name="f_grad", + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + @T.prim_func(private=True) + def f_grad( + A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32") + ): + for i in range(16): + Grad[i] = 2.0 + + with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Before + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir_with_grad"), + [cls.multiply_by_two, args], + attrs=with_grad_attrs, + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + @T.prim_func(private=True) + def f_grad( + A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32") + ): + for i in range(16): + Grad[i] = 2.0 + + After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) + + assert not tvm.ir.structural_equal(Before, After) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main()