From 4fab5ce11be63704f635f9104ffe8b252250d7a1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 15:51:46 -0500 Subject: [PATCH 01/19] [Driver] Single-module lowering flow in driver_api.cc Prior to this commit, a build that used multiple targets needed to provide `tvm::build` with a `Map` specifying which target should be used to compile each `IRModule`. As a result, lowering passes could not introduce new targets based on a PrimFunc's content (e.g. a `with T.target()` frame to delegate out to another device), nor simplify based on cross-device subroutines (e.g. simplify a host-side conditional based on the known output of a device-side internal subroutine). This commit makes the `tvm::attr::kTarget` attribute (`"target"`) be the single source of truth for where a `PrimFunc` will be executed. Other existing methods for specifying the target (the `target` parameter for `tvm.build`, the keys in a `Map`, the parameter to the pass `tir::transform::BindTarget`) are still accepted as inputs, and may provide a default value for `tvm::attr::kTarget` if the attribute is missing, but may not overwrite the target attribute. This is part of a series of commits to simplify the handling of multi-target builds. --- include/tvm/driver/driver_api.h | 3 +- src/driver/driver_api.cc | 215 +++++++++++------- .../example_target_hooks/relay_to_tir.cc | 2 +- 3 files changed, 133 insertions(+), 87 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index fffcab49667c..14ea5119e0e5 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -54,7 +54,8 @@ using tvm::transform::Pass; * \param target The device Target. * \return The composite Pass for the fused module. // */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, + Optional target = NullOpt); /*! * \brief Configures and returns the composite Pass for the device Target after device/host from diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d46fab716814..d9f098441d6c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -276,17 +276,6 @@ Array CreatePassList(bool disable_loop_partition) { return pass_list; } -IRModule LowerWithPassList(IRModule mod, Array pass_list) { - auto optimize = tvm::transform::Sequential(pass_list); - mod = optimize(std::move(mod)); - return mod; -} - -IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { - mod = seq(std::move(mod)); - return mod; -} - // Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, @@ -339,7 +328,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") IRModule LowerModule(IRModule mod, bool simple_mode) { Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); + tvm::transform::Sequential optimize(pass_list, "tvm.lower"); + return optimize(std::move(mod)); } TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { @@ -356,10 +346,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_ f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - - // Get the pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); + return LowerModule(mod, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower_primfunc") @@ -381,9 +368,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const std const std::unordered_map& binds, GlobalVarSupply global_var_supply, bool simple_mode) { IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply); - // Get the legacy TE pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(mod, pass_list); + return LowerModule(mod, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower_schedule") @@ -400,35 +385,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") simple_mode); }); -/** - * This function takes the input module that contains both the device and host opts. - * Then, it applies transformation on the original module before splitting into separate modules for - * device and host. Then it also applies transformations on the new splitted modules. - */ -std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { - Target target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - - ICHECK(mod_mixed.defined()) << "This module must be defined"; +IRModule MergeModules(const Map& inputs) { + if (inputs.size() == 1) { + auto [target, mod] = *inputs.begin(); + return tir::transform::BindTarget(target)(mod); + } - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); + // Take the attrs from the first module so the eventual modules have them. + IRModule first_module = (*inputs.begin()).second; + IRModule merged = IRModule(Map(), {}, {}, {}, first_module->attrs); - IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); + for (auto [target, mod] : inputs) { + mod = tir::transform::BindTarget(target)(mod); + merged->Update(mod); + } - IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); + return merged; +} - auto keys = target->GetKeys(); +Map SplitModule(const IRModule& module) { + Map split; - CheckAndUpdateHostConsistency(&target, &target_host); + for (auto [gvar, base_func] : module->functions) { + auto target_str = base_func->GetAttr(tvm::attr::kTarget).value()->str(); + if (auto it = split.find(target_str); it != split.end()) { + (*it).second->Add(gvar, base_func); + } else { + split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs)); + } + } - bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && device_mod->functions.size() == 0) { - DLOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; + Map out; + for (auto [str, mod] : split) { + out.Set(Target(str), mod); } - return {host_mod, device_mod}; + return out; } runtime::Module TIRToRuntime(const Map& inputs_arg, @@ -457,52 +449,74 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, // Update target host for all targets CheckAndUpdateHostConsistency(&inputs, &target_host); - // Take the attrs from the first module so the eventual modules have them. - // Ideally this would just be one unified module all the way through; - IRModule first_module = (*inputs.begin()).second; - IRModule mhost_all = IRModule(Map(), {}, {}, {}, first_module->attrs); - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - for (const auto& it : inputs) { - if (it.second.defined()) { - const Target& target = it.first; - const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); - auto& host_mod = pair.first; - auto& device_mod = pair.second; - - ICHECK(host_mod.defined()) << "The split host module must be defined"; - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - // We don't want library modules going back into host codegen - // unless they're supposed to. Here if we overrode the target host - // to allow lowering previously we check that it's meant to be placed - // back into the host Module. - bool overrides_host_target = - target->GetTargetDeviceType() == target_host->GetTargetDeviceType(); - bool non_host_target_kind = target->kind != target_host->kind; - if (overrides_host_target && non_host_target_kind) { - device_modules.push_back(codegen::Build(host_mod, it.first)); - } else { - mhost_all->Update(host_mod); + auto has_gpu_function = [](const IRModule& mod) -> bool { + for (const auto& [gvar, func] : mod->functions) { + if (auto target = func->GetAttr(tvm::attr::kTarget)) { + if (target.value()->HasKey("gpu")) { + return true; + } } + } + return false; + }; + + IRModule merged = MergeModules(inputs); + + bool contains_gpu_function_pre = has_gpu_function(merged); + merged = MixedModulePassManager(merged)(merged); + bool contains_gpu_function_post = has_gpu_function(merged); + if (contains_gpu_function_pre && !contains_gpu_function_post) { + DLOG(WARNING) << "Specified GPU targets, " + << "but cannot find device code. Did you forget to bind?"; + } + + Map split = SplitModule(merged); - if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); + Map built; + for (const auto& [target, mod] : split) { + built.Set(target, codegen::Build(mod, target)); + } + + auto host_target = [&]() -> Target { + Array targets_with_entry_func; + Array cpu_targets; + for (const auto& [target, mod] : split) { + bool contains_entry_func = std::any_of( + mod->functions.begin(), mod->functions.end(), + [](const auto& kv) { return kv.second->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc); }); + if (contains_entry_func) { + targets_with_entry_func.push_back(target); + } + + if (target->HasKey("cpu")) { + cpu_targets.push_back(target); } } - } - runtime::Module mhost = codegen::Build(mhost_all, target_host); - for (const auto& it : device_modules) { - if (it.operator->()) { - mhost.Import(it); + if (targets_with_entry_func.size()) { + ICHECK_EQ(targets_with_entry_func.size(), 1) + << "Expected at most one function " + << "annotated with tvm::tir::attr::kIsEntryFunc " + << "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), " + << "but found: " << targets_with_entry_func; + return targets_with_entry_func[0]; + } else if (cpu_targets.size() == 1) { + return cpu_targets[0]; + } else { + LOG(FATAL) << "Could not determine which target is the host. " + << "No function was annotated with tvm::tir::attr::kIsEntryFunc (\"" + << tvm::tir::attr::kIsEntryFunc << "\"), " + << "and " << cpu_targets.size() << " targets have the 'cpu' key"; } - } + }(); - return mhost; + auto runtime_module = built[host_target]; + for (const auto& [target, mod] : built) { + if (!mod.same_as(runtime_module)) { + runtime_module.Import(mod); + } + } + return runtime_module; } TVM_REGISTER_GLOBAL("driver.tir_to_runtime") @@ -543,18 +557,20 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional target) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array mixed_pass_list; + if (target) { + mixed_pass_list.push_back(tir::transform::BindTarget(target.value())); + } + // VerifyVTCMLimit must occur before LowerVtcmAlloc mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -600,7 +616,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); - return transform::Sequential(mixed_pass_list); + // Only applies to the device functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::LowerWarpMemory()); + + // Only applies to the host functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin()); + + // Apply to both host and device functions + mixed_pass_list.push_back(tir::transform::Simplify()); + mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + mixed_pass_list.push_back(tir::transform::LowerIntrin()); + mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + + // Only applies to the host functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::CombineContextCall()); + if (pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value()) { + mixed_pass_list.push_back(tir::transform::InstallDebugSpans()); + } + + return transform::Sequential(mixed_pass_list, "tvm.build"); } TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") @@ -609,6 +646,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") }); transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { + LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. " + << "All lowering passes are now included " + << "as part of driver.mixed_mod_passes."; + transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_debug = pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value(); @@ -634,7 +675,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho host_pass_list.push_back(tir::transform::InstallDebugSpans()); } - return transform::Sequential(host_pass_list); + return transform::Sequential(host_pass_list, "tir.host_mod_passes"); } TVM_REGISTER_GLOBAL("driver.host_mod_passes") @@ -643,6 +684,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes") }); transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { + LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. " + << "All lowering passes are now included " + << "as part of driver.mixed_mod_passes."; + Array device_pass_list; runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == @@ -658,7 +703,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); device_pass_list.push_back(tir::transform::LowerIntrin()); - return transform::Sequential(device_pass_list); + return transform::Sequential(device_pass_list, "tir.device_mod_passes"); } TVM_REGISTER_GLOBAL("driver.device_mod_passes") diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 2b037181653c..90c0fd41dc7c 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator { explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) : ir_module_(ir_module), host_target_(host_target), - custom_target_(Target("example_target_hook")) {} + custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {} IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); From b43209a08b151a35e43f29c9a4695c84ef8d8499 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Jun 2023 09:19:49 -0500 Subject: [PATCH 02/19] Clarify behavior of emit_fwd_func_decl_ --- src/target/source/codegen_c_host.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index aeba685f7422..e597d3ac9ab6 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -89,7 +89,15 @@ class CodeGenCHost : public CodeGenC { Array function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; - /*! \brief whether to emit forwared function declarations in the resulting C code */ + /*! \brief whether to emit forwared function declarations in the resulting C code + * + * Determines the behavior when encountering an unknown symbol as + * the callee in a `CallNode` whose operation is + * `builtin::call_extern`. If true, the unknown symbol will be + * forward-declared as a function, derived from the TIR types of + * CallNode's argument/return value. If false, the forward + * declaration is omitted. + */ bool emit_fwd_func_decl_; FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle); From 6ac4259d577cd70cea1134fe5cfb382d2c922105 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Jun 2023 09:21:15 -0500 Subject: [PATCH 03/19] Annotate entire body as target region if no subregions found Otherwise, in cases of a custom codegen, the device specification may be dropped entirely. --- src/tir/transforms/annotate_device_regions.cc | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index a81af7d7805b..56835e2b73d4 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -34,16 +34,37 @@ namespace tir { class DeviceRegionAnnotater : public StmtMutator { public: + static Stmt Apply(Target host_target, Target device_target, Stmt body) { + DeviceRegionAnnotater mutator(device_target); + body = mutator(body); + + bool same_host_and_device = host_target->str() == device_target->str(); + + // If no region was found that must be on the device, but the + // device and host differ (e.g. `T.target('c', host='llvm')`), + // then the entire region should be annotated. This preserves the + // host-side handling of DLTensor arguments, while ensuring that + // any device targets are used for the codegen. + if (!mutator.found_target_region_ && !same_host_and_device) { + body = AttrStmt(device_target, tvm::attr::kTarget, 0, body); + } + + return body; + } + + private: explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. + found_target_region_ = true; return GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. + found_target_region_ = true; Stmt body = GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { @@ -52,8 +73,8 @@ class DeviceRegionAnnotater : public StmtMutator { } } - private: Target device_target_; + bool found_target_region_{false}; }; namespace transform { @@ -64,9 +85,12 @@ Pass AnnotateDeviceRegions() { ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; Target target = opt_target.value(); - if (target->GetHost()) { - DeviceRegionAnnotater mutator(target.WithoutHost()); - func.CopyOnWrite()->body = mutator(func->body); + if (auto opt_host = target->GetHost()) { + auto new_body = + DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body); + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } } return func; }; From 0f1d17f40655c2934bedfecb3a108c5b53cfa634 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Jun 2023 10:05:34 -0500 Subject: [PATCH 04/19] Better detection of host target --- src/driver/driver_api.cc | 20 +++++++++++++++---- src/tir/transforms/split_host_device.cc | 3 ++- .../test_tir_transform_split_host_device.py | 4 ++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d9f098441d6c..b154c89c7d35 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -478,17 +478,29 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, } auto host_target = [&]() -> Target { + // All targets that contain a kIsEntryFunc=True function Array targets_with_entry_func; + + // All targets that can run on the CPU and contain at least one + // function without kIsEntryFunc=False. Array cpu_targets; for (const auto& [target, mod] : split) { - bool contains_entry_func = std::any_of( - mod->functions.begin(), mod->functions.end(), - [](const auto& kv) { return kv.second->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc); }); + bool contains_entry_func = false; + bool may_contain_entry_func = false; + for (const auto& [gvar, func] : mod->functions) { + Optional is_entry_func = func->attrs.GetAttr(tvm::tir::attr::kIsEntryFunc); + if (is_entry_func.defined() && is_entry_func.value()->value) { + contains_entry_func = true; + } else if (!is_entry_func.defined()) { + may_contain_entry_func = true; + } + } + if (contains_entry_func) { targets_with_entry_func.push_back(target); } - if (target->HasKey("cpu")) { + if (may_contain_entry_func && target->HasKey("cpu")) { cpu_targets.push_back(target); } } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 9b1dbf1a6618..faaab47f1517 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -95,7 +95,8 @@ class HostDeviceSplitter : public StmtMutator { PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, - {tir::attr::kIsGlobalFunc, Bool(true)}}); + {tir::attr::kIsGlobalFunc, Bool(true)}, + {tir::attr::kIsEntryFunc, Bool(false)}}); (*device_mod_)->Add(kernel_symbol_global, device_func); Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index b61fcc66014e..50fb750f28a0 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -122,6 +122,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -159,6 +160,7 @@ def main_kernel(n: T.int32) -> T.int32: "target": T.target("llvm"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -200,6 +202,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -261,6 +264,7 @@ def main_kernel_1(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) From 406e9739602eb5d677c81defab43914c79474bf3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 15 Jun 2023 08:55:09 -0500 Subject: [PATCH 05/19] Allow "ext_dev" to act as host. Currently, the Target does two independent tasks: (1) defining which device owns the buffers that are passed as input to a PrimFunc, and (2) defining which codegen will be used for a PrimFunc. Prior to this commit, the "ext_dev" target was required to define the device ownership, but did not provide the `"target.build.ext_dev"` function that is required for codegen. This worked, because `SplitHostDevice` would remove the `"ext_dev"` target without making a device-side function. With the single-module lowering flow, the separate device-side function is required to support UMA codegen. To resolve this issue, `"ext_dev"` now provides a codegen function, which is identical to the LLVM codegen. This may be improved in the future by allowing the buffer device and the codegen to be specified independently. --- apps/extension/tests/test_ext.py | 2 +- src/target/llvm/llvm_module.cc | 16 ++++++++++------ src/target/target_kind.cc | 2 +- src/tir/transforms/lower_intrin.cc | 4 ++++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index 994a673298f1..d387263a06a8 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -39,7 +39,7 @@ def test_ext_dev(): def check_llvm(): if not tvm.testing.device_enabled("llvm"): return - f = tvm.build(s, [A, B], "ext_dev", "llvm") + f = tvm.build(s, [A, B], "ext_dev", "ext_dev") dev = tvm.ext_dev(0) # launch the kernel. a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 85750fbf146e..ae465c823c44 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -444,12 +444,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, } } -TVM_REGISTER_GLOBAL("target.build.llvm") - .set_body_typed([](IRModule mod, Target target) -> runtime::Module { - auto n = make_object(); - n->Init(mod, target); - return runtime::Module(n); - }); +namespace { +runtime::Module BuildLLVM(IRModule mod, Target target) { + auto n = make_object(); + n->Init(mod, target); + return runtime::Module(n); +} +} // namespace + +TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM); +TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM); TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 56066fcfb6ab..984f84975f34 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -450,7 +450,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break .set_default_keys({"cpu"}); -TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); +TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev).set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 212ccf6e5616..fbc5d4fda92d 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -44,6 +44,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") : IRMutatorWithAnalyzer(analyzer) { + if (target == "ext_dev") { + target = "llvm"; + } + std::vector patterns; patterns.push_back(target + ".FLowerIntrinsic"); patterns.push_back(target + ".FLegalize"); From 77ade9d71ccae546575ce9a94322ad5ce515b8f2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 14:38:53 -0500 Subject: [PATCH 06/19] Improved device region annotation Previously, if no device-specific attribute is found, assume that the entire function should be executed on the device. Now, identify host-specific Call (e.g. `builtin::call_packed()`) and ensure these remain on the host. --- src/tir/transforms/annotate_device_regions.cc | 92 +++++++++++++++++-- ...t_tir_transform_annotate_device_regions.py | 71 ++++++++++++++ 2 files changed, 155 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index 56835e2b73d4..087fba20586e 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -29,23 +29,32 @@ #include #include +#include +#include +#include + namespace tvm { namespace tir { -class DeviceRegionAnnotater : public StmtMutator { +class DeviceRegionAnnotater : public StmtExprMutator { + using Parent = StmtExprMutator; + public: static Stmt Apply(Target host_target, Target device_target, Stmt body) { + bool same_host_and_device = host_target->str() == device_target->str(); + if (same_host_and_device) { + return body; + } + DeviceRegionAnnotater mutator(device_target); body = mutator(body); - bool same_host_and_device = host_target->str() == device_target->str(); - // If no region was found that must be on the device, but the // device and host differ (e.g. `T.target('c', host='llvm')`), // then the entire region should be annotated. This preserves the // host-side handling of DLTensor arguments, while ensuring that // any device targets are used for the codegen. - if (!mutator.found_target_region_ && !same_host_and_device) { + if (mutator.current_region_ == Region::Either && !same_host_and_device) { body = AttrStmt(device_target, tvm::attr::kTarget, 0, body); } @@ -58,23 +67,90 @@ class DeviceRegionAnnotater : public StmtMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. - found_target_region_ = true; + current_region_ = Region::Device; return GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. - found_target_region_ = true; + current_region_ = Region::Device; Stmt body = GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { // All other annotations are ignored - return StmtMutator::VisitStmt_(op); + return Parent::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + std::vector regions; + Array seq = op->seq.Map([&](Stmt stmt) { + current_region_ = Region::Either; + stmt = VisitStmt(stmt); + regions.push_back(current_region_); + return stmt; + }); + + bool has_host_function = std::any_of(regions.begin(), regions.end(), + [](const auto& reg) { return reg == Region::Host; }); + if (has_host_function) { + current_region_ = Region::Host; + + Array new_seq; + Array device_seq; + auto finish_device_seq = [&]() { + if (device_seq.size()) { + new_seq.push_back( + AttrStmt(device_target_, tvm::attr::kTarget, 0, SeqStmt::Flatten(device_seq))); + device_seq.clear(); + } + }; + + for (size_t i = 0; i < seq.size(); i++) { + if (regions[i] == Region::Host) { + finish_device_seq(); + new_seq.push_back(seq[i]); + } else { + device_seq.push_back(seq[i]); + } + } + finish_device_seq(); + + return SeqStmt::Flatten(new_seq); + } else if (seq.same_as(op->seq)) { + return GetRef(op); + } else { + return SeqStmt(seq); } } + PrimExpr VisitExpr_(const CallNode* op) final { + // TODO(Lunderberg): Make a new attribute in builtin.cc to label + // host-only operations. + bool is_host_only_op = + op->op.same_as(builtin::tvm_call_packed()) || op->op.same_as(builtin::tvm_call_cpacked()) || + op->op.same_as(builtin::tvm_call_packed_lowered()) || + op->op.same_as(builtin::tvm_call_cpacked_lowered()) || + op->op.same_as(builtin::tvm_struct_get()) || op->op.same_as(builtin::tvm_struct_set()) || + op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::call_pure_extern()) || + op->op.same_as(builtin::tvm_throw_last_error()) || + op->op.same_as(builtin::tvm_stack_alloca()) || + op->op.same_as(builtin::tvm_stack_make_shape()) || + op->op.same_as(builtin::tvm_stack_make_array()); + if (is_host_only_op) { + current_region_ = Region::Host; + } + return Parent::VisitExpr_(op); + } + Target device_target_; - bool found_target_region_{false}; + + enum class Region { + Either, + Host, + Device, + }; + Region current_region_{Region::Either}; }; namespace transform { diff --git a/tests/python/unittest/test_tir_transform_annotate_device_regions.py b/tests/python/unittest/test_tir_transform_annotate_device_regions.py index efa43027e9c6..7b869ddf7694 100644 --- a/tests/python/unittest/test_tir_transform_annotate_device_regions.py +++ b/tests/python/unittest/test_tir_transform_annotate_device_regions.py @@ -54,5 +54,76 @@ def expected(A: T.Buffer(1, "float32")): A[0] = 0.0 +class TestAnnotateEntireBody(BaseCompare): + """Annotation inserted to wrap entire function + + Function is assumed to belong on the device. + """ + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + A[0] = 0.0 + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + A[0] = 0.0 + + +class TestNoAnnotationForSameHostDevice(BaseCompare): + """No annotation is needed if host/device are the same""" + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm", host="llvm")}) + A[0] = 0.0 + + expected = before + + +class TestAnnotationAvoidsHostConstructs(BaseCompare): + """Device annotation does not contain host-only functions + + Calls that must be on the host side (e.g. T.call_packed) remain on + the host. + """ + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + A[0] = 0.0 + T.call_packed("dummy_function", dtype="void") + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + with T.attr(T.target("cuda"), "target", 0): + A[0] = 0.0 + T.call_packed("dummy_function", dtype="void") + + +class TestAnnotationNoRepetition(BaseCompare): + """Device annotation does not contain host-only functions + + When placing everything that isn't a host-specific function into + target block, sequential device statements should be in the same + block. + """ + + def before(A: T.Buffer(2, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + A[0] = 0.0 + A[1] = 1.0 + T.call_packed("dummy_function", dtype="void") + + def expected(A: T.Buffer(2, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + with T.attr(T.target("cuda"), "target", 0): + A[0] = 0.0 + A[1] = 1.0 + T.call_packed("dummy_function", dtype="void") + + if __name__ == "__main__": tvm.testing.main() From b507879e8276e7c61b8c86942006c2eba7737748 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 21 Jun 2023 12:18:09 -0500 Subject: [PATCH 07/19] Allow call_extern from kDLCPU to kDLExtDev --- src/tir/transforms/lower_device_kernel_launch.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 932116485fa1..a33376bd69ee 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -260,7 +260,9 @@ class DeviceKernelMutator : public StmtExprMutator { bool same_device_type = caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); - if (same_device_type) { + bool linkable_module = (caller_target->GetTargetDeviceType() == kDLCPU) && + (callee_target->GetTargetDeviceType() == kDLExtDev); + if (same_device_type || linkable_module) { // Calls to another target using the same device (e.g. LLVM // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. From 5a82fe0aeb7e44371f1fcfa494610b1244174ba7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 21 Jun 2023 12:18:44 -0500 Subject: [PATCH 08/19] Save .so instead of .o in tutorials Since the ext_dev target may be compiled separately. --- vta/scripts/tune_resnet.py | 6 +++--- vta/tutorials/matrix_multiply.py | 6 +++--- vta/tutorials/optimize/convolution_opt.py | 6 +++--- vta/tutorials/optimize/matrix_multiply_opt.py | 6 +++--- vta/tutorials/vta_get_started.py | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py index 2c284f05a0de..3f5c693b78a0 100644 --- a/vta/scripts/tune_resnet.py +++ b/vta/scripts/tune_resnet.py @@ -344,9 +344,9 @@ def tune_tasks( # Export library temp = utils.tempdir() - lib.save(temp.relpath("graphlib.o")) - remote.upload(temp.relpath("graphlib.o")) - lib = remote.load_module("graphlib.o") + lib.export_library(temp.relpath("graphlib.so")) + remote.upload(temp.relpath("graphlib.so")) + lib = remote.load_module("graphlib.so") # If detailed runtime info is needed build with debug runtime if opt.debug_profile: diff --git a/vta/tutorials/matrix_multiply.py b/vta/tutorials/matrix_multiply.py index 0d1167854458..1d1dd98dfaf3 100644 --- a/vta/tutorials/matrix_multiply.py +++ b/vta/tutorials/matrix_multiply.py @@ -392,13 +392,13 @@ # Write the compiled module into an object file. temp = utils.tempdir() -my_gemm.save(temp.relpath("gemm.o")) +my_gemm.export_library(temp.relpath("gemm.so")) # Send the executable over RPC -remote.upload(temp.relpath("gemm.o")) +remote.upload(temp.relpath("gemm.so")) # Load the compiled module -f = remote.load_module("gemm.o") +f = remote.load_module("gemm.so") ###################################################################### # Running the Function diff --git a/vta/tutorials/optimize/convolution_opt.py b/vta/tutorials/optimize/convolution_opt.py index 521a73ab510d..3c757fdc0c2b 100644 --- a/vta/tutorials/optimize/convolution_opt.py +++ b/vta/tutorials/optimize/convolution_opt.py @@ -374,9 +374,9 @@ s, [data, kernel, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_conv" ) temp = utils.tempdir() -my_conv.save(temp.relpath("conv2d.o")) -remote.upload(temp.relpath("conv2d.o")) -f = remote.load_module("conv2d.o") +my_conv.export_library(temp.relpath("conv2d.so")) +remote.upload(temp.relpath("conv2d.so")) +f = remote.load_module("conv2d.so") # Get the remote device context ctx = remote.ext_dev(0) diff --git a/vta/tutorials/optimize/matrix_multiply_opt.py b/vta/tutorials/optimize/matrix_multiply_opt.py index b470475b16e7..ea70b5260c56 100644 --- a/vta/tutorials/optimize/matrix_multiply_opt.py +++ b/vta/tutorials/optimize/matrix_multiply_opt.py @@ -314,9 +314,9 @@ s, [data, weight, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_gemm" ) temp = utils.tempdir() -my_gemm.save(temp.relpath("gemm.o")) -remote.upload(temp.relpath("gemm.o")) -f = remote.load_module("gemm.o") +my_gemm.export_library(temp.relpath("gemm.so")) +remote.upload(temp.relpath("gemm.so")) +f = remote.load_module("gemm.so") # Get the remote device context ctx = remote.ext_dev(0) diff --git a/vta/tutorials/vta_get_started.py b/vta/tutorials/vta_get_started.py index 3482258dece8..6edb34184fb4 100644 --- a/vta/tutorials/vta_get_started.py +++ b/vta/tutorials/vta_get_started.py @@ -327,17 +327,17 @@ # Write the compiled module into an object file. temp = utils.tempdir() -my_vadd.save(temp.relpath("vadd.o")) +my_vadd.export_library(temp.relpath("vadd.so")) # Send the executable over RPC -remote.upload(temp.relpath("vadd.o")) +remote.upload(temp.relpath("vadd.so")) ###################################################################### # Loading the Module # ~~~~~~~~~~~~~~~~~~ # We can load the compiled module from the file system to run the code. -f = remote.load_module("vadd.o") +f = remote.load_module("vadd.so") ###################################################################### # Running the Function From 4b75c77e349769325d45c5657b609653c5dcf863 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 5 Jul 2023 10:08:45 -0500 Subject: [PATCH 09/19] Mark VTA's CPU-side allocations with disable_lower_builtin Because VTA's codegen is allowed to call host-side, it uses the `"cpu"` tag. Therefore, the allocations that are already handled with `VTABufferCPUPtr` should opt-out of using the device API from `LowerTVMBuiltin`. --- vta/python/vta/transform.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index b1135c0eb007..5a26ecf83af7 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -202,7 +202,14 @@ def _post_order(op): ), op.body, ) - alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt) + alloc = tvm.tir.Allocate( + buffer_var, + op.dtype, + op.extents, + op.condition, + let_stmt, + annotations={"disable_lower_builtin": True}, + ) del var_remap[buffer_var] bufs_to_delete = [ old_buf for old_buf in buf_remap if old_buf.data.same_as(buffer_var) From 358d45909dc32374ed560c7b88c4ed49b8044ccf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 6 Jul 2023 13:03:40 -0500 Subject: [PATCH 10/19] Use int8_t* for StringImm literals. --- src/tir/op/op.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index fd14f4892154..4181546bb8ac 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -96,6 +96,11 @@ Type GetType(const PrimExpr& expr) { return PointerType(PrimType(address->dtype)); } } + + if (expr.as()) { + return PointerType(PrimType(DataType::Int(8))); + } + // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); return GetTypeFromRuntimeDataType(dtype); From e9a4c32d7356ec292e65eaf607aefb5ebfb1d574 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 7 Jul 2023 13:41:23 -0500 Subject: [PATCH 11/19] Allow call_extern in device-side functions May be used by kernels to call device-specific intrinsics (e.g. for cmsis-nn) --- src/tir/transforms/annotate_device_regions.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index 087fba20586e..4c5209e4b3b3 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -132,7 +132,6 @@ class DeviceRegionAnnotater : public StmtExprMutator { op->op.same_as(builtin::tvm_call_packed_lowered()) || op->op.same_as(builtin::tvm_call_cpacked_lowered()) || op->op.same_as(builtin::tvm_struct_get()) || op->op.same_as(builtin::tvm_struct_set()) || - op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::call_pure_extern()) || op->op.same_as(builtin::tvm_throw_last_error()) || op->op.same_as(builtin::tvm_stack_alloca()) || op->op.same_as(builtin::tvm_stack_make_shape()) || From 85b1d5e080b60623a50358a728754e7310c5aa6c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Jul 2023 09:04:35 -0500 Subject: [PATCH 12/19] [EthosU] Use ethos-u as both device and host --- python/tvm/relay/backend/contrib/ethosu/tir/compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index d47b3d4a7de6..8f6232347059 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function: primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"]) primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name)) + primfunc = primfunc.with_attr( + "target", tvm.target.Target(compiler_name, host=compiler_name) + ) return primfunc def __call__(self, *args, **kwargs): From 775742d34611e0c38028dc5252d7facae4f2e6dd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 May 2023 16:30:33 -0500 Subject: [PATCH 13/19] [TIR] End-to-end tests for PrimFunc-to-PrimFunc subroutines The functionality tested in this commit was added across several recent PRs, each of which tested their features in isolation. This PR adds unit tests to validate the end-to-end behavior of TIR subroutine calls. PRs building up to this point: - TVMScript - https://github.com/apache/tvm/pull/14889 - https://github.com/apache/tvm/pull/14915 - https://github.com/apache/tvm/pull/14919 - https://github.com/apache/tvm/pull/14941 - Functionality improvements of existing TIR passes - https://github.com/apache/tvm/pull/14913 - https://github.com/apache/tvm/pull/14914 - https://github.com/apache/tvm/pull/14918 - https://github.com/apache/tvm/pull/14951 - Changes to the TIR lowering flow - https://github.com/apache/tvm/pull/14942 - https://github.com/apache/tvm/pull/14985 - Codegen updates - https://github.com/apache/tvm/pull/14958 - https://github.com/apache/tvm/pull/14901 - Compatibility updates/fixes - https://github.com/apache/tvm/pull/14892 - https://github.com/apache/tvm/pull/14950 - https://github.com/apache/tvm/pull/14943 - https://github.com/apache/tvm/pull/14944 - https://github.com/apache/tvm/pull/14945 - https://github.com/apache/tvm/pull/14952 - https://github.com/apache/tvm/pull/14982 - https://github.com/apache/tvm/pull/14949 --- .../unittest/test_tir_subroutine_call.py | 275 ++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100755 tests/python/unittest/test_tir_subroutine_call.py diff --git a/tests/python/unittest/test_tir_subroutine_call.py b/tests/python/unittest/test_tir_subroutine_call.py new file mode 100755 index 000000000000..a5d84abf268f --- /dev/null +++ b/tests/python/unittest/test_tir_subroutine_call.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring + +import pytest +import numpy as np + +import tvm +import tvm.testing + +from tvm.script import tir as T, ir as I + + +@tvm.testing.parametrize_targets("llvm") +def test_call_noop(target, dev): + """TIR functions on the CPU may call other functions + + The simplest test case, where the subroutine is a no-op. + """ + + @I.ir_module + class module: + @T.prim_func + def subroutine(): + T.evaluate(0) + + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine() + A[0] = 42.0 + + built = tvm.build(module, target=target) + + arr = tvm.nd.empty([1], dtype="float32", device=dev) + built(arr) + + assert arr.numpy()[0] == 42.0 + + +@tvm.testing.parametrize_targets("llvm") +def test_call_noop_defined_below(target, dev): + """Calling a subroutine does not depend on the definition order + + All GlobalVar instances are in-scope for subroutine calls. + """ + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine() + A[0] = 42.0 + + @T.prim_func + def subroutine(): + T.evaluate(0) + + built = tvm.build(module, target=target) + + arr = tvm.nd.empty([1], dtype="float32", device=dev) + built(arr) + + assert arr.numpy()[0] == 42.0 + + +@tvm.testing.parametrize_targets("llvm") +def test_subroutine_call_with_pointer_param(target, dev): + """TIR functions on the CPU may call other functions + + Buffers may be exposed to subroutines through data pointers. + """ + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(2, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine(A.data) + module.subroutine(T.address_of(A[1])) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + A = T.decl_buffer(shape=[1], dtype="float32", data=A_data) + A[0] = 42.0 + + built = tvm.build(module, target=target) + + arr = tvm.nd.empty([2], dtype="float32", device=dev) + built(arr) + + assert arr.numpy()[0] == 42.0 + assert arr.numpy()[1] == 42.0 + + +@pytest.mark.xfail(reason="Depends on LLVM version") +@tvm.testing.parametrize_targets("llvm") +def test_failed_subroutine_call_for_incorrect_type(target, dev): + """Calls into a subroutine must have correct argument types + + This currently relies on the `llvm::verifyModule` function during + codegen. In the future, this should be moved to a dedicated check + of TIR validity. + """ + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("int32")): + A = T.decl_buffer(shape=[1], dtype="int32", data=A_data) + A[0] = -1 + + lowered = tvm.lower(module) + with pytest.raises(tvm.TVMError): + tvm.build(lowered) + + +@tvm.testing.parametrize_targets("llvm") +def test_subroutine_call_with_scalar_param(target, dev): + """Subroutines may also accept scalar parameters""" + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine(A.data, 42.0) + + @T.prim_func + def subroutine(A_data: T.handle("float32"), val: T.float32): + A = T.decl_buffer([1], "float32", data=A_data) + A[0] = 2 * val + + built = tvm.build(module, target=target) + + arr = tvm.nd.empty([1], dtype="float32", device=dev) + built(arr) + + assert arr.numpy()[0] == 84.0 + + +@tvm.testing.parametrize_targets("llvm") +def test_internal_subroutine_is_not_exposed_externally(target, dev): + """An internal subroutine may not be called externally + + An internal subroutine is any subroutine without a "global_symbol" + attribute. These are not exposed in the runtime::Module and do + not have an externally linkable symbol. + """ + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine(A.data, 42.0) + + @T.prim_func + def subroutine(A_data: T.handle("float32"), val: T.float32): + A = T.decl_buffer([1], "float32", data=A_data) + A[0] = 2 * val + + built = tvm.build(module, target=target) + with pytest.raises(AttributeError): + built["subroutine"] + + +@tvm.testing.parametrize_targets("llvm") +def test_call_to_externally_visible_subroutine(target, dev): + """Subroutines may be exposed externally. + + A subroutine may be exposed externally. Externally-exposed + subroutines may be called by an external API, or may be called by + other functions in the same IRModule. + + The current implementation lowers internal subroutine calls to + `T.tvm_call_cpacked`. This avoids the overhead of the global + registry lookup used by `T.tvm_call_packed`, but still requires + the overhead of packing/unpacking the `PackedFunc` interface, and + is limited to callers whose target supports the `PackedFunc` + interface. + """ + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine(A.data, 42.0) + + @T.prim_func + def subroutine(A_data: T.handle("float32"), val: T.float32): + T.func_attr({"global_symbol": "subroutine"}) + A = T.Buffer([1], "float32", data=A_data) + A[0] = 2 * val + + built = tvm.build(module, target=target) + + arr = tvm.nd.empty([1], dtype="float32", device=dev) + built["main"](arr) + assert arr.numpy()[0] == 84.0 + + arr = np.zeros(shape=[1], dtype="float32") + built["subroutine"](arr.ctypes._data, 100.0) + assert arr[0] == 200.0 + + +is_external_subroutine = tvm.testing.parameter(by_dict={"external": True, "internal": False}) + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_call_to_device_subroutine(target, dev, is_external_subroutine): + """Subroutines may be exposed externally. + + This feature is currently limited to host-side subroutine calls of + externally-exposed subroutines. + """ + is_gpu = "gpu" in tvm.target.Target(target).keys + + if is_gpu and not is_external_subroutine: + pytest.xfail(reason="Not yet implemented.") + + if is_external_subroutine: + func_attr = {"global_symbol": "subroutine"} + else: + func_attr = {} + + @I.ir_module + class module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main"}) + module.subroutine(A.data, 42.0) + + @T.prim_func + def subroutine(A_data: T.handle("float32"), val: T.float32): + T.func_attr(func_attr) + A = T.Buffer([1], "float32", data=A_data) + iterator = T.meta_var( + T.thread_binding(0, 1, thread="threadIdx.x") if is_gpu else range(1) + ) + for i in iterator: + A[0] = 2 * val + + built = tvm.build(module, target=target) + + arr = tvm.nd.empty([1], dtype="float32", device=dev) + built["main"](arr) + assert arr.numpy()[0] == 84.0 + + +if __name__ == "__main__": + tvm.testing.main() From 210b674e9952099957e42361fc9738acbd1da8aa Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 15:30:03 -0500 Subject: [PATCH 14/19] Updated example of PrintFuncPrefix Now that the function return type is handled by `CodeGenC`, updating the docstring to a usage other than the return type. --- src/target/source/codegen_c.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 2921a56ef3a1..ce74733159a1 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -138,7 +138,7 @@ class CodeGenC : public ExprFunctor, * \brief Print the function header before the argument list * \param os The output stream * - * Example: stream << "void"; + * Example: os << "extern \"C\""; */ virtual void PrintFuncPrefix(std::ostream& os); // NOLINT(*) /*! From 3015ab0502ea76ea8f772c5d6a2d703d091feab5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 15:31:20 -0500 Subject: [PATCH 15/19] Handle missing global symbol in ExtractFuncInfo --- src/target/build_common.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/target/build_common.h b/src/target/build_common.h index 7c9ad8cb3c68..a79e935f0bde 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -55,8 +55,14 @@ inline std::unordered_map ExtractFuncInfo(co info.launch_param_tags.push_back(tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - fmap[static_cast(global_symbol.value())] = info; + + if (auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol)) { + info.name = global_symbol.value(); + } else { + info.name = kv.first->name_hint; + } + + fmap[info.name] = info; } return fmap; } From 92b74eeb848d26fd1f89227c6912489e82b3e137 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 15:31:37 -0500 Subject: [PATCH 16/19] Annotate internal functions with __device__ instead of __global__ Calling a function annotated with `__global__` can be done from the GPU (see https://stackoverflow.com/a/39448797), but requires a different calling convention. --- src/target/source/codegen_cuda.cc | 8 +++++++- .../unittest/test_tir_transform_inject_ptx_async_copy.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6c0234819199..f89be0b1600f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) { ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } -void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; } +void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" "; } class ThreadIdxExtractor : public tir::StmtVisitor { private: @@ -76,6 +76,12 @@ class ThreadIdxExtractor : public tir::StmtVisitor { }; void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { + if (f->GetAttr(tvm::attr::kGlobalSymbol)) { + os << " __global__"; + } else { + os << " __device__"; + } + ThreadIdxExtractor extractor; extractor(f->body); arith::Analyzer analyzer; diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index b39fca72c871..0abff0abcc80 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -204,8 +204,8 @@ def test_inject_async_copy_shared_dyn(): #define int64_t long long #define uint64_t unsigned long long #endif -extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C); -extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { +extern "C" void __global__ __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C); +extern "C" void __global__ __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64]; A_shared[((int)threadIdx.x)] = 0.000000e+00f; From a8815f2f9e50fbb9a4b68018acb03c0bd48f3797 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 15:46:17 -0500 Subject: [PATCH 17/19] AnnotateDeviceRegions, use host to call exposed device functions Externally exposed function is lowered into a PackedFunc call, and calling a PackedFunc requires the caller to be on the host. In the future, this can be improved by having a pass that identifies internal callers of an externally-exposed callee, rewriting to extract an internal method that is called by both externally-exposed functions. --- src/tir/transforms/annotate_device_regions.cc | 55 +++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index 4c5209e4b3b3..20f02ecc44db 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -40,13 +40,14 @@ class DeviceRegionAnnotater : public StmtExprMutator { using Parent = StmtExprMutator; public: - static Stmt Apply(Target host_target, Target device_target, Stmt body) { + static Stmt Apply(Target host_target, Target device_target, + const std::unordered_set& externally_exposed, Stmt body) { bool same_host_and_device = host_target->str() == device_target->str(); if (same_host_and_device) { return body; } - DeviceRegionAnnotater mutator(device_target); + DeviceRegionAnnotater mutator(device_target, externally_exposed); body = mutator(body); // If no region was found that must be on the device, but the @@ -62,7 +63,9 @@ class DeviceRegionAnnotater : public StmtExprMutator { } private: - explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} + explicit DeviceRegionAnnotater(Target device_target, + const std::unordered_set& externally_exposed) + : device_target_(device_target), externally_exposed_(externally_exposed) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { @@ -135,7 +138,8 @@ class DeviceRegionAnnotater : public StmtExprMutator { op->op.same_as(builtin::tvm_throw_last_error()) || op->op.same_as(builtin::tvm_stack_alloca()) || op->op.same_as(builtin::tvm_stack_make_shape()) || - op->op.same_as(builtin::tvm_stack_make_array()); + op->op.same_as(builtin::tvm_stack_make_array()) || + externally_exposed_.count(op->op.as()); if (is_host_only_op) { current_region_ = Region::Host; } @@ -143,6 +147,7 @@ class DeviceRegionAnnotater : public StmtExprMutator { } Target device_target_; + const std::unordered_set& externally_exposed_; enum class Region { Either, @@ -155,22 +160,40 @@ class DeviceRegionAnnotater : public StmtExprMutator { namespace transform { Pass AnnotateDeviceRegions() { - auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { - auto opt_target = func->GetAttr(tvm::attr::kTarget); - ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; - Target target = opt_target.value(); - - if (auto opt_host = target->GetHost()) { - auto new_body = - DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body); - if (!new_body.same_as(func->body)) { - func.CopyOnWrite()->body = new_body; + auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { + std::unordered_set externally_exposed; + for (const auto& [gvar, base_func] : mod->functions) { + if (base_func->GetAttr(tvm::attr::kGlobalSymbol)) { + externally_exposed.insert(gvar.get()); } } - return func; + + IRModule updates; + + for (const auto& [gvar, base_func] : mod->functions) { + auto func = Downcast(base_func); + auto opt_target = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; + Target target = opt_target.value(); + + if (auto opt_host = target->GetHost()) { + auto new_body = DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), + externally_exposed, func->body); + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + updates->Add(gvar, func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + + return mod; }; - return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); } TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions); From c4e4fd4f661435c658aecb7aaba03a736fe073fc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 15:50:45 -0500 Subject: [PATCH 18/19] Relax cuda requirement for CallingConv::kDeviceKernelLaunch Previously, all functions required the `tvm::attr::kCallingConv` attribute to be set to `CallingConv::kDeviceKernelLaunch` (2). Now, this is only required for externally-exposed functions. --- src/target/opt/build_cuda_on.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index e0f53e350992..179d7335e13d 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -136,8 +136,14 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto prim_func = Downcast(base_func); auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + bool is_device_kernel_launch = calling_conv == CallingConv::kDeviceKernelLaunch; + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + bool is_internal_function = !global_symbol.defined(); + ICHECK(is_device_kernel_launch || is_internal_function) + << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch " + << "for externally exposed functions, " + << "but function " << gvar << " has calling_conv " << calling_conv << " and global symbol " + << global_symbol; functions.Set(gvar, prim_func); } From 3d6f97e1497d5821585e46d1096be82b0a99f445 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Jun 2023 15:54:18 -0500 Subject: [PATCH 19/19] Enable unit test for internal GPU subroutine --- tests/python/unittest/test_tir_subroutine_call.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/unittest/test_tir_subroutine_call.py b/tests/python/unittest/test_tir_subroutine_call.py index a5d84abf268f..58a70d47ac0c 100755 --- a/tests/python/unittest/test_tir_subroutine_call.py +++ b/tests/python/unittest/test_tir_subroutine_call.py @@ -239,9 +239,6 @@ def test_call_to_device_subroutine(target, dev, is_external_subroutine): """ is_gpu = "gpu" in tvm.target.Target(target).keys - if is_gpu and not is_external_subroutine: - pytest.xfail(reason="Not yet implemented.") - if is_external_subroutine: func_attr = {"global_symbol": "subroutine"} else: