From 076f94775dd7f3dcfa5db775eeedf10c050f2bf1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 11:43:50 -0500 Subject: [PATCH 1/2] [TIR] Handle callees on same target, different codegen Prior to this commit, any caller that uses a different `Target` than its callee is lowered to a device-kernel launch. However, if the caller and callee are on the same device, despite using a different target (e.g. `Target("llvm")` and `Target("c")` both use `kDLCPU`), then the kernel launch is unnecessary. This commit updates `LowerDeviceKernelLaunch` to produce a kernel launch only when the callee is on another device, and to produce `T.call_extern` for callees on the same device. --- .../transforms/lower_device_kernel_launch.cc | 45 ++++++++++++++--- ...test_tir_transform_device_kernel_launch.py | 49 +++++++++++++++++++ 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 5ffbf0d7a7fd..0348d1089de2 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -170,14 +170,27 @@ class DeviceKernelMutator : public StmtExprMutator { } PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { - if (device_kernel_launch_.count(gvar.get())) { + bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); + bool is_call_extern = extern_method_call_.count(gvar.get()); + CHECK(!is_kernel_launch || !is_call_extern) + << "Function " << gvar << " has multiple callees, " + << "and would need to be lowered into a call_extern at some call sites, " + << "and a device kernel launch at others. " + << "This case is not yet supported."; + + if (is_kernel_launch || is_call_extern) { + func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true)); + } + + if (is_kernel_launch) { const auto& info = device_info_map_.at(gvar.get()); func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)}, {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, - {tvm::attr::kGlobalSymbol, info.global_symbol}, - {tvm::tir::attr::kIsGlobalFunc, Bool(true)}}); + {tvm::attr::kGlobalSymbol, info.global_symbol}}); + } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); } return func; @@ -196,12 +209,31 @@ class DeviceKernelMutator : public StmtExprMutator { << gvar->name_hint << " did not appear within the IRModule"; const KernelInfo& dev_info = it->second; - auto caller_device_type = current_target_.value()->GetTargetDeviceType(); - auto callee_device_type = dev_info.target->GetTargetDeviceType(); - if (caller_device_type == callee_device_type) { + auto caller_target = current_target_.value(); + auto callee_target = dev_info.target; + + bool same_target = caller_target->str() == callee_target->str(); + if (same_target) { + // Calls within the same target may be handled at codegen time + // as internal subroutine calls. return std::move(node); } + bool same_device_type = + caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); + if (same_device_type) { + // Calls to another target using the same device (e.g. LLVM + // calling a custom TIRToRuntime target) do not require a kernel + // launch, but need to be replaced with call_extern. + extern_method_call_.insert(gvar); + Array args; + args.push_back(StringImm(gvar->name_hint)); + for (const auto& arg : node->args) { + args.push_back(arg); + } + return Call(node->dtype, builtin::call_extern(), args); + } + ICHECK(dev_info.launch_params.defined()) << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " << dev_info.target << ", but subroutine " << gvar->name_hint @@ -243,6 +275,7 @@ class DeviceKernelMutator : public StmtExprMutator { Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; + std::unordered_set extern_method_call_; }; namespace transform { diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py index a0f77da3766b..34cde4e4b6ce 100644 --- a/tests/python/unittest/test_tir_transform_device_kernel_launch.py +++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py @@ -189,5 +189,54 @@ def kernel(A_data: T.handle("float32")): return mod +class TestSameDeviceDifferentTarget(BaseCompare): + """Handle subroutine calls to same device, different codegen + + The device kernel launch is only required when the caller and + callee are on different devices. However, if the caller and + callee use different codegen, then the call cannot be handled as + an internal call by a single codegen. Instead, it should be + lowered to a `T.call_extern`. + """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr({"target": T.target("c")}) + A = T.decl_buffer(16, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + T.call_extern("kernel", A.data, dtype="void") + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("c"), + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + if __name__ == "__main__": tvm.testing.main() From 922df815b441eb5cb08efee9bcacde2516da0ead Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 3 Jun 2023 14:39:55 -0500 Subject: [PATCH 2/2] Rename "extern_method_call_" to "extern_function_call_" --- src/tir/transforms/lower_device_kernel_launch.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 0348d1089de2..52f06ea45c7c 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -171,7 +171,7 @@ class DeviceKernelMutator : public StmtExprMutator { PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); - bool is_call_extern = extern_method_call_.count(gvar.get()); + bool is_call_extern = extern_function_call_.count(gvar.get()); CHECK(!is_kernel_launch || !is_call_extern) << "Function " << gvar << " has multiple callees, " << "and would need to be lowered into a call_extern at some call sites, " @@ -225,7 +225,7 @@ class DeviceKernelMutator : public StmtExprMutator { // Calls to another target using the same device (e.g. LLVM // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. - extern_method_call_.insert(gvar); + extern_function_call_.insert(gvar); Array args; args.push_back(StringImm(gvar->name_hint)); for (const auto& arg : node->args) { @@ -275,7 +275,7 @@ class DeviceKernelMutator : public StmtExprMutator { Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; - std::unordered_set extern_method_call_; + std::unordered_set extern_function_call_; }; namespace transform {