From f386132165c3449aa172c30481d90ce2e98542fe Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 18 Apr 2024 14:54:55 +0800 Subject: [PATCH] [Relax] Prevent to generate duplicate func in dispatch_sort_scan The current pass would generate multiple PrimFuncs even if they are structural equal, which is because `bb.update_func` will not check whether the new func is already in the list. This PR apply dlight at the end of the dispatching instead of after every function. --- .../tvm/relax/backend/dispatch_sort_scan.py | 57 +++++++++++-------- .../relax/test_backend_dispatch_sort_scan.py | 38 +++++++++++++ 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index f0e42f401bc2..eb82e49d9a99 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -19,10 +19,11 @@ from functools import reduce from operator import mul +from typing import Dict from tvm import DataType, dlight, relax, topi from tvm.contrib.thrust import can_use_thrust -from tvm.ir import Op +from tvm.ir import GlobalVar, Op from tvm.ir.module import IRModule from tvm.ir.transform import PassContext, module_pass from tvm.relax import PyExprMutator, expr_functor @@ -41,8 +42,11 @@ class SortScanDispatcher(PyExprMutator): """ + calls_to_update: Dict[GlobalVar, Target] + def __init__(self, mod): super().__init__(mod) + self.calls_to_update = {} def _get_target(self, sinfo: relax.StructInfo) -> Target: # Get target information from TensorStructInfo @@ -64,22 +68,32 @@ def _get_target(self, sinfo: relax.StructInfo) -> Target: ) return target - def _apply_dlight_gpu_fallback(self, target: Target, tir_call: relax.Call) -> None: - # Apply dlight.gpu.Fallback() on GPU + def apply_dlight_gpu_fallback( + self, + ) -> None: + """Apply DLight rules for all the calls that need to be updated.""" + for gvar, target in self.calls_to_update.items(): + func = self.builder_.get()[gvar] + sch = dlight.base.transform._apply_rules( + func, + target, + rules=[dlight.gpu.Fallback()], + tunable=False, + ) + if sch is not None: + assert len(sch) == 1 + self.builder_.update_func(gvar, sch[0].mod["main"].with_attr("tir.is_scheduled", 1)) + + def _append_calls_to_update(self, tir_call: relax.Call, target: Target) -> None: gvar = tir_call.args[0] - assert isinstance(gvar, relax.GlobalVar) - scan_prim_func = self.builder_.get()[gvar] - sch = dlight.base.transform._apply_rules( - scan_prim_func, - target, - [ - dlight.gpu.Fallback(), - ], - False, - ) - if sch is not None: - assert len(sch) == 1 - self.builder_.update_func(gvar, sch[0].mod["main"].with_attr("tir.is_scheduled", 1)) + assert isinstance(gvar, GlobalVar) + existing_tgt = self.calls_to_update.get(gvar, None) + if existing_tgt is not None and existing_tgt != target: + raise ValueError( + f"Multiple targets detected for function {gvar}. " + f"Existing target: {existing_tgt}, new target: {target}" + ) + self.calls_to_update[gvar] = target def visit_call_(self, call: relax.Call) -> relax.Expr: if not isinstance(call.op, Op): @@ -135,10 +149,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: dtype=call.attrs.dtype, **kwargs, ) - if not is_gpu_target(tgt): - return tir_call - # apply dlight gpu fallback - self._apply_dlight_gpu_fallback(tgt, tir_call) + self._append_calls_to_update(tir_call, tgt) return tir_call if call.op.name in ("relax.cumprod", "relax.cumsum"): tgt = self._get_target(call.struct_info) @@ -161,10 +172,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: call.attrs.exclusive, **kwargs, ) - if not is_gpu_target(tgt): - return tir_call - # apply dlight gpu fallback - self._apply_dlight_gpu_fallback(tgt, tir_call) + self._append_calls_to_update(tir_call, tgt) return tir_call return super().visit_call_(call) @@ -211,4 +219,5 @@ def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: if isinstance(func, relax.Function): func = sort_scan_dispater.visit_expr(func) sort_scan_dispater.builder_.update_func(gv, func) + sort_scan_dispater.apply_dlight_gpu_fallback() return sort_scan_dispater.builder_.finalize() diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 5a291725d8f7..0fb39dfc9ca1 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -361,5 +361,43 @@ def foo(x: R.Tensor((2, 3), "float32", "cuda")): assert_structural_equal(mod, expected_mod) +def test_dispatch_topk_gpu(): + @I.ir_module + class Before: + I.module_global_infos({"vdevice": [I.vdevice("vulkan")]}) + + @R.function + def foo(x: R.Tensor((2, 3), "float32", "vulkan")): + with R.dataflow(): + # Two same calls should have only one PrimFunc + lv0 = R.topk(x, k=2, axis=1, largest=True) + lv1 = R.topk(x, k=2, axis=1, largest=True) + gv = (lv0, lv1) + R.output(gv) + return gv + + target = tvm.target.Target("vulkan", host="llvm") + + vdevices = [I.vdevice("vulkan", 0)] + x = relax.Var("x", R.Tensor((2, 3), "float32", vdevices[0])) + bb = relax.BlockBuilder() + with target: + with bb.function("foo", (x,), {"global_symbol": "foo"}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.cuda.topk, x, k=2, axis=1, is_ascend=False, dtype="int32") + lv1 = bb.emit_te(topi.cuda.topk, x, k=2, axis=1, is_ascend=False, dtype="int32") + out = (lv0, lv1) + out = bb.emit_output(out) + bb.emit_func_output(out) + expected_mod = bb.finalize() + expected_mod.update_global_info("vdevice", vdevices) + + with target: + mod = DispatchSortScan()(Before) + expected_mod = dlight.ApplyDefaultSchedule(dlight.gpu.Fallback())(expected_mod) + + assert_structural_equal(mod, expected_mod) + + if __name__ == "__main__": tvm.testing.main()