From 2f21055400ad35fba37e16d1f9fa717fdb5dcacd Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 14 Jun 2022 11:02:41 -0700 Subject: [PATCH] [MetaSchedule] Apply-History-Best Task Filtering This PR enables task filtering in Apply-History-Best, which is used in Relay/Relax integration. Previously, even though a task is ruled out during task extraction, it still shows up in Relay compilation due to the lack of filtering on `Apply-History-Best`. However, TE-to-TIR conversion `te.CreatePrimFunc` doesn't support all cases with hybrid operators involved, which leads to post-tuning failure affecting multiple models. --- .../tvm/meta_schedule/apply_history_best.h | 21 ++++++- include/tvm/meta_schedule/extracted_task.h | 23 ++++++++ .../tvm/meta_schedule/apply_history_best.py | 26 +++++++-- python/tvm/meta_schedule/relay_integration.py | 16 ++++-- python/tvm/meta_schedule/testing/utils.py | 15 ++++- src/meta_schedule/apply_history_best.cc | 15 ++++- src/meta_schedule/extracted_task.cc | 55 ++++++++++++++++++- src/meta_schedule/utils.h | 1 + src/relay/backend/task_extraction.cc | 45 ++------------- src/relay/backend/te_compiler_cache.cc | 17 +++--- .../test_meta_schedule_relay_tir_compute.py | 15 ++++- 11 files changed, 183 insertions(+), 66 deletions(-) diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h index 5b1816cef41f..82bb350e1c5e 100644 --- a/include/tvm/meta_schedule/apply_history_best.h +++ b/include/tvm/meta_schedule/apply_history_best.h @@ -29,6 +29,12 @@ #include #include +namespace tvm { +namespace te { +class Tensor; +} // namespace te +} // namespace tvm + namespace tvm { namespace meta_schedule { @@ -38,12 +44,21 @@ namespace meta_schedule { */ class ApplyHistoryBestNode : public runtime::Object { public: + using FTEFilterFunc = + runtime::TypedPackedFunc(const Array&)>; + /*! \brief The database to be queried from */ Database database{nullptr}; + /*! 
\brief The filtering function for TE computation */ + FTEFilterFunc te_filter_func{nullptr}; /*! \brief The logging function to be used */ PackedFunc logging_func; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("database", &database); } + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("database", &database); + // `te_filter_func` is not visited + // `logging_func` is not visited + } /*! * \brief Query the best entry from the database * \param task_name The name of the task to be queried @@ -67,9 +82,11 @@ class ApplyHistoryBest : public runtime::ObjectRef { /*! * \brief Constructor * \param database The database to be queried from + * \param te_filter_func The filtering function for TE computation * \param logging_func The logging function to use */ - explicit ApplyHistoryBest(Database database, PackedFunc logging_func); + explicit ApplyHistoryBest(Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func, + PackedFunc logging_func); /*! * \brief The current ApplyHistoryBest in the context * \return The ApplyHistoryBest in the current scope. diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index 898b974d8772..bed1428f8303 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -26,6 +26,15 @@ #include #include +namespace tvm { +namespace tir { +class PrimFunc; +} // namespace tir +namespace te { +class Tensor; +} // namespace te +} // namespace tvm + namespace tvm { namespace meta_schedule { @@ -67,6 +76,20 @@ class ExtractedTask : public runtime::ObjectRef { ExtractedTaskNode); }; +/*! + * \brief The default TE task filter + * \param args The input/output arguments of the TE compute graph + * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc + */ +Optional DefaultTaskFilter(const Array& args); + +/*! 
+ * \brief The default TE task filter, with `te.extern` allowed + * \param args The input/output arguments of the TE compute graph + * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc + */ +Optional DefaultTaskFilterAllowExtern(const Array& args); + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py index bcde7c97b04d..d618c3a04fa1 100644 --- a/python/tvm/meta_schedule/apply_history_best.py +++ b/python/tvm/meta_schedule/apply_history_best.py @@ -16,12 +16,14 @@ # under the License. """A context manager that injects the best tuning record in the database into compilation""" import logging -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union -from tvm._ffi import register_object +from tvm._ffi import get_global_func, register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.target import Target +from tvm.te import Tensor +from tvm.tir import PrimFunc from . import _ffi_api from .database import Database @@ -38,13 +40,29 @@ class ApplyHistoryBest(Object): ---------- database : Database The database to be queried from + te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None + The filtering function for TE computation + If it's a string, it's the name of the filtering function. 
Built in functions are + - "meta_schedule.DefaultTaskFilter" + - "meta_schedule.DefaultTaskFilterAllowExtern" + If it's None, it's the default filtering function + If it's a callable, it's the filtering function """ database: Database - def __init__(self, database: Database) -> None: + def __init__( + self, + database: Database, + te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None, + ) -> None: + if isinstance(te_filter_func, str): + te_filter_func = get_global_func(te_filter_func) self.__init_handle_by_constructor__( - _ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member + _ffi_api.ApplyHistoryBest, # type: ignore # pylint: disable=no-member + database, + te_filter_func, + make_logging_func(logger), ) def query( diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index b55633817413..833f100a0d16 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """MetaSchedule-Relay integration""" -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np # type: ignore from tvm import nd @@ -24,6 +24,7 @@ from tvm.runtime import NDArray from tvm.target import Target from tvm.te import Tensor +from tvm.tir import PrimFunc from .extracted_task import ExtractedTask from .utils import autotvm_silencer @@ -37,7 +38,7 @@ def extract_task_from_relay( opt_level: int = 3, pass_config: Optional[Dict[str, Any]] = None, disabled_pass: Optional[List[str]] = None, - filter_func: Callable[[List[Tensor]], bool] = None, + te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. 
@@ -55,8 +56,13 @@ def extract_task_from_relay( The pass config of the compiler disabled_pass : Optional[List[str]] The list of disabled passes of the compiler - filter_func : Callable[[List[tvm.te.Tensor]], bool] + te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None The filter function to filter out the extracted tasks + If it's a string, it's the name of the filtering function. Built in functions are + - "meta_schedule.DefaultTaskFilter" + - "meta_schedule.DefaultTaskFilterAllowExtern" + If it's None, it's the default filtering function + If it's a callable, it's the filtering function Returns ------- @@ -68,6 +74,8 @@ def extract_task_from_relay( # pylint: enable=import-outside-toplevel + if isinstance(te_filter_func, str): + te_filter_func = get_global_func(te_filter_func) extract_task_func = get_global_func( "relay.backend.MetaScheduleExtractTask", allow_missing=False, @@ -94,4 +102,4 @@ def extract_task_from_relay( config=pass_config, disabled_pass=disabled_pass, ): - return list(extract_task_func(mod, target, relay_params, filter_func)) + return list(extract_task_func(mod, target, relay_params, te_filter_func)) diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index f353d401a10c..bdd3852e40a3 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -30,6 +30,7 @@ def apply_fixed_schedules( target: Union[str, Target], params: Optional[Dict[str, NDArray]], schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool], + te_filter_func=None, ): """Apply fixed schedules (manually written, without any tunable knobs) as specified by schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest. @@ -45,6 +46,13 @@ def apply_fixed_schedules( schedule_fn : Callable[[ExtractedTask, Schedule], bool] A callable that is applied for each extracted task and the corresponding default schedule.
Returns True if the given schedule should be committed to the database, False otherwise. + te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None + The filtering function for TE computation + If it's a string, it's the name of the filtering function. Built in functions are + - "meta_schedule.DefaultTaskFilter" + - "meta_schedule.DefaultTaskFilterAllowExtern" + If it's None, it's the default filtering function + If it's a callable, it's the filtering function Returns ------- @@ -52,7 +60,12 @@ def apply_fixed_schedules( The database containing dummy tuning records for manually scheduled traces. """ target = Target(target) if isinstance(target, str) else target - extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params) + extracted_tasks = ms.extract_task_from_relay( + relay_mod, + target, + params, + te_filter_func=te_filter_func, + ) database = ms.database.MemoryDatabase() for task in extracted_tasks: mod = ms.default_config.mod(task.dispatched[0]) diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc index 18135811f5f1..e5cc929fd01f 100644 --- a/src/meta_schedule/apply_history_best.cc +++ b/src/meta_schedule/apply_history_best.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "./utils.h" namespace tvm { @@ -87,10 +89,16 @@ void ApplyHistoryBest::ExitWithScope() { /**************** ApplyHistoryBest ****************/ -ApplyHistoryBest::ApplyHistoryBest(Database database, PackedFunc logging_func) { +ApplyHistoryBest::ApplyHistoryBest(Database database, + ApplyHistoryBestNode::FTEFilterFunc te_filter_func, + PackedFunc logging_func) { ObjectPtr n = make_object(); n->database = database; + n->te_filter_func = te_filter_func; n->logging_func = logging_func; + if (te_filter_func == nullptr) { + n->te_filter_func = DefaultTaskFilter; + } data_ = n; } @@ -129,8 +137,9 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModu TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") - .set_body_typed([](Database database, PackedFunc logging_func) -> ApplyHistoryBest { - return ApplyHistoryBest(database, logging_func); + .set_body_typed([](Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func, + PackedFunc logging_func) -> ApplyHistoryBest { + return ApplyHistoryBest(database, te_filter_func, logging_func); }); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope") .set_body_typed(ApplyHistoryBestInternal::EnterScope); diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index b1044fc87d0f..abd7235acb99 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -17,6 +17,12 @@ * under the License. 
*/ #include +#include +#include +#include + +#include "../te/operation/create_primfunc.h" +#include "./utils.h" namespace tvm { namespace meta_schedule { @@ -32,12 +38,59 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, data_ = n; } +Optional DefaultTaskFilterImpl(const Array& args, bool allow_extern_op) { + using namespace ::tvm::te; + std::vector stack; + std::unordered_set visited; + for (const Tensor& v : args) { + for (const PrimExpr& e : v->shape) { + // Dynamic shape is not supported for now + if (!e->IsInstance()) { + return NullOpt; + } + } + if (!visited.count(v.get())) { + visited.insert(v.get()); + stack.push_back(v); + } + } + while (!stack.empty()) { + Tensor tensor = stack.back(); + stack.pop_back(); + if (tensor->op->IsInstance()) { + // do nothing + } else if (tensor->op->IsInstance() || + (allow_extern_op && tensor->op->IsInstance())) { + Array inputs = tensor->op->InputTensors(); + for (const Tensor& v : inputs) { + if (!visited.count(v.get())) { + visited.insert(v.get()); + stack.push_back(v); + } + } + } else { + return NullOpt; + } + } + return te::CreatePrimFunc(args); +} + +Optional DefaultTaskFilter(const Array& args) { + return DefaultTaskFilterImpl(args, false); +} + +Optional DefaultTaskFilterAllowExtern(const Array& args) { + return DefaultTaskFilterImpl(args, true); +} + TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); }); - +TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter").set_body_typed(DefaultTaskFilter); +TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern") + .set_body_typed(DefaultTaskFilterAllowExtern); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index c399696a82d7..76deb62f2376 100644 --- 
a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 421a92c245e7..af4b49b4f1da 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -31,48 +31,12 @@ namespace tvm { namespace relay { namespace backend { -bool DefaultTaskFilter(const Array& args) { - using namespace ::tvm::te; - std::vector stack; - std::unordered_set visited; - for (const Tensor& v : args) { - for (const PrimExpr& e : v->shape) { - // Dynamic shape is not supported for now - if (!e->IsInstance()) { - return false; - } - } - if (!visited.count(v.get())) { - visited.insert(v.get()); - stack.push_back(v); - } - } - while (!stack.empty()) { - Tensor tensor = stack.back(); - stack.pop_back(); - if (tensor->op->IsInstance()) { - // do nothing - } else if (tensor->op->IsInstance() || tensor->op->IsInstance()) { - Array inputs = tensor->op->InputTensors(); - for (const Tensor& v : inputs) { - if (!visited.count(v.get())) { - visited.insert(v.get()); - stack.push_back(v); - } - } - } else { - return false; - } - } - return true; -} - Array ExtractTask( IRModule mod, Target target, Map params, - runtime::TypedPackedFunc&)> filter_func) { + runtime::TypedPackedFunc(const Array&)> filter_func) { using meta_schedule::ExtractedTask; if (filter_func == nullptr) { - filter_func = DefaultTaskFilter; + filter_func = tvm::meta_schedule::DefaultTaskFilter; } backend::BindParamsInModule(mod, params); // is_vm=true for backward compatibility @@ -98,11 +62,10 @@ Array ExtractTask( std::string fused_name; std::tie(inputs_outputs, fused_name) = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); - if (filter_func(inputs_outputs)) { - tir::PrimFunc prim_func = tir::CreatePrimFunc(inputs_outputs); + if (Optional prim_func = filter_func(inputs_outputs)) { GlobalVar prim_fn_var(fused_name); 
IRModule relay_mod({{prim_fn_var, relay_func}}); - IRModule tir_mod({{prim_fn_var, prim_func}}); + IRModule tir_mod({{prim_fn_var, prim_func.value()}}); ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1); tasks.push_back(extracted_task); cache.emplace(cache_key, extracted_task); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d219e9bb6787..5b23843c95e6 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -346,15 +346,18 @@ class ScheduleBuilder : public ExprVisitor { } } if (meta_schedule_ctx_) { - IRModule relay_mod({{prim_fn_var, relay_func}}); - IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); - if (Optional scheduled_mod = meta_schedule_ctx_.value()->Query( - prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod})) { - ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1); - prim_func = Downcast(scheduled_mod.value()->functions[prim_fn_var]); + Array te_args = Concat(fn_inputs, tensor_outs); + if (Optional tir_func = + meta_schedule_ctx_.value()->te_filter_func(te_args)) { + IRModule relay_mod({{prim_fn_var, relay_func}}); + IRModule tir_mod({{prim_fn_var, tir_func.value()}}); + if (Optional scheduled_mod = meta_schedule_ctx_.value()->Query( + prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod})) { + ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1); + prim_func = Downcast(scheduled_mod.value()->functions[prim_fn_var]); + } } } - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. 
if (!schedule.defined() && !prim_func.defined()) { if (anchor_op_.defined()) { diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py index b208276539cc..058012cb643a 100644 --- a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py +++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py @@ -18,7 +18,7 @@ import tvm import tvm.testing import tvm.topi.testing -from tvm import autotvm, relay, te, tir +from tvm import autotvm, relay, te from tvm.meta_schedule import ApplyHistoryBest from tvm.meta_schedule.testing.utils import apply_fixed_schedules from tvm.relay.testing.temp_op_attr import TempOpAttr @@ -147,8 +147,17 @@ def schedule_fn(task, sch): return False with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy): - database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) - with ApplyHistoryBest(database): + database = apply_fixed_schedules( + relay_mod, + target, + params, + schedule_fn, + te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern", + ) + with ApplyHistoryBest( + database, + te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern", + ): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True},