From b5d2122e6af8e9d67df8e50a43a2fdac19ea7d2f Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 23 May 2022 13:07:18 -0700 Subject: [PATCH 1/3] [Relay] Plumb external codegen target via Target.current() for all external codegen paths (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md). We want both old-style (via relay.ext.$toolchain) and new-style (via "RelayToTIR" Pass attribute on target kind) external codegen to be able to access the current 'external codegen' Target instance via Target.current(). - For old-style, plumb the true Target through TEComplier and push it on the context stack before calling relay.ext.$toolchain. - For new-style, pass the CompilationConfig to the RelayToTIRTargetHook pass, make the jump from "Compiler" attribute value to Target via the new CompilationConfig::FindPrimitiveTargetForKind method, and push on the stack before invoking the custom "RelayToTIR" pass. While working on this discovered RelayToTIRTargetHook was incompatible with the VM's compilation flow since RelayToTIRTargetHook assumes all "Compiler" attributed functions are inlined. Generalize it to support both inline and global function styles. Extend Target::IsExternalCodegen to recognize target kinds with "RelayToTIR" attributes as external. Update target hooks unit test to exercise new support for outline-style, picking up the current target, and compiling via the VM. --- include/tvm/relay/transform.h | 4 +- include/tvm/target/target_kind.h | 10 + src/relay/backend/aot_executor_codegen.cc | 2 +- src/relay/backend/contrib/cmsisnn/target.cc | 2 +- .../backend/contrib/codegen_c/codegen.cc | 12 ++ src/relay/backend/contrib/ethosu/codegen.cc | 2 +- .../example_target_hooks/relay_to_tir.cc | 201 +++++++++++++----- .../contrib/example_target_hooks/target.cc | 5 +- src/relay/backend/graph_executor_codegen.cc | 2 +- src/relay/backend/interpreter.cc | 8 +- src/relay/backend/te_compiler.cc | 57 ++--- src/relay/backend/te_compiler.h | 11 +- src/relay/backend/vm/compiler.cc | 34 +-- src/relay/backend/vm/compiler.h | 4 +- src/relay/transforms/dead_code.cc | 2 + src/relay/transforms/inline.cc | 1 + src/relay/transforms/target_hooks.cc | 134 +++++++++--- src/target/target.cc | 8 +- tests/cpp/target_test.cc | 6 + tests/python/frontend/onnx/test_forward.py | 2 +- .../relay/dyn/test_dynamic_op_level2.py | 4 +- tests/python/relay/test_external_codegen.py | 54 +++++ tests/python/relay/test_target_hooks.py | 53 ++++- tests/python/relay/utils/external_codegen.py | 2 +- 24 files changed, 460 insertions(+), 160 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 0d518e4ed547..e42bf4009de1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -464,9 +464,11 @@ TVM_DLL Pass SimplifyExpr(); /*! * \brief Run any registered RelayToTIR passes registered on the functions in a module. * + * \param config All available targets. + * * \return The pass. */ -TVM_DLL Pass RelayToTIRTargetHook(); +TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config); /*! * \brief A pass for manifesting explicit memory allocations and rewriting diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 395d3aab6757..4879470e7654 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -402,6 +402,16 @@ namespace attr { * See also \p Target::IsExternalCodegenFor */ constexpr const char* kIsExternalCodegen = "is_external_codegen"; + +/*! + * \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name + * also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass + * to apply before the TVM lowering. + * + * See also \p Target::IsExternalCodegenFor + */ +constexpr const char* kRelayToTIR = "RelayToTIR"; + } // namespace attr /*! diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 60f108aacf66..167afd2c5f78 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1079,7 +1079,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // lowering process directly. tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); }, - config_->host_virtual_device)(mod); + config_)(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 99bc0bc7cb20..fd2f18aa9905 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -31,7 +31,7 @@ tvm::transform::Pass RelayToTIR(); runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) - .set_attr("RelayToTIR", RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace cmsisnn diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 19b8c579cd8b..fd1c39bb9283 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -227,6 +227,14 @@ class CSourceCodegen : public CSourceModuleCodegenBase { Array variables = std::get<0>(res); String func_name = std::get<1>(res); + Optional opt_target = Target::Current(); + if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") { + Optional header = opt_target.value()->GetAttr("header"); + if (header.defined() && !header.value().empty()) { + code_stream_ << header.value().c_str() << "\n"; + } + } + // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -293,6 +301,10 @@ runtime::Module CCompiler(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); +TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .add_attr_option("header", String("")); // value is prepended to every output CModule + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 47c80b47c579..afa17750d8a8 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -320,7 +320,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr("RelayToTIR", RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace ethosu diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c498baa6d11d..10113dbf1a7d 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -28,12 +28,35 @@ #include #include "../../../op/call/call.h" +#include "tvm/tir/function.h" namespace tvm { namespace relay { namespace contrib { namespace example_target_hooks { +namespace { + +/*! + * \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay + * Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a + * TIR PrimFunc implementing subtraction. + * + * Illustrates six aspects a custom 'lowering' style pass may need to account for: + * - Lowerable functions can appear inline as calls ops, bound to let-bound variables, or as + * global functions. + * - Let-bound lowerable functions should be inlined on-the-fly since after processing the + * let-binding is no longer required. + * - There may be multiple calls to the same lowerable function. All calls need to be + * rewritten, even though the function itself need be rewritten only once. + * - GlobalVars must be shared between all calls and the new definition itself. + * - Calls to lowered functions must use the "call_lowered" calling convention. + * - The Target::Current() may hold an instance of the TargetKind from which the custom Pass + * was extracted. + * + * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add + * runtime::Modules to the output IRModule's "external_mods" attribute. + */ class ConvertAddToSubtract : public MixedModeMutator { public: explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) @@ -56,42 +79,96 @@ class ConvertAddToSubtract : public MixedModeMutator { return tir::BufferLoad(buffer, {index}); } - void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { - tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); - tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); - tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); + GlobalVar ReplaceAddWithSubtractPrimFunc(const Function& func) { + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); + ICHECK(func_name.defined()); + + // -------------------------------------------------------------------------------------------- + // Cases: + // - Inline function: + // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call. + // - Thereafter (via object sharing): discover global var already in module, replace call + // - Global function: + // - func_name == global_var->name_hint + // - First encounter: rewrite to PrimFunc and update binding, replace call + // - Thereafter (via global var): Just replace call + // - func_name != global_var->name_hint + // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call + // (The original Relay function should also be tagged as 'extern', ie given attribute + // "ExternalSymbol".) + // - Thereafter (via global var): discover global var already in module, replace call + // -------------------------------------------------------------------------------------------- + + // If necessary, introduce a new global var to map the function to and copy the source type + // over for InferType. + GlobalVar global_var; + bool need_rewriting; + if (ir_module_->ContainGlobalVar(func_name.value())) { + global_var = ir_module_->GetGlobalVar(func_name.value()); + // Only rewrite to a PrimFunc if the global definition is still a Relay function. + need_rewriting = ir_module_->Lookup(global_var)->IsInstance(); + } else { + global_var = GlobalVar(func_name.value()); + global_var->checked_type_ = func->checked_type(); + need_rewriting = true; + } + + // For illustration only, check if the current target matches the example_target_hook kind, + // and if so extract the example attribute value. + int64_t example_attribute_value = 0; + Optional opt_current_target = Target::Current(); + if (opt_current_target.defined() && + opt_current_target.value()->kind->name == "example_target_hook") { + example_attribute_value = + opt_current_target.value()->GetAttr("example_attribute").value()->value; + } - tir::Var x_var("x", DataType::Handle()); - tir::Var y_var("y", DataType::Handle()); - tir::Var out_var("out", DataType::Handle()); + if (need_rewriting) { + // The called function is still in Relay form. Convert to TIR. + tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); + tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); + tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); - Map dict_attrs; - dict_attrs.Set("global_symbol", new_global_var->name_hint); - dict_attrs.Set("tir.noalias", Bool(true)); + tir::Var x_var("x", DataType::Handle()); + tir::Var y_var("y", DataType::Handle()); + tir::Var out_var("out", DataType::Handle()); - te::Var index("index", DataType::Int(32)); - tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); - tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); - tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); + Map dict_attrs; + dict_attrs.Set("global_symbol", global_var->name_hint); + dict_attrs.Set("tir.noalias", Bool(true)); - Map buffer_map = { - {x_var, x_buffer}, - {y_var, y_buffer}, - {out_var, out_buffer}, - }; + te::Var index("index", DataType::Int(32)); + tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); + if (example_attribute_value > 0) { + // For illustration only, fold the example attribute into the result. + indexed_sub = tir::Sub(indexed_sub, FloatImm(DataType::Float(32), + static_cast(example_attribute_value))); + } - tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, {}, DictAttrs(dict_attrs)); + tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); + tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); - // Switch to TIRToRuntime hook for testing - Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); - if (tir_to_runtime) { - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); - } else { - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + Map buffer_map = { + {x_var, x_buffer}, + {y_var, y_buffer}, + {out_var, out_buffer}, + }; + + tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), + buffer_map, {}, DictAttrs(dict_attrs)); + + // Switch to TIRToRuntime hook for testing + Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); + if (tir_to_runtime) { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); + } else { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + } + + ir_module_->Update(global_var, replacement_func); // Will Add if global_var is new. } - ir_module_->Add(new_global_var, replacement_func); + return global_var; } Expr VisitExpr_(const LetNode* op) final { @@ -99,8 +176,8 @@ class ConvertAddToSubtract : public MixedModeMutator { Expr var = this->VisitExpr(op->var); Expr value = this->VisitExpr(op->value); - // Outlineable function no longer needs let binding - if (this->CanLowerExpr(value)) { + if (AsLowerableFunction(value)) { + // Inline on-the-fly if the let-bound value is lowerable. this->memo_[var] = value; } }; @@ -110,8 +187,8 @@ class ConvertAddToSubtract : public MixedModeMutator { Expr body = this->VisitExpr(op->body); auto expr = GetRef(op); - // Drop the let binding - if (this->CanLowerExpr(value)) { + if (AsLowerableFunction(value)) { + // The let binding is no longer needed since inlined on-the-fly above. this->memo_[expr] = this->VisitExpr(op->body); } else { Var var = Downcast(this->VisitExpr(op->var)); @@ -126,39 +203,49 @@ class ConvertAddToSubtract : public MixedModeMutator { return memo_[GetRef(op)]; } - bool CanLowerExpr(const Expr& expr) { - const auto* func = expr.as(); - if (func == nullptr) { - return false; - } - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - if (!func_name.defined()) { - return false; + const FunctionNode* AsLowerableFunction(const Expr& expr) { + if (const auto* function_node = expr.as()) { + auto func_name = function_node->GetAttr(::tvm::attr::kGlobalSymbol); + if (!func_name.defined()) { + return nullptr; + } + if (func_name != "replace_add_with_subtract") { + return nullptr; + } + return function_node; + } else if (const auto* global_var_node = expr.as()) { + return AsLowerableFunction(ir_module_->Lookup(GetRef(global_var_node))); + } else { + return nullptr; } - if (func_name != "replace_add_with_subtract") { - return false; + } + + const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) { + if (const auto* global_var_node = expr.as()) { + if (ir_module_->Lookup(GetRef(global_var_node)).as()) { + return global_var_node; + } } - return true; + return nullptr; } Expr Rewrite_(const CallNode* pre, const Expr& post) override { - if (const CallNode* call = post.as()) { - if (CanLowerExpr(call->op)) { - auto* func = call->op.as(); - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - - // Introduce a new global var to map the function to and copy the source type - // over for InferType - GlobalVar new_global_var(func_name.value()); - new_global_var->checked_type_ = func->checked_type(); - ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - + if (const auto* call = post.as()) { + GlobalVar new_op; + if (const auto* function_node = AsLowerableFunction(call->op)) { + // Add or replace the function with a PrimFunc. + new_op = ReplaceAddWithSubtractPrimFunc(GetRef(function_node)); + } else if (const auto* global_var_node = AsAlreadyLoweredFunction(call->op)) { + // The function has already been rewritten, so we just need to update the call. + new_op = GetRef(global_var_node); + } + if (new_op.defined()) { // Since we are replacing the Relay function with a call to a TIR function, we must use // the call_lowered op. CallLoweredAttrs attrs; attrs.metadata.Set("relay_attrs", call->attrs); ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; - return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span); + return CallLowered(std::move(new_op), call->args, std::move(attrs), call->span); } } @@ -171,10 +258,12 @@ class ConvertAddToSubtract : public MixedModeMutator { Target custom_target_; }; +} // namespace + transform::Pass RelayToTIR() { runtime::TypedPackedFunc pass_func = [=](IRModule ir_module, transform::PassContext pass_context) { - auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c")); + ConvertAddToSubtract relay_to_tir(std::move(ir_module), Target("c")); return relay_to_tir.Mutate(); }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 6f1914eac4c3..19bfa8c68298 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -34,7 +34,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) - .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); + .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 2734439cddbd..7dba23803f8c 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -232,7 +232,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); }, - config_->host_virtual_device)(mod); + config_)(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 65ef29651695..9661040eab30 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -945,14 +945,13 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -IRModule Prepare(IRModule mod, CompilationConfig config) { - VirtualDevice host_virtual_device = config->host_virtual_device; +IRModule Prepare(IRModule mod, const CompilationConfig& config) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), qnn::transform::Legalize(), // Figure out which devices should be used to execute. // TODO(mbs): Should ignore all existing annotations when constant folding - transform::PlanDevices(std::move(config)), + transform::PlanDevices(config), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. transform::FuseOps(/*fuse_opt_level=*/0), @@ -962,8 +961,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, - std::move(host_virtual_device))}); + tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 76dbfef5386d..73b44f7361a5 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -299,11 +299,10 @@ class TECompilerImpl : public TECompilerNode { // the module's globals. Furthermore, the external codegen tool must bind the compiled // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName // here. - auto target = Target("ext_dev"); auto global_var = GlobalVar(opt_global_symbol.value()); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr}, + value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr}, tir::PrimFunc{nullptr}, {}, ir_module); // Collect these here as it's removed in LowerExternalFunctions() device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value()); @@ -531,14 +530,14 @@ using AnalysisRemapping = std::unordered_maptarget); + CCacheKey shape_key(func, config_->host_virtual_device->target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -733,7 +732,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Special case: device_copies are left as calls to primitive operators // (thus undoing FuseOps) so that each backend can handle them directly. - // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone. + // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy + // alone. if (const auto* function_node = primitive_func.as()) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body); if (device_copy_props.body.defined()) { @@ -771,10 +771,18 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Typical case: call to fused primitive Relay Function. // Find the desired target device. Target target; - if (primitive_func->GetAttr(attr::kCompiler).defined()) { - // The generic 'external device' target. - // TODO(mbs): Retire once replaced unified BYOC compiler and target machinery - target = Target("ext_dev"); + Optional opt_compiler = primitive_func->GetAttr(attr::kCompiler); + if (opt_compiler.defined()) { + // This function needs to be compiled with external codegen. + Optional opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); + if (opt_target.defined()) { + // The target is what's supplied by the compilation config for kind matching the + // "Compiler" name. + target = opt_target.value(); + } else { + // Legacy fallback. + target = Target("ext_dev"); + } } else { // The target corresponding to the call_node expression's annotation. VirtualDevice virtual_device = GetVirtualDevice(GetRef(call_node)); @@ -791,6 +799,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { IRModule module_; ProcessFn process_fn_; + /*! \brief All available targets. */ + CompilationConfig config_; // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered // primitive function as we go. Those vars will be bound in the target device-type specific @@ -799,21 +809,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { std::unordered_map primitive_functions_; String module_name_; TECompiler compiler_; - /*! - * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation - * must live. - */ - VirtualDevice host_virtual_device_; // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; }; Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn, - VirtualDevice host_virtual_device) { + CompilationConfig config) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler, - host_virtual_device); + LowerTensorExprMutator lower_te(module, process_fn, config, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -1043,7 +1047,7 @@ void UpdateFunctionMetadata(BaseFunc func, } IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device) { + CompilationConfig config) { TECompiler compiler(module); // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten @@ -1058,8 +1062,8 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // GlobalVar, and calls updated (sticking with regular Relay Call). // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated // (using call_lowered convention). - IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn), - std::move(host_virtual_device))(module); + IRModule updated_module = + LowerTensorExpr(module_name, compiler, std::move(process_fn), std::move(config))(module); // The Functions tagged with "Compiler" are now residing in the cache ready to be // compiled by LowerExternalFunctions. However we still need a record of them in the @@ -1159,15 +1163,14 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device) { +Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig complilation_config) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LowerTE(module, module_name, process_fn, host_virtual_device); + return LowerTE(module, module_name, process_fn, complilation_config); }; return tvm::transform::Sequential( - {tvm::relay::transform::RelayToTIRTargetHook(), + {tvm::relay::transform::RelayToTIRTargetHook(complilation_config), tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(), tvm::tir::transform::ExtractPrimFuncConstants()}); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 0b2288d6a156..8312a20cb862 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -189,7 +189,8 @@ IRModule LowerTE( const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](BaseFunc f) {}); -/*! \brief Pass to lower an IRModule's primitive functions to TIR. +/*! + * \brief Pass to lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" * to TE expressions, schedules them, and then to TIR. It annotates all functions @@ -198,11 +199,11 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \param host_virtual_device \p VirtualDevice for host data and computations - * \returns The pass which lowers primative functions to TIR + * \param config All available targets. + * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device); +transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 5a62ac66f736..e0b742a84090 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -523,11 +523,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { op_index = itr->second; } - // Capture the dictionary of attributes from the original primitive function so that they - // can contribute to the hash of the compiled primitive. This way we can distinguish primitives - // with the same body expression but different attributes which may arbitrarily influence code - // generation. - op_attrs[op_index] = attrs->dict; + if (attrs.defined() && attrs->dict.defined()) { + // Capture the dictionary of attributes from the original primitive function so that they + // can contribute to the hash of the compiled primitive. This way we can distinguish + // primitives with the same body expression but different attributes which may arbitrarily + // influence code generation. + op_attrs[op_index] = attrs->dict; + } Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), argument_registers)); @@ -981,25 +983,25 @@ void VMCompiler::LowerImpl(IRModule mod) { } } -transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_device) { +transform::Sequential VMCompiler::MemoryOpt(const CompilationConfig& config) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower any new shape functions and device_copies. - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Fuse & lower any new allocations. - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is // incomplete to provide memory resuse optimizations. Disable it until we can @@ -1011,10 +1013,10 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower yet again - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -1030,7 +1032,7 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de return transform::Sequential(std::move(pass_seqs)); } -transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& host_virtual_device) { +transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& config) { Array pass_seqs; // Hoist operators to "primitive" Functions. pass_seqs.push_back(FuseOps()); @@ -1043,7 +1045,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& hos backend::UpdateConstants(func, ¶ms_); } }, - host_virtual_device)); + config)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); @@ -1094,7 +1096,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { backend::UpdateConstants(func, ¶ms_); } }, - config_->host_virtual_device)); + config_)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. @@ -1111,7 +1113,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(config_->host_virtual_device)); + pass_seqs.push_back(MemoryOpt(config_)); pass_seqs.push_back(transform::InferType()); transform::Sequential seq(pass_seqs); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index a65bdc5ab3cb..163ec399013b 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -146,10 +146,10 @@ class VMCompiler : public runtime::ModuleNode { IRModule OptimizeModuleImpl(IRModule mod); /*! \brief Returns the passes which layout memory. */ - transform::Sequential MemoryOpt(const VirtualDevice& host_virtual_device); + transform::Sequential MemoryOpt(const CompilationConfig& config); /*! \brief Returns the passes which fuse then lower Relay primitive operators. */ - transform::Sequential FuseAndLowerOperators(const VirtualDevice& host_virtual_device); + transform::Sequential FuseAndLowerOperators(const CompilationConfig& config); /*! * \brief Populate the global function names in a map where the value is used diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index ca1e04ae59fa..45cb8271b074 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -534,6 +534,7 @@ namespace transform { // Declared in relay/transform.h Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); // Which let bindings are pure and can be safely elided? std::unordered_map var_to_purity; if (!ignore_impurity) { @@ -566,6 +567,7 @@ Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { result->Add(kv.first, kv.second); } } + VLOG(1) << "After:" << std::endl << PrettyPrint(result); return result; }; diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index a6e26364bbc4..c55b6778093e 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -69,6 +69,7 @@ class Inliner : ExprMutator { for (auto arg : vanilla_call->args) { new_args.push_back(VisitExpr(arg)); } + // TODO(mbs): Does not handle multiple calls to the same global function. cur_node_->RemoveCallTo(gv); return MakeNewExpr(gv, new_args, GetRef(call_node)); } diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 0022baf881ba..53077a7e73e4 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -30,61 +30,129 @@ namespace tvm { namespace relay { namespace transform { -class TargetHookVisitor : public tvm::relay::MixedModeVisitor { - /*! \brief Collected pass list for all nodes */ - std::vector pass_list_; - /*! \brief Attribute map for all registered targets */ - TargetKindAttrMap target_attr_map_; - using tvm::relay::MixedModeVisitor::VisitExpr_; +namespace { + +/*! + * \brief A pass extracted from a target kind's "RelayToTIR" attribute, along with any + * 'external codegen' Target instance with matching kind name. + */ +struct CustomPass { + std::string target_kind_name; + Pass pass; + Optional opt_target; + + CustomPass(std::string target_kind_name, Pass pass, Optional opt_target) + : target_kind_name(std::move(target_kind_name)), + pass(std::move(pass)), + opt_target(std::move(opt_target)) {} +}; +class TargetHookVisitor : public MixedModeVisitor { public: - TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap("RelayToTIR")) {} + TargetHookVisitor(IRModule mod, CompilationConfig config) + : mod_(std::move(mod)), + config_(std::move(config)), + target_attr_map_(tvm::TargetKind::GetAttrMap(tvm::attr::kRelayToTIR)) {} - std::vector Visit(const IRModule& ir_mod) { - for (const auto& it : ir_mod->functions) { + std::vector Visit() { + ICHECK(custom_passes_.empty()); + for (const auto& it : mod_->functions) { if (const auto* function_node = it.second.as()) { + // May be a top-level function with "Compiler" attribute. + MaybeAddPassForFunction(function_node); + // May have calls to inlined "Compiler" functions in body. VisitExpr(GetRef(function_node)); } } - return pass_list_; + return std::move(custom_passes_); } - void VisitExpr_(const LetNode* op) final { - auto pre_visit = [this](const LetNode* op) { - this->VisitExpr(op->var); - this->VisitExpr(op->value); + private: + using tvm::relay::MixedModeVisitor::VisitExpr_; + + void VisitExpr_(const LetNode* let_node) final { + auto pre_visit = [this](const LetNode* inner_let_node) { + this->VisitExpr(inner_let_node->var); + this->VisitExpr(inner_let_node->value); }; - auto post_visit = [this](const LetNode* op) { - this->VisitExpr(op->body); - this->visit_counter_[op] += 1; + auto post_visit = [this](const LetNode* inner_let_node) { + this->VisitExpr(inner_let_node->body); + this->visit_counter_[inner_let_node] += 1; }; - ExpandANormalForm(op, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); + } + + void VisitExpr_(const FunctionNode* function_node) override { + ExprVisitor::VisitExpr_(function_node); + MaybeAddPassForFunction(function_node); } - void VisitExpr_(const FunctionNode* func) override { - ExprVisitor::VisitExpr_(func); - if (!func->GetAttr(attr::kCompiler).defined()) { + /*! + * \brief If \p function_node has a "Compiler" attribute, check if we should include a + * matching custom pass. Otherwise no-op. + */ + void MaybeAddPassForFunction(const FunctionNode* function_node) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (!opt_compiler) { + // No external codegen required. return; } - String code_gen_name = func->GetAttr(attr::kCompiler).value(); - Optional target_kind = tvm::TargetKind::Get(code_gen_name); - if (!target_kind || !target_attr_map_.count(target_kind.value())) { + // First cross-over: use "Compiler" attribute name as target kind. + std::string kind_name = opt_compiler.value(); + Optional opt_target_kind = tvm::TargetKind::Get(kind_name); + if (!opt_target_kind || !target_attr_map_.count(opt_target_kind.value())) { + // Target kind does not exist or have the "RelayToTIR" attribute, no custom pass to consider. return; } - Pass custom_target_pass = target_attr_map_[target_kind.value()]; - if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { - pass_list_.push_back(custom_target_pass); + if (!seen_kinds_.emplace(kind_name).second) { + // Already accounted for custom pass. + return; } + // Second (optional) cross-over: find unique Target instance in overall available targets with + // the same kind so that it can be made available when custom pass is invoked. + Optional opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); + Pass custom_target_pass = target_attr_map_[opt_target_kind.value()]; + custom_passes_.emplace_back(std::move(kind_name), std::move(custom_target_pass), + std::move(opt_target)); } + + /*! \brief IRModule we are visiting. */ + IRModule mod_; + /*! \brief All available targets. */ + CompilationConfig config_; + /*! \brief Cached attribute map for all registered targets */ + TargetKindAttrMap target_attr_map_; + /*! \brief Which target kind names have already contributed to the custom passes list. */ + std::unordered_set seen_kinds_; + /*! + * \brief All the custom passes to run, paired with their corresponding target instances, if any. + */ + std::vector custom_passes_; }; -Pass RelayToTIRTargetHook() { - auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { - auto target_hook_visitor = TargetHookVisitor(); - std::vector pass_list = target_hook_visitor.Visit(mod); - Sequential run_hooks(pass_list); +} // namespace - return run_hooks(mod); +Pass RelayToTIRTargetHook(CompilationConfig config) { + auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { + VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); + TargetHookVisitor target_hook_visitor(mod, config); + std::vector custom_passes = target_hook_visitor.Visit(); + for (const auto& custom_pass : custom_passes) { + if (custom_pass.opt_target.defined()) { + VLOG(0) << "Invoking custom pass for target " + << custom_pass.opt_target.value()->ToDebugString(); + // Push the target on the stack. + With with_target(custom_pass.opt_target.value()); + // Invoke the pass with target in scope. + mod = custom_pass.pass(mod); + } else { + // Invoke the pass. + VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'"; + mod = custom_pass.pass(mod); + } + } + VLOG(1) << "After:" << std::endl << PrettyPrint(mod); + return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); } diff --git a/src/target/target.cc b/src/target/target.cc index 75126ed11c70..3cdfa0cc0d5e 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -495,8 +495,12 @@ Target::Target(TargetKind kind, Optional host, String tag, Array attr_map = TargetKind::GetAttrMap(::tvm::attr::kIsExternalCodegen); - return attr_map.get(get()->kind, Bool(false)); + TargetKindAttrMap is_external_codegen_map = + TargetKind::GetAttrMap(tvm::attr::kIsExternalCodegen); + TargetKindAttrMap relay_to_tir_map = + TargetKind::GetAttrMap(tvm::attr::kRelayToTIR); + return is_external_codegen_map.get(get()->kind, Bool(false)) || + relay_to_tir_map.count(get()->kind); } bool Target::IsExternalCodegenFor(const Target& that) const { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index b657ac0c5783..2c85e47e7fb8 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -144,16 +145,21 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); +TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) + .set_attr(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType()); + TEST(Target, ExternalCodegen) { Target regular("cuda"); Target external0("test_external_codegen_0"); Target external1("test_external_codegen_1"); Target external2("test_external_codegen_2"); + Target external3("test_external_codegen_3"); ASSERT_FALSE(regular.IsExternalCodegen()); ASSERT_TRUE(external0.IsExternalCodegen()); ASSERT_TRUE(external1.IsExternalCodegen()); ASSERT_TRUE(external2.IsExternalCodegen()); + ASSERT_TRUE(external3.IsExternalCodegen()); ASSERT_TRUE(external0.IsExternalCodegenFor(regular)); ASSERT_FALSE(regular.IsExternalCodegenFor(external0)); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 41123a254825..dbc5147e2030 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -6653,4 +6653,4 @@ def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index a017762ce35d..690ddcac8d51 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -208,6 +208,4 @@ def verify_pad_default_fill(dshape, pad_width, dtype): if __name__ == "__main__": - test_dyn_pad() - test_dyn_upsampling_infer_type_const() - test_dyn_upsampling_run() + tvm.testing.main() diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c5a9041b15fe..4f451a125184 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -31,6 +31,8 @@ set_external_func_attr, parametrize_external_codegen_checks, parametrize_external_json_codegen_checks, + check_graph_executor_result, + check_vm_result, ) @@ -180,6 +182,58 @@ def test_extern_gcc(check_result): check_result(mod, inputs, (2, 2), (y_data * y_data) - (x_data + x_data)) +# TODO(mbs): The check_aot_executor_result does not support the list-of-targets, mostly because +# tvm.testing.aot.compile_and_run requires the target to be a kind name string, and +# tvm.testing.aot.compile_models requires a single Target object. However, code outside of +# tvm.testing.aot is ready for this more general form. +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_extern_gcc_with_target_instance(check_result): + shape = (8, 8) + dtype = "int32" + + def make_mod(): + x0 = relay.var("x0", shape=shape, dtype=dtype) + y0 = relay.var("y0", shape=shape, dtype=dtype) + z = x0 + y0 + f = relay.Function([x0, y0], z) + f = set_external_func_attr(f, "ccompiler", "ccompiler_0") + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.var("y", shape=shape, dtype=dtype) + call = relay.Call(f, [x, y]) + return tvm.IRModule.from_expr(call) + + host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("llvm", host=host_target) + # The header attribute is just whitespace, so compilation is as usual. + good_extern_codegen_target = tvm.target.Target( + {"kind": "ccompiler", "header": "// Good"}, host=host_target + ) + # The header attribute is ill-formed, so compilation is expected to fail. + bogus_extern_codegen_target = tvm.target.Target( + {"kind": "ccompiler", "header": "Bogus"}, host=host_target + ) + + mod = make_mod() + + x_data = np.random.rand(*shape).astype(dtype) + y_data = np.random.rand(*shape).astype(dtype) + expected_result = x_data + y_data + inputs = {"x": x_data, "y": y_data} + + check_result( + mod, inputs, shape, expected_result, target=[generic_target, good_extern_codegen_target] + ) + + with pytest.raises(RuntimeError): + check_result( + mod, + inputs, + shape, + expected_result, + target=[generic_target, bogus_extern_codegen_target], + ) + + @pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") def test_extern_gcc_consts(): @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 22b3b8cb3063..cf61b1d55d3d 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -18,19 +18,25 @@ import sys import numpy as np import pytest +import logging +import tvm import tvm.testing from tvm import relay, IRModule from utils.external_codegen import ( + parametrize_external_codegen_checks, set_external_func_attr, check_aot_executor_result, check_graph_executor_result, + check_vm_result, ) +logging.basicConfig(level=logging.INFO) -@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) -def test_tir_external_generation(check_result): + +@parametrize_external_codegen_checks +def test_tir_external_generation_inline_without_target_instance(check_result): shape = (8,) x_data = np.random.randint(255, size=shape).astype("float32") y_data = np.random.randint(255, size=shape).astype("float32") @@ -50,6 +56,49 @@ def test_tir_external_generation(check_result): check_result(func, inputs, (8,), x_data - y_data) +# TODO(mbs): The check_aot_executor_result does not support the list-of-targets, mostly because +# tvm.testing.aot.compile_and_run requires the target to be a kind name string, and +# tvm.testing.aot.compile_models requires a single Target object. However, code outside of +# tvm.testing.aot is ready for this more general form. +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_tir_external_generation_outline_with_target_instance(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + # Compile with an instance of the hooked target kind to demonstrate plumbing target attributes + # into custom passes. + host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("llvm", host=host_target) + extern_codegen_target = tvm.target.Target( + "example_target_hook -example_attribute=42", host=host_target + ) + mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(8), float32], %y: Tensor[(8), float32]) -> Tensor[(8), float32] { + @replace_add_with_subtract(%x, %y) * 2.0f + } + + def @replace_add_with_subtract(%x: Tensor[(8), float32], %y: Tensor[(8), float32], + Inline=1, + Primitive=1, + Compiler="example_target_hook", + global_symbol="replace_add_with_subtract") -> Tensor[(8), float32] { + %x + %y // will be rewritten to TIR implementing %x - %y - 42.0f by custom pass + } + """ + ) + + check_result( + mod, + inputs, + (8,), + (x_data - y_data - 42.0) * 2.0, + target=[generic_target, extern_codegen_target], + ) + + @pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) def test_runtime_module_generation(check_result): shape = (8,) diff --git a/tests/python/relay/utils/external_codegen.py b/tests/python/relay/utils/external_codegen.py index 6d3d917ff5a2..8e5ab803de7a 100644 --- a/tests/python/relay/utils/external_codegen.py +++ b/tests/python/relay/utils/external_codegen.py @@ -22,7 +22,7 @@ import pytest import tvm -from tvm import relay, runtime +from tvm import relay, runtime, testing from tvm.contrib import utils From fefc3d4d90035a988602a28948ca39177bdf6bfe Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 26 May 2022 13:50:21 -0700 Subject: [PATCH 2/3] - A bit of polishing en passant. --- .../example_target_hooks/relay_to_tir.cc | 19 +++++++-------- src/relay/transforms/target_hooks.cc | 24 +++++++++++++++---- tests/python/relay/test_target_hooks.py | 2 +- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 10113dbf1a7d..eb6cf1cce420 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -43,7 +43,7 @@ namespace { * TIR PrimFunc implementing subtraction. * * Illustrates six aspects a custom 'lowering' style pass may need to account for: - * - Lowerable functions can appear inline as calls ops, bound to let-bound variables, or as + * - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as * global functions. * - Let-bound lowerable functions should be inlined on-the-fly since after processing the * let-binding is no longer required. @@ -55,7 +55,9 @@ namespace { * was extracted. * * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add - * runtime::Modules to the output IRModule's "external_mods" attribute. + * runtime::Modules to the output IRModule's "external_mods" attribute. In this case the + * IRModule must be left with an 'extern' Function definition with the matching "external_symbol" + * name. */ class ConvertAddToSubtract : public MixedModeMutator { public: @@ -89,14 +91,9 @@ class ConvertAddToSubtract : public MixedModeMutator { // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call. // - Thereafter (via object sharing): discover global var already in module, replace call // - Global function: - // - func_name == global_var->name_hint - // - First encounter: rewrite to PrimFunc and update binding, replace call - // - Thereafter (via global var): Just replace call - // - func_name != global_var->name_hint - // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call - // (The original Relay function should also be tagged as 'extern', ie given attribute - // "ExternalSymbol".) - // - Thereafter (via global var): discover global var already in module, replace call + // - Assume func_name == global_var->name_hint + // - First encounter: create global var, rewrite to PrimFunc, update binding, replace call + // - Thereafter (via global var): discover global var already in module, replace call // -------------------------------------------------------------------------------------------- // If necessary, introduce a new global var to map the function to and copy the source type @@ -171,6 +168,8 @@ class ConvertAddToSubtract : public MixedModeMutator { return global_var; } + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { Expr var = this->VisitExpr(op->var); diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 53077a7e73e4..00953a1907e1 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -34,7 +34,8 @@ namespace { /*! * \brief A pass extracted from a target kind's "RelayToTIR" attribute, along with any - * 'external codegen' Target instance with matching kind name. + * 'external codegen' Target instance with matching kind name which should be current when + * the pass is applied. */ struct CustomPass { std::string target_kind_name; @@ -47,6 +48,10 @@ struct CustomPass { opt_target(std::move(opt_target)) {} }; +/*! + * \brief Collect all the \p CustomPasses needed according to the "Compiler" attributes on + * inlined or global functions. + */ class TargetHookVisitor : public MixedModeVisitor { public: TargetHookVisitor(IRModule mod, CompilationConfig config) @@ -56,10 +61,19 @@ class TargetHookVisitor : public MixedModeVisitor { std::vector Visit() { ICHECK(custom_passes_.empty()); - for (const auto& it : mod_->functions) { - if (const auto* function_node = it.second.as()) { - // May be a top-level function with "Compiler" attribute. + // To ensure the passes are run in a deterministic order we'll search for functions in + // lexicographic order. + std::vector> functions; + for (const auto& kv : mod_->functions) { + functions.emplace_back(kv.first->name_hint, kv.second); + } + std::sort(functions.begin(), functions.end()); + for (const auto& kv : functions) { + if (const auto* function_node = kv.second.as()) { + // May be a top-level function with a "Compiler" attribute. MaybeAddPassForFunction(function_node); + } + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { // May have calls to inlined "Compiler" functions in body. VisitExpr(GetRef(function_node)); } @@ -88,7 +102,7 @@ class TargetHookVisitor : public MixedModeVisitor { } /*! - * \brief If \p function_node has a "Compiler" attribute, check if we should include a + * \brief If \p function_node has a "Compiler" attribute, checks if we should include a * matching custom pass. Otherwise no-op. */ void MaybeAddPassForFunction(const FunctionNode* function_node) { diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index cf61b1d55d3d..046b2c7e541d 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -56,7 +56,7 @@ def test_tir_external_generation_inline_without_target_instance(check_result): check_result(func, inputs, (8,), x_data - y_data) -# TODO(mbs): The check_aot_executor_result does not support the list-of-targets, mostly because +# TODO(mbs): The check_aot_executor_result does not support list-of-targets, mostly because # tvm.testing.aot.compile_and_run requires the target to be a kind name string, and # tvm.testing.aot.compile_models requires a single Target object. However, code outside of # tvm.testing.aot is ready for this more general form. From 673c5ce978a0f22d250dcd23c23eb4f25318b1b9 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 27 May 2022 10:48:59 -0700 Subject: [PATCH 3/3] - Add comment as per Josh's suggestion Can't repro tests/python/contrib/test_ethosu/cascader/test_scheduler.py::test_compute_cycles_annotation failure, flake? --- include/tvm/relay/transform.h | 39 ++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index e42bf4009de1..6e3bddf9adf5 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -462,7 +462,44 @@ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); TVM_DLL Pass SimplifyExpr(); /*! - * \brief Run any registered RelayToTIR passes registered on the functions in a module. + * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds. + * + * This pass looks for inline, let-bound or global functions which have a "Compiler" attribute. + * If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the + * 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole. + * + * If, in addition, the \p config has a Target with a matching TargetKind, that Target is set + * as the 'current' target before the custom pass is executed. In this way it is possible + * for custom passes to pick up target options which may guide how they transform the IRModule. + * (Those targets are referred to as 'extern codegen targets' elsewhere). + * + * A typical custom pass will: + * - Find calls to "Compiler" attributes functions with matching compiler name. + * - Lower those function to TIR PrimFuncs. + * - Bind those functions into the IRModule under the the functions' "global_symbol" attribute. + * - Replace all calls to those functions with 'call_lowered' to the matching global. + * Care should be taken to handle multiple calls to the same function. + * See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass. + * + * It is also possible (despite the pass and attribute names!) for the custom pass to proceed + * directly to a runtime::Module, which can be attached to the output IRModules "external_mods" + * attribute (taking care not to clobber any existing modules). In this case the flow is as above, + * except: + * - The runtime::Module must contain a binding for each compiled function under their + * "global_symbol" (ie runtime::Module::ImplementsFunction should return true). + * - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same + * "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body + * should be the original function body. In this way we always have a TVM definition matching + * every global function name. + * + * There are many existing runtime::Modules, ranging from source to object to dynamic libaries to + * entirely custom implementations. Some of those may require additional compilation using + * 'export_library' on the final build artifact. + * + * The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility + * passes can be used by custom passes to take care of some of the boilerplate. + * + * TODO(mbs): Rename PreLoweringTargetHooks? * * \param config All available targets. *