def run_module_via_rpc(
    rpc_config: "RPCConfig",
    lib: "Module",
    dev_type: str,
    args: List["np.ndarray"],
    continuation: Callable,
):
    """Upload a compiled module to an RPC peer and run a callback against it.

    The module is exported as a tar archive into a temporary directory,
    uploaded through a session obtained from ``rpc_config.connect_server()``,
    and loaded on the remote side. Every entry of ``args`` is copied to device
    0 of kind ``dev_type`` on that session before ``continuation`` is invoked.

    Parameters
    ----------
    rpc_config : RPCConfig
        The configuration used to open the RPC session.
    lib : Module
        The module to export and upload.
    dev_type : str
        The device type string passed to ``session.device``.
    args : List[np.ndarray]
        The host-side arguments to copy onto the remote device.
    continuation : Callable
        Called as ``continuation(rt_mod, dev, *remote_args)``.

    Returns
    -------
    result : Any
        Whatever ``continuation`` returns.
    """
    # pylint: disable=import-outside-toplevel
    import os
    import tempfile

    from tvm.contrib.tar import tar
    from tvm.runtime import ndarray

    # pylint: enable=import-outside-toplevel

    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
        lib.export_library(local_path, tar)

        session = rpc_config.connect_server()
        session.upload(local_path)
        # The uploaded file is addressed on the remote side by its base name only.
        remote_name = os.path.basename(local_path)
        rt_mod = session.load_module(remote_name)
        dev = session.device(dev_type=dev_type, dev_id=0)

        remote_args = []
        for host_arg in args:
            remote_args.append(ndarray.array(host_arg, dev))
        # The callback runs while the temporary directory still exists.
        return continuation(rt_mod, dev, *remote_args)
"""Workloads in Relay IR""" # pylint: disable=import-outside-toplevel +import logging import multiprocessing import os import pickle @@ -29,6 +30,8 @@ from tvm.runtime import NDArray, load_param_dict, save_param_dict from tvm.target import Target +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + def _get_network( args: Tuple[str, List[int]] @@ -170,7 +173,7 @@ def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]: path = os.path.join(os.path.expanduser(cache_dir), filename) if not os.path.exists(path): return None - print(f"Load from cache: {path}") + logger.info("Loaded from cached: %s", path) with open(path, "rb") as i_f: return pickle.load(i_f) diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py new file mode 100644 index 000000000000..37484226e85b --- /dev/null +++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=missing-docstring
import argparse
import json
import os

import numpy as np  # type: ignore
import tvm
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network


def _parse_args():
    """Parse the command line and attach derived objects to the namespace.

    Besides the raw options, the returned namespace has ``target`` converted to
    a ``tvm.target.Target``, ``input_shape`` decoded from its JSON string, and a
    new ``rpc_config`` attribute built from the ``--rpc-*`` options.
    """
    args = argparse.ArgumentParser()
    args.add_argument(
        "--workload",
        type=str,
        required=True,
    )
    args.add_argument(
        "--input-shape",
        type=str,
        required=True,
    )
    args.add_argument(
        "--target",
        type=str,
        required=True,
    )
    args.add_argument(
        "--num-trials",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-host",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-port",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-key",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-workers",
        type=int,
        required=True,
    )
    args.add_argument(
        "--log-dir",
        type=str,
        required=True,
    )
    args.add_argument(
        "--cache-dir",
        type=str,
        default=None,
    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
    # The shape is given on the command line as a JSON string.
    parsed.input_shape = json.loads(parsed.input_shape)
    parsed.rpc_config = ms.runner.RPCConfig(
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
        session_timeout_sec=3600,
    )
    return parsed


ARGS = _parse_args()


def _random_input(shape, dtype):
    """Create a random host array of the given shape and dtype name.

    Float dtypes are sampled with ``np.random.uniform``; every other dtype is
    treated as an integer type and sampled with ``np.random.randint``.
    """
    if dtype.startswith("float"):
        return np.random.uniform(size=shape).astype(dtype)
    # `high` must be representable in `dtype`, otherwise `randint` raises
    # ValueError (e.g. for int8/uint8). Cap it at the dtype's maximum; dtypes
    # that can already hold 10000 keep exactly the original bound.
    high = min(10000, int(np.iinfo(dtype).max))
    return np.random.randint(low=0, high=high, size=shape, dtype=dtype)


def main():
    """Tune a Relay workload with auto_scheduler and time the result over RPC.

    The steps are: build the RPC runner and hardware parameters, extract the
    tuning tasks, tune them while recording to ``<log-dir>/<workload>.json``,
    compile with the recorded history applied, then upload the library and
    print the measured run times.
    """
    # Make sure the directory exists before the record file path is used.
    os.makedirs(ARGS.log_dir, exist_ok=True)
    log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")

    runner = auto_scheduler.RPCRunner(
        key=ARGS.rpc_key,
        host=ARGS.rpc_host,
        port=ARGS.rpc_port,
        n_parallel=ARGS.rpc_workers,
        number=3,
        repeat=1,
        min_repeat_ms=100,  # TODO
        enable_cpu_cache_flush=False,  # TODO
    )

    if ARGS.target.kind.name == "llvm":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=int(ARGS.target.attrs["num-cores"]),
            target=ARGS.target,
        )
    elif ARGS.target.kind.name == "cuda":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=-1,
            vector_unit_bytes=16,
            cache_line_bytes=64,
            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
            # The value `max_local_memory_per_block` is not used in AutoScheduler,
            # but is required by the API.
            max_local_memory_per_block=12345678,
            max_vthread_extent=8,
            warp_size=32,
        )
    else:
        raise NotImplementedError(f"Unsupported target {ARGS.target}")
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    print(f"Workload: {ARGS.workload}")
    print(f"  input_name: {input_name}")
    print(f"  input_shape: {input_shape}")
    print(f"  input_dtype: {input_dtype}")
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"],
        params,
        target=ARGS.target,
        hardware_params=hardware_params,
    )
    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
        print(task.compute_dag)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(
        auto_scheduler.TuningOptions(
            num_measure_trials=ARGS.num_trials,
            runner=runner,
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
            ],
        )
    )

    # Compile the whole network with the records written during tuning.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_auto_scheduler": True},
        ):
            lib = relay.build(
                mod,
                target=ARGS.target,
                params=params,
            )

    input_data = _random_input(input_shape, input_dtype)

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel

        # Named `graph_mod` so it does not shadow the Relay module `mod` above.
        graph_mod = GraphModule(rt_mod["default"](dev))
        graph_mod.set_input(input_name, input_data)
        ftimer = graph_mod.module.time_evaluator(
            "run",
            dev,
            min_repeat_ms=500,
            repeat=3,
        )
        return list(np.array(ftimer().results))

    results = run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=[input_data],
        continuation=f_timer,
    )

    print(results)


if __name__ == "__main__":
    main()
# pylint: disable=missing-docstring
import argparse
import json
import logging

import numpy as np  # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network


def _parse_args():
    """Parse the command line and attach derived objects to the namespace.

    Besides the raw options, the returned namespace has ``target`` converted to
    a ``tvm.target.Target``, ``input_shape`` decoded from its JSON string, and a
    new ``rpc_config`` attribute built from the ``--rpc-*`` options.
    """
    args = argparse.ArgumentParser()
    args.add_argument(
        "--workload",
        type=str,
        required=True,
    )
    args.add_argument(
        "--input-shape",
        type=str,
        required=True,
    )
    args.add_argument(
        "--target",
        type=str,
        required=True,
    )
    args.add_argument(
        "--num-trials",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-host",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-port",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-key",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-workers",
        type=int,
        required=True,
    )
    args.add_argument(
        "--work-dir",
        type=str,
        required=True,
    )
    args.add_argument(
        "--cache-dir",
        type=str,
        default=None,
    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
    # The shape is given on the command line as a JSON string.
    parsed.input_shape = json.loads(parsed.input_shape)
    parsed.rpc_config = ms.runner.RPCConfig(
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
        session_timeout_sec=60,
    )
    return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
ARGS = _parse_args()


def _random_input(shape, dtype):
    """Create a random host array of the given shape and dtype name.

    Float dtypes are sampled with ``np.random.uniform``; every other dtype is
    treated as an integer type and sampled with ``np.random.randint``.
    """
    if dtype.startswith("float"):
        return np.random.uniform(size=shape).astype(dtype)
    # `high` must be representable in `dtype`, otherwise `randint` raises
    # ValueError (e.g. for int8/uint8). Cap it at the dtype's maximum; dtypes
    # that can already hold 10000 keep exactly the original bound.
    high = min(10000, int(np.iinfo(dtype).max))
    return np.random.randint(low=0, high=high, size=shape, dtype=dtype)


def main():
    """Tune a Relay workload with meta schedule and time the result over RPC.

    The steps are: load the workload, build an RPC runner, call
    ``ms.tune_relay`` to obtain the compiled library, then upload the library
    and print the measured run times.
    """
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    runner = ms.runner.RPCRunner(
        rpc_config=ARGS.rpc_config,
        evaluator_config=ms.runner.EvaluatorConfig(
            number=3,
            repeat=1,
            min_repeat_ms=100,
            enable_cpu_cache_flush=False,
        ),
        alloc_repeat=1,
        max_workers=ARGS.rpc_workers,
    )
    lib = ms.tune_relay(
        mod=mod,
        target=ARGS.target,
        config=ms.EvolutionarySearchConfig(
            num_trials_per_iter=64,
            num_trials_total=ARGS.num_trials,
            init_min_unmeasured=50,
        ),
        runner=runner,  # type: ignore
        work_dir=ARGS.work_dir,
        params=params,
    )

    input_data = _random_input(input_shape, input_dtype)

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel

        # Named `graph_mod` so it does not shadow the Relay module `mod` above.
        graph_mod = GraphModule(rt_mod["default"](dev))
        graph_mod.set_input(input_name, input_data)
        ftimer = graph_mod.module.time_evaluator(
            "run",
            dev,
            min_repeat_ms=500,
            repeat=3,
        )
        return list(np.array(ftimer().results))

    results = run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=[input_data],
        continuation=f_timer,
    )

    print(results)


if __name__ == "__main__":
    main()
-68,6 +63,11 @@ def _parse_args(): type=int, required=True, ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) parsed = args.parse_args() parsed.target = tvm.target.Target(parsed.target) parsed.rpc_config = ms.runner.RPCConfig( diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 27d1fbcf1a91..7d751ea12fcb 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -19,7 +19,8 @@ import json import os import shutil -from typing import Any, List, Optional, Union, Callable +from contextlib import contextmanager +from typing import Any, Callable, List, Optional, Union import psutil # type: ignore import tvm @@ -132,14 +133,17 @@ def __setattr__(self, name, value): @register_func("meta_schedule.cpu_count") def _cpu_count_impl(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system + Parameters ---------- logical : bool = True If True, return the number of logical CPUs, otherwise return the number of physical CPUs + Returns ------- cpu_count : int The number of logical or physical CPUs in the system + Note ---- The meta schedule search infra intentionally does not adopt the following convention in TVM: @@ -356,3 +360,16 @@ def _to_hex_address(handle: ctypes.c_void_p) -> str: The hexadecimal address of the handle. 
@contextmanager
def autotvm_silencer():
    """Temporarily force ``autotvm.GLOBAL_SCOPE.silent`` to ``True``.

    The value the flag had on entry is remembered and written back on exit,
    including when the body of the ``with`` statement raises.
    """
    from tvm import autotvm  # pylint: disable=import-outside-toplevel

    scope = autotvm.GLOBAL_SCOPE
    previous = scope.silent
    scope.silent = True
    try:
        yield
    finally:
        # Restore the flag as it was on entry rather than assuming it was False.
        scope.silent = previous