Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from tarfile import ReadError
import argparse
import sys
import json

import numpy as np

import tvm
Expand All @@ -33,6 +35,7 @@
from tvm.contrib import graph_executor as executor
from tvm.contrib.debugger import debug_executor
from tvm.runtime import profiler_vm
from tvm.relay.param_dict import load_param_dict
from . import TVMCException
from .arguments import TVMCSuppressedArgumentParser
from .project import (
Expand Down Expand Up @@ -292,6 +295,56 @@ def drive_run(args):
result.save(args.outputs)


def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]):
"""Return the 'shape' and 'dtype' dictionaries for the input
tensors of a compiled module.

.. note::
We can't simply get the input tensors from a TVM graph
because weight tensors are treated equivalently. Therefore, to
find the input tensors we look at the 'arg_nodes' in the graph
(which are either weights or inputs) and check which ones don't
appear in the params (where the weights are stored). These nodes
are therefore inferred to be input tensors.

.. note::
There exists a more recent API to retrieve the input information
directly from the module. However, this isn't supported when using
with RPC due to a lack of support for Array and Map datatypes.
Therefore, this function exists only as a fallback when RPC is in
use. If RPC isn't being used, please use the more recent API.

Parameters
----------
graph_str : str
JSON graph of the module serialized as a string.
params : dict
Parameter dictionary mapping name to value.

Returns
-------
shape_dict : dict
Shape dictionary - {input_name: tuple}.
dtype_dict : dict
dtype dictionary - {input_name: dtype}.
"""

shape_dict = {}
dtype_dict = {}
params_dict = load_param_dict(params)
param_names = [k for (k, v) in params_dict.items()]
graph = json.loads(graph_str)
for node_id in graph["arg_nodes"]:
node = graph["nodes"][node_id]
# If a node is not in the params, infer it to be an input node
name = node["name"]
if name not in param_names:
shape_dict[name] = graph["attrs"]["shape"][1][node_id]
dtype_dict[name] = graph["attrs"]["dltype"][1][node_id]

return shape_dict, dtype_dict


def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str):
"""Generate data to produce a tensor of given shape and dtype.

Expand Down Expand Up @@ -586,7 +639,14 @@ def run_module(
module.load_params(tvmc_package.params)

logger.debug("Collecting graph input shape and type:")
shape_dict, dtype_dict = module.get_input_info()

if isinstance(session, tvm.rpc.client.RPCSession):
# RPC does not support datatypes such as Array and Map,
# fallback to obtaining input information from graph json.
shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params)
else:
shape_dict, dtype_dict = module.get_input_info()

logger.debug("Graph input shape: %s", shape_dict)
logger.debug("Graph input type: %s", dtype_dict)

Expand Down
43 changes: 41 additions & 2 deletions tests/python/driver/tvmc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
import numpy as np

from tvm import rpc
from tvm.driver import tvmc
from tvm.driver.tvmc.model import TVMCResult
from tvm.driver.tvmc.result_utils import get_top_results
Expand Down Expand Up @@ -103,6 +104,44 @@ def test_run_tflite_module__with_profile__valid_input(
assert (
tiger_cat_mobilenet_id in top_5_ids
), "tiger cat is expected in the top-5 for mobilenet v1"
assert type(result.outputs) is dict
assert type(result.times) is BenchmarkResult
assert isinstance(result.outputs, dict)
assert isinstance(result.times, BenchmarkResult)
assert "output_0" in result.outputs.keys()


def test_run_tflite_module_with_rpc(
tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
):
"""
Test to check that TVMC run is functional when it is being used in
conjunction with an RPC server.
"""
pytest.importorskip("tflite")

inputs = np.load(imagenet_cat)
input_dict = {"input": inputs["input"].astype("uint8")}

tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)

server = rpc.Server("127.0.0.1", 9099)
result = tvmc.run(
tflite_compiled_model,
inputs=input_dict,
hostname=server.host,
port=server.port,
device="cpu",
)

top_5_results = get_top_results(result, 5)
top_5_ids = top_5_results[0]

# IDs were collected from this reference:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/
# java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
tiger_cat_mobilenet_id = 283

assert (
tiger_cat_mobilenet_id in top_5_ids
), "tiger cat is expected in the top-5 for mobilenet v1"
assert isinstance(result.outputs, dict)
assert "output_0" in result.outputs.keys()