diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 0d518e4ed547..6e3bddf9adf5 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -462,11 +462,50 @@ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); TVM_DLL Pass SimplifyExpr(); /*! - * \brief Run any registered RelayToTIR passes registered on the functions in a module. + * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds. + * + * This pass looks for inline, let-bound or global functions which have a "Compiler" attribute. + * If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the + * 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole. + * + * If, in addition, the \p config has a Target with a matching TargetKind, that Target is set + * as the 'current' target before the custom pass is executed. In this way it is possible + * for custom passes to pick up target options which may guide how they transform the IRModule. + * (Those targets are referred to as 'extern codegen targets' elsewhere). + * + * A typical custom pass will: + * - Find calls to "Compiler" attributed functions with matching compiler name. + * - Lower those functions to TIR PrimFuncs. + * - Bind those functions into the IRModule under the functions' "global_symbol" attribute. + * - Replace all calls to those functions with 'call_lowered' to the matching global. + * Care should be taken to handle multiple calls to the same function. + * See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass. + * + * It is also possible (despite the pass and attribute names!) for the custom pass to proceed + * directly to a runtime::Module, which can be attached to the output IRModule's "external_mods" + * attribute (taking care not to clobber any existing modules). 
In this case the flow is as above, + * except: + * - The runtime::Module must contain a binding for each compiled function under its + * "global_symbol" (i.e. runtime::Module::ImplementsFunction should return true). + * - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same + * "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body + * should be the original function body. In this way we always have a TVM definition matching + * every global function name. + * + * There are many existing runtime::Modules, ranging from source to object to dynamic libraries to + * entirely custom implementations. Some of those may require additional compilation using + * 'export_library' on the final build artifact. + * + * The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility + * passes can be used by custom passes to take care of some of the boilerplate. + * + * TODO(mbs): Rename PreLoweringTargetHooks? + * + * \param config All available targets. * * \return The pass. */ -TVM_DLL Pass RelayToTIRTargetHook(); +TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config); /*! * \brief A pass for manifesting explicit memory allocations and rewriting diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 395d3aab6757..4879470e7654 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -402,6 +402,16 @@ namespace attr { * See also \p Target::IsExternalCodegenFor */ constexpr const char* kIsExternalCodegen = "is_external_codegen"; + +/*! + * \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name + * also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass + * to apply before the TVM lowering. + * + * See also \p Target::IsExternalCodegenFor + */ +constexpr const char* kRelayToTIR = "RelayToTIR"; + } // namespace attr /*! 
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 60f108aacf66..167afd2c5f78 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1079,7 +1079,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // lowering process directly. tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); }, - config_->host_virtual_device)(mod); + config_)(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 99bc0bc7cb20..fd2f18aa9905 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -31,7 +31,7 @@ tvm::transform::Pass RelayToTIR(); runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) - .set_attr("RelayToTIR", RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace cmsisnn diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 19b8c579cd8b..fd1c39bb9283 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -227,6 +227,14 @@ class CSourceCodegen : public CSourceModuleCodegenBase { Array variables = std::get<0>(res); String func_name = std::get<1>(res); + Optional opt_target = Target::Current(); + if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") { + Optional header = opt_target.value()->GetAttr("header"); + if (header.defined() && !header.value().empty()) { + code_stream_ << header.value().c_str() << "\n"; + } + } + // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -293,6 +301,10 @@ runtime::Module CCompiler(const ObjectRef& ref) { 
TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); +TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .add_attr_option("header", String("")); // value is prepended to every output CModule + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 47c80b47c579..afa17750d8a8 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -320,7 +320,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr("RelayToTIR", RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace ethosu diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c498baa6d11d..eb6cf1cce420 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -28,12 +28,37 @@ #include #include "../../../op/call/call.h" +#include "tvm/tir/function.h" namespace tvm { namespace relay { namespace contrib { namespace example_target_hooks { +namespace { + +/*! + * \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay + * Function with "global_symbol" attribute of "replace_add_with_subtract" with a call to a + * TIR PrimFunc implementing subtraction. + * + * Illustrates six aspects a custom 'lowering' style pass may need to account for: + * - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as + * global functions. + * - Let-bound lowerable functions should be inlined on-the-fly since after processing the + * let-binding is no longer required. 
+ * - There may be multiple calls to the same lowerable function. All calls need to be + * rewritten, even though the function itself need be rewritten only once. + * - GlobalVars must be shared between all calls and the new definition itself. + * - Calls to lowered functions must use the "call_lowered" calling convention. + * - The Target::Current() may hold an instance of the TargetKind from which the custom Pass + * was extracted. + * + * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add + * runtime::Modules to the output IRModule's "external_mods" attribute. In this case the + * IRModule must be left with an 'extern' Function definition with the matching "global_symbol" + * name. + */ class ConvertAddToSubtract : public MixedModeMutator { public: explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) @@ -56,51 +81,102 @@ class ConvertAddToSubtract : public MixedModeMutator { return tir::BufferLoad(buffer, {index}); } - void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { - tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); - tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); - tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); + GlobalVar ReplaceAddWithSubtractPrimFunc(const Function& func) { + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); + ICHECK(func_name.defined()); - tir::Var x_var("x", DataType::Handle()); - tir::Var y_var("y", DataType::Handle()); - tir::Var out_var("out", DataType::Handle()); + // -------------------------------------------------------------------------------------------- + // Cases: + // - Inline function: + // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call. 
+ // - Thereafter (via object sharing): discover global var already in module, replace call + // - Global function: + // - Assume func_name == global_var->name_hint + // - First encounter: create global var, rewrite to PrimFunc, update binding, replace call + // - Thereafter (via global var): discover global var already in module, replace call + // -------------------------------------------------------------------------------------------- - Map dict_attrs; - dict_attrs.Set("global_symbol", new_global_var->name_hint); - dict_attrs.Set("tir.noalias", Bool(true)); + // If necessary, introduce a new global var to map the function to and copy the source type + // over for InferType. + GlobalVar global_var; + bool need_rewriting; + if (ir_module_->ContainGlobalVar(func_name.value())) { + global_var = ir_module_->GetGlobalVar(func_name.value()); + // Only rewrite to a PrimFunc if the global definition is still a Relay function. + need_rewriting = ir_module_->Lookup(global_var)->IsInstance(); + } else { + global_var = GlobalVar(func_name.value()); + global_var->checked_type_ = func->checked_type(); + need_rewriting = true; + } - te::Var index("index", DataType::Int(32)); - tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); - tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); - tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); + // For illustration only, check if the current target matches the example_target_hook kind, + // and if so extract the example attribute value. 
+ int64_t example_attribute_value = 0; + Optional opt_current_target = Target::Current(); + if (opt_current_target.defined() && + opt_current_target.value()->kind->name == "example_target_hook") { + example_attribute_value = + opt_current_target.value()->GetAttr("example_attribute").value()->value; + } - Map buffer_map = { - {x_var, x_buffer}, - {y_var, y_buffer}, - {out_var, out_buffer}, - }; + if (need_rewriting) { + // The called function is still in Relay form. Convert to TIR. + tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); + tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); + tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); - tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, {}, DictAttrs(dict_attrs)); + tir::Var x_var("x", DataType::Handle()); + tir::Var y_var("y", DataType::Handle()); + tir::Var out_var("out", DataType::Handle()); - // Switch to TIRToRuntime hook for testing - Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); - if (tir_to_runtime) { - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); - } else { - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + Map dict_attrs; + dict_attrs.Set("global_symbol", global_var->name_hint); + dict_attrs.Set("tir.noalias", Bool(true)); + + te::Var index("index", DataType::Int(32)); + tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); + if (example_attribute_value > 0) { + // For illustration only, fold the example attribute into the result. 
+ indexed_sub = tir::Sub(indexed_sub, FloatImm(DataType::Float(32), + static_cast(example_attribute_value))); + } + + tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); + tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); + + Map buffer_map = { + {x_var, x_buffer}, + {y_var, y_buffer}, + {out_var, out_buffer}, + }; + + tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), + buffer_map, {}, DictAttrs(dict_attrs)); + + // Switch to TIRToRuntime hook for testing + Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); + if (tir_to_runtime) { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); + } else { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + } + + ir_module_->Update(global_var, replacement_func); // Will Add if global_var is new. } - ir_module_->Add(new_global_var, replacement_func); + return global_var; } + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { Expr var = this->VisitExpr(op->var); Expr value = this->VisitExpr(op->value); - // Outlineable function no longer needs let binding - if (this->CanLowerExpr(value)) { + if (AsLowerableFunction(value)) { + // Inline on-the-fly if the let-bound value is lowerable. this->memo_[var] = value; } }; @@ -110,8 +186,8 @@ class ConvertAddToSubtract : public MixedModeMutator { Expr body = this->VisitExpr(op->body); auto expr = GetRef(op); - // Drop the let binding - if (this->CanLowerExpr(value)) { + if (AsLowerableFunction(value)) { + // The let binding is no longer needed since inlined on-the-fly above. 
this->memo_[expr] = this->VisitExpr(op->body); } else { Var var = Downcast(this->VisitExpr(op->var)); @@ -126,39 +202,49 @@ class ConvertAddToSubtract : public MixedModeMutator { return memo_[GetRef(op)]; } - bool CanLowerExpr(const Expr& expr) { - const auto* func = expr.as(); - if (func == nullptr) { - return false; - } - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - if (!func_name.defined()) { - return false; + const FunctionNode* AsLowerableFunction(const Expr& expr) { + if (const auto* function_node = expr.as()) { + auto func_name = function_node->GetAttr(::tvm::attr::kGlobalSymbol); + if (!func_name.defined()) { + return nullptr; + } + if (func_name != "replace_add_with_subtract") { + return nullptr; + } + return function_node; + } else if (const auto* global_var_node = expr.as()) { + return AsLowerableFunction(ir_module_->Lookup(GetRef(global_var_node))); + } else { + return nullptr; } - if (func_name != "replace_add_with_subtract") { - return false; + } + + const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) { + if (const auto* global_var_node = expr.as()) { + if (ir_module_->Lookup(GetRef(global_var_node)).as()) { + return global_var_node; + } } - return true; + return nullptr; } Expr Rewrite_(const CallNode* pre, const Expr& post) override { - if (const CallNode* call = post.as()) { - if (CanLowerExpr(call->op)) { - auto* func = call->op.as(); - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - - // Introduce a new global var to map the function to and copy the source type - // over for InferType - GlobalVar new_global_var(func_name.value()); - new_global_var->checked_type_ = func->checked_type(); - ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - + if (const auto* call = post.as()) { + GlobalVar new_op; + if (const auto* function_node = AsLowerableFunction(call->op)) { + // Add or replace the function with a PrimFunc. 
+ new_op = ReplaceAddWithSubtractPrimFunc(GetRef(function_node)); + } else if (const auto* global_var_node = AsAlreadyLoweredFunction(call->op)) { + // The function has already been rewritten, so we just need to update the call. + new_op = GetRef(global_var_node); + } + if (new_op.defined()) { // Since we are replacing the Relay function with a call to a TIR function, we must use // the call_lowered op. CallLoweredAttrs attrs; attrs.metadata.Set("relay_attrs", call->attrs); ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; - return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span); + return CallLowered(std::move(new_op), call->args, std::move(attrs), call->span); } } @@ -171,10 +257,12 @@ class ConvertAddToSubtract : public MixedModeMutator { Target custom_target_; }; +} // namespace + transform::Pass RelayToTIR() { runtime::TypedPackedFunc pass_func = [=](IRModule ir_module, transform::PassContext pass_context) { - auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c")); + ConvertAddToSubtract relay_to_tir(std::move(ir_module), Target("c")); return relay_to_tir.Mutate(); }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 6f1914eac4c3..19bfa8c68298 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -34,7 +34,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) - .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); + .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr("TIRToRuntime", 
relay::contrib::example_target_hooks::TIRToRuntime) + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 2734439cddbd..7dba23803f8c 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -232,7 +232,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); }, - config_->host_virtual_device)(mod); + config_)(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 65ef29651695..9661040eab30 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -945,14 +945,13 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -IRModule Prepare(IRModule mod, CompilationConfig config) { - VirtualDevice host_virtual_device = config->host_virtual_device; +IRModule Prepare(IRModule mod, const CompilationConfig& config) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), qnn::transform::Legalize(), // Figure out which devices should be used to execute. // TODO(mbs): Should ignore all existing annotations when constant folding - transform::PlanDevices(std::move(config)), + transform::PlanDevices(config), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. 
transform::FuseOps(/*fuse_opt_level=*/0), @@ -962,8 +961,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, - std::move(host_virtual_device))}); + tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 76dbfef5386d..73b44f7361a5 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -299,11 +299,10 @@ class TECompilerImpl : public TECompilerNode { // the module's globals. Furthermore, the external codegen tool must bind the compiled // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName // here. - auto target = Target("ext_dev"); auto global_var = GlobalVar(opt_global_symbol.value()); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr}, + value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr}, tir::PrimFunc{nullptr}, {}, ir_module); // Collect these here as it's removed in LowerExternalFunctions() device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value()); @@ -531,14 +530,14 @@ using AnalysisRemapping = std::unordered_maptarget); + CCacheKey shape_key(func, config_->host_virtual_device->target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -733,7 +732,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Special case: device_copies are left as calls to primitive operators // (thus undoing 
FuseOps) so that each backend can handle them directly. - // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone. + // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy + // alone. if (const auto* function_node = primitive_func.as()) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body); if (device_copy_props.body.defined()) { @@ -771,10 +771,18 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Typical case: call to fused primitive Relay Function. // Find the desired target device. Target target; - if (primitive_func->GetAttr(attr::kCompiler).defined()) { - // The generic 'external device' target. - // TODO(mbs): Retire once replaced unified BYOC compiler and target machinery - target = Target("ext_dev"); + Optional opt_compiler = primitive_func->GetAttr(attr::kCompiler); + if (opt_compiler.defined()) { + // This function needs to be compiled with external codegen. + Optional opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); + if (opt_target.defined()) { + // The target is what's supplied by the compilation config for kind matching the + // "Compiler" name. + target = opt_target.value(); + } else { + // Legacy fallback. + target = Target("ext_dev"); + } } else { // The target corresponding to the call_node expression's annotation. VirtualDevice virtual_device = GetVirtualDevice(GetRef(call_node)); @@ -791,6 +799,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { IRModule module_; ProcessFn process_fn_; + /*! \brief All available targets. */ + CompilationConfig config_; // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered // primitive function as we go. 
Those vars will be bound in the target device-type specific @@ -799,21 +809,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { std::unordered_map primitive_functions_; String module_name_; TECompiler compiler_; - /*! - * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation - * must live. - */ - VirtualDevice host_virtual_device_; // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; }; Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn, - VirtualDevice host_virtual_device) { + CompilationConfig config) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler, - host_virtual_device); + LowerTensorExprMutator lower_te(module, process_fn, config, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -1043,7 +1047,7 @@ void UpdateFunctionMetadata(BaseFunc func, } IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device) { + CompilationConfig config) { TECompiler compiler(module); // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten @@ -1058,8 +1062,8 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // GlobalVar, and calls updated (sticking with regular Relay Call). // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated // (using call_lowered convention). 
- IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn), - std::move(host_virtual_device))(module); + IRModule updated_module = + LowerTensorExpr(module_name, compiler, std::move(process_fn), std::move(config))(module); // The Functions tagged with "Compiler" are now residing in the cache ready to be // compiled by LowerExternalFunctions. However we still need a record of them in the @@ -1159,15 +1163,14 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device) { +Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig complilation_config) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LowerTE(module, module_name, process_fn, host_virtual_device); + return LowerTE(module, module_name, process_fn, complilation_config); }; return tvm::transform::Sequential( - {tvm::relay::transform::RelayToTIRTargetHook(), + {tvm::relay::transform::RelayToTIRTargetHook(complilation_config), tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(), tvm::tir::transform::ExtractPrimFuncConstants()}); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 0b2288d6a156..8312a20cb862 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -189,7 +189,8 @@ IRModule LowerTE( const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](BaseFunc f) {}); -/*! \brief Pass to lower an IRModule's primitive functions to TIR. +/*! + * \brief Pass to lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" * to TE expressions, schedules them, and then to TIR. 
It annotates all functions @@ -198,11 +199,11 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \param host_virtual_device \p VirtualDevice for host data and computations - * \returns The pass which lowers primative functions to TIR + * \param config All available targets. + * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device); +transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 5a62ac66f736..e0b742a84090 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -523,11 +523,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { op_index = itr->second; } - // Capture the dictionary of attributes from the original primitive function so that they - // can contribute to the hash of the compiled primitive. This way we can distinguish primitives - // with the same body expression but different attributes which may arbitrarily influence code - // generation. - op_attrs[op_index] = attrs->dict; + if (attrs.defined() && attrs->dict.defined()) { + // Capture the dictionary of attributes from the original primitive function so that they + // can contribute to the hash of the compiled primitive. This way we can distinguish + // primitives with the same body expression but different attributes which may arbitrarily + // influence code generation. 
+ op_attrs[op_index] = attrs->dict; + } Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), argument_registers)); @@ -981,25 +983,25 @@ void VMCompiler::LowerImpl(IRModule mod) { } } -transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_device) { +transform::Sequential VMCompiler::MemoryOpt(const CompilationConfig& config) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower any new shape functions and device_copies. - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Fuse & lower any new allocations. - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is // incomplete to provide memory resuse optimizations. Disable it until we can @@ -1011,10 +1013,10 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower yet again - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // Create allocations for math introduced by dynamic region math. 
- pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -1030,7 +1032,7 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de return transform::Sequential(std::move(pass_seqs)); } -transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& host_virtual_device) { +transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& config) { Array pass_seqs; // Hoist operators to "primitive" Functions. pass_seqs.push_back(FuseOps()); @@ -1043,7 +1045,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& hos backend::UpdateConstants(func, ¶ms_); } }, - host_virtual_device)); + config)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); @@ -1094,7 +1096,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { backend::UpdateConstants(func, ¶ms_); } }, - config_->host_virtual_device)); + config_)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. @@ -1111,7 +1113,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(config_->host_virtual_device)); + pass_seqs.push_back(MemoryOpt(config_)); pass_seqs.push_back(transform::InferType()); transform::Sequential seq(pass_seqs); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index a65bdc5ab3cb..163ec399013b 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -146,10 +146,10 @@ class VMCompiler : public runtime::ModuleNode { IRModule OptimizeModuleImpl(IRModule mod); /*! 
\brief Returns the passes which layout memory. */ - transform::Sequential MemoryOpt(const VirtualDevice& host_virtual_device); + transform::Sequential MemoryOpt(const CompilationConfig& config); /*! \brief Returns the passes which fuse then lower Relay primitive operators. */ - transform::Sequential FuseAndLowerOperators(const VirtualDevice& host_virtual_device); + transform::Sequential FuseAndLowerOperators(const CompilationConfig& config); /*! * \brief Populate the global function names in a map where the value is used diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index ca1e04ae59fa..45cb8271b074 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -534,6 +534,7 @@ namespace transform { // Declared in relay/transform.h Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); // Which let bindings are pure and can be safely elided? std::unordered_map var_to_purity; if (!ignore_impurity) { @@ -566,6 +567,7 @@ Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { result->Add(kv.first, kv.second); } } + VLOG(1) << "After:" << std::endl << PrettyPrint(result); return result; }; diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index a6e26364bbc4..c55b6778093e 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -69,6 +69,7 @@ class Inliner : ExprMutator { for (auto arg : vanilla_call->args) { new_args.push_back(VisitExpr(arg)); } + // TODO(mbs): Does not handle multiple calls to the same global function. 
cur_node_->RemoveCallTo(gv); return MakeNewExpr(gv, new_args, GetRef(call_node)); } diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 0022baf881ba..00953a1907e1 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -30,61 +30,143 @@ namespace tvm { namespace relay { namespace transform { -class TargetHookVisitor : public tvm::relay::MixedModeVisitor { - /*! \brief Collected pass list for all nodes */ - std::vector pass_list_; - /*! \brief Attribute map for all registered targets */ - TargetKindAttrMap target_attr_map_; - using tvm::relay::MixedModeVisitor::VisitExpr_; +namespace { + +/*! + * \brief A pass extracted from a target kind's "RelayToTIR" attribute, along with any + * 'external codegen' Target instance with matching kind name which should be current when + * the pass is applied. + */ +struct CustomPass { + std::string target_kind_name; + Pass pass; + Optional opt_target; + CustomPass(std::string target_kind_name, Pass pass, Optional opt_target) + : target_kind_name(std::move(target_kind_name)), + pass(std::move(pass)), + opt_target(std::move(opt_target)) {} +}; + +/*! + * \brief Collect all the \p CustomPasses needed according to the "Compiler" attributes on + * inlined or global functions. + */ +class TargetHookVisitor : public MixedModeVisitor { public: - TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap("RelayToTIR")) {} + TargetHookVisitor(IRModule mod, CompilationConfig config) + : mod_(std::move(mod)), + config_(std::move(config)), + target_attr_map_(tvm::TargetKind::GetAttrMap(tvm::attr::kRelayToTIR)) {} - std::vector Visit(const IRModule& ir_mod) { - for (const auto& it : ir_mod->functions) { - if (const auto* function_node = it.second.as()) { + std::vector Visit() { + ICHECK(custom_passes_.empty()); + // To ensure the passes are run in a deterministic order we'll search for functions in + // lexicographic order. 
+ std::vector> functions; + for (const auto& kv : mod_->functions) { + functions.emplace_back(kv.first->name_hint, kv.second); + } + std::sort(functions.begin(), functions.end()); + for (const auto& kv : functions) { + if (const auto* function_node = kv.second.as()) { + // May be a top-level function with a "Compiler" attribute. + MaybeAddPassForFunction(function_node); + } + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + // May have calls to inlined "Compiler" functions in body. VisitExpr(GetRef(function_node)); } } - return pass_list_; + return std::move(custom_passes_); } - void VisitExpr_(const LetNode* op) final { - auto pre_visit = [this](const LetNode* op) { - this->VisitExpr(op->var); - this->VisitExpr(op->value); + private: + using tvm::relay::MixedModeVisitor::VisitExpr_; + + void VisitExpr_(const LetNode* let_node) final { + auto pre_visit = [this](const LetNode* inner_let_node) { + this->VisitExpr(inner_let_node->var); + this->VisitExpr(inner_let_node->value); }; - auto post_visit = [this](const LetNode* op) { - this->VisitExpr(op->body); - this->visit_counter_[op] += 1; + auto post_visit = [this](const LetNode* inner_let_node) { + this->VisitExpr(inner_let_node->body); + this->visit_counter_[inner_let_node] += 1; }; - ExpandANormalForm(op, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); + } + + void VisitExpr_(const FunctionNode* function_node) override { + ExprVisitor::VisitExpr_(function_node); + MaybeAddPassForFunction(function_node); } - void VisitExpr_(const FunctionNode* func) override { - ExprVisitor::VisitExpr_(func); - if (!func->GetAttr(attr::kCompiler).defined()) { + /*! + * \brief If \p function_node has a "Compiler" attribute, checks if we should include a + * matching custom pass. Otherwise no-op. 
+ */ + void MaybeAddPassForFunction(const FunctionNode* function_node) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (!opt_compiler) { + // No external codegen required. return; } - String code_gen_name = func->GetAttr(attr::kCompiler).value(); - Optional target_kind = tvm::TargetKind::Get(code_gen_name); - if (!target_kind || !target_attr_map_.count(target_kind.value())) { + // First cross-over: use "Compiler" attribute name as target kind. + std::string kind_name = opt_compiler.value(); + Optional opt_target_kind = tvm::TargetKind::Get(kind_name); + if (!opt_target_kind || !target_attr_map_.count(opt_target_kind.value())) { + // Target kind does not exist or have the "RelayToTIR" attribute, no custom pass to consider. return; } - Pass custom_target_pass = target_attr_map_[target_kind.value()]; - if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { - pass_list_.push_back(custom_target_pass); + if (!seen_kinds_.emplace(kind_name).second) { + // Already accounted for custom pass. + return; } + // Second (optional) cross-over: find unique Target instance in overall available targets with + // the same kind so that it can be made available when custom pass is invoked. + Optional opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); + Pass custom_target_pass = target_attr_map_[opt_target_kind.value()]; + custom_passes_.emplace_back(std::move(kind_name), std::move(custom_target_pass), + std::move(opt_target)); } + + /*! \brief IRModule we are visiting. */ + IRModule mod_; + /*! \brief All available targets. */ + CompilationConfig config_; + /*! \brief Cached attribute map for all registered targets */ + TargetKindAttrMap target_attr_map_; + /*! \brief Which target kind names have already contributed to the custom passes list. */ + std::unordered_set seen_kinds_; + /*! + * \brief All the custom passes to run, paired with their corresponding target instances, if any. 
+ */ + std::vector custom_passes_; }; -Pass RelayToTIRTargetHook() { - auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { - auto target_hook_visitor = TargetHookVisitor(); - std::vector pass_list = target_hook_visitor.Visit(mod); - Sequential run_hooks(pass_list); +} // namespace - return run_hooks(mod); +Pass RelayToTIRTargetHook(CompilationConfig config) { + auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { + VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); + TargetHookVisitor target_hook_visitor(mod, config); + std::vector custom_passes = target_hook_visitor.Visit(); + for (const auto& custom_pass : custom_passes) { + if (custom_pass.opt_target.defined()) { + VLOG(0) << "Invoking custom pass for target " + << custom_pass.opt_target.value()->ToDebugString(); + // Push the target on the stack. + With with_target(custom_pass.opt_target.value()); + // Invoke the pass with target in scope. + mod = custom_pass.pass(mod); + } else { + // Invoke the pass. 
+ VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'"; + mod = custom_pass.pass(mod); + } + } + VLOG(1) << "After:" << std::endl << PrettyPrint(mod); + return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); } diff --git a/src/target/target.cc b/src/target/target.cc index 75126ed11c70..3cdfa0cc0d5e 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -495,8 +495,12 @@ Target::Target(TargetKind kind, Optional host, String tag, Array attr_map = TargetKind::GetAttrMap(::tvm::attr::kIsExternalCodegen); - return attr_map.get(get()->kind, Bool(false)); + TargetKindAttrMap is_external_codegen_map = + TargetKind::GetAttrMap(tvm::attr::kIsExternalCodegen); + TargetKindAttrMap relay_to_tir_map = + TargetKind::GetAttrMap(tvm::attr::kRelayToTIR); + return is_external_codegen_map.get(get()->kind, Bool(false)) || + relay_to_tir_map.count(get()->kind); } bool Target::IsExternalCodegenFor(const Target& that) const { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index b657ac0c5783..2c85e47e7fb8 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -144,16 +145,21 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); +TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) + .set_attr(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType()); + TEST(Target, ExternalCodegen) { Target regular("cuda"); Target external0("test_external_codegen_0"); Target external1("test_external_codegen_1"); Target external2("test_external_codegen_2"); + Target external3("test_external_codegen_3"); ASSERT_FALSE(regular.IsExternalCodegen()); ASSERT_TRUE(external0.IsExternalCodegen()); ASSERT_TRUE(external1.IsExternalCodegen()); ASSERT_TRUE(external2.IsExternalCodegen()); + 
ASSERT_TRUE(external3.IsExternalCodegen()); ASSERT_TRUE(external0.IsExternalCodegenFor(regular)); ASSERT_FALSE(regular.IsExternalCodegenFor(external0)); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 41123a254825..dbc5147e2030 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -6653,4 +6653,4 @@ def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index a017762ce35d..690ddcac8d51 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -208,6 +208,4 @@ def verify_pad_default_fill(dshape, pad_width, dtype): if __name__ == "__main__": - test_dyn_pad() - test_dyn_upsampling_infer_type_const() - test_dyn_upsampling_run() + tvm.testing.main() diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c5a9041b15fe..4f451a125184 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -31,6 +31,8 @@ set_external_func_attr, parametrize_external_codegen_checks, parametrize_external_json_codegen_checks, + check_graph_executor_result, + check_vm_result, ) @@ -180,6 +182,58 @@ def test_extern_gcc(check_result): check_result(mod, inputs, (2, 2), (y_data * y_data) - (x_data + x_data)) +# TODO(mbs): The check_aot_executor_result does not support the list-of-targets, mostly because +# tvm.testing.aot.compile_and_run requires the target to be a kind name string, and +# tvm.testing.aot.compile_models requires a single Target object. However, code outside of +# tvm.testing.aot is ready for this more general form. 
+@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_extern_gcc_with_target_instance(check_result): + shape = (8, 8) + dtype = "int32" + + def make_mod(): + x0 = relay.var("x0", shape=shape, dtype=dtype) + y0 = relay.var("y0", shape=shape, dtype=dtype) + z = x0 + y0 + f = relay.Function([x0, y0], z) + f = set_external_func_attr(f, "ccompiler", "ccompiler_0") + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.var("y", shape=shape, dtype=dtype) + call = relay.Call(f, [x, y]) + return tvm.IRModule.from_expr(call) + + host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("llvm", host=host_target) + # The header attribute is just whitespace, so compilation is as usual. + good_extern_codegen_target = tvm.target.Target( + {"kind": "ccompiler", "header": "// Good"}, host=host_target + ) + # The header attribute is ill-formed, so compilation is expected to fail. + bogus_extern_codegen_target = tvm.target.Target( + {"kind": "ccompiler", "header": "Bogus"}, host=host_target + ) + + mod = make_mod() + + x_data = np.random.rand(*shape).astype(dtype) + y_data = np.random.rand(*shape).astype(dtype) + expected_result = x_data + y_data + inputs = {"x": x_data, "y": y_data} + + check_result( + mod, inputs, shape, expected_result, target=[generic_target, good_extern_codegen_target] + ) + + with pytest.raises(RuntimeError): + check_result( + mod, + inputs, + shape, + expected_result, + target=[generic_target, bogus_extern_codegen_target], + ) + + @pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") def test_extern_gcc_consts(): @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 22b3b8cb3063..046b2c7e541d 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -18,19 +18,25 @@ import sys import numpy as np import pytest 
+import logging +import tvm import tvm.testing from tvm import relay, IRModule from utils.external_codegen import ( + parametrize_external_codegen_checks, set_external_func_attr, check_aot_executor_result, check_graph_executor_result, + check_vm_result, ) +logging.basicConfig(level=logging.INFO) -@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) -def test_tir_external_generation(check_result): + +@parametrize_external_codegen_checks +def test_tir_external_generation_inline_without_target_instance(check_result): shape = (8,) x_data = np.random.randint(255, size=shape).astype("float32") y_data = np.random.randint(255, size=shape).astype("float32") @@ -50,6 +56,49 @@ def test_tir_external_generation(check_result): check_result(func, inputs, (8,), x_data - y_data) +# TODO(mbs): The check_aot_executor_result does not support list-of-targets, mostly because +# tvm.testing.aot.compile_and_run requires the target to be a kind name string, and +# tvm.testing.aot.compile_models requires a single Target object. However, code outside of +# tvm.testing.aot is ready for this more general form. +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_tir_external_generation_outline_with_target_instance(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + # Compile with an instance of the hooked target kind to demonstrate plumbing target attributes + # into custom passes. 
+ host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("llvm", host=host_target) + extern_codegen_target = tvm.target.Target( + "example_target_hook -example_attribute=42", host=host_target + ) + mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(8), float32], %y: Tensor[(8), float32]) -> Tensor[(8), float32] { + @replace_add_with_subtract(%x, %y) * 2.0f + } + + def @replace_add_with_subtract(%x: Tensor[(8), float32], %y: Tensor[(8), float32], + Inline=1, + Primitive=1, + Compiler="example_target_hook", + global_symbol="replace_add_with_subtract") -> Tensor[(8), float32] { + %x + %y // will be rewritten to TIR implementing %x - %y - 42.0f by custom pass + } + """ + ) + + check_result( + mod, + inputs, + (8,), + (x_data - y_data - 42.0) * 2.0, + target=[generic_target, extern_codegen_target], + ) + + @pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) def test_runtime_module_generation(check_result): shape = (8,) diff --git a/tests/python/relay/utils/external_codegen.py b/tests/python/relay/utils/external_codegen.py index 6d3d917ff5a2..8e5ab803de7a 100644 --- a/tests/python/relay/utils/external_codegen.py +++ b/tests/python/relay/utils/external_codegen.py @@ -22,7 +22,7 @@ import pytest import tvm -from tvm import relay, runtime +from tvm import relay, runtime, testing from tvm.contrib import utils