diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index 888f1bad4ebe..7e3ddd5e07d4 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -224,6 +224,7 @@ def relay_to_relax( trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, opt_config: Optional[Dict[str, str]] = None, + build_folder: msc_utils.MSCDirectory = None, ) -> tvm.IRModule: """Change relay IRModule to relax MSCGraph. @@ -239,6 +240,8 @@ def relay_to_relax( The config for build MSCGraph. opt_config: dict The config for optimize the relay before translate. + build_folder: MSCDirectory + The folder for saving scripts and datas. Returns ------- @@ -254,4 +257,4 @@ def relay_to_relax( opt_config=opt_config, ) - return to_relax(graph, weights, codegen_config={"from_relay": True}) + return to_relax(graph, weights, codegen_config={"from_relay": True}, build_folder=build_folder) diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py index e6461d107941..9f706dbf745f 100644 --- a/python/tvm/contrib/msc/core/utils/dataset.py +++ b/python/tvm/contrib/msc/core/utils/dataset.py @@ -20,12 +20,13 @@ import os import shutil import json -from typing import List, Union, Dict, Any +from typing import List, Union, Dict, Any, Tuple import numpy as np import tvm from .arguments import load_dict from .info import cast_array, is_array +from .namespace import MSCFramework def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], style="dict") -> Any: @@ -64,6 +65,51 @@ def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], styl raise TypeError("Unexpected style " + str(style)) +def random_data( + info: Union[List, Tuple, dict], + framework: str = MSCFramework.MSC, + device: str = "cpu", + max_val: int = None, +) -> Any: + """Create random data from info + + Parameters + ---------- + info: list| tuple| dict + The data info. + framework: str + The framework. + device: str + The device. + """ + + if isinstance(info, (tuple, list)): + if len(info) == 1: + info = {"name": "data", "shape": info[0], "dtype": "float32"} + elif len(info) == 2: + info = {"name": "data", "shape": info[0], "dtype": info[1]} + elif len(info) == 3: + info = {"name": info[0], "shape": info[1], "dtype": info[2]} + else: + raise Exception("Unexpected info " + str(info)) + assert isinstance(info, dict) and all( + key in info for key in ["shape", "dtype"] + ), "shape and dtype should be given to create randome data" + if info["dtype"] in ("int32", "int64"): + if max_val is None: + data = np.zeros(info["shape"]).astype(info["dtype"]) + else: + data = np.random.randint(0, high=max_val, size=info["shape"]).astype(info["dtype"]) + elif info["dtype"] == "bool": + data = np.random.rand(*info["shape"]).astype("float32") + data = np.where(data >= 0.5, True, False) + else: + data = np.random.rand(*info["shape"]).astype(info["dtype"]) + if max_val is not None: + data *= max_val + return cast_array(data, framework, device=device) + + class BaseDataLoader(object): """Basic dataset loader for MSC diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py index c8c2844c2859..04597bd3419b 100644 --- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py @@ -25,6 +25,7 @@ from tvm.contrib.msc.core.ir.graph import MSCGraph from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs from tvm.contrib.msc.core.codegen import relay_to_relax +from tvm.contrib.msc.core import utils as msc_utils def set_weight_alias(graph: MSCGraph) -> MSCGraph: @@ -70,6 +71,7 @@ def from_torch( opt_config: Optional[Dict[str, str]] = None, as_msc: bool = True, custom_convert_map: dict = None, + build_folder: msc_utils.MSCDirectory = None, ) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]: """Change torch nn.Module to MSCGraph. @@ -93,6 +95,8 @@ def from_torch( Set to to return msc graph, otherwise relax mod custom_convert_map: dict The convert map for plugin + build_folder: MSCDirectory + The folder for saving scripts and datas. Returns ------- @@ -102,9 +106,15 @@ def from_torch( The weights from the IRModule. """ + # try to symbolic_trace if via_relax: - input_info = normalize_inputs(input_info) - graph_model, params = torch.fx.symbolic_trace(model), None + try: + graph_model = torch.fx.symbolic_trace(model) + except: # pylint: disable=bare-except + via_relax = False + + if via_relax: + input_info, params = normalize_inputs(input_info), None with torch.no_grad(): relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map) else: @@ -122,7 +132,9 @@ def from_torch( relay_mod, params = tvm.relay.frontend.from_pytorch( scripted_model, shape_list, custom_convert_map=custom_convert_map ) - relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config) + relax_mod = relay_to_relax( + relay_mod, params, trans_config, build_config, opt_config, build_folder=build_folder + ) if not as_msc: return relax_mod, params graph, weights = from_relax(relax_mod, trans_config=trans_config, build_config=build_config) diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py b/python/tvm/contrib/msc/pipeline/pipeline.py index e003f692241c..09fc7727a622 100644 --- a/python/tvm/contrib/msc/pipeline/pipeline.py +++ b/python/tvm/contrib/msc/pipeline/pipeline.py @@ -21,7 +21,6 @@ import json from typing import Any, Union, List, Tuple import traceback -import numpy as np from tvm.contrib.msc.core.tools import get_tool_cls, BaseTool from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey @@ -678,7 +677,7 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: def get_random(): def _to_data(inp): shape = [1 if isinstance(d, str) else d for d in inp[1]] - return np.random.rand(*shape).astype(inp[2]) + return msc_utils.random_data([shape, inp[2]]) for _ in range(max_batch): yield {i[0]: _to_data(i) for i in self._config["inputs"]} diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 322ee04e0c20..d84993c68d4e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -783,6 +783,20 @@ def _reshape(self, node: fx.Node) -> relax.Var: dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.reshape(x, dims)) + def _scatter(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + if len(node.args) == 1: + dim = node.kwargs["dim"] + index = self.env[node.kwargs["index"]] + src = self.env[node.kwargs["src"]] + elif len(node.args) == 4: + dim = node.args[1] + index = self.env[node.args[2]] + src = self.env[node.args[3]] + else: + raise Exception("Unexpected args " + str(node.args)) + return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] split_size = node.args[1] @@ -801,6 +815,24 @@ def _squeeze(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) return self.block_builder.emit(relax.op.squeeze(x, dim)) + def _stack(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + in_args = args[0] + assert all( + a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:] + ), "Expect all dim at {} to be the same, get {}".format( + axis, [a.struct_info.shape for a in args] + ) + cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis)) + s_shape = [] + for idx, s in enumerate(cat.struct_info.shape): + if idx == axis: + s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]]) + else: + s_shape.append(s) + return self.block_builder.emit(relax.op.reshape(cat, s_shape)) + def _tile(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 9fbc95fa7c00..746010a4dc8a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -676,9 +676,11 @@ def create_convert_map( "permute": self._permute, "repeat": self._repeat, "reshape": self._reshape, + "scatter": self._scatter, "size": self._size, "split": self._split, "squeeze": self._squeeze, + "stack": self._stack, "tile": self._tile, "transpose": self._transpose, "unsqueeze": lambda node: self.block_builder.emit( diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index 9ae825b804aa..abac3682fbb1 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -214,14 +214,28 @@ class TorchConstantCodeGen : public TorchOpCode { protected: void CodeGenInit() final { + const auto& dtype = node()->OutputAt(0)->DTypeName(); + const auto& ref_name = StringUtils::Replace(node()->name, ".", "_"); if (node()->HasAttr("scalar")) { - if (node()->OutputAt(0)->DTypeName() == "int32") { + if (dtype == "int32") { stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); - } else if (node()->OutputAt(0)->DTypeName() == "int64") { + } else if (dtype == "int64") { stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); - } else if (node()->OutputAt(0)->DTypeName() == "float32") { + } else if (dtype == "float32") { stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); } + } else if (dtype == "int32") { + stack_.func_call("register_buffer", "", "self") + .call_arg(DocUtils::ToStr(ref_name)) + .inplace_start("torch.IntTensor") + .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) + .inplace_end(); + } else if (dtype == "int64") { + stack_.func_call("register_buffer", "", "self") + .call_arg(DocUtils::ToStr(ref_name)) + .inplace_start("torch.LongTensor") + .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) + .inplace_end(); } else { stack_.func_call("torch.Tensor", "data") .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) @@ -565,6 +579,39 @@ class TorchSimpleCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchSimpleCodeGen); }; +class TorchScatterElementsCodeGen : public TorchOpCode { + TORCH_OP_CODEGEN_METHODS(TorchScatterElementsCodeGen) + + protected: + void CodeGenForward() final { + if (node()->InputAt(1)->DTypeName() == "int32") { + stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64"); + } + stack_.op_call() + .op_input_arg() + .op_arg("axis", "dim") + .op_input_arg(1, "index") + .op_input_arg(2, "src"); + } +}; + +class TorchScatterNDCodeGen : public TorchOpCode { + TORCH_OP_CODEGEN_METHODS(TorchScatterNDCodeGen) + + protected: + void CodeGenForward() final { + if (node()->InputAt(1)->DTypeName() == "int32") { + stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64"); + } + // relax add extra dim for indices + if (node()->InputAt(1)->Ndim() == node()->OutputAt(0)->Ndim()) { + stack_.func_call("squeeze", IdxInput(1), IdxInput(1)).call_arg(-1); + } + stack_.assign(DocUtils::ToIndex(IdxInput(0), IdxInput(1)), IdxInput(2)) + .assign(IdxNode(), IdxInput(0)); + } +}; + class TorchSplitCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchSplitCodeGen) @@ -719,6 +766,9 @@ const std::shared_ptr>> map->emplace("permute_dims", std::make_shared("", "torch.permute")); map->emplace("repeat", std::make_shared("", "repeat")); map->emplace("reshape", std::make_shared("", "torch.reshape")); + map->emplace("scatter_elements", + std::make_shared("", "torch.scatter")); + map->emplace("scatter_nd", std::make_shared("", "")); map->emplace("split", std::make_shared("", "torch.split")); map->emplace("strided_slice", std::make_shared("", "")); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 73722f987701..a4be884858dc 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -568,6 +568,34 @@ class RelaxReshapeCodeGen : public RelaxOpCode { } }; +class RelaxScatterElementsCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxScatterElementsCodeGen) + + protected: + void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_arg("axis"); } +}; + +class RelaxScatterNDCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxScatterNDCodeGen) + + protected: + void CodeGenBuild() final { + if (config()->from_relay) { + size_t ndim = node()->InputAt(1)->Ndim(); + std::vector axes; + axes.push_back(ndim - 1); + for (size_t i = 0; i < ndim - 1; i++) { + axes.push_back(i); + } + stack_.func_call("relax.op.permute_dims", IdxInput(1)) + .call_arg(IdxInput(1)) + .call_arg(DocUtils::ToList(axes)); + BuilderEmit(IdxInput(1), "permute_" + std::to_string(node()->index)); + } + stack_.op_call().op_inputs_arg(false).op_str_arg("mode", "reduction"); + } +}; + class RelaxResize2dCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxResize2dCodeGen) @@ -626,6 +654,20 @@ class RelaxSplitCodeGen : public RelaxOpCode { } }; +class RelaxStackCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxStackCodeGen) + + protected: + void CodeGenBuild() final { + stack_.op_call().op_inputs_arg().op_arg("axis"); + BuilderEmit(IdxNode(), "cat_" + std::to_string(node()->index)); + const auto& out_shape = GetPrims(node()->OutputAt(0)); + stack_.func_call("relax.op.reshape", IdxNode()) + .call_arg(IdxNode()) + .call_arg(DocUtils::ToList(out_shape), "shape"); + } +}; + class RelaxTakeCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen) @@ -763,7 +805,11 @@ const std::shared_ptr>> map->emplace("permute_dims", std::make_shared("relax.op.permute_dims")); map->emplace("repeat", std::make_shared("relax.op.repeat")); map->emplace("reshape", std::make_shared("relax.op.reshape")); + map->emplace("scatter_elements", + std::make_shared("relax.op.scatter_elements")); + map->emplace("scatter_nd", std::make_shared("relax.op.scatter_nd")); map->emplace("split", std::make_shared("relax.op.split")); + map->emplace("stack", std::make_shared("relax.op.concat")); map->emplace("strided_slice", std::make_shared("relax.op.strided_slice")); map->emplace("take", std::make_shared("relax.op.take")); diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 76e3147a5507..647879378e0c 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -20,21 +20,16 @@ import pytest import torch -from torch import fx from torch.nn import Module import tvm.testing -from tvm.relax.frontend.torch import from_fx -from tvm.contrib.msc.core.frontend import translate, normalize_inputs +from tvm.contrib.msc.framework.torch.frontend import translate +from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils def verify_model(torch_model, input_info, expected): - input_info = normalize_inputs(input_info) - graph_model = fx.symbolic_trace(torch_model) - with torch.no_grad(): - mod = from_fx(graph_model, input_info) - graph, _ = translate.from_relax(mod) + graph, _ = translate.from_torch(torch_model, input_info) inspect = graph.inspect() assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format( inspect, expected @@ -2389,6 +2384,119 @@ def forward(self, data): verify_model(Cat2(), [([1, 3, 10, 10], "float32")], expected2) +@pytest.mark.parametrize("dynamic", [True, False]) +def test_stack(dynamic): + """test graph builder for stack""" + + bz = "bz" if dynamic else 1 + + class Stack(Module): + def forward(self, data, data1, data2): + return torch.stack((data, data1, data2), dim=0) + + input_info = [ + ([bz, 3, 10, 10], "float32"), + ([bz, 3, 10, 10], "float32"), + ([bz, 3, 10, 10], "float32"), + ] + + expected = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}, + {"name": "inp_2", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}, + ], + "outputs": [ + { + "name": "reshape", + "shape": [3, bz, 3, 10, 10], + "dtype": "float32", + "layout": "" if dynamic else "EABCD", + } + ], + "nodes": {"total": 5, "input": 3, "concat": 1, "reshape": 1}, + } + + if dynamic: + expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} + + verify_model(Stack(), input_info, expected) + + +@pytest.mark.parametrize("dynamic", [True, False]) +def test_scatter(dynamic): + """test graph builder for scatter""" + + bz = "bz" if dynamic else 20 + + class Scatter1(Module): + def __init__(self): + super().__init__() + self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5) + + def forward(self, data, src): + return data.scatter(dim=0, index=self.index, src=src) + + class Scatter2(Module): + def forward(self, data, index, src): + return data.scatter(0, index, src) + + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": ""}, + ], + "outputs": [ + {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""} + ], + "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 1}, + } + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""}, + {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": ""}, + ], + "outputs": [ + {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""} + ], + "nodes": {"total": 4, "input": 3, "scatter_elements": 1}, + } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + + verify_model(Scatter1(), [([bz, 20], "float32"), ([2, 5], "float32")], expected1) + verify_model( + Scatter2(), [([bz, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")], expected2 + ) + + +def test_put(): + """test graph builder for index_put""" + + class IndexPut(Module): + def __init__(self): + super().__init__() + self.index = msc_utils.random_data([(5), "int64"], MSCFramework.TORCH, max_val=5) + + def forward(self, data, src): + data[self.index] = src + return data + + expected = { + "inputs": [ + {"name": "input0", "shape": [10, 20], "dtype": "float32", "layout": ""}, + {"name": "input1", "shape": [5, 20], "dtype": "float32", "layout": ""}, + ], + "outputs": [{"name": "scatter_nd", "shape": [10, 20], "dtype": "float32", "layout": ""}], + "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_nd": 1}, + } + + input_info = [([10, 20], "float32"), ([5, 20], "float32")] + verify_model(IndexPut(), input_info, expected) + + @pytest.mark.parametrize("dynamic", [True, False]) def test_attention(dynamic): """test graph builder for attention""" diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 64d00bb0922e..27a02844e19d 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -18,27 +18,26 @@ """ Test translate from relax. """ import torch -from torch import fx from torch.nn import Module import numpy as np import tvm.testing -from tvm.relax.frontend.torch import from_fx -from tvm.contrib.msc.core.frontend import translate +from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate + from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen +from tvm.contrib.msc.core.frontend import translate as core_translate +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils def verify_model(torch_model, input_info, opt_config=None): """Compare torch module IR""" - graph_model = fx.symbolic_trace(torch_model) - with torch.no_grad(): - orig_mod = from_fx(graph_model, input_info) - + orig_mod, _ = torch_translate.from_torch(torch_model, input_info, as_msc=False) target = "llvm" dev = tvm.cpu() - args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for shape, dtype in input_info] + args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info] def _tvm_runtime_to_np(obj): if isinstance(obj, tvm.runtime.NDArray): @@ -60,7 +59,7 @@ def _run_relax(relax_mod): return _tvm_runtime_to_np(res) rt_mod = tvm_codegen.to_relax( - *translate.from_relax(orig_mod, opt_config=opt_config), + *core_translate.from_relax(orig_mod, opt_config=opt_config), codegen_config={"explicit_name": False}, ) @@ -1153,6 +1152,63 @@ def forward(self, data): verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) +def test_stack(): + """test relax translator for stack""" + + class Stack1(Module): + def forward(self, data, data1, data2): + return torch.stack((data, data1, data2), dim=0) + + class Stack2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.stack((data, const1, const2), dim=1) + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + verify_model(Stack1(), input_info) + verify_model(Stack2(), [([1, 3, 10, 10], "float32")]) + + +def test_scatter(): + """test relax translator for scatter""" + + class Scatter1(Module): + def __init__(self): + super().__init__() + self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5) + + def forward(self, data, src): + return data.scatter(dim=0, index=self.index, src=src) + + class Scatter2(Module): + def forward(self, data, index, src): + return data.scatter(0, index, src) + + verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")]) + verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")]) + + +def test_put(): + """test relax translator for index_put""" + + class IndexPut(Module): + def __init__(self): + super().__init__() + self.index = msc_utils.random_data([(5), "int64"], MSCFramework.TORCH, max_val=5) + + def forward(self, data, src): + data[self.index] = src + return data + + input_info = [([10, 20], "float32"), ([5, 20], "float32")] + verify_model(IndexPut(), input_info) + + def test_attention(): """test relax translator for attention""" diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py index ebba339a4a3e..801893e9debd 100644 --- a/tests/python/contrib/test_msc/test_translate_relay.py +++ b/tests/python/contrib/test_msc/test_translate_relay.py @@ -18,8 +18,6 @@ """ Test translate from relay. """ -import numpy as np - import torch from torch import fx from torch.nn import Module @@ -66,7 +64,7 @@ def verify_model(torch_model, input_info, opt_config=None, codegen_config=None, expected = tvm.relax.transform.CanonicalizeBindings()(expected) # graph from relay - datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] + datas = [msc_utils.random_data(i) for i in input_info] torch_datas = [torch.from_numpy(i) for i in datas] with torch.no_grad(): scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval() # type: ignore diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index 6d87ca8753dc..e0fd39249a31 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -18,7 +18,6 @@ """ Test translate for TensorrRT. """ import pytest -import numpy as np import torch from torch import fx @@ -91,7 +90,7 @@ def verify_model(torch_model, input_info, **trans_config): """Build model and verify results""" graph_model = fx.symbolic_trace(torch_model) - datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] + datas = [msc_utils.random_data(i) for i in input_info] torch_datas = [torch.from_numpy(i) for i in datas] with torch.no_grad(): golden = torch_model(*torch_datas) diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 55bae682ef20..6ed28c0ac0b7 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -17,24 +17,24 @@ """ Test translate from torch. """ -import numpy as np - import torch from torch.nn import Module import tvm.testing from tvm.contrib.msc.framework.torch.frontend import translate from tvm.contrib.msc.framework.torch import codegen +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils def verify_model(torch_model, input_info, via_relax=True): """Compare torch module results""" - graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax) - model = codegen.to_torch(graph, weights) - torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i in input_info] + torch_datas = [msc_utils.random_data(i, MSCFramework.TORCH) for i in input_info] with torch.no_grad(): golden = torch_model(*torch_datas) + graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax) + model = codegen.to_torch(graph, weights) with torch.no_grad(): if not graph.get_inputs(): result = model() @@ -1128,6 +1128,67 @@ def forward(self, data): verify_model(Cat2(), [([1, 3, 10, 10], "float32")], via_relax) +def test_stack(): + """test torch translator for stack""" + + class Stack1(Module): + def forward(self, data, data1, data2): + return torch.stack((data, data1, data2), dim=0) + + class Stack2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.stack((data, const1, const2), dim=1) + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + for via_relax in [True, False]: + verify_model(Stack1(), input_info, via_relax) + verify_model(Stack2(), [([1, 3, 10, 10], "float32")], via_relax) + + +def test_scatter(): + """test torch translator for scatter""" + + class Scatter1(Module): + def __init__(self): + super().__init__() + self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5) + + def forward(self, data, src): + return data.scatter(dim=0, index=self.index, src=src) + + class Scatter2(Module): + def forward(self, data, index, src): + return data.scatter(0, index, src) + + for via_relax in [True, False]: + verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")], via_relax) + verify_model( + Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")], via_relax + ) + + +def test_put(): + """test torch translator for index_put""" + + class IndexPut(Module): + def __init__(self): + super().__init__() + self.index = msc_utils.random_data([(5), "int64"], MSCFramework.TORCH, max_val=5) + + def forward(self, data, src): + data[self.index] = src + return data + + input_info = [([10, 20], "float32"), ([5, 20], "float32")] + verify_model(IndexPut(), input_info, False) + + def test_attention(): """test torch translator for attention""" diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 2cabcba325b2..08331f08612b 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3963,5 +3963,65 @@ def main( verify_model(SymSizeInt1(dim=-2), [([1, 3, 4], "float32")], {}, Expected1) +def test_stack(): + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + + class Stack(Module): + def forward(self, data, data1, data2): + return torch.stack((data, data1, data2), dim=0) + + @tvm.script.ir_module + class expected: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_2: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((3, 1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 3, 10, 10), dtype="float32") = R.concat( + (inp_0, inp_1, inp_2), axis=0 + ) + lv1: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.reshape( + lv, R.shape([3, 1, 3, 10, 10]) + ) + gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Stack(), input_info, {}, expected) + + +def test_scatter(): + input_info = [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")] + + class Scatter(Module): + def forward(self, data, index, src): + return data.scatter(dim=0, index=index, src=src) + + @tvm.script.ir_module + class expected: + @R.function + def main( + inp_0: R.Tensor((20, 20), dtype="float32"), + inp_1: R.Tensor((2, 5), dtype="int64"), + inp_2: R.Tensor((2, 5), dtype="float32"), + ) -> R.Tensor((20, 20), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((20, 20), dtype="float32") = R.scatter_elements( + inp_0, inp_1, inp_2, axis=0, reduction="update" + ) + gv: R.Tensor((20, 20), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Scatter(), input_info, {}, expected) + + if __name__ == "__main__": tvm.testing.main()