diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 00fa9f597d06..aa0217db046c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2325,6 +2325,11 @@ def nll_loss(self, inputs, input_types): weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + def flip(self, inputs, input_types): + data = inputs[0] + axis = inputs[1] + return _op.transform.reverse(data, axis=axis[0]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2539,6 +2544,7 @@ def create_convert_map(self): "aten::_unique2": self.unique, "aten::nll_loss": self.nll_loss, "aten::nll_loss2d": self.nll_loss, + "aten::flip": self.flip, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2ec281094080..f76ea9a5d324 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3893,6 +3893,25 @@ def test_forward_nll_loss(): verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets]) +@tvm.testing.uses_gpu +def test_forward_flip(): + torch.set_grad_enabled(False) + + class Flip(Module): + def __init__(self, axis=0): + super().__init__() + self.axis = axis + + def forward(self, x): + return x.flip([self.axis]) + + input = torch.randn(2, 3, 4) + verify_model(Flip(axis=0), input_data=input) + verify_model(Flip(axis=1), input_data=input) + verify_model(Flip(axis=2), input_data=input) + verify_model(Flip(axis=-1), input_data=input) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4035,6 +4054,7 @@ def test_forward_nll_loss(): test_hard_swish() test_hard_sigmoid() test_forward_nll_loss() + test_forward_flip() # Model tests test_resnet18()