From 019dd965afbf365904c82accb988dcae02a2c57b Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Tue, 4 Nov 2025 19:07:13 -0500 Subject: [PATCH 1/2] finish1 --- .../frontend/torch/base_fx_graph_translator.py | 14 ++++++++++++++ .../relax/test_frontend_from_exported_program.py | 8 ++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index aedef8acf84c..3a67ebed5389 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1725,6 +1725,20 @@ def _squeeze(self, node: fx.Node) -> relax.Var: # Support both "dim" and "dims" parameters if dim is None: dim = node.kwargs.get("dims", None) + + # If dims is a list, filter out axes where dimension is not 1 + # This is needed because PyTorch decomposition may pass all axes + if isinstance(dim, (list, tuple)) and len(dim) > 0: + shape = self.shape_of(x) + # Filter to only include axes where the dimension is 1 + valid_dims = [] + for d in dim: + axis = d if d >= 0 else len(shape) + d + if axis < len(shape) and shape[axis] == 1: + valid_dims.append(d) + # If no valid dims, use None to squeeze all size-1 dimensions + dim = valid_dims if valid_dims else None + return self.block_builder.emit(relax.op.squeeze(x, dim)) def _stack(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8a9fe66a0fad..a6001b85a5cb 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4545,18 +4545,18 @@ def forward(self, input): class Expected2: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + input: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) + lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[1, 3]) gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) R.output(gv) return gv example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),) - verify_model(Squeeze1(), example_args, {}, Expected1) - verify_model(Squeeze2(), example_args, {}, Expected2) + verify_model(Squeeze1(), example_args, {}, Expected1, run_ep_decomposition=True) + verify_model(Squeeze2(), example_args, {}, Expected2, run_ep_decomposition=True) def test_stack(): From d31664b8d3debb6ed9e4b4c153215798acd3127c Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Tue, 4 Nov 2025 23:01:53 -0500 Subject: [PATCH 2/2] finish2 --- .../torch/base_fx_graph_translator.py | 4 +- .../torch/exported_program_translator.py | 20 +- .../test_frontend_from_exported_program.py | 221 ++++++++---------- 3 files changed, 114 insertions(+), 131 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3a67ebed5389..03e3b8d557d0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1725,7 +1725,7 @@ def _squeeze(self, node: fx.Node) -> relax.Var: # Support both "dim" and "dims" parameters if dim is None: dim = node.kwargs.get("dims", None) - + # If dims is a list, filter out axes where dimension is not 1 # This is needed because PyTorch decomposition may pass all axes if isinstance(dim, (list, tuple)) and len(dim) > 0: @@ -1738,7 +1738,7 @@ def _squeeze(self, node: fx.Node) -> relax.Var: valid_dims.append(d) # If no valid dims, use None to squeeze all size-1 dimensions dim = valid_dims if valid_dims else None - + return self.block_builder.emit(relax.op.squeeze(x, dim)) def _stack(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3be255a29a65..4f3132b8d8f2 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -701,11 +701,23 @@ def _select(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.take(x, index, dim)) def _slice(self, node: fx.Node) -> relax.Var: + import sys + x = self.env[node.args[0]] - axes = [node.args[1]] - begin = [node.args[2]] - end = [node.args[3]] - stride = [node.args[4] if len(node.args) > 4 else 1] + dim = node.args[1] if len(node.args) > 1 else 0 + start = node.args[2] if len(node.args) > 2 else None + end_val = node.args[3] if len(node.args) > 3 else None + step = node.args[4] if len(node.args) > 4 else 1 + + if start is None: + start = 0 + if end_val is None: + end_val = sys.maxsize + + axes = [dim] + begin = [start] + end = [end_val] + stride = [step] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) def _unflatten(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index a6001b85a5cb..44248c1c59f4 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4111,7 +4111,7 @@ def main( return gv example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) - verify_model(Reshape(), example_args, {}, expected1) + verify_model(Reshape(), example_args, {}, expected1, run_ep_decomposition=True) def test_reshape_as(): @@ -4137,7 +4137,7 @@ def main( torch.randn(1, 2, 3, 4, dtype=torch.float32), torch.randn(2, 12, dtype=torch.float32), ) - verify_model(ReshapeAs(), example_args, {}, expected1) + verify_model(ReshapeAs(), example_args, {}, expected1, run_ep_decomposition=True) def test_roll(): @@ -4160,25 +4160,14 @@ class Expected1: def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): with R.dataflow(): lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8])) - lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( - lv, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(7)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv1: R.Tensor((8,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(8), R.prim_value(1), dtype="int64" ) - lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( - lv, - axes=[0], - begin=[R.prim_value(7)], - end=[R.prim_value(8)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) - lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,) + lv2: R.Tensor((8,), dtype="int64") = R.add(lv1, R.const(7, "int64")) + lv3: R.Tensor((8,), dtype="int64") = R.mod(lv2, R.const(8, "int64")) + lv4: R.Tensor((8,), dtype="int64") = R.take(lv, lv3, axis=0, mode="fast") + lv5: R.Tensor((4, 2), dtype="int64") = R.reshape(lv4, R.shape([4, 2])) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) R.output(gv) return gv @@ -4188,24 +4177,13 @@ class Expected2: @R.function def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): with R.dataflow(): - lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(1)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(1)], - end=[R.prim_value(4)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv: R.Tensor((4,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(4), R.prim_value(1), dtype="int64" ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) + lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(1, "int64")) + lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, "int64")) + lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, mode="fast") + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv3,) R.output(gv) return gv @@ -4216,43 +4194,20 @@ class Expected3: def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): with R.dataflow(): # First roll along dim=0 with shift=2 - lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(2)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(2)], - end=[R.prim_value(4)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv: R.Tensor((4,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(4), R.prim_value(1), dtype="int64" ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - + lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(2, "int64")) + lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, "int64")) + lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, mode="fast") # Second roll along dim=1 with shift=1 - lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, - axes=[1], - begin=[R.prim_value(0)], - end=[R.prim_value(1)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, - axes=[1], - begin=[R.prim_value(1)], - end=[R.prim_value(2)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv4: R.Tensor((2,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64" ) - lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) + lv5: R.Tensor((2,), dtype="int64") = R.add(lv4, R.const(1, "int64")) + lv6: R.Tensor((2,), dtype="int64") = R.mod(lv5, R.const(2, "int64")) + lv7: R.Tensor((4, 2), dtype="int64") = R.take(lv3, lv6, axis=1, mode="fast") + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv7,) R.output(gv) return gv @@ -4260,9 +4215,9 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64) # Run verification for each case - verify_model(Roll1(), (example_input,), {}, Expected1) - verify_model(Roll2(), (example_input,), {}, Expected2) - verify_model(Roll3(), (example_input,), {}, Expected3) + verify_model(Roll1(), (example_input,), {}, Expected1, run_ep_decomposition=True) + verify_model(Roll2(), (example_input,), {}, Expected2, run_ep_decomposition=True) + verify_model(Roll3(), (example_input,), {}, Expected3, run_ep_decomposition=True) def test_select_slice(): @@ -4342,10 +4297,10 @@ def main( return gv example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Slice1(), example_args, {}, expected1) + verify_model(Slice1(), example_args, {}, expected1, run_ep_decomposition=True) example_args = (torch.randn(8, 16, dtype=torch.float32),) - verify_model(Slice2(), example_args, {}, expected2) + verify_model(Slice2(), example_args, {}, expected2, run_ep_decomposition=True) def test_slice_scatter(): @@ -4387,10 +4342,10 @@ def main( return gv example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), torch.randn(8, 3, 10, 10)) - verify_model(SliceScatter1(), example_args, {}, expected1) + verify_model(SliceScatter1(), example_args, {}, expected1, run_ep_decomposition=True) example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 16)) - verify_model(SliceScatter2(), example_args, {}, expected2) + verify_model(SliceScatter2(), example_args, {}, expected2, run_ep_decomposition=True) def test_split(): @@ -4402,7 +4357,7 @@ def forward(self, input): class Expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -4414,7 +4369,7 @@ def main( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=1) + ) = R.split(input, indices_or_sections=[1, 2], axis=1) lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1] lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2] @@ -4434,7 +4389,7 @@ def forward(self, data): class expected1: @R.function def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + data: R.Tensor((3, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -4442,30 +4397,38 @@ def main( ): # block 0 with R.dataflow(): - lv: R.Tuple( - R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((1, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=0) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] - lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) - lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] - lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) - lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] - lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) - lv7: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv2, lv4, lv6) - lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] - lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] - lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(2),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[0]) + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[0]) gv: R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv8, lv9, lv10) + ) = (lv3, lv4, lv5) R.output(gv) return gv @@ -4477,7 +4440,7 @@ def forward(self, data): class expected2: @R.function def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + data: R.Tensor((3, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -4485,39 +4448,47 @@ def main( ): # block 0 with R.dataflow(): - lv: R.Tuple( - R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 1, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=1) - lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] - lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) - lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] - lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) - lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] - lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) - lv7: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv2, lv4, lv6) - lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] - lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] - lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + lv: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv2: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[1]) + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[1]) gv: R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv8, lv9, lv10) + ) = (lv3, lv4, lv5) R.output(gv) return gv example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Chunk(), example_args, {}, Expected) + verify_model(Chunk(), example_args, {}, Expected, run_ep_decomposition=True) example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) - verify_model(Unbind1(), example_args, {}, expected1) - verify_model(Unbind2(), example_args, {}, expected2) + verify_model(Unbind1(), example_args, {}, expected1, run_ep_decomposition=True) + verify_model(Unbind2(), example_args, {}, expected2, run_ep_decomposition=True) def test_squeeze():