From aedd0f7d21e414c38dd2c5b66d15e3a46f433578 Mon Sep 17 00:00:00 2001 From: sung Date: Wed, 8 Jun 2022 10:18:11 -0700 Subject: [PATCH 1/4] refactor three TIR passes - BindTarget, AnnotateEntryFunc, Filter --- include/tvm/tir/transform.h | 19 +++ python/tvm/tir/transform/transform.py | 68 +++++++--- src/driver/driver_api.cc | 45 ++----- src/tir/transforms/helpers.cc | 63 +++++++++ src/tir/transforms/lower_init_block.cc | 2 +- .../convert_pool_allocations_to_offsets.cc | 2 +- .../unittest/test_tir_transform_helpers.py | 123 ++++++++++++++++++ 7 files changed, 265 insertions(+), 57 deletions(-) create mode 100644 src/tir/transforms/helpers.cc create mode 100644 tests/python/unittest/test_tir_transform_helpers.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 24c3cfa78f72..f70a8fc27ad9 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -24,6 +24,7 @@ #ifndef TVM_TIR_TRANSFORM_H_ #define TVM_TIR_TRANSFORM_H_ +#include #include #include #include @@ -624,6 +625,24 @@ TVM_DLL Pass ExtractPrimFuncConstants(); */ TVM_DLL Pass RenormalizeSplitPattern(); +/*! + * \brief Annotate a PrimFunc with a given target. + * \return The pass. + */ +TVM_DLL Pass BindTarget(Target target); + +/*! + * \brief Set a PrimFunc as the entry point if it is only function in IRModule. + * \return The pass. + */ +TVM_DLL Pass AnnotateEntryFunc(); + +/*! + * \brief Filter PrimFuncs with a given condition. + * \return The pass. + */ +TVM_DLL Pass Filter(runtime::TypedPackedFunc fcond); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 802fdc576c41..330f51fdf913 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -16,9 +16,10 @@ # under the License. """Wrapping existing transformations.""" # pylint: disable=invalid-name -from typing import Optional +from typing import Optional, Callable from . import _ffi_api from . import function_pass as _fpass +from tvm.target import Target def Apply(ftransform): @@ -43,26 +44,6 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore -def Filter(fcond): - """Filter functions by the calling convention attribute. - - Parameters - ---------- - fcond : tvm.tir.PrimFunc -> bool - The condition of the filtering. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - # pylint: disable=unused-argument - def _transform(func, mod, ctx): - return func if fcond(func) else None - - return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") # type: ignore - - def InjectPrefetch(): """Inject prefetch instructions into stmt. @@ -806,3 +787,48 @@ def RenormalizeSplitPattern(): The result pass """ return _ffi_api.RenormalizeSplitPattern() # type: ignore + + +def RenormalizeSplitPattern(): + """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RenormalizeSplitPattern() # type: ignore + + +def BindTarget(target: Target): + """Annotate a PrimFunc with a given target. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BindTarget(target) # type: ignore + + +def AnnotateEntryFunc(): + """Set a PrimFunc as the entry point if it is only function in IRModule. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateEntryFunc() # type: ignore + + +def Filter(fcond: Callable): + """Filter out PrimFuncs that does not satisfy the given condition. + `fcond` should be a function that takes a primfunc and returns boolean. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.Filter(fcond) # type: ignore diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7df1a844acc2..f66153e9b0b0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -163,32 +163,6 @@ TVM_REGISTER_GLOBAL("driver.get_binds") return out_arr; }); -transform::Pass BindTarget(Target target) { - auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - return WithAttr(std::move(f), tvm::attr::kTarget, target); - }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); -} - -static transform::Pass AnnotateEntryFunc(bool b) { - auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); - }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); -} - -template -transform::Pass Filter(FCond fcond) { - auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - if (fcond(f)) { - return f; - } else { - return tir::PrimFunc(nullptr); - } - }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); -} - Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -560,12 +534,12 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; - mixed_pass_list.push_back(BindTarget(target)); + mixed_pass_list.push_back(tir::transform::BindTarget(target)); mixed_pass_list.push_back(tir::transform::VerifyMemory()); if (ShouldAnnotateEntryFunc(mixed_mod)) { - mixed_pass_list.push_back(AnnotateEntryFunc(true)); + mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); } bool detect_global_barrier = @@ -602,14 +576,16 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { Array host_pass_list; - host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + + runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; - })); + }; + host_pass_list.push_back(tir::transform::Filter(fcond)); ICHECK(mixed_mod.defined()) << "This module must be defined"; - host_pass_list.push_back(BindTarget(target_host)); + host_pass_list.push_back(tir::transform::BindTarget(target_host)); host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); @@ -627,12 +603,13 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes") transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { Array device_pass_list; - device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; - })); + }; + device_pass_list.push_back(tir::transform::Filter(fcond)); - device_pass_list.push_back(BindTarget(target)); + device_pass_list.push_back(tir::transform::BindTarget(target)); device_pass_list.push_back(tir::transform::LowerWarpMemory()); device_pass_list.push_back(tir::transform::Simplify()); diff --git a/src/tir/transforms/helpers.cc b/src/tir/transforms/helpers.cc new file mode 100644 index 000000000000..d02400aa3c3d --- /dev/null +++ b/src/tir/transforms/helpers.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file helpers.cc + * \brief Passes that serve as helper functions. + */ + +#include +#include + +namespace tvm { +namespace tir { +namespace transform { +transform::Pass BindTarget(Target target) { + auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + return WithAttr(std::move(f), tvm::attr::kTarget, target); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {}); +} + +transform::Pass AnnotateEntryFunc() { + auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + ICHECK(m->functions.size() == 1); + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.AnnotateEntryFunc", {}); +} + +transform::Pass Filter(runtime::TypedPackedFunc fcond) { + auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + if (fcond(f)) { + return f; + } else { + return tir::PrimFunc(nullptr); + } + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); +TVM_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); +TVM_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index d8621ac3b3e6..17b4e3fb22e6 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -81,7 +81,7 @@ Pass LowerInitBlock() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { return LowerInitBlock(std::move(f)); }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerReduction", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); } TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index dc71e3d60891..1611ed596a1c 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -60,7 +60,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { PoolInfo pool_info = pool_allocation->pool_info; int byte_pool_offset = pool_allocation->byte_offset->value; int required_pool_size_for_allocation = - byte_pool_offset + CalculateExtentsSize(allocate_node.operator->()); + byte_pool_offset + int(CalculateExtentsSize(allocate_node.operator->())); if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { all_pools_sizes_[pool_info] = required_pool_size_for_allocation; } else { diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py new file mode 100644 index 000000000000..01496e0e0fc1 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.script import tir as T +import tvm.testing + + +def test_annotate_entry_func_single_primfunc(): + @tvm.script.ir_module + class MockModule: + @T.prim_func + def func1(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + if i == 5: + A[i] = 0.0 + + mod = MockModule + assert mod + assert mod["func1"].attrs is None + after = tvm.tir.transform.AnnotateEntryFunc()(mod) + assert ( + after["func1"].attrs + and "tir.is_entry_func" in after["func1"].attrs + and after["func1"].attrs["tir.is_entry_func"] + ) + + +# Test module +@tvm.script.ir_module +class MockModule: + @T.prim_func + def func1(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + if i == 5: + A[i] = 0.0 + + @T.prim_func + def func2(A: T.Buffer[(32,), "float32"]): + for i in T.serial(32): + if i == 15: + if i == 15: + A[i] = 0.0 + + +@pytest.mark.xfail +def test_annotate_entry_func_multiple_primfunc(): + mod = MockModule + assert mod + assert mod["func1"].attrs is None + assert mod["func2"].attrs is None + # This should fail + after = tvm.tir.transform.AnnotateEntryFunc()(mod) + + +def test_bind_target(): + mod = MockModule + assert mod + + target = tvm.target.Target("cuda") + assert mod["func1"].attrs is None + assert mod["func2"].attrs is None + after = tvm.tir.transform.BindTarget(target)(mod) + + assert after["func1"].attrs and "target" in after["func1"].attrs + assert after["func1"].attrs["target"] == target + assert after["func2"].attrs and "target" in after["func2"].attrs + assert after["func2"].attrs["target"] == target + + +def test_filter_primfunc(): + mod = MockModule + assert mod + # Annotate each function for testing + mod["func1"] = mod["func1"].with_attr("temp", "test1") + mod["func2"] = mod["func2"].with_attr("temp", "test2") + + # Test condition that does not filter out anything + def checker_filter_out_none(func: tvm.tir.PrimFunc): + return (func.attrs is not None) and ("temp" in func.attrs) + + after = tvm.tir.transform.Filter(checker_filter_out_none)(mod) + assert len(after.functions) == 2 + # Filtered functions should satisfy the given condition. + assert checker_filter_out_none(after["func1"]) + assert checker_filter_out_none(after["func2"]) + + # Test condition that selectively filters out primfuncs + def checker_filter_out_one(func: tvm.tir.PrimFunc): + return (func.attrs is not None) and ("temp" in func.attrs) and func.attrs["temp"] == "test1" + + after = tvm.tir.transform.Filter(checker_filter_out_one)(mod) + assert len(after.functions) == 1 + # Filtered functions should satisfy the given condition. + assert checker_filter_out_one(after["func1"]) + + # Test condition that filters out everything + def checker_filter_out_both(func: tvm.tir.PrimFunc): + return (func.attrs is not None) and ("invalid_attr" in func.attrs) + + after = tvm.tir.transform.Filter(checker_filter_out_both)(mod) + assert len(after.functions) == 0 + + +if __name__ == "__main__": + tvm.testing.main() From 2dded9ec83d3cb27a0484acb0eea9c3ce05a4017 Mon Sep 17 00:00:00 2001 From: sung Date: Wed, 8 Jun 2022 13:33:05 -0700 Subject: [PATCH 2/4] fix lint --- python/tvm/tir/transform/transform.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 330f51fdf913..b636967c6433 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -17,9 +17,9 @@ """Wrapping existing transformations.""" # pylint: disable=invalid-name from typing import Optional, Callable +from tvm.target import Target from . import _ffi_api from . import function_pass as _fpass -from tvm.target import Target def Apply(ftransform): @@ -789,17 +789,6 @@ def RenormalizeSplitPattern(): return _ffi_api.RenormalizeSplitPattern() # type: ignore -def RenormalizeSplitPattern(): - """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.RenormalizeSplitPattern() # type: ignore - - def BindTarget(target: Target): """Annotate a PrimFunc with a given target. From 4a7d1342061eeacad2f187237b16338a5eb820c5 Mon Sep 17 00:00:00 2001 From: sung Date: Wed, 8 Jun 2022 14:47:37 -0700 Subject: [PATCH 3/4] fix another lint --- src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 1611ed596a1c..1161962f1287 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -60,7 +60,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { PoolInfo pool_info = pool_allocation->pool_info; int byte_pool_offset = pool_allocation->byte_offset->value; int required_pool_size_for_allocation = - byte_pool_offset + int(CalculateExtentsSize(allocate_node.operator->())); + byte_pool_offset + static_cast(CalculateExtentsSize(allocate_node.operator->())); if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { all_pools_sizes_[pool_info] = required_pool_size_for_allocation; } else { From 7a0db73299e6751ba66b43b28f8d0764ce6c2076 Mon Sep 17 00:00:00 2001 From: sung Date: Thu, 9 Jun 2022 07:41:58 -0700 Subject: [PATCH 4/4] reflect the feedback --- include/tvm/tir/transform.h | 2 +- python/tvm/tir/transform/transform.py | 8 ++++++-- src/tir/transforms/{helpers.cc => primfunc_utils.cc} | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) rename src/tir/transforms/{helpers.cc => primfunc_utils.cc} (98%) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index f70a8fc27ad9..2af9dbfaf2b8 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -24,8 +24,8 @@ #ifndef TVM_TIR_TRANSFORM_H_ #define TVM_TIR_TRANSFORM_H_ -#include #include +#include #include #include diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index b636967c6433..351368c954ec 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -17,7 +17,7 @@ """Wrapping existing transformations.""" # pylint: disable=invalid-name from typing import Optional, Callable -from tvm.target import Target + from . import _ffi_api from . import function_pass as _fpass @@ -789,8 +789,12 @@ def RenormalizeSplitPattern(): return _ffi_api.RenormalizeSplitPattern() # type: ignore -def BindTarget(target: Target): +def BindTarget(target): """Annotate a PrimFunc with a given target. + Parameters + ------- + target : tvm.target.Target + target Returns ------- diff --git a/src/tir/transforms/helpers.cc b/src/tir/transforms/primfunc_utils.cc similarity index 98% rename from src/tir/transforms/helpers.cc rename to src/tir/transforms/primfunc_utils.cc index d02400aa3c3d..d2bb259f9921 100644 --- a/src/tir/transforms/helpers.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -18,7 +18,7 @@ */ /*! - * \file helpers.cc + * \file primfunc_utils.cc * \brief Passes that serve as helper functions. */