From 5f370390347ac5a093ce9ac06059f0b743eb3e04 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 26 Aug 2024 17:30:05 +0900 Subject: [PATCH 1/3] add test --- tests/python/relax/test_frontend_from_fx.py | 36 +++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 6be3e7b23e9d..be908eb706e1 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3191,6 +3191,42 @@ def main( verify_model(Transpose(), input_info, {}, expected1) +def test_repeat(): + class Tile1(Module): + def forward(self, x: torch.Tensor): + return x.repeat(2) + + class Tile2(Module): + def forward(self, x: torch.Tensor): + return x.repeat(4, 2) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2) + gv: R.Tensor((6,), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tensor((4, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tile1(), [([3], "float32")], {}, expected1) + verify_model(Tile2(), [([1, 3], "float32")], {}, expected2) + verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2) + + def test_view(): input_info = [([1, 2, 3, 4], "float32")] From 801a480ac7d2108541e21d456bfddbd055a3c555 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 26 Aug 2024 17:30:39 +0900 Subject: [PATCH 2/3] add support for torch.repeat --- python/tvm/relax/frontend/torch/fx_translator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 35131d324076..447457ab87ea 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -612,6 +612,15 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) + def _repeat(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + print(args) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + def _tile(self, node: fx.node.Node) -> relax.Var: import torch # type: ignore @@ -1456,6 +1465,7 @@ def create_convert_map(self): "expand": self._expand, "flatten": self._flatten, "permute": self._permute, + "repeat": self._repeat, "reshape": self._reshape, "split": self._split, "tile": self._tile, From 701eea0cdae67fed1f4e0e1a9f68bf890aa588ba Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 26 Aug 2024 17:53:40 +0900 Subject: [PATCH 3/3] remove debug print --- python/tvm/relax/frontend/torch/fx_translator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 447457ab87ea..e82881487a45 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -616,7 +616,6 @@ def _repeat(self, node: fx.node.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) - print(args) if isinstance(args[1], (torch.Size, tuple, list)): return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.tile(args[0], args[1:]))