diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 2a4a03bbe8e7..6f331499b042 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -31,23 +31,22 @@ We implement these in python to utilize python's multiprocessing and error handling. """ +import logging +import multiprocessing import os -import time import shutil import tempfile -import multiprocessing -import logging +import time import tvm._ffi -from tvm.runtime import Object, module, ndarray +from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope +from tvm.contrib import ndk, tar +from tvm.contrib.popen_pool import PopenPoolExecutor, PopenWorker, StatusKind from tvm.driver import build_module from tvm.ir import transform -from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope -from tvm.contrib import tar, ndk -from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind +from tvm.runtime import Object, module, ndarray from tvm.target import Target - from . import _ffi_api from .loop_state import StateObject from .utils import ( @@ -59,8 +58,8 @@ request_remote, ) from .workload_registry import ( - serialize_workload_registry_entry, deserialize_workload_registry_entry, + serialize_workload_registry_entry, ) # pylint: disable=invalid-name @@ -555,8 +554,8 @@ def __init__( device=0, ): # pylint: disable=import-outside-toplevel - from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server + from tvm.rpc.tracker import Tracker self.tracker = Tracker(port=9000, port_end=10000, silent=True) device_key = "$local$device$%d" % self.tracker.port @@ -630,7 +629,7 @@ def _local_build_worker(inp_serialized, build_func, verbose): filename = os.path.join(dirname, "tmp_func." + build_func.output_format) try: - with transform.PassContext(): + with transform.PassContext().current(): func = build_module.build(sch, args, target=task.target) func.export_library(filename, build_func) # pylint: disable=broad-except diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index f582bd1974aa..8fc0da89c4c6 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -31,9 +31,9 @@ import time import traceback import typing +import warnings from collections import namedtuple from random import getrandbits -import warnings import tvm._ffi import tvm.ir.transform @@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option if not config.valid(): raise InstantiationError(config.errors) - opts = build_option or {} - if check_gpu: # Add verify pass to filter out invalid configs in advance. - opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))] - # if target is vta, we need to use vta build if ( hasattr(measure_input.target, "device_name") @@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option func = vta.build(s, args, target_host=task.target_host) else: - with tvm.ir.transform.PassContext(config=opts): + current_pass_context: tvm.ir.transform.PassContext = ( + tvm.ir.transform.PassContext.current() + ) + current_config = dict(current_pass_context.config) + if build_option is not None: + current_config.update(build_option) + + if "tir.add_lower_pass" in current_config: + current_add_lower_pass = list(current_config["tir.add_lower_pass"]) + else: + current_add_lower_pass = [] + if check_gpu: + current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu))) + current_config["tir.add_lower_pass"] = current_add_lower_pass + + with tvm.ir.transform.PassContext( + opt_level=current_pass_context.opt_level, + required_pass=current_pass_context.required_pass, + disabled_pass=current_pass_context.disabled_pass, + instruments=current_pass_context.instruments, + config=current_config, + ): func = build(s, args, target_host=task.target_host, runtime=runtime) return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 03f38aa9cc9e..a3dca33e71ee 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -28,7 +28,11 @@ import tvm.relay import tvm.testing from tvm import autotvm, te +from tvm.autotvm.measure import measure_methods from tvm.autotvm.tuner import RandomTuner +from tvm.contrib import tar +from tvm.ir.instrument import pass_instrument +from tvm.ir.transform import PassContext from tvm.target import Target @@ -180,6 +184,114 @@ def runner(target, dev): run_test_with_all_multiprocessing(runner, target, dev) +@tvm.testing.parametrize_targets("cuda", "opencl") +def test_tuning_gpu_inherits_pass_context(target, dev): + """Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default. + + Test that using PassContext inherits passes properly but also runs gpu verification pass. + """ + from tvm.tir.analysis import _ffi_api as _analysis_ffi_api + + @pass_instrument + class PassInstrumentChecker: + """Pass Instrument that simply sees if it's been run.""" + + def __init__(self): + self.has_been_run = False + + def run_after_pass(self, mod, info): + self.has_been_run = True + + class GPUVerifyPassMocked: + """Context manager that mocks tir.analysis.verify_gpu_code meant + to verify the pass has been run. This is done by patching the ffi func handles.""" + + FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code" + FUNC_NAME = "verify_gpu_code" + + def __init__(self) -> None: + self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE) + self.has_been_run = False + + def gpu_verify_pass_mocked(self): + """Get the replacement for the gpu verification pass.""" + + def _gpu_verify_pass_mocked(*args, **kwargs): + self.has_been_run = True + return self.old_impl(*args, **kwargs) + + return _gpu_verify_pass_mocked + + def __enter__(self): + tvm._ffi.register_func( + self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(), override=True + ) + + # Also overwrite the python bindings + setattr( + _analysis_ffi_api, self.FUNC_NAME, tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE) + ) + + def __exit__(self, *args, **kwargs): + # Restore FFI status back to normal + tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl, override=True) + setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl) + + class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc): + """BuildFunc that mocks and patches as necessary to test proper passes are run.""" + + def __call__(self, measure_input, tmp_dir, **kwargs): + instrument = PassInstrumentChecker() + mocked_pass_checker = GPUVerifyPassMocked() + with mocked_pass_checker: + with PassContext(instruments=[instrument]): + regular_result = super().__call__(measure_input, tmp_dir, **kwargs) + + # Check instrument has been run, meaning context was inherited by builder + assert instrument.has_been_run + + # But also check the gpu verification pass has been run + # (which was not in the inherited ctx) + assert mocked_pass_checker.has_been_run + + return regular_result + + class MockedLocalBuilder(measure_methods.LocalBuilder): + """As measure_methods.LocalBuilder but overwrites the PassContext for testing.""" + + def __init__( + self, + timeout=10, + n_parallel=None, + build_kwargs=None, + build_func="default", + do_fork=False, + runtime=None, + ): + super().__init__(timeout, n_parallel, build_kwargs, build_func, do_fork, runtime) + self.build_func = OverwrittenBuildFunc(tar.tar, runtime) + + def runner(target, dev): + task, target = get_sample_task(target, None) + logging.info("task config space: %s", task.config_space) + + # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder() + measure_option = autotvm.measure_option(MockedLocalBuilder(), autotvm.LocalRunner()) + + results = [] + + tuner = RandomTuner(task) + tuner.tune( + n_trial=1, + measure_option=measure_option, + callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), + ) + + assert len(results) == 1 + + run_test_with_all_multiprocessing(runner, target, dev) + + def test_tuning_cpu(): def runner(): ir_mod = tvm.parser.fromtext(