diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index 6337c6e6fec5..ac6803ca9842 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -158,6 +158,7 @@ def __init__(self, module): self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] self._get_input_index = module["get_input_index"] + self._get_input_info = module["get_input_info"] self._get_num_inputs = module["get_num_inputs"] self._load_params = module["load_params"] self._share_params = module["share_params"] @@ -258,6 +259,32 @@ def get_input_index(self, name): """ return self._get_input_index(name) + def get_input_info(self): + """Return the 'shape' and 'dtype' dictionaries of the graph. + + .. note:: + We can't simply get the input tensors from a TVM graph + because weight tensors are treated equivalently. Therefore, to + find the input tensors we look at the 'arg_nodes' in the graph + (which are either weights or inputs) and check which ones don't + appear in the params (where the weights are stored). These nodes + are therefore inferred to be input tensors. + + Returns + ------- + shape_dict : Map + Shape dictionary - {input_name: tuple}. + dtype_dict : Map + dtype dictionary - {input_name: dtype}. + """ + input_info = self._get_input_info() + assert "shape" in input_info + shape_dict = input_info["shape"] + assert "dtype" in input_info + dtype_dict = input_info["dtype"] + + return shape_dict, dtype_dict + def get_output(self, index, out=None): """Get index-th output to out diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index a2343962af95..4f1be94f6523 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -18,10 +18,9 @@ Provides support to run compiled networks both locally and remotely. """ from contextlib import ExitStack -import json import logging import pathlib -from typing import Dict, List, Optional, Union +from typing import Dict, Optional, Union from tarfile import ReadError import argparse import sys @@ -32,7 +31,6 @@ from tvm.autotvm.measure import request_remote from tvm.contrib import graph_executor as runtime from tvm.contrib.debugger import debug_executor -from tvm.relay.param_dict import load_param_dict from . import TVMCException from .arguments import TVMCSuppressedArgumentParser from .project import ( @@ -282,53 +280,6 @@ def drive_run(args): result.save(args.outputs) -def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]): - """Return the 'shape' and 'dtype' dictionaries for the input - tensors of a compiled module. - - .. note:: - We can't simply get the input tensors from a TVM graph - because weight tensors are treated equivalently. Therefore, to - find the input tensors we look at the 'arg_nodes' in the graph - (which are either weights or inputs) and check which ones don't - appear in the params (where the weights are stored). These nodes - are therefore inferred to be input tensors. - - Parameters - ---------- - graph_str : str - JSON graph of the module serialized as a string. - params : dict - Parameter dictionary mapping name to value. - - Returns - ------- - shape_dict : dict - Shape dictionary - {input_name: tuple}. - dtype_dict : dict - dtype dictionary - {input_name: dtype}. - """ - - shape_dict = {} - dtype_dict = {} - params_dict = load_param_dict(params) - param_names = [k for (k, v) in params_dict.items()] - graph = json.loads(graph_str) - for node_id in graph["arg_nodes"]: - node = graph["nodes"][node_id] - # If a node is not in the params, infer it to be an input node - name = node["name"] - if name not in param_names: - shape_dict[name] = graph["attrs"]["shape"][1][node_id] - dtype_dict[name] = graph["attrs"]["dltype"][1][node_id] - - logger.debug("Collecting graph input shape and type:") - logger.debug("Graph input shape: %s", shape_dict) - logger.debug("Graph input type: %s", dtype_dict) - - return shape_dict, dtype_dict - - def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str): """Generate data to produce a tensor of given shape and dtype. @@ -370,8 +321,8 @@ def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str): def make_inputs_dict( - shape_dict: Dict[str, List[int]], - dtype_dict: Dict[str, str], + shape_dict: tvm.container.Map, + dtype_dict: tvm.container.Map, inputs: Optional[Dict[str, np.ndarray]] = None, fill_mode: str = "random", ): @@ -383,9 +334,9 @@ def make_inputs_dict( Parameters ---------- - shape_dict : dict + shape_dict : Map Shape dictionary - {input_name: tuple}. - dtype_dict : dict + dtype_dict : Map dtype dictionary - {input_name: dtype}. inputs : dict, optional A dictionary that maps input names to numpy values. @@ -420,8 +371,10 @@ def make_inputs_dict( logger.debug("setting input '%s' with user input data", input_name) inputs_dict[input_name] = inputs[input_name] else: - shape = shape_dict[input_name] - dtype = dtype_dict[input_name] + # container.ShapleTuple -> tuple + shape = tuple(shape_dict[input_name]) + # container.String -> str + dtype = str(dtype_dict[input_name]) logger.debug( "generating data for input '%s' (shape: %s, dtype: %s), using fill-mode '%s'", @@ -580,7 +533,11 @@ def run_module( logger.debug("Loading params into the runtime module.") module.load_params(tvmc_package.params) - shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) + logger.debug("Collecting graph input shape and type:") + shape_dict, dtype_dict = module.get_input_info() + logger.debug("Graph input shape: %s", shape_dict) + logger.debug("Graph input type: %s", dtype_dict) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) logger.debug("Setting inputs to the module.") diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 7f83693292ba..cae408b6121d 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -161,3 +161,14 @@ def __len__(self): def __getitem__(self, idx): return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx) + + def __eq__(self, other): + if self.same_as(other): + return True + if len(self) != len(other): + return False + for a, b in zip(self, other): + if a != b: + return False + + return True diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index d12c24250f43..f713671317b8 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -97,6 +98,7 @@ void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module mod output_map_[name] = i; } } + /*! * \brief Get the input index given the name of input. * \param name The name of the input. @@ -109,6 +111,29 @@ int GraphExecutor::GetInputIndex(const std::string& name) { } return -1; } + +/*! + * \brief Get the input info of Graph by parsing the input nodes. + * \return The shape and dtype tuple. + */ +std::tuple GraphExecutor::GetInputInfo() const { + GraphExecutor::ShapeInfo shape_dict; + GraphExecutor::DtypeInfo dtype_dict; + for (uint32_t nid : input_nodes_) { + CHECK_LE(nid, nodes_.size()); + std::string name = nodes_[nid].name; + if (param_names_.find(name) == param_names_.end()) { + CHECK_LE(nid, attrs_.shape.size()); + auto shape = attrs_.shape[nid]; + shape_dict.Set(name, ShapeTuple(shape)); + CHECK_LE(nid, attrs_.dltype.size()); + auto dtype = attrs_.dltype[nid]; + dtype_dict.Set(name, String(dtype)); + } + } + return std::make_tuple(shape_dict, dtype_dict); +} + /*! * \brief Get the output index given the name of output. * \param name The name of the output. @@ -252,6 +277,7 @@ void GraphExecutor::LoadParams(const std::string& param_blob) { void GraphExecutor::LoadParams(dmlc::Stream* strm) { Map params = ::tvm::runtime::LoadParams(strm); for (auto& p : params) { + param_names_.insert(p.first); int in_idx = GetInputIndex(p.first); if (in_idx < 0) continue; uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); @@ -614,6 +640,16 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; *rv = this->GetInputIndex(args[0].operator String()); }); + } else if (name == "get_input_info") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + GraphExecutor::ShapeInfo shape_info; + GraphExecutor::DtypeInfo dtype_info; + std::tie(shape_info, dtype_info) = this->GetInputInfo(); + Map input_info; + input_info.Set("shape", shape_info); + input_info.Set("dtype", dtype_info); + *rv = input_info; + }); } else { return PackedFunc(); } diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 87e8aa3cee34..25b01a253c7d 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -33,7 +33,9 @@ #include #include +#include #include +#include #include #include @@ -71,6 +73,8 @@ class TVM_DLL GraphExecutor : public ModuleNode { }; public: + using ShapeInfo = Map; + using DtypeInfo = Map; /*! * \brief Get member function to front-end * \param name The name of the function. @@ -107,6 +111,12 @@ class TVM_DLL GraphExecutor : public ModuleNode { */ int GetInputIndex(const std::string& name); + /*! + * \brief Get the input info of Graph by parsing the input nodes. + * \return The shape and dtype tuple. + */ + std::tuple GetInputInfo() const; + /*! * \brief Get the output index given the name of output. * \param name The name of the output. @@ -417,6 +427,8 @@ class TVM_DLL GraphExecutor : public ModuleNode { std::vector nodes_; /*! \brief The argument nodes. */ std::vector input_nodes_; + /*! \brief The parameter names. */ + std::unordered_set param_names_; /*! \brief Map of input names to input indices. */ std::unordered_map input_map_; /*! \brief Map of output names to output indices. */ diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index c04ae0039658..e817e588a516 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -325,6 +325,21 @@ def test_graph_executor_api(): assert mod.get_input_index(dname_0) == 0 assert mod.get_input_index("Invalid") == -1 + shape_dict, dtype_dict = mod.get_input_info() + assert isinstance(shape_dict, tvm.container.Map) + assert isinstance(dtype_dict, tvm.container.Map) + for data in [data_0, data_1]: + name = data.name_hint + ty = data.type_annotation + # verify shape + assert name in shape_dict + assert isinstance(shape_dict[name], tvm.runtime.container.ShapeTuple) + assert shape_dict[name] == tvm.runtime.container.ShapeTuple([i.value for i in ty.shape]) + # verify dtype + assert name in dtype_dict + assert isinstance(dtype_dict[name], tvm.runtime.container.String) + assert dtype_dict[name] == ty.dtype + @tvm.testing.requires_llvm def test_benchmark(): diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 4c72f2c6083b..8c302e920577 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -83,6 +83,12 @@ def test_shape_tuple(): len(stuple) == len(shape) for a, b in zip(stuple, shape): assert a == b + # ShapleTuple vs. list + assert stuple == list(shape) + # ShapleTuple vs. tuple + assert stuple == tuple(shape) + # ShapleTuple vs. ShapeTuple + assert stuple == _container.ShapeTuple(shape) if __name__ == "__main__":