diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index ee166e867916..9541232a6a38 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -483,7 +483,4 @@ def is_auto_scheduler_enabled(): return PassContext.current().config.get( "relay.backend.use_auto_scheduler", False, - ) or PassContext.current().config.get( - "relay.backend.use_meta_schedule", - False, ) diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 26cf446b1090..eb40b32e7c29 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -30,10 +30,10 @@ search_strategy, space_generator, ) -from .profiler import Profiler from .apply_history_best import ApplyHistoryBest from .extracted_task import ExtractedTask -from .relay_integration import extract_task_from_relay +from .profiler import Profiler +from .relay_integration import extract_task_from_relay, is_meta_schedule_enabled from .search_strategy import MeasureCandidate from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 833f100a0d16..84a6c559562a 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -103,3 +103,17 @@ def extract_task_from_relay( disabled_pass=disabled_pass, ): return list(extract_task_func(mod, target, relay_params, te_filter_func)) + + +def is_meta_schedule_enabled() -> bool: + """Return whether the meta-schedule is enabled. + + Returns + ------- + enabled: bool + Whether the meta schedule is enabled + """ + return transform.PassContext.current().config.get( + "relay.backend.use_meta_schedule", + False, + ) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 6ccb449d0e08..4c5af610d709 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -15,16 +15,19 @@ # specific language governing permissions and limitations # under the License. """Definition of ARM CPU operator strategy.""" +import logging + # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import re -import logging from tvm import relay, topi + +from ....auto_scheduler import is_auto_scheduler_enabled +from ....meta_schedule import is_meta_schedule_enabled from ....target import arm_isa from ....topi.generic import conv2d as conv2d_generic -from ....auto_scheduler import is_auto_scheduler_enabled -from .generic import * from .. import op as _op +from .generic import * logger = logging.getLogger("strategy") @@ -477,7 +480,9 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target): logger.warning("dense is not optimized for arm cpu.") strategy.add_implementation( wrap_compute_dense( - topi.nn.dense, need_auto_scheduler_layout=is_auto_scheduler_enabled() + topi.nn.dense, + need_auto_scheduler_layout=is_auto_scheduler_enabled(), + need_meta_schedule_layout=is_meta_schedule_enabled(), ), wrap_topi_schedule(topi.generic.schedule_dense), name="dense.generic", diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 4a7cff5f3f33..072b958da213 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -20,11 +20,12 @@ from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.contrib import nvcc from tvm.contrib.thrust import can_use_thrust +from tvm.meta_schedule import is_meta_schedule_enabled from tvm.te import SpecializedCondition -from .. import op as _op from ....target import Target from ....tir import IntImm +from .. import op as _op from .generic import * @@ -251,7 +252,17 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ) # register auto-scheduler implementations - if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler: + if ( + is_auto_scheduler_enabled() or is_meta_schedule_enabled() + ) and judge_winograd_auto_scheduler: + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc.winograd", + plevel=15, + ) + # register meta-schedule implementations + if is_meta_schedule_enabled() and judge_winograd_auto_scheduler: strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), naive_schedule, # this implementation should never be picked by autotvm @@ -534,7 +545,14 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda", ) - if is_auto_scheduler_enabled(): + if is_auto_scheduler_enabled() or is_meta_schedule_enabled(): + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc_winograd_without_weight_transform", + plevel=15, + ) + if is_meta_schedule_enabled(): strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform), naive_schedule, # this implementation should never be picked by autotvm @@ -805,7 +823,13 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target): """Matmul cuda strategy.""" strategy = _op.OpStrategy() - if is_auto_scheduler_enabled(): + if is_auto_scheduler_enabled() or is_meta_schedule_enabled(): + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul), + naive_schedule, + name="matmul.cuda", + ) + elif is_meta_schedule_enabled(): strategy.add_implementation( wrap_compute_matmul(topi.nn.matmul), naive_schedule, diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 2bb009dbc8f7..4ff7490b89ac 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -21,7 +21,12 @@ from tvm import _ffi, ir, te, topi from tvm.target import generic_func, override_native_generic_func -from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple +from tvm.topi.utils import ( + get_const_float, + get_const_int, + get_const_tuple, + get_float_tuple, +) from .. import op as _op @@ -211,6 +216,9 @@ def schedule_bitpack(attrs, outs, target): get_auto_scheduler_rewritten_layout = _ffi.get_global_func( "relay.attrs.get_auto_scheduler_rewritten_layout" ) +get_meta_schedule_original_shape = _ffi.get_global_func( + "relay.attrs.get_meta_schedule_original_shape" +) # conv2d def wrap_compute_conv2d( @@ -219,6 +227,7 @@ def wrap_compute_conv2d( need_out_layout=False, has_groups=False, need_auto_scheduler_layout=False, + need_meta_schedule_layout=False, ): """Wrap conv2d topi compute""" @@ -240,6 +249,9 @@ def _compute_conv2d(attrs, inputs, out_type): args.append(out_dtype) if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) + elif need_meta_schedule_layout: + args.append("") + args.append(get_meta_schedule_original_shape(attrs)) return [topi_compute(*args)] return _compute_conv2d @@ -530,7 +542,12 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target): # conv3d -def wrap_compute_conv3d(topi_compute, need_layout=False, need_auto_scheduler_layout=False): +def wrap_compute_conv3d( + topi_compute, + need_layout=False, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=False, +): """wrap conv3d topi compute""" def _compute_conv3d(attrs, inputs, out_type): @@ -552,6 +569,9 @@ def _compute_conv3d(attrs, inputs, out_type): args.append(out_dtype) if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) + elif need_meta_schedule_layout: + args.append("") + args.append(get_meta_schedule_original_shape(attrs)) return [topi_compute(*args)] return _compute_conv3d @@ -782,7 +802,11 @@ def copy_if_identical(tensor_a, tensor_b): # matmul -def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False): +def wrap_compute_matmul( + topi_compute, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=False, +): """wrap matmul topi compute""" def _compute_matmul(attrs, inputs, out_type): @@ -799,6 +823,9 @@ def _compute_matmul(attrs, inputs, out_type): ] if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) + elif need_meta_schedule_layout: + args.append("") + args.append(get_meta_schedule_original_shape(attrs)) args[1] = copy_if_identical(inputs[0], inputs[1]) return [topi_compute(*args)] @@ -819,7 +846,11 @@ def matmul_strategy(attrs, inputs, out_type, target): # dense -def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False): +def wrap_compute_dense( + topi_compute, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=False, +): """wrap dense topi compute""" def _compute_dense(attrs, inputs, out_type): @@ -829,6 +860,9 @@ def _compute_dense(attrs, inputs, out_type): args = [inputs[0], inputs[1], None, out_dtype] if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) + elif need_meta_schedule_layout: + args.append("") + args.append(get_meta_schedule_original_shape(attrs)) args[1] = copy_if_identical(inputs[0], inputs[1]) return [topi_compute(*args)] @@ -862,7 +896,13 @@ def dense_pack_strategy(attrs, inputs, out_type, target): # batch_matmul -def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False): +def wrap_compute_batch_matmul( + topi_compute, + *, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=False, + need_out_dtype=False, +): """wrap batch_matmul topi compute""" def _compute_batch_matmul(attrs, inputs, out_type): @@ -872,6 +912,9 @@ def _compute_batch_matmul(attrs, inputs, out_type): args.append(attrs.transpose_b) if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) + elif need_meta_schedule_layout: + args.append("") + args.append(get_meta_schedule_original_shape(attrs)) args[1] = copy_if_identical(inputs[0], inputs[1]) return [topi_compute(*args)] diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index e5f4b4e58562..dca684835ba4 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -17,10 +17,13 @@ """Definition of mali operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import re + from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled -from .generic import * +from tvm.meta_schedule import is_meta_schedule_enabled + from .. import op as _op +from .generic import * @conv2d_strategy.register("mali") @@ -72,15 +75,15 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): ) elif layout == "NHWC": assert kernel_layout == "HWIO" - if not is_auto_scheduler_enabled(): - strategy.add_implementation( - wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack), - wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack), - name="conv2d_nhwc_spatial_pack.mali", - ) - else: + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + if need_auto_scheduler_layout or need_meta_schedule_layout: strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + wrap_compute_conv2d( + topi.nn.conv2d_nhwc, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, + ), naive_schedule, name="conv2d_nhwc.mali", ) @@ -98,14 +101,36 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): and dilation_w == 1 ) if is_winograd_applicable: - strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True - ), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd", - plevel=15, - ) + if need_meta_schedule_layout: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=True, + ), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc.winograd", + plevel=15, + ) + elif need_auto_scheduler_layout: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc, + need_auto_scheduler_layout=True, + need_meta_schedule_layout=False, + ), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc.winograd", + plevel=15, + ) + else: + raise RuntimeError("Both AutoScheduler and MetaSchedule are not enabled") + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.mali", + ) else: raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout)) @@ -119,18 +144,24 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): ) elif layout == "NHWC": assert kernel_layout == "HWOI" - if not is_auto_scheduler_enabled(): + if is_auto_scheduler_enabled(): strategy.add_implementation( - wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + naive_schedule, name="depthwise_conv2d_nhwc.mali", ) - else: + elif is_meta_schedule_enabled(): strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), naive_schedule, name="depthwise_conv2d_nhwc.mali", ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.mali", + ) else: raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout)) else: # group_conv2d @@ -158,19 +189,23 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty name="conv2d_nchw_winograd.mali", ) elif layout == "NHWC": - if not is_auto_scheduler_enabled(): + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + if need_auto_scheduler_layout or need_meta_schedule_layout: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc_without_weight_transform, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, + ), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc_winograd_without_weight_transform", + plevel=15, + ) + else: raise RuntimeError( "Winograd conv2d NHWC is not enabled for mali without auto_scheduler." ) - strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc_without_weight_transform, - need_auto_scheduler_layout=True, - ), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc_winograd_without_weight_transform", - plevel=15, - ) else: raise RuntimeError( "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) @@ -182,16 +217,22 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty def dense_strategy_mali(attrs, inputs, out_type, target): """dense mali strategy""" strategy = _op.OpStrategy() - if not is_auto_scheduler_enabled(): + if is_auto_scheduler_enabled(): strategy.add_implementation( - wrap_compute_dense(topi.mali.dense), - wrap_topi_schedule(topi.mali.schedule_dense), + wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True), + naive_schedule, name="dense.mali", ) - else: + elif is_meta_schedule_enabled(): strategy.add_implementation( - wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True), + wrap_compute_dense(topi.nn.dense, need_meta_schedule_layout=True), naive_schedule, name="dense.mali", ) + else: + strategy.add_implementation( + wrap_compute_dense(topi.mali.dense), + wrap_topi_schedule(topi.mali.schedule_dense), + name="dense.mali", + ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index a032fd00bf34..abbc9d9a4c57 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -17,16 +17,18 @@ """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import logging - import re -from tvm import topi, tir -from tvm.topi.x86.utils import target_has_vnni + +from tvm import tir, topi from tvm.auto_scheduler import is_auto_scheduler_enabled -from tvm.te import SpecializedCondition +from tvm.meta_schedule import is_meta_schedule_enabled from tvm.relay.ty import is_dynamic from tvm.target import Target -from .generic import * +from tvm.te import SpecializedCondition +from tvm.topi.x86.utils import target_has_vnni + from .. import op as _op +from .generic import * logger = logging.getLogger("strategy") @@ -111,6 +113,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + if groups == 1: if layout == "NCHW": assert kernel_layout == "OIHW" @@ -137,7 +142,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" - if not is_auto_scheduler_enabled(): + if (not need_auto_scheduler_layout) and (not need_meta_schedule_layout): logger.warning("conv2d NHWC layout is not optimized for x86 with autotvm.") if "dnnl" in target.libs: strategy.add_implementation( @@ -147,7 +152,11 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): ) else: strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + wrap_compute_conv2d( + topi.nn.conv2d_nhwc, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, + ), wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), name="conv2d_nhwc.x86", ) @@ -171,10 +180,14 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): ) # register auto-scheduler implementations - if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler: + if ( + need_auto_scheduler_layout or need_meta_schedule_layout + ) and judge_winograd_auto_scheduler: strategy.add_implementation( wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True + topi.nn.conv2d_winograd_nhwc, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, ), naive_schedule, # this implementation should never be picked by autotvm name="conv2d_nhwc.winograd", @@ -182,7 +195,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): ) elif layout == "HWCN": assert kernel_layout == "HWIO" - if not is_auto_scheduler_enabled(): + if (not need_auto_scheduler_layout) or (not need_meta_schedule_layout): logger.warning("conv2d HWCN layout is not optimized for x86 with autotvm.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), @@ -216,7 +229,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWOI" - if not is_auto_scheduler_enabled(): + if (not need_auto_scheduler_layout) and (not need_meta_schedule_layout): logger.warning( "depthwise_conv2d NHWC layout is not optimized for x86 with autotvm." ) @@ -237,7 +250,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): ) elif layout == "NHWC": assert kernel_layout == "HWIO" - if not is_auto_scheduler_enabled(): + if (not need_auto_scheduler_layout) and (not need_meta_schedule_layout): logger.warning("group_conv2d is not optimized for x86 with autotvm.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True), @@ -328,7 +341,9 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): """conv3d generic strategy""" strategy = _op.OpStrategy() layout = attrs.data_layout - if is_auto_scheduler_enabled(): + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + if need_auto_scheduler_layout or need_meta_schedule_layout: # Use auto-scheduler. We should provide clear compute definition without autotvm templates # or packed layouts. if layout == "NCDHW": @@ -339,7 +354,11 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): ) elif layout == "NDHWC": strategy.add_implementation( - wrap_compute_conv3d(topi.nn.conv3d_ndhwc, need_auto_scheduler_layout=True), + wrap_compute_conv3d( + topi.nn.conv3d_ndhwc, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, + ), naive_schedule, name="conv3d_ndhwc.x86", ) @@ -456,9 +475,15 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target): if length_before == length_after: logger.warning("Currently dnnl only support the data type to be float32. Skip.") - if is_auto_scheduler_enabled(): + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + if need_auto_scheduler_layout or need_meta_schedule_layout: strategy.add_implementation( - wrap_compute_matmul(topi.nn.matmul, need_auto_scheduler_layout=True), + wrap_compute_matmul( + topi.nn.matmul, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, + ), naive_schedule, name="matmul.generic", plevel=11, @@ -499,9 +524,16 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): plevel=10, ) - if is_auto_scheduler_enabled(): + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + + if need_auto_scheduler_layout or need_meta_schedule_layout: strategy.add_implementation( - wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True), + wrap_compute_dense( + topi.nn.dense, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, + ), naive_schedule, name="dense.generic", plevel=11, @@ -568,6 +600,9 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() mcpu = Target.current().mcpu + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() + if ( not attrs.transpose_a and attrs.transpose_b @@ -583,10 +618,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): name="batch_matmul_vnni.x86", plevel=10, ) - elif is_dynamic(out_type) or is_auto_scheduler_enabled(): + elif is_dynamic(out_type) or need_auto_scheduler_layout or need_meta_schedule_layout: strategy.add_implementation( wrap_compute_batch_matmul( - topi.nn.batch_matmul, need_auto_scheduler_layout=True, need_out_dtype=True + topi.nn.batch_matmul, + need_out_dtype=True, + need_auto_scheduler_layout=need_auto_scheduler_layout, + need_meta_schedule_layout=need_meta_schedule_layout, ), wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), name="batch_matmul.generic", @@ -733,15 +771,31 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ assert strides == (1, 1), "Do not support strides now" assert groups == 1, "Do not supoort arbitrary group number" strategy = _op.OpStrategy() + need_auto_scheduler_layout = is_auto_scheduler_enabled() + need_meta_schedule_layout = is_meta_schedule_enabled() if layout == "NHWC": - strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc_without_weight_transform, - need_auto_scheduler_layout=True, - ), - naive_schedule, - name="ansor.winograd", - ) + if need_meta_schedule_layout: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc_without_weight_transform, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=True, + ), + naive_schedule, + name="ansor.winograd", + ) + elif need_auto_scheduler_layout: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc_without_weight_transform, + need_auto_scheduler_layout=True, + need_meta_schedule_layout=False, + ), + naive_schedule, + name="ansor.winograd", + ) + else: + raise RuntimeError("Both AutoScheduler and MetaSchedule are not enabled") else: raise RuntimeError( "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)