From 73551f6076a5a3c56ef025d673280397993133ad Mon Sep 17 00:00:00 2001 From: nirdesh17 Date: Thu, 26 Mar 2026 15:38:01 +0530 Subject: [PATCH] [Relax][PyTorch] Add interpolate 3D support Signed-off-by: nirdesh17 --- .../tvm/relax/frontend/torch/fx_translator.py | 40 +++- tests/python/relax/test_frontend_from_fx.py | 222 ++++++++++++++++++ 2 files changed, 251 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e7fcb0c202b9..c81768f6d946 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -527,7 +527,7 @@ def _interpolate(self, node: fx.Node) -> relax.Var: # Determine spatial dimension indices based on layout # NCHW: spatial dims are [2, 3, ...] (skip batch and channel) # NHWC: spatial dims are [1, 2, ...] (skip batch, before channel) - if self.default_image_layout == "NHWC": + if self.default_image_layout in ("NHWC", "NDHWC"): spatial_start = 1 spatial_end = len(shape) - 1 else: # NCHW or other layouts @@ -547,25 +547,43 @@ def _interpolate(self, node: fx.Node) -> relax.Var: if method.startswith("nearest"): method = "nearest_neighbor" - elif method[0:2] == "bi": + elif method.startswith("bi"): method = method[2:] + elif method.startswith("tri"): + method = method[3:] if method == "nearest_neighbor": coord_trans = "asymmetric" - elif align_corners: + elif align_corners is True: coord_trans = "align_corners" else: coord_trans = "half_pixel" - return self.block_builder.emit( - relax.op.image.resize2d( - data, - size, - layout=self.default_image_layout, - method=method, - coordinate_transformation_mode=coord_trans, + if data.struct_info.ndim == 5: + if self.default_image_layout == "NDHWC": + layout_3d = "NDHWC" + else: + layout_3d = "NCDHW" + + return self.block_builder.emit( + relax.op.image.resize3d( + data, + size, + layout=layout_3d, + method=method, + coordinate_transformation_mode=coord_trans, + ) + ) + else: + return self.block_builder.emit( + relax.op.image.resize2d( + data, + size, + layout=self.default_image_layout, + method=method, + coordinate_transformation_mode=coord_trans, + ) ) - ) def _linear_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 102b8a0f420f..4d9060bf720e 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3671,6 +3671,149 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor( verify_model(Interpolate4(), input_info, {}, expected4) + input_info_5d = [([1, 3, 4, 10, 10], "float32")] + class Interpolate5(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=None, + scale_factor=(2.0, 2.0, 2.0), + mode="trilinear", + align_corners=False, + ) + @tvm.script.ir_module + class expected5: + @R.function + def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( + (1, 3, 8, 20, 20), dtype="float32" + ): + + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = R.image.resize3d( + input_5, + (8, 20, 20), + roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], + layout="NCDHW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate5(), input_info_5d, {}, expected5) + + class Interpolate6(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=None, + scale_factor=(2.0,4.0,4.0), + mode="trilinear", + align_corners=False, + ) + @tvm.script.ir_module + class expected6: + @R.function + def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( + (1, 3, 8, 40, 40), dtype="float32" + ): + + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( + input_5, + (8, 40, 40), + roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], + layout="NCDHW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate6(), input_info_5d, {}, expected6) + + class Interpolate7(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=(8,40,40), + mode="trilinear", + align_corners=False, + ) + @tvm.script.ir_module + class expected7: + @R.function + def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( + (1, 3, 8, 40, 40), dtype="float32" + ): + + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( + input_5, + (8, 40, 40), + roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], + layout="NCDHW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate7(), input_info_5d, {}, expected7) + + class Interpolate8(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=(8,40,40), + mode="trilinear", + align_corners=True, + ) + @tvm.script.ir_module + class expected8: + @R.function + def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( + (1, 3, 8, 40, 40), dtype="float32" + ): + + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( + input_5, + (8, 40, 40), + roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], + layout="NCDHW", + method="linear", + coordinate_transformation_mode="align_corners", + rounding_method="", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate8(), input_info_5d, {}, expected8) + def test_interpolate_nhwc_layout(): # First verify backward compatibility - default should still be NCHW @@ -3786,6 +3929,85 @@ def main(input_1: R.Tensor((1, 10, 10, 3), dtype="float32")) -> R.Tensor( mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC") tvm.ir.assert_structural_equal(mod2, expected_nhwc2) + input_info_5d = [([1, 4, 10, 10, 3], "float32")] + + class InterpolateNHWC3(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=None, + scale_factor=(2.0,4.0,4.0), + mode="trilinear", + align_corners=False, + ) + @tvm.script.ir_module + class expected_nhwc3: + @R.function + def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor( + (1, 8, 40, 40, 3), dtype="float32" + ): + + with R.dataflow(): + lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = R.image.resize3d( + input_5, + (8, 40, 40), + roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], + layout="NDHWC", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv + R.output(gv) + return gv + + graph_model3 = fx.symbolic_trace(InterpolateNHWC3()) + with torch.no_grad(): + mod3 = from_fx(graph_model3, input_info_5d, default_image_layout="NDHWC") + tvm.ir.assert_structural_equal(mod3, expected_nhwc3) + + class InterpolateNHWC4(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=None, + scale_factor=(2.0,4.0,4.0), + mode="trilinear", + align_corners=True, + ) + @tvm.script.ir_module + class expected_nhwc4: + @R.function + def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor( + (1, 8, 40, 40, 3), dtype="float32" + ): + + with R.dataflow(): + lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = R.image.resize3d( + input_5, + (8, 40, 40), + roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], + layout="NDHWC", + method="linear", + coordinate_transformation_mode="align_corners", + rounding_method="", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv + R.output(gv) + return gv + + graph_model4 = fx.symbolic_trace(InterpolateNHWC4()) + with torch.no_grad(): + mod4 = from_fx(graph_model4, input_info_5d, default_image_layout="NDHWC") + tvm.ir.assert_structural_equal(mod4, expected_nhwc4) def test_addmm(): input_info = [