From 5af7f1846dee1fed1f2c8d442db51ea212ebe183 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 22 Jul 2024 17:42:56 +0900 Subject: [PATCH 1/3] add testcase --- tests/python/relax/test_frontend_from_fx.py | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index dd2719f8ce91..9e8976b0962d 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -650,6 +650,50 @@ def main( ) +def test_einsum(): + class Einsum1(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.einsum("ii", x) + + class Einsum2(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.einsum('i,j->ij', x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 4), dtype="float32") + ) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), + inp_1: R.Tensor((4,), dtype="float32") + ) -> R.Tensor((5, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 4), dtype="float32") = R.einsum((inp_0, inp_1), subscripts="i,j->ij") + gv: R.Tensor((5, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Einsum1(), [([4, 4], "float32")], {}, Expected1) + verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2) + + def test_relu(): class ReLU0(Module): def __init__(self): From 93f2ad7681002cadfa1ac72d71ffbcd36a58e8e6 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 22 Jul 2024 17:43:33 +0900 Subject: [PATCH 2/3] add support for torch.einsum --- python/tvm/relax/frontend/torch/fx_translator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5ed0f18deb9e..ab3198f1d0a7 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -518,6 +518,14 @@ def _baddbmm(self, node: fx.node.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res + def _einsum(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) + return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) + ########## Manipulation ########## def _cat(self, node: fx.node.Node) -> relax.Var: @@ -1478,6 +1486,7 @@ def create_convert_map(self): "max": self._max, "cross_entropy": self._cross_entropy, "scaled_dot_product_attention": self._scaled_dot_product_attention, + "einsum": self._einsum, } def update_convert_map(self, custom_convert_map: dict): From 56778a9737030674580861b250d17cfca9a269e2 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 22 Jul 2024 19:05:59 +0900 Subject: [PATCH 3/3] format --- tests/python/relax/test_frontend_from_fx.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 9e8976b0962d..be7edc913cc2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -663,14 +663,12 @@ def __init__(self): super().__init__() def forward(self, x, y): - return torch.einsum('i,j->ij', x, y) + return torch.einsum("i,j->ij", x, y) @tvm.script.ir_module class Expected1: @R.function - def main( - inp_0: R.Tensor((4, 4), dtype="float32") - ) -> R.Tensor((), dtype="float32"): + def main(inp_0: R.Tensor((4, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") gv: R.Tensor((), dtype="float32") = lv @@ -681,11 +679,12 @@ def main( class Expected2: @R.function def main( - inp_0: R.Tensor((5,), dtype="float32"), - inp_1: R.Tensor((4,), dtype="float32") - ) -> R.Tensor((5, 4), dtype="float32"): + inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32") + ) -> R.Tensor((5, 4), dtype="float32"): with R.dataflow(): - lv: R.Tensor((5, 4), dtype="float32") = R.einsum((inp_0, inp_1), subscripts="i,j->ij") + lv: R.Tensor((5, 4), dtype="float32") = R.einsum( + (inp_0, inp_1), subscripts="i,j->ij" + ) gv: R.Tensor((5, 4), dtype="float32") = lv R.output(gv) return gv