diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index b17f62738f0a..aedef8acf84c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1722,6 +1722,9 @@ def _split(self, node: fx.Node) -> relax.Var: def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + # Support both "dim" and "dims" parameters + if dim is None: + dim = node.kwargs.get("dims", None) return self.block_builder.emit(relax.op.squeeze(x, dim)) def _stack(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 5bb7a9ea8bc5..48ae002c05c0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1018,6 +1018,7 @@ def create_convert_map( "split_with_sizes.default": self._split, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, + "squeeze.dims": self._squeeze, "stack.default": self._stack, "take.default": self._take, "tile.default": self._tile, @@ -1075,6 +1076,7 @@ def create_convert_map( # other "getitem": self._getitem, "item.default": self._item, + "_local_scalar_dense.default": self._item, } def create_input_vars( diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ac36c3fe8fb3..019d64955857 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -5823,7 +5823,7 @@ def main( return gv example_input = torch.randn(5, 3, dtype=torch.float32) - verify_model(Cumprod(), (example_input,), {}, Expected) + verify_model(Cumprod(), (example_input,), {}, Expected, run_ep_decomposition=True) def test_where(): @@ -5849,7 +5849,7 @@ def main( x = torch.randn(5, 3, dtype=torch.float32) y = torch.randn(5, 3, dtype=torch.float32) - verify_model(Where(), (condition, x, y), {}, Expected) + verify_model(Where(), (condition, x, y), {}, Expected, run_ep_decomposition=True) def test_bucketize(): @@ -5874,7 +5874,7 @@ def main( input_tensor = torch.arange(0, 20) boundaries = torch.arange(0, 20, 2) - verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected) + verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected, run_ep_decomposition=True) def test_argsort(): @@ -5890,12 +5890,18 @@ def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype lv: R.Tensor((5, 3), dtype="int32") = R.argsort( x, axis=1, descending=True, dtype="int32" ) - gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) + lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1) + lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = ( + lv1, + lv, + ) + lv3: R.Tensor((5, 3), dtype="int32") = lv2[1] + gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,) R.output(gv) return gv example_args = (torch.randn(5, 3, dtype=torch.float32),) - verify_model(Argsort(), example_args, {}, Expected) + verify_model(Argsort(), example_args, {}, Expected, run_ep_decomposition=True) def test_topk(): @@ -5923,7 +5929,7 @@ def main( return gv example_args = (torch.randn(5, 3, dtype=torch.float32),) - verify_model(Topk(), example_args, {}, Expected) + verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True) def test_dynamic_shape(): @@ -5972,7 +5978,7 @@ def main( return gv example_args = (torch.randn(5, 1, dtype=torch.float32),) - verify_model(BroadcastTo(), example_args, {}, Expected) + verify_model(BroadcastTo(), example_args, {}, Expected, run_ep_decomposition=True) def test_narrow(): @@ -5992,6 +5998,7 @@ def main( (R.prim_value(1),), (R.prim_value(0),), (R.prim_value(2),), + (R.prim_value(1),), assume_inbound=False, ) gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) @@ -6000,7 +6007,7 @@ def main( return gv example_args = (torch.randn(5, 3, dtype=torch.float32),) - verify_model(Narrow(), example_args, {}, Expected) + verify_model(Narrow(), example_args, {}, Expected, run_ep_decomposition=True) def test_item(): @@ -6019,7 +6026,7 @@ def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype=" return gv example_args = (torch.randn(1, dtype=torch.float32),) - verify_model(Item(), example_args, {}, Expected) + verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True) def test_norm(): @@ -6131,7 +6138,9 @@ def main( example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),) for (p, dim, keepdim), expected in norms: - verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected) + verify_model( + Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected, run_ep_decomposition=True + ) def test_eye(): @@ -6146,8 +6155,20 @@ def main( input: R.Tensor((3, 5), dtype="float32") ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32") - gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,) + lv: R.Tensor((3,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((5,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" + ) + lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, axis=[-1]) + lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1) + lv4: R.Tensor((1,), dtype="float32") = R.full( + R.shape([1]), R.const(1.0, "float32"), dtype="float32" + ) + lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5) + gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,) R.output(gv) return gv @@ -6162,16 +6183,28 @@ def main( input: R.Tensor((5,), dtype="float32") ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): with R.dataflow(): - lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32") - gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) + lv: R.Tensor((5,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((5,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" + ) + lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, axis=[-1]) + lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1) + lv4: R.Tensor((1,), dtype="float32") = R.full( + R.shape([1]), R.const(1.0, "float32"), dtype="float32" + ) + lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5) + gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,) R.output(gv) return gv example_args1 = (torch.randn(3, 5, dtype=torch.float32),) - verify_model(Eye1(), example_args1, {}, Expected1) + verify_model(Eye1(), example_args1, {}, Expected1, run_ep_decomposition=True) example_args2 = (torch.randn(5, dtype=torch.float32),) - verify_model(Eye2(), example_args2, {}, Expected2) + verify_model(Eye2(), example_args2, {}, Expected2, run_ep_decomposition=True) def test_cross_entropy(): @@ -6187,21 +6220,39 @@ def forward(self, x): @tvm.script.ir_module class Expected1: @R.function - def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): + def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1) - lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss( - lv, - targets=R.const([0, 1, 2, 1], dtype="int64"), - reduction="mean", - ignore_index=-100, + lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32") + lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1) + lv2: R.Tensor((4,), dtype="bool") = R.not_equal( + R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") + ) + lv3: R.Tensor((), dtype="int64") = R.const(0, "int64") + lv4: R.Tensor((4,), dtype="int64") = R.where( + lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3 ) - gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) + lv5: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv4, axis=[1]) + lv6: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv1, lv5, axis=1) + lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1]) + lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7) + lv9: R.Tensor((4,), dtype="bool") = R.not_equal( + R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") + ) + lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10) + lv12: R.Tensor((4,), dtype="bool") = R.not_equal( + R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") + ) + lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False) + lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32") + lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False) + lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14) + gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,) R.output(gv) return gv example_args1 = (torch.randn(4, 3, dtype=torch.float32),) - verify_model(CrossEntropyModule(), example_args1, {}, Expected1) + verify_model(CrossEntropyModule(), example_args1, {}, Expected1, run_ep_decomposition=True) def test_linspace(): @@ -6216,13 +6267,24 @@ def main( input: R.Tensor((9, 9), dtype="float32") ) -> R.Tuple(R.Tensor((9,), dtype="float32")): with R.dataflow(): - lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") - gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) + lv: R.Tensor((9,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64")) + lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32") + lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32")) + lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32")) + lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv) + lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32") + lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32")) + lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7) + lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8) + gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,) R.output(gv) return gv example_args = (torch.randn(9, 9, dtype=torch.float32),) - verify_model(Linspace(), example_args, {}, Expected) + verify_model(Linspace(), example_args, {}, Expected, run_ep_decomposition=True) @pytest.mark.parametrize( @@ -6259,7 +6321,7 @@ def main( R.output(gv) return gv - verify_model(Model(), example_args, {}, Expected) + verify_model(Model(), example_args, {}, Expected, run_ep_decomposition=True) def test_mm(): @@ -6285,7 +6347,7 @@ def main( R.output(gv) return gv - verify_model(MatrixMultiply(), example_args, {}, Expected) + verify_model(MatrixMultiply(), example_args, {}, Expected, run_ep_decomposition=True) def test_lstm():