From bfb2272b2261d29a6dbda1cddf77b4a2cc6a3908 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 31 May 2022 20:24:48 +0000 Subject: [PATCH 1/5] Reuse hexagon launcher in test session --- python/tvm/contrib/hexagon/build.py | 157 ++++++++++-------- python/tvm/contrib/hexagon/pytest_plugin.py | 61 +++++-- python/tvm/contrib/hexagon/session.py | 79 ++++----- .../contrib/test_hexagon/test_launcher.py | 2 - 4 files changed, 175 insertions(+), 124 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 43856253cb18..cb3e93f224df 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -26,6 +26,9 @@ import socket import stat import subprocess +import random +import string +import tempfile from typing import Union import tvm @@ -34,6 +37,7 @@ HEXAGON_RPC_LIB_DIR = os.environ.get("HEXAGON_RPC_LIB_DIR") +ANDROID_BASH_FILE_NAME = "android_bash.sh" def _get_hexagon_rpc_lib_dir() -> pathlib.Path: @@ -58,7 +62,9 @@ def _get_hexagon_rpc_lib_dir() -> pathlib.Path: def _get_test_directory_name() -> str: """Generate a time-stamped name for use as a test directory name.""" - return datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + date_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + random_str = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + return f"{date_str}-{random_str}" class HexagonLauncherRPC(metaclass=abc.ABCMeta): @@ -112,7 +118,6 @@ def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None): self._rpc_info.update(rpc_info) self._workspace = self._create_workspace(workspace) self._device_key = self.HEXAGON_REMOTE_DEVICE_KEY - self._serial_number = None @abc.abstractmethod def start_server(self): @@ -140,13 +145,18 @@ def _copy_to_remote( ... @abc.abstractmethod - def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]): + def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]) -> pathlib.Path: """Create a directory in the remote location. Parameters ---------- remote_path : str or pathlib.Path Name of the directory to be created. + + Returns + ------- + pathlib.Path : + Absolute path of the remote workspace. """ ... @@ -167,10 +177,9 @@ def _create_workspace(self, workspace: Union[str, pathlib.Path]) -> pathlib.Path if not workspace: base_dir = self._rpc_info["workspace_base"] workspace = os.path.join(base_dir, _get_test_directory_name()) - self._create_remote_directory(workspace) - return pathlib.Path(workspace) + return self._create_remote_directory(workspace) - def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): + def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str) -> pathlib.Path: """Upload a local file to the remote workspace. Parameters @@ -179,9 +188,16 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): Path to the local file to be copied. remote_filename : str Name of the file in the remote workspace. + + Returns + ------- + pathlib.Path : + Uploaded file remote path. """ assert self._workspace - self._copy_to_remote(local_path, os.path.join(str(self._workspace), remote_filename)) + remote_file_path = self._workspace / remote_filename + self._copy_to_remote(local_path, str(remote_file_path)) + return remote_file_path def start_session(self, session_name: str = "hexagon-rpc") -> Session: """Connect to the RPC server. @@ -217,10 +233,7 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], sess session and loaded. If the object passed is a string or pathlib.Path, it must - be either a bare file name (without any path components), - or a full path in the remote system. If it is a file name, - the file must already have been uploaded to the remote, - and be placed in the remote workspace. + be a full path in the remote system. session : Session @@ -236,7 +249,10 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], sess return session.load_module(module) def get_graph_executor( - self, graph_json: str, module_name: Union[str, pathlib.Path], session: Session + self, + graph_json: str, + module: Union[str, pathlib.Path, tvm.runtime.Module], + session: Session, ): """Create a local GraphModule which consumes a remote libmod. @@ -244,8 +260,14 @@ def get_graph_executor( ---------- graph_json : str The string with the graph JSON. - module_name : str or pathlib.Path - Remote module filename. Same restrictions apply as in load_module(). + module : Union[str, pathlib.Path, tvm.runtime.Module] + + The module to load. If `module` is a + `tvm.runtime.Module`, it will be uploaded to the remote + session and loaded. + + If the object passed is a string or pathlib.Path, it must + be a full path in the remote system. session : Session Remote session. The session must be established (via __enter__) prior to calling this function. @@ -255,13 +277,12 @@ def get_graph_executor( GraphModule : Runtime graph module that can be used to execute the graph. """ - graph_mod = self.load_module(module_name, session) - return tvm.contrib.graph_executor.create(graph_json, graph_mod, session.device) + return session.get_graph_executor(graph_json, module) def get_graph_debug_executor( self, graph_json: str, - module_name: Union[str, pathlib.Path], + module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session, dump_root: Union[str, pathlib.Path] = None, ): @@ -271,39 +292,24 @@ def get_graph_debug_executor( ---------- graph_json : str The string with the graph JSON. - module_name : str or pathlib.Path - Remote module filename. Same restrictions apply as in load_module(). - session : Session - Remote session. The session must be established (via __enter__) - prior to calling this function. - - Returns - ------- - GraphModuleDebug : - Runtime debug graph module that can be used to debug the graph. - """ - graph_mod = self.load_module(module_name, session) - return tvm.contrib.debugger.debug_executor.create( - graph_json, graph_mod, session.device, dump_root=str(dump_root) - ) + module : Union[str, pathlib.Path, tvm.runtime.Module] - def get_aot_executor(self, module_name: Union[str, pathlib.Path], session: Session): - """Create a local AoTModule which consumes a remote libmod. + The module to load. If `module` is a + `tvm.runtime.Module`, it will be uploaded to the remote + session and loaded. - Parameters - ---------- - module_name : str or pathlib.Path - Remote module filename. Same restrictions apply as in load_module(). + If the object passed is a string or pathlib.Path, it must + be a full path in the remote system. session : Session Remote session. The session must be established (via __enter__) prior to calling this function. Returns ------- - aot_module : AotModule - Runtime AOT module that can be used to execute. + GraphModuleDebug : + Runtime debug graph module that can be used to debug the graph. """ - return session.get_aot_executor(module_name) + return session.get_graph_debug_executor(graph_json, module, dump_root=dump_root) class HexagonLauncherAndroid(HexagonLauncherRPC): @@ -311,7 +317,6 @@ class HexagonLauncherAndroid(HexagonLauncherRPC): ANDROID_HEXAGON_TEST_BASE_DIR = pathlib.Path("/data/local/tmp/hexagon_test") ANDROID_HEXAGON_RPC_FILES = [ - "android_bash.sh", "libhexagon_rpc_skel.so", "libtvm_runtime.so", "tvm_rpc_android", @@ -350,39 +355,42 @@ def _copy_to_remote( self._adb_device_sub_cmd + ["push", str(local_path), str(remote_path)] ) - def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]): + def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]) -> pathlib.Path: """Abstract method implementation. See description in HexagonLauncherRPC.""" subprocess.check_call(self._adb_device_sub_cmd + ["shell", "mkdir", "-p", str(remote_path)]) + return pathlib.Path(remote_path) def _copy_binaries(self): """Upload Android server binaries.""" # Create bash script - android_bash_script_path = _get_hexagon_rpc_lib_dir() / "android_bash.sh" - with open(_get_hexagon_rpc_lib_dir() / "android_bash.sh.template", "r") as src_f: - if os.path.exists(android_bash_script_path): - os.remove(android_bash_script_path) - with open(android_bash_script_path, "w") as dest_f: - for line in src_f.readlines(): - if "" in line: - line = line.replace( - "", str(self._rpc_info["rpc_tracker_host"]) - ) - if "" in line: - line = line.replace( - "", str(self._rpc_info["rpc_tracker_port"]) - ) - if "" in line: - line = line.replace("", self._device_key) - if "" in line: - line = line.replace( - "", str(self._rpc_info["rpc_server_port"]) - ) - dest_f.write(line) - - # Make shell script executable - android_bash_stat = os.stat(android_bash_script_path) - os.chmod(android_bash_script_path, android_bash_stat.st_mode | stat.S_IEXEC) + with open(_get_hexagon_rpc_lib_dir() / f"{ANDROID_BASH_FILE_NAME}.template", "r") as src_f: + with tempfile.TemporaryDirectory() as temp_dir: + android_bash_script_path = pathlib.Path(temp_dir) / ANDROID_BASH_FILE_NAME + with open(android_bash_script_path, "w") as dest_f: + for line in src_f.readlines(): + if "" in line: + line = line.replace( + "", str(self._rpc_info["rpc_tracker_host"]) + ) + if "" in line: + line = line.replace( + "", str(self._rpc_info["rpc_tracker_port"]) + ) + if "" in line: + line = line.replace("", self._device_key) + if "" in line: + line = line.replace( + "", str(self._rpc_info["rpc_server_port"]) + ) + dest_f.write(line) + + # Make shell script executable + android_bash_stat = os.stat(android_bash_script_path) + os.chmod(android_bash_script_path, android_bash_stat.st_mode | stat.S_IEXEC) + self._copy_to_remote( + android_bash_script_path, self._workspace / android_bash_script_path.name + ) # Push files lib_dir = _get_hexagon_rpc_lib_dir() @@ -432,7 +440,8 @@ def _run_server_script(self): # Run server and connect to tracker subprocess.Popen( - self._adb_device_sub_cmd + ["shell", f"cd {self._workspace} && ./android_bash.sh"], + self._adb_device_sub_cmd + + ["shell", f"cd {self._workspace} && ./{ANDROID_BASH_FILE_NAME}"], stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.PIPE, @@ -468,7 +477,7 @@ def _terminate_remote(self): self._adb_device_sub_cmd + ["shell", f"kill `cat {self._workspace}/rpc_pid.txt`"] ) - def _cleanup_directory(self): + def cleanup_directory(self): # Remove workspace directory on remote target subprocess.Popen(self._adb_device_sub_cmd + ["shell", f"rm -rf {self._workspace}"]) @@ -481,7 +490,7 @@ def stop_server(self): """Abstract method implementation. See description in HexagonLauncherRPC.""" self._cleanup_port_forwarding() self._terminate_remote() - self._cleanup_directory() + self.cleanup_directory() class HexagonLauncherSimulator(HexagonLauncherRPC): @@ -507,9 +516,10 @@ def _copy_to_remote( """Abstract method implementation. See description in HexagonLauncherRPC.""" subprocess.check_call(["cp", str(local_path), str(remote_path)]) - def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]): + def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]) -> pathlib.Path: """Abstract method implementation. See description in HexagonLauncherRPC.""" subprocess.check_call(["mkdir", "-p", str(remote_path)]) + return pathlib.Path(os.path.abspath(remote_path)) def _copy_libcxx(self, dest_dir: Union[str, pathlib.Path]): """Copy libc++ libraries to the remote workspace.""" @@ -581,6 +591,9 @@ def _start(self): self._server_process = mp.Process(target=lambda *a: _start(self, *a)) self._server_process.start() + def cleanup_directory(self): + pass + def stop_server(self): """Abstract method implementation. See description in HexagonLauncherRPC.""" self._server_process.terminate() diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py index 278bd833da95..14f717e972cc 100644 --- a/python/tvm/contrib/hexagon/pytest_plugin.py +++ b/python/tvm/contrib/hexagon/pytest_plugin.py @@ -56,7 +56,7 @@ def _compose(args, decs): requires_hexagon_toolchain = tvm.testing.requires_hexagon(support_required="compile-only") -@tvm.testing.fixture +@pytest.fixture(scope="session") def android_serial_number() -> Optional[str]: serial = os.getenv(ANDROID_SERIAL_NUMBER, default="") # Setting ANDROID_SERIAL_NUMBER to an empty string should be @@ -138,21 +138,28 @@ def tvm_tracker_port(_tracker_info) -> int: return port -@tvm.testing.fixture +@pytest.fixture(scope="session") +def rpc_server_port_for_session() -> int: + return get_free_port() + + +@pytest.fixture() def rpc_server_port() -> int: return get_free_port() -@tvm.testing.fixture +@pytest.fixture(scope="session") def adb_server_socket() -> str: return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037") -@tvm.testing.fixture -def hexagon_launcher( - request, android_serial_number, rpc_server_port, adb_server_socket +@pytest.fixture(scope="session") +def hexagon_server_process( + request, android_serial_number, rpc_server_port_for_session, adb_server_socket ) -> HexagonLauncherRPC: - """Initials and returns hexagon launcher if ANDROID_SERIAL_NUMBER is defined""" + """Initials and returns hexagon launcher if ANDROID_SERIAL_NUMBER is defined. + This launcher is started only once per test session. + """ if android_serial_number is None: yield None else: @@ -165,19 +172,51 @@ def hexagon_launcher( rpc_info = { "rpc_tracker_host": tvm_tracker_host, "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": rpc_server_port, + "rpc_server_port": rpc_server_port_for_session, "adb_server_socket": adb_server_socket, } launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.start_server() + try: + launcher.start_server() yield launcher finally: launcher.stop_server() -@tvm.testing.fixture -def hexagon_session(hexagon_launcher) -> Session: +@pytest.fixture +def hexagon_launcher( + hexagon_server_process, + rpc_server_port, + tvm_tracker_host, + tvm_tracker_port, + adb_server_socket, + android_serial_number, +) -> HexagonLauncherRPC: + """Initials and returns hexagon launcher which reuses RPC info and Android serial number.""" + if hexagon_server_process._serial_number != "simulator": + rpc_info = hexagon_server_process._rpc_info + serial_number = hexagon_server_process._serial_number + else: + serial_number = android_serial_number + rpc_info = { + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, + "rpc_server_port": rpc_server_port, + "adb_server_socket": adb_server_socket, + } + + launcher = HexagonLauncher(serial_number=serial_number, rpc_info=rpc_info) + try: + if hexagon_server_process._serial_number == "simulator": + launcher.start_server() + yield launcher + finally: + launcher.cleanup_directory() + + +@pytest.fixture +def hexagon_session(hexagon_launcher: HexagonLauncherRPC) -> Session: if hexagon_launcher is None: yield None else: diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index f30fe6e47096..1026138df0fd 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -93,7 +93,8 @@ def __enter__(self): raise exception def __exit__(self, exc_type, exc_value, exc_traceback): - pass + # close session to the tracker + del self._rpc @property def device(self): @@ -109,7 +110,7 @@ def device(self): return self._device - def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): + def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str) -> pathlib.Path: """Upload a local file to the remote workspace. Parameters @@ -118,8 +119,13 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): Path to the local file to be copied. remote_filename : str Name of the file in the remote workspace. + + Returns + ------- + pathlib.Path : + Uploaded file remote path. """ - self._launcher.upload(local_path, remote_filename) + return self._launcher.upload(local_path, remote_filename) def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): """Load TVM module. @@ -136,10 +142,7 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): session and loaded. If the object passed is a string or pathlib.Path, it must - be either a bare file name (without any path components), - or a full path in the remote system. If it is a file name, - the file must already have been uploaded to the remote, - and be placed in the remote workspace. + be a full path in the remote system. Returns ------- @@ -155,16 +158,19 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): binary_name = "test_binary.so" binary_path = temp_dir / binary_name module.save(str(binary_path)) - self.upload(binary_path, binary_name) - module = binary_name + remote_file_path = self.upload(binary_path, binary_name) + else: + remote_file_path = module - assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module)) - return self._rpc.get_function("tvm.hexagon.load_module")(str(module)) + assert isinstance(remote_file_path, (str, pathlib.Path)), "Invalid path type:" + str( + type(remote_file_path) + ) + return self._rpc.get_function("tvm.hexagon.load_module")(str(remote_file_path)) def get_graph_executor( self, graph_json: str, - module_name: Union[str, pathlib.Path], + module_name: Union[str, pathlib.Path, tvm.runtime.Module], ): """Create a local GraphModule which consumes a remote libmod. @@ -173,14 +179,10 @@ def get_graph_executor( Parameters ---------- - - module_name : Union[str, pathlib.Path] - + module_name : Union[str, pathlib.Path, tvm.runtime.Module] The remote module filename, following the same restrictions as `load_module`. - graph_json : str - The string with the graph JSON. Returns @@ -194,33 +196,36 @@ def get_graph_executor( self._set_device_type(graph_mod) return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device) - def get_aot_executor( + def get_graph_debug_executor( self, - module_name: Union[str, pathlib.Path], + graph_json: str, + module_name: Union[str, pathlib.Path, tvm.runtime.Module], + dump_root: Union[str, pathlib.Path] = None, ): - """Create a local GraphModule which consumes a remote libmod. - - The session must be established (via __enter__) prior to - calling this function. + """Create a local GraphModuleDebug which consumes a remote libmod. Parameters ---------- - - module_name : Union[str, pathlib.Path] - + graph_json : str + The string with the graph JSON. + module_name : Union[str, pathlib.Path, tvm.runtime.Module] The remote module filename, following the same restrictions as `load_module`. + session : Session + Remote session. The session must be established (via __enter__) + prior to calling this function. Returns ------- - GraphModule : - Runtime graph module that can be used to execute the graph. - + GraphModuleDebug : + Runtime debug graph module that can be used to debug the graph. """ - aot_mod = self.load_module(module_name) - self._set_device_type(aot_mod) - return tvm.runtime.executor.AotModule(aot_mod["default"](self.device)) + graph_debug_mod = self.load_module(module_name) + self._set_device_type(graph_debug_mod) + return tvm.contrib.debugger.debug_executor.create( + graph_json, graph_debug_mod, self.device, dump_root=str(dump_root) + ) def get_executor_from_factory(self, module: ExecutorFactoryModule): """Create a local GraphModule which consumes a remote libmod. @@ -286,11 +291,7 @@ def _graph_executor_from_factory( Runtime graph module that can be used to execute the graph. """ - - graph_json = module.get_graph_json() - graph_mod = self.load_module(module.get_lib()) - - return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device) + return self.get_graph_executor(module.get_graph_json(), module.get_lib()) def _aot_executor_from_factory( self, @@ -354,7 +355,7 @@ def _aot_executor_from_factory( f"Target kind should be from these options: [hexagon, llvm]." ) - self.upload(binary_path, binary_name) + remote_file_path = self.upload(binary_path, binary_name) - aot_mod = self.load_module(binary_name) + aot_mod = self.load_module(str(remote_file_path)) return tvm.runtime.executor.AotModule(aot_mod["default"](self.device)) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index ad798925ee88..aae2e598f617 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -import sys -import pytest import numpy as np import tvm.testing From 6f86717c90305d45eb0e27d1ce23c30c14320e57 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Mon, 6 Jun 2022 17:29:33 +0000 Subject: [PATCH 2/5] separate random name generation --- python/tvm/contrib/hexagon/build.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index cb3e93f224df..17431c59a756 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -26,8 +26,6 @@ import socket import stat import subprocess -import random -import string import tempfile from typing import Union @@ -62,9 +60,7 @@ def _get_hexagon_rpc_lib_dir() -> pathlib.Path: def _get_test_directory_name() -> str: """Generate a time-stamped name for use as a test directory name.""" - date_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - random_str = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) - return f"{date_str}-{random_str}" + return datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") class HexagonLauncherRPC(metaclass=abc.ABCMeta): From f5d04908c9e4da6bd385fa894a53cfcdd65940eb Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 7 Jun 2022 19:10:38 +0000 Subject: [PATCH 3/5] revert get_aot_executor --- python/tvm/contrib/hexagon/session.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 1026138df0fd..0c0bf296df44 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -196,6 +196,26 @@ def get_graph_executor( self._set_device_type(graph_mod) return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device) + def get_aot_executor( + self, + module_file: Union[str, pathlib.Path], + ): + """Create a local GraphModule which consumes a remote libmod. + The session must be established (via __enter__) prior to + calling this function. + Parameters + ---------- + module_file : Union[str, pathlib.Path] + The remote module filename, following the same restrictions + as `load_module`. The filename should be an absolute path. + Returns + ------- + GraphModule : + Runtime graph module that can be used to execute the graph. + """ + aot_mod = self.load_module(module_file) + return tvm.runtime.executor.AotModule(aot_mod["default"](self.device)) + def get_graph_debug_executor( self, graph_json: str, @@ -357,5 +377,4 @@ def _aot_executor_from_factory( remote_file_path = self.upload(binary_path, binary_name) - aot_mod = self.load_module(str(remote_file_path)) - return tvm.runtime.executor.AotModule(aot_mod["default"](self.device)) + return self.get_aot_executor(remote_file_path) From af22dddff03e432851d0c1ec20beaeb8504ef921 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 8 Jun 2022 00:10:47 +0000 Subject: [PATCH 4/5] Fix launcher for simulator case --- python/tvm/contrib/hexagon/build.py | 11 +++++++++-- python/tvm/contrib/hexagon/pytest_plugin.py | 13 +++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 17431c59a756..78d2261e156d 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -22,6 +22,7 @@ import multiprocessing as mp import os import pathlib +import shutil import signal import socket import stat @@ -125,6 +126,11 @@ def stop_server(self): """Stop the RPC server""" ... + @abc.abstractmethod + def cleanup_directory(self): + """Cleanup working directory""" + ... + @abc.abstractmethod def _copy_to_remote( self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path] @@ -474,7 +480,7 @@ def _terminate_remote(self): ) def cleanup_directory(self): - # Remove workspace directory on remote target + """Abstract method implementation. See description in HexagonLauncherRPC.""" subprocess.Popen(self._adb_device_sub_cmd + ["shell", f"rm -rf {self._workspace}"]) def start_server(self): @@ -588,7 +594,8 @@ def _start(self): self._server_process.start() def cleanup_directory(self): - pass + """Abstract method implementation. See description in HexagonLauncherRPC.""" + shutil.rmtree(self._workspace) def stop_server(self): """Abstract method implementation. See description in HexagonLauncherRPC.""" diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py index 14f717e972cc..7480e7f3ebe8 100644 --- a/python/tvm/contrib/hexagon/pytest_plugin.py +++ b/python/tvm/contrib/hexagon/pytest_plugin.py @@ -160,7 +160,7 @@ def hexagon_server_process( """Initials and returns hexagon launcher if ANDROID_SERIAL_NUMBER is defined. This launcher is started only once per test session. """ - if android_serial_number is None: + if android_serial_number is None or android_serial_number == "simulator": yield None else: # Requesting these fixtures sets up a local tracker, if one @@ -194,11 +194,12 @@ def hexagon_launcher( android_serial_number, ) -> HexagonLauncherRPC: """Initials and returns hexagon launcher which reuses RPC info and Android serial number.""" - if hexagon_server_process._serial_number != "simulator": + if android_serial_number is None: + yield None + + if android_serial_number != "simulator": rpc_info = hexagon_server_process._rpc_info - serial_number = hexagon_server_process._serial_number else: - serial_number = android_serial_number rpc_info = { "rpc_tracker_host": tvm_tracker_host, "rpc_tracker_port": tvm_tracker_port, @@ -206,9 +207,9 @@ def hexagon_launcher( "adb_server_socket": adb_server_socket, } - launcher = HexagonLauncher(serial_number=serial_number, rpc_info=rpc_info) + launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) try: - if hexagon_server_process._serial_number == "simulator": + if android_serial_number == "simulator": launcher.start_server() yield launcher finally: From c6bec0db9e5b8a7c111435e1347f99eb37933b79 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 8 Jun 2022 16:56:39 +0000 Subject: [PATCH 5/5] add stop server for simulator --- python/tvm/contrib/hexagon/build.py | 2 -- python/tvm/contrib/hexagon/pytest_plugin.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 78d2261e156d..66b1a8ac75da 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -22,7 +22,6 @@ import multiprocessing as mp import os import pathlib -import shutil import signal import socket import stat @@ -595,7 +594,6 @@ def _start(self): def cleanup_directory(self): """Abstract method implementation. See description in HexagonLauncherRPC.""" - shutil.rmtree(self._workspace) def stop_server(self): """Abstract method implementation. See description in HexagonLauncherRPC.""" diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py index 7480e7f3ebe8..1841c654b934 100644 --- a/python/tvm/contrib/hexagon/pytest_plugin.py +++ b/python/tvm/contrib/hexagon/pytest_plugin.py @@ -213,6 +213,8 @@ def hexagon_launcher( launcher.start_server() yield launcher finally: + if android_serial_number == "simulator": + launcher.stop_server() launcher.cleanup_directory()