From 0c81811fb192ad3b202e212a7e68965a367effa7 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Thu, 12 May 2022 14:53:31 -0700 Subject: [PATCH 01/12] fix issues for models with more than one inputs --- .../testing/custom_builder_runner.py | 2 +- .../testing/tune_relay_meta_schedule.py | 32 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 83bb4aab516b..8019e7f6ac2a 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -167,4 +167,4 @@ def run_module_via_rpc( rt_mod = session.load_module(filename) dev = session.device(dev_type=dev_type, dev_id=0) args = [ndarray.array(arg, dev) for arg in args] - return continuation(rt_mod, dev, *args) + return continuation(rt_mod, dev, args) diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index 88de0c336073..bd608bbaed7c 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -103,10 +103,13 @@ def main(): ARGS.input_shape, cache_dir=ARGS.cache_dir, ) + input_info = {input_name: input_shape} + inputs = [] print(f"Workload: {ARGS.workload}") - print(f" input_name: {input_name}") - print(f" input_shape: {input_shape}") - print(f" input_dtype: {input_dtype}") + for input_name, input_shape in input_info.items(): + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") alloc_repeat = 1 runner = ms.runner.RPCRunner( rpc_config=ARGS.rpc_config, @@ -133,19 +136,21 @@ def main(): params=params, ) graph, rt_mod, params = lib.graph_json, lib.lib, lib.params - if input_dtype.startswith("float"): - input_data = np.random.uniform(size=input_shape).astype(input_dtype) - else: - input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype) + for input_name, input_shape in input_info.items(): + if input_dtype.startswith("float"): + inputs.append(np.random.uniform(size=input_shape).astype(input_dtype)) + else: + inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)) - def f_timer(rt_mod, dev, input_data): + def f_timer(rt_mod, dev, inputs): # pylint: disable=import-outside-toplevel from tvm.contrib.graph_executor import GraphModule # pylint: enable=import-outside-toplevel mod = GraphModule(rt_mod["default"](dev)) - mod.set_input(input_name, input_data) + for index, (input_name, _) in enumerate(input_info.items()): + mod.set_input(input_name, inputs[index]) ftimer = mod.module.time_evaluator( "run", dev, @@ -159,17 +164,18 @@ def f_timer(rt_mod, dev, input_data): rpc_config=ARGS.rpc_config, lib=lib, dev_type=ARGS.target.kind.name, - args=[input_data], + args=inputs, continuation=f_timer, ) - def f_per_layer(rt_mod, dev, input_data): + def f_per_layer(rt_mod, dev, inputs): # pylint: disable=import-outside-toplevel from tvm.contrib.debugger.debug_executor import create # pylint: enable=import-outside-toplevel mod = create(graph, rt_mod, dev) - mod.set_input(input_name, input_data) + for index, (input_name, _) in enumerate(input_info.items()): + mod.set_input(input_name, inputs[index]) graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) print("|graph_nodes| = ", len(graph_nodes)) @@ -182,7 +188,7 @@ def f_per_layer(rt_mod, dev, input_data): rpc_config=ARGS.rpc_config, lib=rt_mod, dev_type=ARGS.target.kind.name, - args=[input_data], + args=inputs, continuation=f_per_layer, ) From 16657971417874ca25f11d59e9aaca809d223758 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Thu, 12 May 2022 14:57:27 -0700 Subject: [PATCH 02/12] fix auto schedule --- .../testing/tune_relay_auto_scheduler.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py index 2a2c20868bb7..d9c92cac8346 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py +++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py @@ -134,10 +134,13 @@ def main(): ARGS.input_shape, cache_dir=ARGS.cache_dir, ) + input_info = {input_name: input_shape} + inputs = [] print(f"Workload: {ARGS.workload}") - print(f" input_name: {input_name}") - print(f" input_shape: {input_shape}") - print(f" input_dtype: {input_dtype}") + for input_name, input_shape in input_info.items(): + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") tasks, task_weights = auto_scheduler.extract_tasks( mod["main"], params, @@ -170,19 +173,21 @@ def main(): params=params, ) graph, rt_mod, params = lib.graph_json, lib.lib, lib.params - if input_dtype.startswith("float"): - input_data = np.random.uniform(size=input_shape).astype(input_dtype) - else: - input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype) + for input_name, input_shape in input_info.items(): + if input_dtype.startswith("float"): + inputs.append(np.random.uniform(size=input_shape).astype(input_dtype)) + else: + inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)) - def f_timer(rt_mod, dev, input_data): + def f_timer(rt_mod, dev, inputs): # pylint: disable=import-outside-toplevel from tvm.contrib.graph_executor import GraphModule # pylint: enable=import-outside-toplevel mod = GraphModule(rt_mod["default"](dev)) - mod.set_input(input_name, input_data) + for index, (input_name, _) in enumerate(input_info.items()): + mod.set_input(input_name, inputs[index]) ftimer = mod.module.time_evaluator( "run", dev, @@ -196,17 +201,18 @@ def f_timer(rt_mod, dev, input_data): rpc_config=ARGS.rpc_config, lib=lib, dev_type=ARGS.target.kind.name, - args=[input_data], + args=inputs, continuation=f_timer, ) - def f_per_layer(rt_mod, dev, input_data): + def f_per_layer(rt_mod, dev, inputs): # pylint: disable=import-outside-toplevel from tvm.contrib.debugger.debug_executor import create # pylint: enable=import-outside-toplevel mod = create(graph, rt_mod, dev) - mod.set_input(input_name, input_data) + for index, (input_name, _) in enumerate(input_info.items()): + mod.set_input(input_name, inputs[index]) graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) print("|graph_nodes| = ", len(graph_nodes)) @@ -219,7 +225,7 @@ def f_per_layer(rt_mod, dev, input_data): rpc_config=ARGS.rpc_config, lib=rt_mod, dev_type=ARGS.target.kind.name, - args=[input_data], + args=inputs, continuation=f_per_layer, ) From 0bd1f370ebdb776824dd2b03c83a8d501ffca47a Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Fri, 13 May 2022 16:26:33 -0700 Subject: [PATCH 03/12] only get run_individual_node on local --- python/tvm/contrib/debugger/debug_executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py index f144b3cb4a82..1698d465631c 100644 --- a/python/tvm/contrib/debugger/debug_executor.py +++ b/python/tvm/contrib/debugger/debug_executor.py @@ -113,7 +113,8 @@ def __init__(self, module, device, graph_json_str, dump_root): self._dump_root = dump_root self._dump_path = None self._run_individual = module["run_individual"] - self._run_individual_node = module["run_individual_node"] + if module.type_key != "rpc": + self._run_individual_node = module["run_individual_node"] self._debug_get_output = module["debug_get_output"] self._execute_node = module["execute_node"] self._get_node_output = module["get_node_output"] From 7ac84efb58f1acd349eda02e89e5ec7ecd7fe9c4 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Sun, 15 May 2022 16:31:26 -0700 Subject: [PATCH 04/12] roll back change --- python/tvm/contrib/debugger/debug_executor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py index 1698d465631c..f144b3cb4a82 100644 --- a/python/tvm/contrib/debugger/debug_executor.py +++ b/python/tvm/contrib/debugger/debug_executor.py @@ -113,8 +113,7 @@ def __init__(self, module, device, graph_json_str, dump_root): self._dump_root = dump_root self._dump_path = None self._run_individual = module["run_individual"] - if module.type_key != "rpc": - self._run_individual_node = module["run_individual_node"] + self._run_individual_node = module["run_individual_node"] self._debug_get_output = module["debug_get_output"] self._execute_node = module["execute_node"] self._get_node_output = module["get_node_output"] From 790eeadb62223bce8c0bdf6ddb88a5d6f23cd4f7 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 17 May 2022 17:04:55 -0700 Subject: [PATCH 05/12] address comments --- .../testing/tune_relay_auto_scheduler.py | 22 +++++++++-------- .../testing/tune_relay_meta_schedule.py | 24 ++++++++++--------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py index d9c92cac8346..ec39d742192c 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py +++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py @@ -135,7 +135,7 @@ def main(): cache_dir=ARGS.cache_dir, ) input_info = {input_name: input_shape} - inputs = [] + input_data = {} print(f"Workload: {ARGS.workload}") for input_name, input_shape in input_info.items(): print(f" input_name: {input_name}") @@ -175,19 +175,21 @@ def main(): graph, rt_mod, params = lib.graph_json, lib.lib, lib.params for input_name, input_shape in input_info.items(): if input_dtype.startswith("float"): - inputs.append(np.random.uniform(size=input_shape).astype(input_dtype)) + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) else: - inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)) + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) - def f_timer(rt_mod, dev, inputs): + def f_timer(rt_mod, dev, input_data): # pylint: disable=import-outside-toplevel from tvm.contrib.graph_executor import GraphModule # pylint: enable=import-outside-toplevel mod = GraphModule(rt_mod["default"](dev)) - for index, (input_name, _) in enumerate(input_info.items()): - mod.set_input(input_name, inputs[index]) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) ftimer = mod.module.time_evaluator( "run", dev, @@ -201,7 +203,7 @@ def f_timer(rt_mod, dev, inputs): rpc_config=ARGS.rpc_config, lib=lib, dev_type=ARGS.target.kind.name, - args=inputs, + args=input_data, continuation=f_timer, ) @@ -211,8 +213,8 @@ def f_per_layer(rt_mod, dev, inputs): # pylint: enable=import-outside-toplevel mod = create(graph, rt_mod, dev) - for index, (input_name, _) in enumerate(input_info.items()): - mod.set_input(input_name, inputs[index]) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) print("|graph_nodes| = ", len(graph_nodes)) @@ -225,7 +227,7 @@ def f_per_layer(rt_mod, dev, inputs): rpc_config=ARGS.rpc_config, lib=rt_mod, dev_type=ARGS.target.kind.name, - args=inputs, + args=input_data, continuation=f_per_layer, ) diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index bd608bbaed7c..bd858e0f2d36 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -104,7 +104,7 @@ def main(): cache_dir=ARGS.cache_dir, ) input_info = {input_name: input_shape} - inputs = [] + input_data = {} print(f"Workload: {ARGS.workload}") for input_name, input_shape in input_info.items(): print(f" input_name: {input_name}") @@ -138,19 +138,21 @@ def main(): graph, rt_mod, params = lib.graph_json, lib.lib, lib.params for input_name, input_shape in input_info.items(): if input_dtype.startswith("float"): - inputs.append(np.random.uniform(size=input_shape).astype(input_dtype)) + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) else: - inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)) + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) - def f_timer(rt_mod, dev, inputs): + def f_timer(rt_mod, dev, input_data): # pylint: disable=import-outside-toplevel from tvm.contrib.graph_executor import GraphModule # pylint: enable=import-outside-toplevel mod = GraphModule(rt_mod["default"](dev)) - for index, (input_name, _) in enumerate(input_info.items()): - mod.set_input(input_name, inputs[index]) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) ftimer = mod.module.time_evaluator( "run", dev, @@ -164,18 +166,18 @@ def f_timer(rt_mod, dev, inputs): rpc_config=ARGS.rpc_config, lib=lib, dev_type=ARGS.target.kind.name, - args=inputs, + args=input_data, continuation=f_timer, ) - def f_per_layer(rt_mod, dev, inputs): + def f_per_layer(rt_mod, dev, input_data): # pylint: disable=import-outside-toplevel from tvm.contrib.debugger.debug_executor import create # pylint: enable=import-outside-toplevel mod = create(graph, rt_mod, dev) - for index, (input_name, _) in enumerate(input_info.items()): - mod.set_input(input_name, inputs[index]) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) print("|graph_nodes| = ", len(graph_nodes)) @@ -188,7 +190,7 @@ def f_per_layer(rt_mod, dev, inputs): rpc_config=ARGS.rpc_config, lib=rt_mod, dev_type=ARGS.target.kind.name, - args=inputs, + args=input_data, continuation=f_per_layer, ) From add70cc35a0cad90e87fecd76a0eb392e3fa03ad Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 25 May 2022 12:54:08 -0700 Subject: [PATCH 06/12] add test for run_module_via_rpc --- .../unittest/test_meta_schedule_tune_tir.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index a7806ebda28a..a829b5f503cf 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -17,12 +17,20 @@ # pylint: disable=missing-docstring import logging import tempfile +import time +import numpy as np import pytest +import tvm + +from tvm import meta_schedule as ms from tvm.meta_schedule import TuneConfig, tune_tir +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -89,6 +97,63 @@ def test_tune_matmul_cuda(): print(sch.trace) +def test_tune_run_module_via_rpc(): + # pylint: disable=import-outside-toplevel + + tracker = Tracker(host="127.0.0.1", port=9000, port_end=10000, silent=True) + device_key = "$local$device$%d" % tracker.port + server = Server( + port=tracker.port, + port_end=10000, + key=device_key, + silent=True, + tracker_addr=("127.0.0.1", tracker.port), + ) + rpc_config = ms.runner.RPCConfig( + tracker_host=tracker.host, + tracker_port=tracker.port, + tracker_key=device_key, + session_timeout_sec=3600, + ) + # Wait for the processes to start + time.sleep(0.5) + + target = tvm.target.Target("llvm") + rt_mod = tvm.build(matmul, target) + + # construct the input + input_data = [] + input_shape = (128, 128) + input_dtype = "float32" + dev = tvm.cpu() + a_np = np.random.uniform(size=input_shape).astype(input_dtype) + b_np = np.random.uniform(size=input_shape).astype(input_dtype) + c_np = np.zeros(input_shape).astype(input_dtype) + for i in range(128): + for j in range(128): + for k in range(128): + c_np[i, j] = c_np[i, j] + a_np[i, k] * b_np[j, k] + input_data.append(tvm.nd.array(a_np, dev)) + input_data.append(tvm.nd.array(b_np, dev)) + input_data.append(tvm.nd.array(np.zeros(input_shape).astype(input_dtype), dev)) + + def f_timer(rt_mod, dev, input_data): + rt_mod(input_data[0], input_data[1], input_data[2]) + return input_data[2] + + result = run_module_via_rpc( + rpc_config=rpc_config, + lib=rt_mod, + dev_type=target.kind.name, + args=input_data, + continuation=f_timer, + ) + tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3) + tracker.terminate() + server.terminate() + + if __name__ == """__main__""": test_tune_matmul_cpu() test_tune_matmul_cuda() + test_tune_run_module_via_rpc() From b078eaee427863cc3af8055e0903af01db37a24d Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 25 May 2022 13:50:29 -0700 Subject: [PATCH 07/12] fix for dict input --- .../meta_schedule/testing/custom_builder_runner.py | 6 ++++-- tests/python/unittest/test_meta_schedule_tune_tir.py | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 8019e7f6ac2a..08adc80d8ad3 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -166,5 +166,7 @@ def run_module_via_rpc( _, filename = os.path.split(filename) rt_mod = session.load_module(filename) dev = session.device(dev_type=dev_type, dev_id=0) - args = [ndarray.array(arg, dev) for arg in args] - return continuation(rt_mod, dev, args) + np_args = {} + for arg_key, arg_value in args.items(): + np_args[arg_key] = ndarray.array(arg_value, dev) + return continuation(rt_mod, dev, np_args) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index a829b5f503cf..6e89db6e578c 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -122,7 +122,7 @@ def test_tune_run_module_via_rpc(): rt_mod = tvm.build(matmul, target) # construct the input - input_data = [] + input_data = {} input_shape = (128, 128) input_dtype = "float32" dev = tvm.cpu() @@ -133,13 +133,13 @@ def test_tune_run_module_via_rpc(): for j in range(128): for k in range(128): c_np[i, j] = c_np[i, j] + a_np[i, k] * b_np[j, k] - input_data.append(tvm.nd.array(a_np, dev)) - input_data.append(tvm.nd.array(b_np, dev)) - input_data.append(tvm.nd.array(np.zeros(input_shape).astype(input_dtype), dev)) + input_data["a"] = a_np + input_data["b"] = b_np + input_data["c"] = np.zeros(input_shape).astype(input_dtype) def f_timer(rt_mod, dev, input_data): - rt_mod(input_data[0], input_data[1], input_data[2]) - return input_data[2] + rt_mod(input_data["a"], input_data["b"], input_data["c"]) + return input_data["c"] result = run_module_via_rpc( rpc_config=rpc_config, From 8736d878f35e73ed4ab9f21229b3491e12a44608 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 25 May 2022 14:00:12 -0700 Subject: [PATCH 08/12] fix mypy linting --- python/tvm/meta_schedule/testing/custom_builder_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 08adc80d8ad3..47efbdd0ce15 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -17,7 +17,7 @@ """Customized builder and runner methods""" # pylint: disable=import-outside-toplevel -from typing import TYPE_CHECKING, Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List, Any if TYPE_CHECKING: import numpy as np # type: ignore @@ -145,7 +145,7 @@ def run_module_via_rpc( rpc_config: "RPCConfig", lib: "Module", dev_type: str, - args: List["np.ndarray"], + args: Dict[str, Any], continuation: Callable, ): """Execute a tvm.runtime.Module on RPC remote""" From 25a699ff42b22e285a7a0c0497bacdffe9e9684a Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 25 May 2022 14:01:37 -0700 Subject: [PATCH 09/12] use np.ndarray --- python/tvm/meta_schedule/testing/custom_builder_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 47efbdd0ce15..c59133a4bcea 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -17,7 +17,7 @@ """Customized builder and runner methods""" # pylint: disable=import-outside-toplevel -from typing import TYPE_CHECKING, Callable, Dict, List, Any +from typing import TYPE_CHECKING, Callable, Dict, List if TYPE_CHECKING: import numpy as np # type: ignore @@ -145,7 +145,7 @@ def run_module_via_rpc( rpc_config: "RPCConfig", lib: "Module", dev_type: str, - args: Dict[str, Any], + args: Dict[str, "np.ndarray"], continuation: Callable, ): """Execute a tvm.runtime.Module on RPC remote""" From 1caf40825eed6a26d677bd17a33e68e57ceb5dcf Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 25 May 2022 14:42:19 -0700 Subject: [PATCH 10/12] fix linting --- python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py index ec39d742192c..abac49c50c6e 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py +++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py @@ -207,7 +207,7 @@ def f_timer(rt_mod, dev, input_data): continuation=f_timer, ) - def f_per_layer(rt_mod, dev, inputs): + def f_per_layer(rt_mod, dev, input_data): # pylint: disable=import-outside-toplevel from tvm.contrib.debugger.debug_executor import create From f0ae9ff2fa5f046c311b3aa70bab0681b18544dd Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 25 May 2022 16:16:37 -0700 Subject: [PATCH 11/12] use LocalRPC --- .../unittest/test_meta_schedule_tune_tir.py | 60 +++++++------------ 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 6e89db6e578c..0e8c205230e6 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -17,7 +17,6 @@ # pylint: disable=missing-docstring import logging import tempfile -import time import numpy as np import pytest @@ -26,11 +25,10 @@ from tvm import meta_schedule as ms from tvm.meta_schedule import TuneConfig, tune_tir from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.meta_schedule.testing.local_rpc import LocalRPC from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -98,26 +96,6 @@ def test_tune_matmul_cuda(): def test_tune_run_module_via_rpc(): - # pylint: disable=import-outside-toplevel - - tracker = Tracker(host="127.0.0.1", port=9000, port_end=10000, silent=True) - device_key = "$local$device$%d" % tracker.port - server = Server( - port=tracker.port, - port_end=10000, - key=device_key, - silent=True, - tracker_addr=("127.0.0.1", tracker.port), - ) - rpc_config = ms.runner.RPCConfig( - tracker_host=tracker.host, - tracker_port=tracker.port, - tracker_key=device_key, - session_timeout_sec=3600, - ) - # Wait for the processes to start - time.sleep(0.5) - target = tvm.target.Target("llvm") rt_mod = tvm.build(matmul, target) @@ -125,7 +103,6 @@ def test_tune_run_module_via_rpc(): input_data = {} input_shape = (128, 128) input_dtype = "float32" - dev = tvm.cpu() a_np = np.random.uniform(size=input_shape).astype(input_dtype) b_np = np.random.uniform(size=input_shape).astype(input_dtype) c_np = np.zeros(input_shape).astype(input_dtype) @@ -137,20 +114,27 @@ def test_tune_run_module_via_rpc(): input_data["b"] = b_np input_data["c"] = np.zeros(input_shape).astype(input_dtype) - def f_timer(rt_mod, dev, input_data): - rt_mod(input_data["a"], input_data["b"], input_data["c"]) - return input_data["c"] - - result = run_module_via_rpc( - rpc_config=rpc_config, - lib=rt_mod, - dev_type=target.kind.name, - args=input_data, - continuation=f_timer, - ) - tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3) - tracker.terminate() - server.terminate() + with LocalRPC() as rpc: + rpc_config = ms.runner.RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + + def f_timer(rt_mod, dev, input_data): + rt_mod(input_data["a"], input_data["b"], input_data["c"]) + return input_data["c"] + + result = run_module_via_rpc( + rpc_config=rpc_config, + lib=rt_mod, + dev_type=target.kind.name, + args=input_data, + continuation=f_timer, + ) + tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3) if __name__ == """__main__""": From 7922bcc60917efbd52d9e18a78110d6dbc60931a Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Thu, 26 May 2022 10:24:05 -0700 Subject: [PATCH 12/12] better naming --- python/tvm/meta_schedule/testing/custom_builder_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index c59133a4bcea..3ba007d9a4d3 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -166,7 +166,7 @@ def run_module_via_rpc( _, filename = os.path.split(filename) rt_mod = session.load_module(filename) dev = session.device(dev_type=dev_type, dev_id=0) - np_args = {} + nd_args = {} for arg_key, arg_value in args.items(): - np_args[arg_key] = ndarray.array(arg_value, dev) - return continuation(rt_mod, dev, np_args) + nd_args[arg_key] = ndarray.array(arg_value, dev) + return continuation(rt_mod, dev, nd_args)