From 0ad90ea734cc7b0f3b239796d633250bd3650a25 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 3 Apr 2025 21:31:32 +0800 Subject: [PATCH 1/5] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 3ddf919c2ed1..f803b453ca2b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -785,6 +785,7 @@ def create_convert_map( "argmin": self._argmax_argmin(relax.op.argmin), "where": self._where, # tensor manipulation + "argsort": self._argsort, "cat": self._cat, "chunk": self._chunk, "concat": self._cat, @@ -803,11 +804,13 @@ def create_convert_map( "scatter": self._scatter, "select": self._select, "size": self._size, + "sort": self._sort, "split": self._split, "squeeze": self._squeeze, "stack": self._stack, "take": self._take, "tile": self._tile, + "topk": self._topk, "transpose": self._transpose, "unsqueeze": lambda node: self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) From 439200f6e8704fbbd9faba1403b7af4d347f3571 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 3 Apr 2025 21:34:35 +0800 Subject: [PATCH 2/5] Update base_fx_graph_translator.py --- .../torch/base_fx_graph_translator.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 890f925079e0..73283a64053f 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -959,6 +959,12 @@ def _where(self, node: fx.Node) -> relax.Var: ########## Manipulation ########## + def _argsort(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) + return self.block_builder.emit(relax.op.argsort(x, dim, descending)) + def _cat(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) @@ -1071,6 +1077,12 @@ def _scatter(self, node: fx.Node) -> relax.Var: raise Exception("Unexpected args " + str(node.args)) return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) + def _sort(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) + return self.block_builder.emit(relax.op.sort(x, dim, descending)) + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] split_size = node.args[1] @@ -1121,6 +1133,22 @@ def _tile(self, node: fx.Node) -> relax.Var: dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.tile(x, dims)) + def _topk(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + k = args[1] if len(args) > 1 else node.kwargs.get("k", 1) + dim = args[2] if len(args) > 2 else node.kwargs.get("dim", -1) + largest = args[3] if len(args) > 3 else node.kwargs.get("largest", True) + sorted = args[4] if len(args) > 4 else node.kwargs.get("sorted", True) + + if not sorted: + msg = "Currently supports only sorted output for topk operator." + raise AssertionError(msg) + + return self.block_builder.emit( + relax.op.topk(x, k=k, axis=dim, largest=largest, ret_type="both", dtype="int64") + ) + def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) full_idx = list(range(len(self.shape_of(args[0])))) From 1df7d403b786254ce0647cac68bfc615964822cc Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 3 Apr 2025 21:35:41 +0800 Subject: [PATCH 3/5] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 64 +++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index d913baf13a0d..6c439ab56c78 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4368,5 +4368,69 @@ def main( ) +def test_argsort(): + class Argsort(Module): + def forward(self, x): + return torch.argsort(x, dim=1, descending=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="int32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="int32") = R.argsort(inp_0, axis=1, descending=True) + gv: R.Tensor((5, 3), dtype="int32") = lv + R.output(gv) + return gv + + verify_model(Argsort(), [([5, 3], "float32")], {}, Expected) + + +def test_sort(): + class Sort(Module): + def forward(self, x): + return torch.sort(x, dim=1, descending=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True) + gv: R.Tensor((5, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sort(), [([5, 3], "float32")], {}, Expected) + + +def test_topk(): + class Topk(Module): + def forward(self, x): + return torch.topk(x, k=2, dim=1, largest=True, sorted=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = R.topk(inp_0, k=2, + axis=1, + ret_type="both", + largest=True, + dtype="int64") + gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = lv + R.output(gv) + return gv + + verify_model(Topk(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From 8b65653a2816dece52f7ab221443e53e684c1a44 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 4 Apr 2025 23:15:44 +0800 Subject: [PATCH 4/5] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 6c439ab56c78..2c5560b577c4 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4417,14 +4417,12 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((5, 3), dtype="float32"), + inp_0: R.Tensor((5, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")): with R.dataflow(): - lv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = R.topk(inp_0, k=2, - axis=1, - ret_type="both", - largest=True, - dtype="int64") + lv: R.Tuple( + R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64") + ) = R.topk(inp_0, k=2, axis=1, ret_type="both", largest=True, dtype="int64") gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = lv R.output(gv) return gv From 89a372d93494a76989cadc4387c9095960123cc1 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 5 Apr 2025 01:18:09 +0800 Subject: [PATCH 5/5] fix lint --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 73283a64053f..cca03e95e62c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1139,9 +1139,9 @@ def _topk(self, node: fx.Node) -> relax.Var: k = args[1] if len(args) > 1 else node.kwargs.get("k", 1) dim = args[2] if len(args) > 2 else node.kwargs.get("dim", -1) largest = args[3] if len(args) > 3 else node.kwargs.get("largest", True) - sorted = args[4] if len(args) > 4 else node.kwargs.get("sorted", True) + _sorted = args[4] if len(args) > 4 else node.kwargs.get("_sorted", True) - if not sorted: + if not _sorted: msg = "Currently supports only sorted output for topk operator." raise AssertionError(msg)