From 43289b168d4b5b5b7199af0d475afda303c33aa5 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Mon, 25 Jul 2022 17:00:01 +0100 Subject: [PATCH 1/4] Pass that removes reshapes post LowerTE Change-Id: Iaf5a5f44776080b0b842af4b563d596134508de1 --- include/tvm/relay/transform.h | 6 + src/relay/backend/aot_executor_codegen.cc | 7 +- src/relay/transforms/remove_reshapes.cc | 116 ++++++++++++ .../contrib/test_cmsisnn/test_conv2d.py | 1 + .../contrib/test_cmsisnn/test_pooling.py | 8 +- .../test_cmsisnn/test_remove_reshapes.py | 169 ++++++++++++++++++ .../contrib/test_ethosu/test_networks.py | 10 +- tests/python/relay/aot/test_crt_aot.py | 2 +- tests/python/relay/aot/test_crt_aot_usmp.py | 36 ++-- 9 files changed, 328 insertions(+), 27 deletions(-) create mode 100644 src/relay/transforms/remove_reshapes.cc create mode 100644 tests/python/contrib/test_cmsisnn/test_remove_reshapes.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index f60912fb012e..15fa8459e8d3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -585,6 +585,12 @@ TVM_DLL Pass CapturePostDfsIndexInSpans(); * expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope */ TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config); + +/*! + * \brief Removes standalone reshapes after lowering the graph. + */ +TVM_DLL Pass RemoveStandaloneReshapes(); + } // namespace transform /*! diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index b380f7b7c8b8..d114198e75f5 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1096,6 +1096,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); })(mod); + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_remove_reshapes = + pass_ctx->GetConfig("relay.RemoveReshapes", Bool(true)).value(); + if (enable_remove_reshapes) { + lowered_mod = transform::RemoveReshapes()(lowered_mod); + } auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); @@ -1203,7 +1209,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Parallel for loops are not supported in AoT codegen. lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); - transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); if (enable_usmp) { lowered_mod = PlanMemoryWithUSMP(lowered_mod); diff --git a/src/relay/transforms/remove_reshapes.cc b/src/relay/transforms/remove_reshapes.cc new file mode 100644 index 000000000000..c26dbf801c8a --- /dev/null +++ b/src/relay/transforms/remove_reshapes.cc @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file remove_reshapes.cc + * \brief Relay pass for removing reshapes from lowered graph. + */ + +#include +#include + +#include "../op/call/call.h" +#include "../op/memory/on_device.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_PASS_CONFIG_OPTION("relay.RemoveReshapes", Bool); +/*! Removes reshapes right after LowerTE. Removes preceding on_device calls + * while removing reshapes. + */ +class RemoveReshapesMutator : public MixedModeMutator { + public: + explicit RemoveReshapesMutator(IRModule& mod) : ir_module_(mod) {} + + using MixedModeMutator::VisitExpr_; + + Expr VisitExpr_(const LetNode* let) final { + Let ret_let; + Var var = Downcast(this->Mutate(let->var)); + auto value = this->Mutate(let->value); + if (auto* on_device_call = value.as()) { + OnDeviceProps on_device_props = GetOnDeviceProps(on_device_call); + if (on_device_props.body.defined() && on_device_props.body->IsInstance()) { + const Call call_lowered = Downcast(on_device_props.body); + if (call_lowered.defined() && call_lowered->op.same_as(CallLoweredOp())) { + let_var_to_call_lowered_.Set(var, call_lowered); + } + } + } + auto body = this->Mutate(let->body); + return WithFields(GetRef(let), var, value, body); + } + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + /* + %1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...); + let %x: = on_device(%1, ...); + %2 = (%x,); + %3 = call_lowered(@tvmgen_default_fused_reshape, %2, ..., + "relay_attrs"=__dict__="relay.reshape_only"=1, ...); + */ + const CallNode* post_call = post.as(); + CallLoweredProps call_lowered_props = GetCallLoweredProps(post_call); + if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) { + if (!call_lowered_props.arguments.empty() && + call_lowered_props.arguments[0]->IsInstance()) { + Var var = Downcast(call_lowered_props.arguments[0]); + if (var.defined() && let_var_to_call_lowered_.find(var) != let_var_to_call_lowered_.end()) { + return let_var_to_call_lowered_[var]; + } + } + } + + return post; + } + + private: + /*! \brief Map of LetNode's var to previous call_lowered. */ + Map let_var_to_call_lowered_; + /*! \brief Module that contains global reshape functions. */ + IRModule& ir_module_; +}; + +namespace transform { + +Pass RemoveReshapes() { + auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { + VLOG(1) << "RemoveReshapes before:" << std::endl << PrettyPrint(mod); + RemoveReshapesMutator remove_reshapes_mutator(mod); + Function main_func = Downcast(mod->Lookup("main")); + Expr new_main_body = remove_reshapes_mutator.VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + auto main_var = mod->GetGlobalVar("main"); + auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, + main_func->type_params, main_func->attrs); + mod->Update(main_var, new_main_func); + } + Array entry_functions{"main"}; + mod = RemoveUnusedFunctions(entry_functions)(mod); + + VLOG(1) << "RemoveReshapes after:" << std::endl << PrettyPrint(mod); + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "RemoveReshapes", {}); +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 623f5c0fc0d7..502743387bfa 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -669,6 +669,7 @@ def test_relay_conv2d_cmsisnn_depthwise_int8( cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"] call_extern = None + # This happens when context buffer is init in case depthM != 1 if isinstance(cmsisnn_func.body, tvm.tir.stmt.Evaluate): call_extern = cmsisnn_func.body.value else: diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index e96f397c04da..29140ad2e656 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -15,14 +15,18 @@ # specific language governing permissions and limitations # under the License. -"""CMSIS-NN integration tests: Conv2D""" +"""CMSIS-NN integration tests: Pooling""" import numpy as np import pytest import tvm from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.testing.aot import ( + generate_ref_data, + AOTTestModel, + compile_and_run, +) from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from .utils import ( make_module, diff --git a/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py b/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py new file mode 100644 index 000000000000..8b33a8a90b76 --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: Reshape removal""" +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + +from tvm.testing.aot import ( + generate_ref_data, + AOTTestModel, + compile_models, + run_and_check, +) +from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER +from .utils import ( + make_module, + get_range_for_dtype_str, + get_same_padding, + make_qnn_relu, + assert_partitioned_function, +) + + +def make_model( + pool_op, + shape=(1, 28, 28, 12), + pool_size=(3, 3), + strides=(2, 2), + padding="VALID", + dtype="int8", + scale=1, + zero_point=-33, + relu_type="RELU", + layout="NHWC", + input_op=None, +): + """Return a model and any parameters it may have, + all parameters are defaulted to known good values + """ + if input_op: + op = input_op + else: + op = relay.var("input", shape=shape, dtype=dtype) + pad_ = (0, 0, 0, 0) + if padding == "SAME": + dilation = (1, 1) + pad_ = get_same_padding((shape[1], shape[2]), pool_size, dilation, strides) + op = relay.nn.pad( + op, + pad_width=[(0, 0), (pad_[0], pad_[2]), (pad_[1], pad_[3]), (0, 0)], + pad_value=zero_point, + pad_mode="constant", + ) + if pool_op.__name__ == relay.nn.avg_pool2d.__name__: + op = relay.cast(op, "int32") + op = pool_op( + op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout + ) + if pool_op.__name__ == relay.nn.avg_pool2d.__name__: + op = relay.cast(op, dtype) + op = make_qnn_relu(op, relu_type, scale, zero_point, dtype) + return op + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +def test_reshape_removal(padding): + """Tests reshape is removed from the network""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_USMP_CORSTONE300_RUNNER + + in_shape = (1, 28, 28, 12) + pool_size = (3, 3) + strides = (2, 2) + relu_type = "NONE" + zero_point, scale = (-34, 0.0256) + + max_pool = make_model( + pool_op=relay.nn.max_pool2d, + shape=in_shape, + pool_size=pool_size, + strides=strides, + padding=padding, + scale=scale, + zero_point=zero_point, + relu_type=relu_type, + ) + new_shape = (1, 28, 28, 3) if padding == "VALID" else (1, 30, 30, 3) + reshape = relay.reshape(max_pool, newshape=new_shape) + + model = make_model( + pool_op=relay.nn.avg_pool2d, + shape=new_shape, + pool_size=pool_size, + strides=strides, + padding=padding, + scale=scale, + zero_point=zero_point, + relu_type=relu_type, + input_op=reshape, + ) + orig_mod = make_module(model) + + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + # validate pattern matching + assert_partitioned_function(orig_mod, cmsisnn_mod) + + # generate reference output + rng = np.random.default_rng(12345) + in_min, in_max = get_range_for_dtype_str("int8") + inputs = {"input": rng.integers(in_min, high=in_max, size=in_shape, dtype="int8")} + output_list = generate_ref_data(orig_mod["main"], inputs, params=None) + + # validate presence of depthwise convolution + compiled_models = compile_models( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=None, + output_tolerance=1, + ), + interface_api, + use_unpacked_api, + pass_config=test_runner.pass_config, + ) + + main_mod = None + for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items(): + if target.kind.name == "c": + main_mod = mod + + # when padding="SAME", extra padding is introduced which causes Reshape to be fused with the + # Pad. RemoveReshapes pass cannot remove a fused Reshape. Whereas padding="VALID" doesn't need + # an extra Pad layer. In this case, the pass removes the Reshape from the graph. + reshapes_present = any(["reshape" in gv.name_hint for gv in main_mod.get_global_vars()]) + check_reshapes = reshapes_present if padding == "SAME" else not reshapes_present + expected_reshapes = "a" if padding == "SAME" else "No" + assert check_reshapes, "Expeting {} reshape layer(s).".format(expected_reshapes) + + # validate the output + run_and_check( + models=compiled_models, + runner=test_runner, + interface_api=interface_api, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index 02643f6c1ded..2b4ffd96caef 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -45,12 +45,12 @@ "accel_type, model_url, workspace_size", [ ("ethos-u65-256", MOBILENET_V1_URL, 1793376), - ("ethos-u65-256", MOBILENET_V2_URL, 2218160), + ("ethos-u65-256", MOBILENET_V2_URL, 2217152), ("ethos-u55-256", MOBILENET_V1_URL, 1793376), - ("ethos-u55-256", MOBILENET_V2_URL, 2218160), - ("ethos-u55-128", MOBILENET_V2_URL, 2218160), - ("ethos-u55-64", MOBILENET_V2_URL, 2218160), - ("ethos-u55-32", MOBILENET_V2_URL, 2218160), + ("ethos-u55-256", MOBILENET_V2_URL, 2217152), + ("ethos-u55-128", MOBILENET_V2_URL, 2217152), + ("ethos-u55-64", MOBILENET_V2_URL, 2217152), + ("ethos-u55-32", MOBILENET_V2_URL, 2217152), ], ) def test_networks_without_usmp(accel_type, model_url, workspace_size): diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 987d425aa63d..edf23ff22781 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -998,7 +998,7 @@ def test_workspace_calculation_cmsis_nn(): ): lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata) - assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14384 + assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14256 def test_aot_codegen_checks_returns(): diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 724932183a54..b79350d172ac 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -105,24 +105,24 @@ def test_synthetic(interface_api, use_unpacked_api, test_runner): "workspace_byte_alignment,constant_byte_alignment," "main_workspace_size,main_constant_size,usmp_algo", [ - (8, 8, 17280, 948, "greedy_by_conflicts"), - (16, 8, 17280, 948, "greedy_by_conflicts"), - (256, 8, 17792, 948, "greedy_by_conflicts"), - (8, 16, 17280, 956, "greedy_by_conflicts"), - (16, 16, 17280, 956, "greedy_by_conflicts"), - (256, 16, 17792, 956, "greedy_by_conflicts"), - (8, 256, 17280, 1804, "greedy_by_conflicts"), - (16, 256, 17280, 1804, "greedy_by_conflicts"), - (256, 256, 17792, 1804, "greedy_by_conflicts"), - (8, 8, 22032, 948, "greedy_by_size"), - (16, 8, 22032, 948, "greedy_by_size"), - (256, 8, 22976, 948, "greedy_by_size"), - (8, 16, 22032, 956, "greedy_by_size"), - (16, 16, 22032, 956, "greedy_by_size"), - (256, 16, 22976, 956, "greedy_by_size"), - (8, 256, 22032, 1804, "greedy_by_size"), - (16, 256, 22032, 1804, "greedy_by_size"), - (256, 256, 22976, 1804, "greedy_by_size"), + (8, 8, 14208, 948, "greedy_by_conflicts"), + (16, 8, 14208, 948, "greedy_by_conflicts"), + (256, 8, 14720, 948, "greedy_by_conflicts"), + (8, 16, 14208, 956, "greedy_by_conflicts"), + (16, 16, 14208, 956, "greedy_by_conflicts"), + (256, 16, 14720, 956, "greedy_by_conflicts"), + (8, 256, 14208, 1804, "greedy_by_conflicts"), + (16, 256, 14208, 1804, "greedy_by_conflicts"), + (256, 256, 14720, 1804, "greedy_by_conflicts"), + (8, 8, 18576, 948, "greedy_by_size"), + (16, 8, 18576, 948, "greedy_by_size"), + (256, 8, 19392, 948, "greedy_by_size"), + (8, 16, 18576, 956, "greedy_by_size"), + (16, 16, 18576, 956, "greedy_by_size"), + (256, 16, 19392, 956, "greedy_by_size"), + (8, 256, 18576, 1804, "greedy_by_size"), + (16, 256, 18576, 1804, "greedy_by_size"), + (256, 256, 19392, 1804, "greedy_by_size"), (8, 8, 11424, 948, "hill_climb"), (16, 8, 11424, 948, "hill_climb"), (256, 8, 11920, 948, "hill_climb"), From 5d1663ff5210a28c4c546660127760d76c06a7ea Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Thu, 28 Jul 2022 14:33:52 +0100 Subject: [PATCH 2/4] Renamed pass to RemoveStandaloneReshapes Change-Id: I1f45ee3b15fbe290fdce69832a850d7d85ea1681 --- include/tvm/relay/transform.h | 5 +++- src/relay/backend/aot_executor_codegen.cc | 4 +-- ...hapes.cc => remove_standalone_reshapes.cc} | 28 +++++++++++-------- 3 files changed, 23 insertions(+), 14 deletions(-) rename src/relay/transforms/{remove_reshapes.cc => remove_standalone_reshapes.cc} (73%) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 15fa8459e8d3..49a8e8851dd7 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -587,7 +587,10 @@ TVM_DLL Pass CapturePostDfsIndexInSpans(); TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config); /*! - * \brief Removes standalone reshapes after lowering the graph. + * \brief Remove non-fused reshapes after lowering the graph. + * + * + * \return The pass. */ TVM_DLL Pass RemoveStandaloneReshapes(); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index d114198e75f5..970ecb321f01 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1098,9 +1098,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_remove_reshapes = - pass_ctx->GetConfig("relay.RemoveReshapes", Bool(true)).value(); + pass_ctx->GetConfig("relay.RemoveStandaloneReshapes.enable", Bool(true)).value(); if (enable_remove_reshapes) { - lowered_mod = transform::RemoveReshapes()(lowered_mod); + lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod); } auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/transforms/remove_reshapes.cc b/src/relay/transforms/remove_standalone_reshapes.cc similarity index 73% rename from src/relay/transforms/remove_reshapes.cc rename to src/relay/transforms/remove_standalone_reshapes.cc index c26dbf801c8a..34ea910cd0fb 100644 --- a/src/relay/transforms/remove_reshapes.cc +++ b/src/relay/transforms/remove_standalone_reshapes.cc @@ -16,10 +16,14 @@ * specific language governing permissions and limitations * under the License. */ - /*! - * \file remove_reshapes.cc - * \brief Relay pass for removing reshapes from lowered graph. + * \file remove_standalone_reshapes.cc + * \brief This file contains the Relay pass for removing unfused reshapes from lowered graph. + * InferType() cannot be invoked after calling this pass as it removes reshapes from the call + * graph. Many targets only need buffer addresses irrespective of the shapes of them. This makes + * reshapes symbolic once the graph has been lowered. Reshape removal results into smaller code + * size and reduced buffer allocations. It opens up opportunities of operator fusion in the target + * backend. Thus, consequently, it improves the performance of the inference. */ #include @@ -31,16 +35,17 @@ namespace tvm { namespace relay { -TVM_REGISTER_PASS_CONFIG_OPTION("relay.RemoveReshapes", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.RemoveStandaloneReshapes.enable", Bool); /*! Removes reshapes right after LowerTE. Removes preceding on_device calls * while removing reshapes. */ -class RemoveReshapesMutator : public MixedModeMutator { +class RemoveStandaloneReshapesMutator : public MixedModeMutator { public: - explicit RemoveReshapesMutator(IRModule& mod) : ir_module_(mod) {} + explicit RemoveStandaloneReshapesMutator(IRModule& mod) : ir_module_(mod) {} using MixedModeMutator::VisitExpr_; + /*! * \brief Generated map of let variables to preceding CallLowered */ Expr VisitExpr_(const LetNode* let) final { Let ret_let; Var var = Downcast(this->Mutate(let->var)); @@ -58,6 +63,7 @@ class RemoveReshapesMutator : public MixedModeMutator { return WithFields(GetRef(let), var, value, body); } + /*! * \brief Returns preceding CallLowered when call is a CallLowered(Reshape) */ Expr Rewrite_(const CallNode* call, const Expr& post) final { /* %1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...); @@ -90,10 +96,10 @@ class RemoveReshapesMutator : public MixedModeMutator { namespace transform { -Pass RemoveReshapes() { +Pass RemoveStandaloneReshapes() { auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { - VLOG(1) << "RemoveReshapes before:" << std::endl << PrettyPrint(mod); - RemoveReshapesMutator remove_reshapes_mutator(mod); + VLOG(1) << "RemoveStandaloneReshapes before:" << std::endl << PrettyPrint(mod); + RemoveStandaloneReshapesMutator remove_reshapes_mutator(mod); Function main_func = Downcast(mod->Lookup("main")); Expr new_main_body = remove_reshapes_mutator.VisitExpr(main_func->body); if (!new_main_body.same_as(main_func->body)) { @@ -105,10 +111,10 @@ Pass RemoveReshapes() { Array entry_functions{"main"}; mod = RemoveUnusedFunctions(entry_functions)(mod); - VLOG(1) << "RemoveReshapes after:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RemoveStandaloneReshapes after:" << std::endl << PrettyPrint(mod); return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "RemoveReshapes", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "RemoveStandaloneReshapes", {}); } } // namespace transform From 689d4b38b6d5b25450a70347ab2163ce7fa43c9d Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Mon, 1 Aug 2022 17:27:12 +0100 Subject: [PATCH 3/4] Unit test for the pass RemoveStandaloneReshapes Change-Id: I81462a552f467d88cf1288acef2f9cbacc3ff532 --- include/tvm/relay/transform.h | 10 +- .../transforms/remove_standalone_reshapes.cc | 10 +- .../test_pass_remove_standalone_reshapes.py | 260 ++++++++++++++++++ 3 files changed, 270 insertions(+), 10 deletions(-) create mode 100644 tests/python/relay/backend/test_pass_remove_standalone_reshapes.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 49a8e8851dd7..b37d0f83adf3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -587,10 +587,12 @@ TVM_DLL Pass CapturePostDfsIndexInSpans(); TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config); /*! - * \brief Remove non-fused reshapes after lowering the graph. - * - * - * \return The pass. + * \brief Removes non-fused reshapes after lowering the graph. + * InferType() cannot be invoked after calling this pass as it removes reshapes from the call + * graph. Many targets only need buffer addresses irrespective of the shapes of them. This makes + * reshapes symbolic once the graph has been lowered. Reshape removal results into smaller code + * size and reduced buffer allocations. It opens up opportunities of operator fusion in the target + * backend. Thus, consequently, it improves the performance of the inference. */ TVM_DLL Pass RemoveStandaloneReshapes(); diff --git a/src/relay/transforms/remove_standalone_reshapes.cc b/src/relay/transforms/remove_standalone_reshapes.cc index 34ea910cd0fb..3d43738b48ef 100644 --- a/src/relay/transforms/remove_standalone_reshapes.cc +++ b/src/relay/transforms/remove_standalone_reshapes.cc @@ -17,13 +17,8 @@ * under the License. */ /*! - * \file remove_standalone_reshapes.cc + * \file src/relay/transforms/remove_standalone_reshapes.cc * \brief This file contains the Relay pass for removing unfused reshapes from lowered graph. - * InferType() cannot be invoked after calling this pass as it removes reshapes from the call - * graph. Many targets only need buffer addresses irrespective of the shapes of them. This makes - * reshapes symbolic once the graph has been lowered. Reshape removal results into smaller code - * size and reduced buffer allocations. It opens up opportunities of operator fusion in the target - * backend. Thus, consequently, it improves the performance of the inference. */ #include @@ -117,6 +112,9 @@ Pass RemoveStandaloneReshapes() { return tvm::transform::CreateModulePass(pass_func, 0, "RemoveStandaloneReshapes", {}); } +TVM_REGISTER_GLOBAL("relay._transform.RemoveStandaloneReshapes") + .set_body_typed(RemoveStandaloneReshapes); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/tests/python/relay/backend/test_pass_remove_standalone_reshapes.py b/tests/python/relay/backend/test_pass_remove_standalone_reshapes.py new file mode 100644 index 000000000000..2113ae7b5c72 --- /dev/null +++ b/tests/python/relay/backend/test_pass_remove_standalone_reshapes.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Exercises the RemoveStandaloneReshapes pass. + +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +import tvm.testing +from tvm.script import tir as T + + +HOST_DEVICE = tvm.device("cpu") +HOST_TARGET = tvm.target.Target("llvm") + +CPU_DEVICE = tvm.device("cpu") +CPU_TARGET = tvm.target.Target("llvm").with_host(HOST_TARGET) + +CPU = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET) # device_type=1 + + +RemoveStandaloneReshapes = tvm._ffi.get_global_func("relay._transform.RemoveStandaloneReshapes") + + +class MarkReshapeOnlyMutator(ExprMutator): + """A pass for marking call_lowered as ReshapeOnly where reshapes exist unfused""" + + def __init__(self): + ExprMutator.__init__(self) + + def visit_call(self, call): + if isinstance(call.args[0], tvm.ir.GlobalVar) and "reshape" in call.args[0].name_hint: + # attrs = {"relay_attrs" : {"relay.reshape_only" : 1}} + dict_attrs = tvm.ir.make_node("DictAttrs", **{"relay.reshape_only": 1}) + attrs = tvm.ir.make_node( + "relay.attrs.CallLoweredAttrs", **{"metadata": {"relay_attrs": dict_attrs}} + ) + return relay.Call(call.op, call.args, attrs) + return super().visit_call(call) + + +# Reshape should not be removed if its the first layer in the network +def test_first_reshape(): + mod = tvm.ir.IRModule() + + @T.prim_func + def reshape_primfunc(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j in T.grid(128, 128): + D[i, j] = A[i, j] + + metatable = {"VirtualDevice": [CPU]} + reshape_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + reshape_gv = relay.GlobalVar("reshape", type_annot=reshape_ty) + mod[reshape_gv] = reshape_primfunc + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][0]) { + %1 = call_lowered(@reshape, (%x,) ); + let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %x_14 + } + """, + "from_string", + mod, + metatable, + ) + + mod["main"] = MarkReshapeOnlyMutator().visit(mod["main"]) + mod = RemoveStandaloneReshapes()(mod) + reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()]) + assert reshapes_present, "Reshape should have been removed." + return + + +# When reshape layer is the last one in the network +def test_last_reshape(): + mod = tvm.ir.IRModule() + + @T.prim_func + def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + D[vi, vj] = A[vi, vk] * B[vj, vk] + + @T.prim_func + def reshape_primfunc(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j in T.grid(128, 128): + D[i, j] = A[i, j] + + metatable = {"VirtualDevice": [CPU]} + mul_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty) + mod[mul_gv] = mul_primfunc + reshape_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + reshape_gv = relay.GlobalVar("reshape", type_annot=reshape_ty) + mod[reshape_gv] = reshape_primfunc + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][0]) { + %0 = call_lowered(@multiply, (%x, %y, %z)); + let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %1 = call_lowered(@reshape, (%x_12,) ); + let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %x_14 + } + """, + "from_string", + mod, + metatable, + ) + + # Expected main: + ##[version = "0.0.5"] + # def @main(%x /* ty=Tensor[(128, 128), float32] */) -> Tensor[(128, 128), float32] { + # %0 = (%x, %y, %z); + # %1 = call_lowered(@multiply, %0); + # let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # %x_14 + # } + + mod["main"] = MarkReshapeOnlyMutator().visit(mod["main"]) + mod = RemoveStandaloneReshapes()(mod) + reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()]) + assert not reshapes_present, "Reshape should have been removed." + return + + +# When reshape layer is not marked as reshape_only +def test_fused_reshape(): + mod = tvm.ir.IRModule() + + @T.prim_func + def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + D[vi, vj] = A[vi, vk] * B[vj, vk] + + @T.prim_func + def fused_reshape_primfunc(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j in T.grid(128, 128): + D[i, j] = A[i, j] + + metatable = {"VirtualDevice": [CPU]} + mul_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty) + mod[mul_gv] = mul_primfunc + reshape_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + reshape_gv = relay.GlobalVar("fused_reshape", type_annot=reshape_ty) + mod[reshape_gv] = fused_reshape_primfunc + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][0]) { + %0 = call_lowered(@multiply, (%x, %y, %z)); + let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %1 = call_lowered(@fused_reshape, (%x_12,) ); + let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %x_14 + } + """, + "from_string", + mod, + metatable, + ) + + # Expected main: + ##[version = "0.0.5"] + # def @main(%x /* ty=Tensor[(128, 128), float32] */) -> Tensor[(128, 128), float32] { + # %0 = (%x, %y, %z); + # %1 = call_lowered(@multiply, %0); + # let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # %x_14 + # } + + mod = RemoveStandaloneReshapes()(mod) + reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()]) + assert reshapes_present, "Reshape should have been removed." + return + + +if __name__ == "__main__": + tvm.testing.main() From 389cadbe5a0bc2259a22e069d9b98fae8dfb7b34 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Thu, 4 Aug 2022 09:46:00 +0100 Subject: [PATCH 4/4] Modified the name of the pass config option for reshape removal Change-Id: I8502bc74eb0914cfcaa86cb809d7c4a9c6e86c70 --- src/relay/backend/aot_executor_codegen.cc | 2 +- src/relay/transforms/remove_standalone_reshapes.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 970ecb321f01..6a9cadb6f770 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1098,7 +1098,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_remove_reshapes = - pass_ctx->GetConfig("relay.RemoveStandaloneReshapes.enable", Bool(true)).value(); + pass_ctx->GetConfig("relay.remove_standalone_reshapes.enable", Bool(true)).value(); if (enable_remove_reshapes) { lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod); } diff --git a/src/relay/transforms/remove_standalone_reshapes.cc b/src/relay/transforms/remove_standalone_reshapes.cc index 3d43738b48ef..28924e8bdfed 100644 --- a/src/relay/transforms/remove_standalone_reshapes.cc +++ b/src/relay/transforms/remove_standalone_reshapes.cc @@ -30,7 +30,7 @@ namespace tvm { namespace relay { -TVM_REGISTER_PASS_CONFIG_OPTION("relay.RemoveStandaloneReshapes.enable", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.remove_standalone_reshapes.enable", Bool); /*! Removes reshapes right after LowerTE. Removes preceding on_device calls * while removing reshapes. */