diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 4907a0bf2bd4..8a25684db7e4 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -697,7 +697,7 @@ def PartitionGraph(): return _ffi_api.PartitionGraph() -def AnnotateTarget(targets): +def AnnotateTarget(targets, include_non_call_ops=True): """Annotate ops in an experession with a provied compiler/target and then use it for codegen. @@ -705,6 +705,9 @@ def AnnotateTarget(targets): ---------- targets : str or List[str] The list of target compilers used for codegen. + include_non_call_ops : boolean + If True then non-call ops also will be annotated with targets + If False then non-call ops will not be processed Returns ------- @@ -714,7 +717,9 @@ def AnnotateTarget(targets): """ if isinstance(targets, str): targets = [targets] - return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets]) + return _ffi_api.AnnotateTarget( + [tvm.runtime.container.String(t) for t in targets], include_non_call_ops + ) def DynamicToStatic(): diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index d5f1e4cc1752..76585cf1272f 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -39,13 +39,19 @@ static const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); static const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - -// A helper class to insert annotation boundaries for a program region that will -// be handled by a specific compiler. +static const char default_target[] = "default"; +// A helper class to insert annotation boundaries for all the ops of a program +// region that will be handled by a specific compiler. class AnnotateTargetRewriter : public ExprRewriter { public: explicit AnnotateTargetRewriter(Array targets) : targets_(std::move(targets)) {} + protected: + /*! \brief The target backends for annotation. */ + Array targets_; + /*! \brief Maintain the decision of the target for each op expr. */ + std::unordered_map op_expr_to_target_; + /*! * \brief This function annotates a compiler end and a compiler begin to all arguments. * @@ -61,20 +67,27 @@ class AnnotateTargetRewriter : public ExprRewriter { std::pair> AnnotateArgs(const Array& args, const std::string& target = "") { std::string ref_target = ""; + Array compiler_begins; Array compiler_ends; for (auto arg : args) { - std::string arg_target = "default"; + std::string arg_target = default_target; const CallNode* call = arg.as(); if (call && call->op == CompilerBeginOp()) { // Argument is already compiler begin node meaning that this is not the first time // running this pass, so we simply remove it and will add a new one later. ICHECK_EQ(call->args.size(), 1U); + // Do not alter existing annotation if not default + if (default_target != call->attrs.as()->compiler) { + compiler_begins.push_back(arg); + } else { + // Remove default + compiler_ends.push_back(call->args[0]); + } const CallNode* end = call->args[0].as(); - if (end->op == CompilerEndOp()) { + if (end && end->op == CompilerEndOp()) { arg_target = end->attrs.as()->compiler; } - compiler_ends.push_back(call->args[0]); } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { arg_target = op_expr_to_target_[arg]; // If an argument is a call node and has no argument, then it should be tensor ops such as @@ -93,18 +106,20 @@ class AnnotateTargetRewriter : public ExprRewriter { if (ref_target == "") { ref_target = arg_target; } else if (ref_target != arg_target) { - ref_target = "default"; + ref_target = default_target; } } // Determine compiler begin target. std::string op_target = (target == "") ? ref_target : target; - Array compiler_begins; - for (const auto& end : compiler_ends) { - compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op)); + if (ref_target != "") { + for (const auto& end : compiler_ends) { + compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op)); + } + } else { + return {op_target, args}; } - return {op_target, compiler_begins}; } @@ -128,14 +143,31 @@ class AnnotateTargetRewriter : public ExprRewriter { * \return An annotated and target-propagated relay expression. */ Expr new_expr = expr; - if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) { - new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); - op_expr_to_target_[new_expr] = op_expr_to_target_[expr]; + const CallNode* call = expr.as(); + if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { + // Check whether expr has args, if not - do not insert compiler_end. + if (expr->IsInstance() || expr->IsInstance() || + expr->IsInstance() || expr->IsInstance() || + expr->IsInstance() || (call && !call->args.empty())) { + std::string target = op_expr_to_target_[new_expr]; + new_expr = InsertAnnotation(new_expr, target, make_end_op); + op_expr_to_target_[new_expr] = target; + } + } else if (call && call->op == CompilerEndOp()) { + if (default_target == call->attrs.as()->compiler) { + ICHECK_EQ(call->args.size(), 1U); + new_expr = call->args[0]; + std::string target = op_expr_to_target_[new_expr]; + new_expr = InsertAnnotation(new_expr, target, make_end_op); + op_expr_to_target_[new_expr] = target; + } } + return std::move(new_expr); } - Expr Rewrite_(const CallNode* pre, const Expr& post) final { + public: + Expr Rewrite_(const CallNode* pre, const Expr& post) override { // Supported targets for this node. The order implies the priority. std::vector supported_targets; @@ -146,13 +178,19 @@ class AnnotateTargetRewriter : public ExprRewriter { // Bypass compiler begin due to lack of target information. It will be processed // when the following op handling arguments. ICHECK_EQ(pre->args.size(), 1U); - return post.as()->args[0]; + // Preserve annotations + return post; } else if (op_node && pre->op == CompilerEndOp()) { // Override compiler end with the new target. ICHECK_EQ(pre->args.size(), 1U); auto input_expr = post.as()->args[0]; + // Already annotated. Recover target + if (op_expr_to_target_.find(input_expr) == op_expr_to_target_.end()) { + op_expr_to_target_[input_expr] = post.as()->attrs.as()->compiler; + } ICHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); - return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); + // Preserve annotated nodes + return post; } // Check prior to peeking first argument if (pre->args.size()) { @@ -161,8 +199,9 @@ class AnnotateTargetRewriter : public ExprRewriter { const CallNode* first_arg_call = pre->args[0].as(); if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { std::string arg_target = first_arg_call->attrs.as()->compiler; - if (arg_target != "default") { - supported_targets.push_back(arg_target); + if (arg_target != default_target) { + // annotated already + return post; } } } @@ -188,7 +227,6 @@ class AnnotateTargetRewriter : public ExprRewriter { // if it is in the target list. Function func = Downcast(pre->op); ICHECK(func.defined()); - if (auto comp_name = func->GetAttr(attr::kComposite)) { std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); @@ -203,16 +241,18 @@ class AnnotateTargetRewriter : public ExprRewriter { } } } - supported_targets.push_back("default"); // Make default as the last option. - + supported_targets.push_back(default_target); // Make default as the last option. + // Visit and mutate arguments after the target of this op has been determined. + Call post_call = Downcast(post); + if (pre->op->IsInstance()) { + auto new_call = RewriteVarCall(post_call); + if (nullptr != new_call) return GetRef(new_call->get()); + } // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with // the highest priority, but we should preserve all supported targets so that // we can make a better decision. std::string target = supported_targets[0]; - // Visit and mutate arguments after the target of this op has been determined. - Call post_call = Downcast(post); - // Add annotations to each arg. auto target_n_args = AnnotateArgs(post_call->args, target); Array compiler_begins = std::get<1>(target_n_args); @@ -221,11 +261,12 @@ class AnnotateTargetRewriter : public ExprRewriter { // Update the target map. op_expr_to_target_[new_call] = target; - return std::move(new_call); } - Expr Rewrite_(const TupleNode* op, const Expr& post) final { + virtual std::unique_ptr RewriteVarCall(const Call& post_call) { return nullptr; } + + Expr Rewrite_(const TupleNode* op, const Expr& post) override { auto expr = Downcast(post); auto target_n_args = AnnotateArgs(expr->fields); @@ -234,7 +275,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->tuple})); @@ -243,7 +284,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr Rewrite_(const FunctionNode* fn, const Expr& post) final { + Expr Rewrite_(const FunctionNode* fn, const Expr& post) override { Function func; Expr new_body; // don't step into composite functions @@ -257,7 +298,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } - Expr Rewrite_(const LetNode* op, const Expr& post) final { + Expr Rewrite_(const LetNode* op, const Expr& post) override { auto let = Downcast(post); Expr new_expr; @@ -274,7 +315,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr Rewrite_(const IfNode* op, const Expr& post) final { + Expr Rewrite_(const IfNode* op, const Expr& post) override { auto expr = Downcast(post); Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond); Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch); @@ -284,7 +325,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr Rewrite_(const RefCreateNode* op, const Expr& post) final { + Expr Rewrite_(const RefCreateNode* op, const Expr& post) override { auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->value})); @@ -293,7 +334,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr Rewrite_(const RefReadNode* op, const Expr& post) final { + Expr Rewrite_(const RefReadNode* op, const Expr& post) override { auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->ref})); @@ -302,7 +343,7 @@ class AnnotateTargetRewriter : public ExprRewriter { return std::move(new_expr); } - Expr Rewrite_(const RefWriteNode* op, const Expr& post) final { + Expr Rewrite_(const RefWriteNode* op, const Expr& post) override { auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->ref, expr->value})); @@ -310,27 +351,85 @@ class AnnotateTargetRewriter : public ExprRewriter { op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return std::move(new_expr); } +}; - private: - /*! \brief The target backends for annotation. */ - Array targets_; - /*! \brief Maintain the decision of the target for each op expr. */ - std::unordered_map op_expr_to_target_; +// A helper class to insert annotation boundaries for call ops and function nodes +// in a program region that will be handled by a specific compiler. +class CallOpsTargetRewriter : public AnnotateTargetRewriter { + public: + explicit CallOpsTargetRewriter(Array targets) + : AnnotateTargetRewriter(std::move(targets)) {} + + std::unique_ptr RewriteVarCall(const Call& post_call) override { + Array ends; + for (auto arg : post_call->args) { + ends.push_back(InsertCompilerEndAndPropogateTarget(arg)); + } + auto new_call = std::make_unique(post_call->op, ends, post_call->attrs); + (*new_call)->checked_type_ = post_call->checked_type_; + return new_call; + } + + Expr Rewrite_(const TupleNode* op, const Expr& post) override { + auto expr = Downcast(post); + Array new_fields; + for (auto f : expr->fields) { + new_fields.push_back(InsertCompilerEndAndPropogateTarget(f)); + } + return std::move(Tuple(new_fields)); + } + + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { + auto expr = Downcast(post); + return std::move(TupleGetItem(InsertCompilerEndAndPropogateTarget(expr->tuple), expr->index)); + } + + Expr Rewrite_(const IfNode* op, const Expr& post) override { + auto expr = Downcast(post); + Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond); + Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch); + Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch); + + auto new_expr = If(new_cond, new_true_branch, new_false_branch); + return std::move(new_expr); + } + + Expr Rewrite_(const RefCreateNode* op, const Expr& post) override { + auto expr = Downcast(post); + auto new_expr = RefCreate(InsertCompilerEndAndPropogateTarget(expr->value)); + return std::move(new_expr); + } + + Expr Rewrite_(const RefReadNode* op, const Expr& post) override { + auto expr = Downcast(post); + auto new_expr = RefRead(InsertCompilerEndAndPropogateTarget(expr->ref)); + return std::move(new_expr); + } + + Expr Rewrite_(const RefWriteNode* op, const Expr& post) override { + auto expr = Downcast(post); + auto new_expr = RefWrite(InsertCompilerEndAndPropogateTarget(expr->ref), + InsertCompilerEndAndPropogateTarget(expr->value)); + return std::move(new_expr); + } }; -Expr AnnotateTarget(const Expr& expr, const Array& targets) { - auto rewriter = AnnotateTargetRewriter(targets); - return PostOrderRewrite(expr, &rewriter); +Expr AnnotateTarget(const Expr& expr, const Array& targets, + bool include_non_call_ops) { + auto r = include_non_call_ops ? std::make_unique(targets) + : std::make_unique(targets); + return PostOrderRewrite(expr, r.get()); } } // namespace annotate_target namespace transform { -Pass AnnotateTarget(const Array& targets) { +Pass AnnotateTarget(const Array& targets, bool include_non_call_ops) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); + return Downcast( + relay::annotate_target::AnnotateTarget(f, targets, include_non_call_ops)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 325826d183da..4f35066a8384 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -212,9 +212,10 @@ def after(): mod = tvm.IRModule.from_expr(f) return mod - result = transform.AnnotateTarget("test")(before()) - expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + for annotate_non_call_ops in [False, True]: + result = transform.AnnotateTarget("test", annotate_non_call_ops)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) def test_type_propagation(): @@ -232,8 +233,57 @@ def before(): mod = tvm.IRModule.from_expr(f) return mod - # If the type isn't propogated, then the relu checker function will fail to get the dtype. - assert transform.AnnotateTarget(target)(before()) + for annotate_non_call_ops in [False, True]: + # If the type isn't propogated, then the relu checker function will fail to get the dtype. + assert transform.AnnotateTarget(target, annotate_non_call_ops)(before()) + + +def test_ref_create_read_write(): + target = "relu" + + @tvm.ir.register_op_attr("nn.relu", "target." + target) + def annotate(expr): + return True + + def before(): + ref = relay.expr.RefCreate(relay.const(1.0)) + r = relay.expr.RefWrite(ref, relay.nn.relu(relay.expr.RefRead(ref))) + return tvm.IRModule.from_expr(r) + + def after(annotate_non_call_ops): + co = relay.const(1.0) + if annotate_non_call_ops: + co = relay.annotation.compiler_begin(co, "default") + + ref = relay.expr.RefCreate(co) + ref1 = ref + if annotate_non_call_ops: + ref = relay.annotation.compiler_end(ref, "default") + ref = relay.annotation.compiler_begin(ref, "default") + ref1 = relay.annotation.compiler_end(ref1, "default") + ref1 = relay.annotation.compiler_begin(ref1, "default") + + read = relay.expr.RefRead(ref1) + if annotate_non_call_ops: + read = relay.annotation.compiler_end(read, "default") + + beg = relay.annotation.compiler_begin(read, target) + relu = relay.nn.relu(beg) + end = relay.annotation.compiler_end(relu, target) + + if annotate_non_call_ops: + end = relay.annotation.compiler_begin(end, "default") + + r = relay.expr.RefWrite(ref, end) + + if annotate_non_call_ops: + r = relay.annotation.compiler_end(r, "default") + return tvm.IRModule.from_expr(r) + + for annotate_non_call_ops in [True, False, True]: + result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) + expected = transform.InferType()(after(annotate_non_call_ops)) + assert tvm.ir.structural_equal(expected, result) def test_tuple(): @@ -259,7 +309,7 @@ def before(): mod = tvm.IRModule.from_expr(f) return mod - def after(): + def after(annotate_non_call_ops): x = relay.var("x", shape=(10, 5)) y = relay.var("y", shape=(10, 5)) cb_1 = relay.annotation.compiler_begin(x, target) @@ -268,10 +318,15 @@ def after(): a_2 = relay.nn.relu(cb_2) ce_1 = relay.annotation.compiler_end(a_1, target) ce_2 = relay.annotation.compiler_end(a_2, target) - cb_3 = relay.annotation.compiler_begin(ce_1, target) - cb_4 = relay.annotation.compiler_begin(ce_2, target) - tup = relay.Tuple([cb_3, cb_4]) - ce_3 = relay.annotation.compiler_end(tup, target) + + if annotate_non_call_ops: + cb_3 = relay.annotation.compiler_begin(ce_1, target) + cb_4 = relay.annotation.compiler_begin(ce_2, target) + tup = relay.Tuple([cb_3, cb_4]) + ce_3 = relay.annotation.compiler_end(tup, target) + else: + ce_3 = relay.Tuple([ce_1, ce_2]) + cb_3 = relay.annotation.compiler_begin(ce_3, target) out = relay.op._make.concatenate(cb_3, 1) ce_4 = relay.annotation.compiler_end(out, target) @@ -279,9 +334,10 @@ def after(): mod = tvm.IRModule.from_expr(f) return mod - result = transform.AnnotateTarget(target)(before()) - expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + for annotate_non_call_ops in [False, True]: + result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) + expected = transform.InferType()(after(annotate_non_call_ops)) + assert tvm.ir.structural_equal(expected, result) def test_composite_function(): @@ -329,6 +385,48 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_double_target(): + @tvm.ir.register_op_attr("nn.relu", "target.double.A") + def relu(expr): # pylint: disable=unused-variable + return True + + def before(): + x = relay.var("x", shape=(10, 5)) + a_1 = relay.nn.relu(x) + mod = tvm.IRModule.from_expr(a_1) + return mod + + for annotate_non_call_ops in [True, False]: + mod = before() + mod1 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod) + mod2 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod1) + assert tvm.ir.structural_equal(mod1, mod2) + + +def test_different_targets(): + @tvm.ir.register_op_attr("nn.relu", "target.different.A") + def relu(expr): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("add", "target.different.B") + def relu(expr): # pylint: disable=unused-variable + return True + + def before(): + x = relay.var("x", shape=(10, 5)) + a_1 = relay.nn.relu(x) + b_1 = relay.add(a_1, a_1) + mod = tvm.IRModule.from_expr(b_1) + return mod + + for annotate_non_call_ops in [True, False]: + mod = before() + mod1 = transform.AnnotateTarget("different.A", annotate_non_call_ops)(mod) + mod1 = transform.AnnotateTarget("different.B", annotate_non_call_ops)(mod1) + mod2 = transform.AnnotateTarget(["different.A", "different.B"], annotate_non_call_ops)(mod) + assert tvm.ir.structural_equal(mod1, mod2) + + def test_multiple_runs(): @tvm.ir.register_op_attr("nn.relu", "target.A") def relu(expr): # pylint: disable=unused-variable @@ -349,10 +447,62 @@ def before(): mod = tvm.IRModule.from_expr(f) return mod - mod = transform.AnnotateTarget("A")(before()) - mod = transform.AnnotateTarget("B")(mod) - expected = transform.AnnotateTarget(["A", "B"])(before()) - assert tvm.ir.structural_equal(expected, mod) + for annotate_non_call_ops in [True, False]: + mod = transform.AnnotateTarget("A", annotate_non_call_ops)(before()) + mod = transform.AnnotateTarget("B", annotate_non_call_ops)(mod) + expected = transform.AnnotateTarget(["A", "B"], annotate_non_call_ops)(before()) + assert tvm.ir.structural_equal(expected, mod) + + +def test_ends_with_tuple(): + trgt = "clip" + + @tvm.ir.register_op_attr("clip", "target." + trgt) + def relu(expr): # pylint: disable=unused-variable + return True + + def get_model(get_item): + """Return a model""" + a = relay.var("a", shape=(1, 16, 16, 4), dtype="uint8") + z = relay.op.clip(a, 0, 255) + b = relay.op.clip(z, 0, 15) + c = relay.op.clip(z, 16, 31) + t = relay.Tuple((c, b)) + tgi = relay.TupleGetItem(t, 1) if get_item else t + foo = relay.Function([a], tgi) + return tvm.IRModule.from_expr(tgi) + + def get_expected(annotate_non_call_ops, get_item): + a_ = relay.var("a", shape=(1, 16, 16, 4), dtype="uint8") + a = relay.annotation.compiler_begin(a_, trgt) + z = relay.op.clip(a, 0, 255) + z1 = relay.annotation.compiler_end(z, trgt) + z1 = relay.annotation.compiler_begin(z1, trgt) + b = relay.op.clip(z1, 0, 15) + b = relay.annotation.compiler_end(b, trgt) + b = relay.annotation.compiler_begin(b, trgt) if annotate_non_call_ops else b + z2 = relay.annotation.compiler_end(z, trgt) + z2 = relay.annotation.compiler_begin(z2, trgt) + c = relay.op.clip(z2, 16, 31) + c = relay.annotation.compiler_end(c, trgt) + c = relay.annotation.compiler_begin(c, trgt) if annotate_non_call_ops else c + t = relay.Tuple((c, b)) + t = relay.annotation.compiler_end(t, trgt) if annotate_non_call_ops else t + if get_item: + t = relay.annotation.compiler_begin(t, trgt) if annotate_non_call_ops else t + tgi = relay.TupleGetItem(t, 1) + tgi = relay.annotation.compiler_end(tgi, trgt) if annotate_non_call_ops else tgi + else: + tgi = t + foo = relay.Function([a_], tgi) + return tvm.IRModule.from_expr(foo) + + for get_item in [True, False]: + for annotate_non_call_ops in [False, True]: + mod = get_model(get_item) + mod = transform.AnnotateTarget("clip", annotate_non_call_ops)(mod) + expected = transform.InferType()(get_expected(annotate_non_call_ops, get_item)) + assert tvm.ir.structural_equal(expected, mod) def test_if_else(): @@ -421,9 +571,10 @@ def after(): mod = tvm.IRModule.from_expr(func) return mod - result = transform.AnnotateTarget(target)(before()) expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + for annotate_non_call_ops in [True, False]: + result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) + assert tvm.ir.structural_equal(expected, result) def test_while_let(): @@ -462,7 +613,7 @@ def before(): mod = tvm.IRModule.from_expr(func_2) return mod - def after(): + def after(annotate_non_call_ops): var1 = relay.var("var1", shape=(2,)) var2 = relay.var("var2", shape=(), dtype="int32") var3 = relay.var("var3", shape=(2,)) @@ -481,23 +632,39 @@ def after(): cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target) add_op_1 = relay.add(cb_3, cb_4) ce_2 = relay.annotation.compiler_end(add_op_1, target) - cb_5 = relay.annotation.compiler_begin(ce_2, "default") + + cb_5 = relay.annotation.compiler_begin(ce_2, "default") if annotate_non_call_ops else ce_2 + cb_6 = relay.annotation.compiler_begin(var3, target) cb_7 = relay.annotation.compiler_begin(var1, target) add_op_2 = relay.add(cb_6, cb_7) ce_3 = relay.annotation.compiler_end(add_op_2, target) - cb_8 = relay.annotation.compiler_begin(ce_3, "default") + + cb_8 = relay.annotation.compiler_begin(ce_3, "default") if annotate_non_call_ops else ce_3 + true_branch = loop(cb_5, cb_8) # while loop - ce_4 = relay.annotation.compiler_end(true_branch, "default") + ce_4 = ( + relay.annotation.compiler_end(true_branch, "default") + if annotate_non_call_ops + else true_branch + ) if_condition = relay.If(ce_1, ce_4, var3) - - cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), "default") + const_1 = relay.const(0, dtype="int32") + cb_9 = ( + relay.annotation.compiler_begin(const_1, "default") + if annotate_non_call_ops + else const_1 + ) cb_10 = relay.annotation.compiler_begin(var1, target) zeros_like = relay.zeros_like(cb_10) ce_5 = relay.annotation.compiler_end(zeros_like, target) - cb_11 = relay.annotation.compiler_begin(ce_5, "default") + cb_11 = relay.annotation.compiler_begin(ce_5, "default") if annotate_non_call_ops else ce_5 while_condition = loop(cb_9, cb_11) - ce_6 = relay.annotation.compiler_end(while_condition, "default") + ce_6 = ( + relay.annotation.compiler_end(while_condition, "default") + if annotate_non_call_ops + else while_condition + ) func_1 = relay.Function([var2, var3], if_condition) ret = relay.Let(loop, func_1, ce_6) @@ -505,9 +672,10 @@ def after(): mod = tvm.IRModule.from_expr(func_2) return mod - result = transform.AnnotateTarget(target)(before()) - expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + for annotate_non_call_ops in [False, True]: + result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) + expected = transform.InferType()(after(annotate_non_call_ops)) + assert tvm.ir.structural_equal(expected, result) def test_if_free_vars(): @@ -570,9 +738,10 @@ def after(): mod = tvm.IRModule.from_expr(func) return mod - result = transform.AnnotateTarget(target)(before()) - expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + for annotate_non_call_ops in [True, False, True]: + result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) def test_free_vars_zeros(): @@ -607,3 +776,7 @@ def after(): test_while_let() test_if_free_vars() test_free_vars_zeros() + test_different_targets() + test_double_target() + test_ends_with_tuple() + test_ref_create_read_write()