diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 67d93b066972..cbf9e33a126f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -809,9 +809,16 @@ def create_convert_map( "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], "dropout_.default": lambda node: self.env[node.args[0]], + "native_dropout.default": lambda node: self.env[node.args[0]], "elu.default": self._elu, "erf.default": self._unary_op(relax.op.erf), "exp.default": self._unary_op(relax.op.exp), + "expm1.default": lambda node: self.block_builder.emit( + relax.op.subtract( + relax.op.exp(self.env[node.args[0]]), + relax.const(1.0, self.env[node.args[0]].struct_info.dtype), + ) + ), "floor.default": self._unary_op(relax.op.floor), "gelu.default": self._gelu, "hardsigmoid.default": self._hardsigmoid, @@ -869,6 +876,7 @@ def create_convert_map( "bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), + "div.Scalar": self._binary_op(relax.op.divide, operator.truediv), "div.Tensor": self._binary_op(relax.op.divide, operator.truediv), "div.Tensor_mode": self._div, "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), @@ -1019,7 +1027,9 @@ def create_convert_map( "detach_.default": self._detach, "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], + "bernoulli.p": lambda node: self.env[node.args[0]], # Dropout: just return input "empty.memory_format": self._empty, + "empty_permuted.default": self._empty, # Similar to empty with permuted layout "empty_like.default": self._empty_like, "eye.default": self._eye, "eye.m": self._eye, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 657ade455bd7..338214156708 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -31,9 +31,11 @@ from tvm.relax.frontend.torch import from_exported_program -def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None): +def verify_model( + torch_model, example_args, binding, expected, dynamic_shapes=None, run_ep_decomposition=False +): exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes) - mod = from_exported_program(exported_program) + mod = from_exported_program(exported_program, run_ep_decomposition=run_ep_decomposition) binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} expected = relax.transform.BindParams("main", binding)(expected) @@ -155,26 +157,19 @@ def main( ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( lv, R.const(1.0, "float32") ) - lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - lv_div, R.const(1.0, "float32") - ) - lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum( - R.const(0.0, "float32"), lv_sub - ) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.0, "float32"), lv_min + lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( + input_1, R.const(0.0, "float32") ) - lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv2, input_1, lv1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv - verify_model(Celu1(), example_args, {}, expected_celu) - verify_model(Celu2(), example_args, {}, expected_celu) + verify_model(Celu1(), example_args, {}, expected_celu, run_ep_decomposition=True) + verify_model(Celu2(), example_args, {}, expected_celu, run_ep_decomposition=True) # clamp class Clamp(Module): @@ -197,7 +192,7 @@ def main( R.output(gv) return gv - verify_model(Clamp(), example_args, {}, expected_clamp) + verify_model(Clamp(), example_args, {}, expected_clamp, run_ep_decomposition=True) class ClampMinOnly(Module): def forward(self, input): @@ -217,7 +212,9 @@ def main( R.output(gv) return gv - verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only) + verify_model( + ClampMinOnly(), example_args, {}, expected_clamp_min_only, run_ep_decomposition=True + ) class ClampTensors(Module): def forward(self, input): @@ -245,7 +242,9 @@ def main( R.output(gv) return gv - verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors) + verify_model( + ClampTensors(), example_args, {}, expected_clamp_tensors, run_ep_decomposition=True + ) # dropout @@ -266,20 +265,44 @@ def forward(self, input): return torch.ops.aten.dropout_(input, 0.5, train=True) @tvm.script.ir_module - class expected_dropout: + class expected_dropout_for_1_2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_dropout_for_3: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros( + R.shape([1, 3, 10, 10]), dtype="float32" + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv, R.const(0.5, "float32") + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv1) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv2, lv2) R.output(gv) return gv - verify_model(Dropout1(), example_args, {}, expected_dropout) - verify_model(Dropout2(), example_args, {}, expected_dropout) - verify_model(Dropout3(), example_args, {}, expected_dropout) + verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2, run_ep_decomposition=True) + verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2, run_ep_decomposition=True) + verify_model(Dropout3(), example_args, {}, expected_dropout_for_3, run_ep_decomposition=True) # elu class Elu(Module): @@ -298,28 +321,32 @@ def forward(self, input): class expected_elu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - R.const(1.0, dtype="float32"), lv_exp + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( + input, R.const(0.0, "float32") ) - lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu( - lv_one_minus_exp + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + input, R.const(1.0, "float32") ) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + input, R.const(1.0, "float32") ) - lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + lv3, R.const(1.0, "float32") + ) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + lv4, R.const(1.0, "float32") + ) + lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,) R.output(gv) return gv - verify_model(Elu(), example_args, {}, expected_elu) - verify_model(Elu2(), example_args, {}, expected_elu) + verify_model(Elu(), example_args, {}, expected_elu, run_ep_decomposition=True) + verify_model(Elu2(), example_args, {}, expected_elu, run_ep_decomposition=True) # hardsigmoid class Hardsigmoid(torch.nn.Module): @@ -341,17 +368,24 @@ def main( inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( - lv1, R.const(6, "float32") + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + inp_0, R.const(3.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv, R.prim_value(0), R.prim_value(T.float64("inf")) + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv1, R.prim_value(T.float64("-inf")), R.prim_value(6) + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv2, R.const(6.0, "float32") ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv - verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) - verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) + verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid, run_ep_decomposition=True) + verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid, run_ep_decomposition=True) # hardwish class Hardswish(torch.nn.Module): @@ -371,25 +405,67 @@ def forward(self, input): return torch.ops.aten.hardswish_(input) @tvm.script.ir_module - class expected1: + class expected_hardswish_for_1_2: @R.function def main( inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( - lv1, R.const(6, "float32") + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + inp_0, R.const(3.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv, R.prim_value(0), R.prim_value(T.float64("inf")) + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv1, R.prim_value(T.float64("-inf")), R.prim_value(6) ) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv3, R.const(6.0, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_hardswish_for_3: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + input, R.const(3.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv, R.prim_value(0), R.prim_value(T.float64("inf")) + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv1, R.prim_value(T.float64("-inf")), R.prim_value(6) + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv3, R.const(6.0, "float32") + ) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv4, lv4) R.output(gv) return gv - verify_model(Hardswish(), example_args, {}, expected1) - verify_model(Hardswish2(), example_args, {}, expected1) - verify_model(Hardswish3(), example_args, {}, expected1) + verify_model( + Hardswish(), example_args, {}, expected_hardswish_for_1_2, run_ep_decomposition=True + ) + verify_model( + Hardswish2(), example_args, {}, expected_hardswish_for_1_2, run_ep_decomposition=True + ) + verify_model( + Hardswish3(), example_args, {}, expected_hardswish_for_3, run_ep_decomposition=True + ) # log2 class Log2(Module): @@ -411,7 +487,7 @@ def main( R.output(gv) return gv - verify_model(Log2(), example_args, {}, Expected_log2) + verify_model(Log2(), example_args, {}, Expected_log2, run_ep_decomposition=True) # log10 class Log10(Module): @@ -433,7 +509,7 @@ def main( R.output(gv) return gv - verify_model(Log10(), example_args, {}, Expected_log10) + verify_model(Log10(), example_args, {}, Expected_log10, run_ep_decomposition=True) # log1p class Log1p(Module): @@ -454,7 +530,7 @@ def main( R.output(gv) return gv - verify_model(Log1p(), example_args, {}, Expected_log1p) + verify_model(Log1p(), example_args, {}, Expected_log1p, run_ep_decomposition=True) # reciprocal class Reciprocal(Module): @@ -475,7 +551,7 @@ def main( R.output(gv) return gv - verify_model(Reciprocal(), example_args, {}, expected_reciprocal) + verify_model(Reciprocal(), example_args, {}, expected_reciprocal, run_ep_decomposition=True) # Returns the maximum value of all elements in the input tensor. class MaxModel(Module): @@ -494,7 +570,7 @@ def main( R.output(gv) return gv - verify_model(MaxModel(), example_args, {}, expected_max) + verify_model(MaxModel(), example_args, {}, expected_max, run_ep_decomposition=True) # Returns the minimum value of all elements in the input tensor. class MinModel(Module): @@ -513,7 +589,7 @@ def main( R.output(gv) return gv - verify_model(MinModel(), example_args, {}, expected_min) + verify_model(MinModel(), example_args, {}, expected_min, run_ep_decomposition=True) # relu6 class ReLU6_1(torch.nn.Module): @@ -558,9 +634,28 @@ def main( R.output(gv) return gv - verify_model(ReLU6_1(), example_args, {}, expected_relu6_1) - verify_model(ReLU6_2(), example_args, {}, expected_relu6_2) - verify_model(ReLU6_3(), example_args, {}, expected_relu6_2) + @tvm.script.ir_module + class expected_relu6_3: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + x, R.prim_value(0), R.prim_value(6) + ) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv, lv) + R.output(gv) + return gv + + verify_model(ReLU6_1(), example_args, {}, expected_relu6_1, run_ep_decomposition=True) + verify_model(ReLU6_2(), example_args, {}, expected_relu6_2, run_ep_decomposition=True) + verify_model(ReLU6_3(), example_args, {}, expected_relu6_3, run_ep_decomposition=True) def test_hardtanh():