Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/tvm/contrib/graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(self, module):
self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"]
self._get_input_index = module["get_input_index"]
self._get_input_info = module["get_input_info"]
self._get_num_inputs = module["get_num_inputs"]
self._load_params = module["load_params"]
self._share_params = module["share_params"]
Expand Down Expand Up @@ -258,6 +259,32 @@ def get_input_index(self, name):
"""
return self._get_input_index(name)

def get_input_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the graph.

.. note::
We can't simply get the input tensors from a TVM graph
because weight tensors are treated equivalently. Therefore, to
find the input tensors we look at the 'arg_nodes' in the graph
(which are either weights or inputs) and check which ones don't
appear in the params (where the weights are stored). These nodes
are therefore inferred to be input tensors.

Returns
-------
shape_dict : Map
Shape dictionary - {input_name: tuple}.
dtype_dict : Map
dtype dictionary - {input_name: dtype}.
"""
input_info = self._get_input_info()
assert "shape" in input_info
shape_dict = input_info["shape"]
assert "dtype" in input_info
dtype_dict = input_info["dtype"]

return shape_dict, dtype_dict

def get_output(self, index, out=None):
"""Get index-th output to out

Expand Down
71 changes: 14 additions & 57 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
Provides support to run compiled networks both locally and remotely.
"""
from contextlib import ExitStack
import json
import logging
import pathlib
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Union
from tarfile import ReadError
import argparse
import sys
Expand All @@ -32,7 +31,6 @@
from tvm.autotvm.measure import request_remote
from tvm.contrib import graph_executor as runtime
from tvm.contrib.debugger import debug_executor
from tvm.relay.param_dict import load_param_dict
from . import TVMCException
from .arguments import TVMCSuppressedArgumentParser
from .project import (
Expand Down Expand Up @@ -282,53 +280,6 @@ def drive_run(args):
result.save(args.outputs)


def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]):
"""Return the 'shape' and 'dtype' dictionaries for the input
tensors of a compiled module.

.. note::
We can't simply get the input tensors from a TVM graph
because weight tensors are treated equivalently. Therefore, to
find the input tensors we look at the 'arg_nodes' in the graph
(which are either weights or inputs) and check which ones don't
appear in the params (where the weights are stored). These nodes
are therefore inferred to be input tensors.

Parameters
----------
graph_str : str
JSON graph of the module serialized as a string.
params : dict
Parameter dictionary mapping name to value.

Returns
-------
shape_dict : dict
Shape dictionary - {input_name: tuple}.
dtype_dict : dict
dtype dictionary - {input_name: dtype}.
"""

shape_dict = {}
dtype_dict = {}
params_dict = load_param_dict(params)
param_names = [k for (k, v) in params_dict.items()]
graph = json.loads(graph_str)
for node_id in graph["arg_nodes"]:
node = graph["nodes"][node_id]
# If a node is not in the params, infer it to be an input node
name = node["name"]
if name not in param_names:
shape_dict[name] = graph["attrs"]["shape"][1][node_id]
dtype_dict[name] = graph["attrs"]["dltype"][1][node_id]

logger.debug("Collecting graph input shape and type:")
logger.debug("Graph input shape: %s", shape_dict)
logger.debug("Graph input type: %s", dtype_dict)

return shape_dict, dtype_dict


def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str):
"""Generate data to produce a tensor of given shape and dtype.

Expand Down Expand Up @@ -370,8 +321,8 @@ def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str):


def make_inputs_dict(
shape_dict: Dict[str, List[int]],
dtype_dict: Dict[str, str],
shape_dict: tvm.container.Map,
dtype_dict: tvm.container.Map,
inputs: Optional[Dict[str, np.ndarray]] = None,
fill_mode: str = "random",
):
Expand All @@ -383,9 +334,9 @@ def make_inputs_dict(

Parameters
----------
shape_dict : dict
shape_dict : Map
Shape dictionary - {input_name: tuple}.
dtype_dict : dict
dtype_dict : Map
dtype dictionary - {input_name: dtype}.
inputs : dict, optional
A dictionary that maps input names to numpy values.
Expand Down Expand Up @@ -420,8 +371,10 @@ def make_inputs_dict(
logger.debug("setting input '%s' with user input data", input_name)
inputs_dict[input_name] = inputs[input_name]
else:
shape = shape_dict[input_name]
dtype = dtype_dict[input_name]
# container.ShapleTuple -> tuple
shape = tuple(shape_dict[input_name])
# container.String -> str
dtype = str(dtype_dict[input_name])

logger.debug(
"generating data for input '%s' (shape: %s, dtype: %s), using fill-mode '%s'",
Expand Down Expand Up @@ -580,7 +533,11 @@ def run_module(
logger.debug("Loading params into the runtime module.")
module.load_params(tvmc_package.params)

shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params)
logger.debug("Collecting graph input shape and type:")
shape_dict, dtype_dict = module.get_input_info()
logger.debug("Graph input shape: %s", shape_dict)
logger.debug("Graph input type: %s", dtype_dict)

inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)

logger.debug("Setting inputs to the module.")
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,14 @@ def __len__(self):

def __getitem__(self, idx):
return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx)

def __eq__(self, other):
if self.same_as(other):
return True
if len(self) != len(other):
return False
for a, b in zip(self, other):
if a != b:
return False

return True
36 changes: 36 additions & 0 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
Expand Down Expand Up @@ -97,6 +98,7 @@ void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module mod
output_map_[name] = i;
}
}

/*!
* \brief Get the input index given the name of input.
* \param name The name of the input.
Expand All @@ -109,6 +111,29 @@ int GraphExecutor::GetInputIndex(const std::string& name) {
}
return -1;
}

/*!
* \brief Get the input info of Graph by parsing the input nodes.
* \return The shape and dtype tuple.
*/
std::tuple<GraphExecutor::ShapeInfo, GraphExecutor::DtypeInfo> GraphExecutor::GetInputInfo() const {
GraphExecutor::ShapeInfo shape_dict;
GraphExecutor::DtypeInfo dtype_dict;
for (uint32_t nid : input_nodes_) {
CHECK_LE(nid, nodes_.size());
std::string name = nodes_[nid].name;
if (param_names_.find(name) == param_names_.end()) {
CHECK_LE(nid, attrs_.shape.size());
auto shape = attrs_.shape[nid];
shape_dict.Set(name, ShapeTuple(shape));
CHECK_LE(nid, attrs_.dltype.size());
auto dtype = attrs_.dltype[nid];
dtype_dict.Set(name, String(dtype));
}
}
return std::make_tuple(shape_dict, dtype_dict);
}

/*!
* \brief Get the output index given the name of output.
* \param name The name of the output.
Expand Down Expand Up @@ -252,6 +277,7 @@ void GraphExecutor::LoadParams(const std::string& param_blob) {
void GraphExecutor::LoadParams(dmlc::Stream* strm) {
Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
for (auto& p : params) {
param_names_.insert(p.first);
int in_idx = GetInputIndex(p.first);
if (in_idx < 0) continue;
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
Expand Down Expand Up @@ -614,6 +640,16 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(args[0].operator String());
});
} else if (name == "get_input_info") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
GraphExecutor::ShapeInfo shape_info;
GraphExecutor::DtypeInfo dtype_info;
std::tie(shape_info, dtype_info) = this->GetInputInfo();
Map<String, ObjectRef> input_info;
input_info.Set("shape", shape_info);
input_info.Set("dtype", dtype_info);
*rv = input_info;
});
} else {
return PackedFunc();
}
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/graph_executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -71,6 +73,8 @@ class TVM_DLL GraphExecutor : public ModuleNode {
};

public:
using ShapeInfo = Map<String, ObjectRef>;
using DtypeInfo = Map<String, ObjectRef>;
/*!
* \brief Get member function to front-end
* \param name The name of the function.
Expand Down Expand Up @@ -107,6 +111,12 @@ class TVM_DLL GraphExecutor : public ModuleNode {
*/
int GetInputIndex(const std::string& name);

/*!
* \brief Get the input info of Graph by parsing the input nodes.
* \return The shape and dtype tuple.
*/
std::tuple<ShapeInfo, DtypeInfo> GetInputInfo() const;

/*!
* \brief Get the output index given the name of output.
* \param name The name of the output.
Expand Down Expand Up @@ -417,6 +427,8 @@ class TVM_DLL GraphExecutor : public ModuleNode {
std::vector<Node> nodes_;
/*! \brief The argument nodes. */
std::vector<uint32_t> input_nodes_;
/*! \brief The parameter names. */
std::unordered_set<std::string> param_names_;
/*! \brief Map of input names to input indices. */
std::unordered_map<std::string, uint32_t> input_map_;
/*! \brief Map of output names to output indices. */
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ def test_graph_executor_api():
assert mod.get_input_index(dname_0) == 0
assert mod.get_input_index("Invalid") == -1

shape_dict, dtype_dict = mod.get_input_info()
assert isinstance(shape_dict, tvm.container.Map)
assert isinstance(dtype_dict, tvm.container.Map)
for data in [data_0, data_1]:
name = data.name_hint
ty = data.type_annotation
# verify shape
assert name in shape_dict
assert isinstance(shape_dict[name], tvm.runtime.container.ShapeTuple)
assert shape_dict[name] == tvm.runtime.container.ShapeTuple([i.value for i in ty.shape])
# verify dtype
assert name in dtype_dict
assert isinstance(dtype_dict[name], tvm.runtime.container.String)
assert dtype_dict[name] == ty.dtype


@tvm.testing.requires_llvm
def test_benchmark():
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_runtime_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def test_shape_tuple():
len(stuple) == len(shape)
for a, b in zip(stuple, shape):
assert a == b
# ShapleTuple vs. list
assert stuple == list(shape)
# ShapleTuple vs. tuple
assert stuple == tuple(shape)
# ShapleTuple vs. ShapeTuple
assert stuple == _container.ShapeTuple(shape)


if __name__ == "__main__":
Expand Down