From 269f20bcd90e681be1eb86ce6e30f4153a3f304b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 9 Jul 2023 23:36:21 -0700 Subject: [PATCH] [Testing] Return BenchmarkResult in local_run and rpc_run Return `profile_result` in `tvm.testing.local_run` and `tvm.testing.rpc_run` for further use. --- python/tvm/testing/runner.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/testing/runner.py b/python/tvm/testing/runner.py index 5b677df4bd8f..03533ba167ab 100644 --- a/python/tvm/testing/runner.py +++ b/python/tvm/testing/runner.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: import numpy as np + from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig from tvm.runtime import Device, Module, NDArray @@ -30,6 +31,7 @@ def _args_to_device(args, device): import numpy as np + from tvm.runtime.ndarray import NDArray, empty uploaded_args = [] @@ -109,6 +111,8 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals ------- args : List[Union[np.ndarray, NDArray, int, float]] The results of running the module. + profile_result : tvm.runtime.BenchmarkResult + The profiling result of running the module. """ import os.path as osp import tempfile @@ -137,13 +141,12 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals if evaluator_config.enable_cpu_cache_flush else "", )(*args) - print(profile_result) remote_mod(*args) args = _args_to_numpy(args) finally: pass - return args + return args, profile_result def rpc_run( # pylint: disable=too-many-arguments,too-many-locals @@ -188,6 +191,8 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals ------- args : List[Union[np.ndarray, NDArray, int, float]] The results of running the module. + profile_result : tvm.runtime.BenchmarkResult + The profiling result of running the module. """ import os.path as osp @@ -220,7 +225,6 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals if evaluator_config.enable_cpu_cache_flush else "", )(*args) - print(profile_result) remote_mod(*args) args = _args_to_numpy(args) finally: @@ -228,4 +232,4 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals session.remove(remote_path + "." + output_format) session.remove("") - return args + return args, profile_result