diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 5be588a3ae7f..afb198ce1c6e 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -24,6 +24,8 @@ from tarfile import ReadError import argparse import sys +import json + import numpy as np import tvm @@ -33,6 +35,7 @@ from tvm.contrib import graph_executor as executor from tvm.contrib.debugger import debug_executor from tvm.runtime import profiler_vm +from tvm.relay.param_dict import load_param_dict from . import TVMCException from .arguments import TVMCSuppressedArgumentParser from .project import ( @@ -292,6 +295,56 @@ def drive_run(args): result.save(args.outputs) +def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]): + """Return the 'shape' and 'dtype' dictionaries for the input + tensors of a compiled module. + + .. note:: + We can't simply get the input tensors from a TVM graph + because weight tensors are treated equivalently. Therefore, to + find the input tensors we look at the 'arg_nodes' in the graph + (which are either weights or inputs) and check which ones don't + appear in the params (where the weights are stored). These nodes + are therefore inferred to be input tensors. + + .. note:: + There exists a more recent API to retrieve the input information + directly from the module. However, this isn't supported when using + with RPC due to a lack of support for Array and Map datatypes. + Therefore, this function exists only as a fallback when RPC is in + use. If RPC isn't being used, please use the more recent API. + + Parameters + ---------- + graph_str : str + JSON graph of the module serialized as a string. + params : dict + Parameter dictionary mapping name to value. + + Returns + ------- + shape_dict : dict + Shape dictionary - {input_name: tuple}. + dtype_dict : dict + dtype dictionary - {input_name: dtype}. + """ + + shape_dict = {} + dtype_dict = {} + params_dict = load_param_dict(params) + param_names = [k for (k, v) in params_dict.items()] + graph = json.loads(graph_str) + for node_id in graph["arg_nodes"]: + node = graph["nodes"][node_id] + # If a node is not in the params, infer it to be an input node + name = node["name"] + if name not in param_names: + shape_dict[name] = graph["attrs"]["shape"][1][node_id] + dtype_dict[name] = graph["attrs"]["dltype"][1][node_id] + + return shape_dict, dtype_dict + + def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str): """Generate data to produce a tensor of given shape and dtype. @@ -586,7 +639,14 @@ def run_module( module.load_params(tvmc_package.params) logger.debug("Collecting graph input shape and type:") - shape_dict, dtype_dict = module.get_input_info() + + if isinstance(session, tvm.rpc.client.RPCSession): + # RPC does not support datatypes such as Array and Map, + # fallback to obtaining input information from graph json. + shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) + else: + shape_dict, dtype_dict = module.get_input_info() + logger.debug("Graph input shape: %s", shape_dict) logger.debug("Graph input type: %s", dtype_dict) diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 3f4ab11f6ba2..f0d363dc59ac 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -17,6 +17,7 @@ import pytest import numpy as np +from tvm import rpc from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCResult from tvm.driver.tvmc.result_utils import get_top_results @@ -103,6 +104,44 @@ def test_run_tflite_module__with_profile__valid_input( assert ( tiger_cat_mobilenet_id in top_5_ids ), "tiger cat is expected in the top-5 for mobilenet v1" - assert type(result.outputs) is dict - assert type(result.times) is BenchmarkResult + assert isinstance(result.outputs, dict) + assert isinstance(result.times, BenchmarkResult) + assert "output_0" in result.outputs.keys() + + +def test_run_tflite_module_with_rpc( + tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat +): + """ + Test to check that TVMC run is functional when it is being used in + conjunction with an RPC server. + """ + pytest.importorskip("tflite") + + inputs = np.load(imagenet_cat) + input_dict = {"input": inputs["input"].astype("uint8")} + + tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant) + + server = rpc.Server("127.0.0.1", 9099) + result = tvmc.run( + tflite_compiled_model, + inputs=input_dict, + hostname=server.host, + port=server.port, + device="cpu", + ) + + top_5_results = get_top_results(result, 5) + top_5_ids = top_5_results[0] + + # IDs were collected from this reference: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/ + # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt + tiger_cat_mobilenet_id = 283 + + assert ( + tiger_cat_mobilenet_id in top_5_ids + ), "tiger cat is expected in the top-5 for mobilenet v1" + assert isinstance(result.outputs, dict) assert "output_0" in result.outputs.keys()