[Relax][PyTorch] Add 3D interpolate support using resize3d #18937

Merged: tlopex merged 1 commit into apache:main from nirdesh17:pytorch-interpolate on Mar 26, 2026

@nirdesh17

Adds support for the 3D mode of torch.nn.functional.interpolate in the Relax frontend.

  • Handles 5D inputs (NCDHW)
  • Maps to relax.op.image.resize3d
  • Ensures correct layout handling
  • Adds tests for scale_factor and size cases

All tests pass locally.
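
As a rough illustration of the scale_factor and size cases listed above, the sketch below computes the target (D, H, W) for a 5D NCDHW input. The helper name `resize3d_target_size` is hypothetical and is not part of the PR; the floor-of-product rule is an assumption about PyTorch's default behavior rather than a copy of the frontend code.

```python
import math
from typing import Optional, Sequence, Tuple, Union

Number = Union[int, float]


def resize3d_target_size(
    input_shape: Sequence[int],
    size: Optional[Union[int, Sequence[int]]] = None,
    scale_factor: Optional[Union[Number, Sequence[Number]]] = None,
) -> Tuple[int, ...]:
    """Hypothetical helper: target (D, H, W) for an NCDHW input."""
    if len(input_shape) != 5:
        raise ValueError("expected a 5D NCDHW shape")
    spatial = tuple(input_shape[2:])  # (D, H, W) in NCDHW
    if size is not None:
        # An explicit size wins; a single int is broadcast to all three axes.
        return (size,) * 3 if isinstance(size, int) else tuple(size)
    if scale_factor is None:
        raise ValueError("either size or scale_factor must be given")
    if isinstance(scale_factor, (int, float)):
        # A scalar scale applies to depth, height, and width alike.
        scale_factor = (scale_factor,) * 3
    # Assumed rule: output extent is floor(input extent * scale).
    return tuple(math.floor(d * s) for d, s in zip(spatial, scale_factor))


shape = (1, 3, 4, 10, 10)
print(resize3d_target_size(shape, scale_factor=(2.0)))
print(resize3d_target_size(shape, scale_factor=(2.0, 4.0, 4.0)))
print(resize3d_target_size(shape, size=(8, 40, 40)))
```

The three calls give (8, 20, 20), (8, 40, 40), and (8, 40, 40), which are the spatial sizes that appear in the expected modules of the new tests. Note that `(2.0)` is a plain float in Python, not a one-element tuple, so the first case exercises the scalar path.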

Part of #18928.

@gemini-code-assist

Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Relax frontend's capability by integrating support for 3D interpolation from PyTorch's torch.nn.functional.interpolate. It specifically addresses 5D tensor inputs (NCDHW layout), translating these operations into relax.op.image.resize3d calls. This expansion allows for more robust and accurate conversion of deep learning models involving 3D data processing, ensuring proper handling of spatial dimensions and interpolation methods within the Relax framework.

Highlights

  • 3D Interpolation Support: Added support for torch.nn.functional.interpolate 3D mode in the Relax frontend.
  • 5D Input Handling: Implemented handling for 5D inputs (NCDHW format) for 3D interpolation.
  • Mapping to resize3d: Mapped 3D interpolation operations to relax.op.image.resize3d.
  • Layout Management: Ensured correct layout handling, specifically for "NCDHW" in 3D cases.
  • Comprehensive Testing: Included new tests covering scale_factor (single and tuple) and size parameters for 3D interpolation.

@gemini-code-assist (bot) left a comment

Contributor

Code Review

This pull request adds support for trilinear interpolation for 5D tensors in the PyTorch frontend by introducing relax.op.image.resize3d calls and updating the interpolation method parsing. New test cases have been added to validate this functionality. The review comments highlight a potential inconsistency in layout handling for 5D tensors when self.default_image_layout is "NHWC", which could lead to incorrect behavior, and suggest adding a test case to expose this. Additionally, it is recommended to refactor the newly added test cases to reduce code duplication and improve clarity, possibly by using parameterized tests and more generic input parameter names.

Comment on lines +562 to +571
if len(shape) == 5:
    return self.block_builder.emit(
        relax.op.image.resize3d(
            data,
            size,
            layout="NCDHW",
            method=method,
            coordinate_transformation_mode=coord_trans,
        )
    )

Contributor

high

There's a potential inconsistency in layout handling for 5D tensors. The size of the output tensor is calculated based on self.default_image_layout (lines 530-535). If self.default_image_layout is "NHWC", the size calculation will assume an NDHWC-like layout for 5D tensors.

However, this block correctly hardcodes the layout to "NCDHW" for relax.op.image.resize3d, as PyTorch's interpolate for 5D inputs only supports NCDHW. This creates a mismatch between how size is calculated and how resize3d is called, which will lead to incorrect behavior when self.default_image_layout is "NHWC".

A robust solution would involve adjusting the size calculation logic to handle 5D tensors specifically, assuming an NCDHW layout regardless of self.default_image_layout. Since that logic is outside the current diff, I'm pointing out the issue here. Adding a test case for a 5D input with default_image_layout="NHWC" would help expose this issue.
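
To make the mismatch concrete, here is a minimal sketch of what goes wrong when spatial extents are read from a shape using the wrong layout string. It does not reproduce the frontend's actual size computation (that code is outside the diff); `spatial_axes` and `spatial_dims` are hypothetical helpers introduced only for illustration.

```python
from typing import List, Sequence, Tuple


def spatial_axes(layout: str) -> List[int]:
    """Indices of the axes that are neither batch (N) nor channel (C)."""
    return [i for i, axis in enumerate(layout) if axis not in ("N", "C")]


def spatial_dims(shape: Sequence[int], layout: str) -> Tuple[int, ...]:
    """Read the spatial extents of `shape`, interpreted under `layout`."""
    return tuple(shape[i] for i in spatial_axes(layout))


# A PyTorch 5D tensor is laid out as NCDHW.
shape = (1, 3, 4, 10, 10)

correct = spatial_dims(shape, "NCDHW")     # depth, height, width
mismatched = spatial_dims(shape, "NDHWC")  # channel axis misread as depth

print(correct)
print(mismatched)
```

Here `correct` is (4, 10, 10), while `mismatched` is (3, 4, 10): the channel count is picked up as the depth. Any output size derived from the second tuple would disagree with a resize3d call that is told the data is NCDHW.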

Comment on lines +3674 to +3779
class Interpolate5(Module):
    def forward(self, input):
        return torch.nn.functional.interpolate(
            input,
            size=None,
            scale_factor=(2.0),
            mode="trilinear",
            align_corners=False,
        )

@tvm.script.ir_module
class expected5:
    @R.function
    def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
        (1, 3, 8, 20, 20), dtype="float32"
    ):

        with R.dataflow():
            lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = R.image.resize3d(
                input_5,
                (8, 20, 20),
                roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
                layout="NCDHW",
                method="linear",
                coordinate_transformation_mode="half_pixel",
                rounding_method="",
                cubic_alpha=-0.75,
                cubic_exclude=0,
                extrapolation_value=0,
                out_dtype="",
            )
            gv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = lv
            R.output(gv)
        return gv

verify_model(Interpolate5(), [([1, 3, 4, 10, 10], "float32")], {}, expected5)

class Interpolate6(Module):
    def forward(self, input):
        return torch.nn.functional.interpolate(
            input,
            size=None,
            scale_factor=(2.0,4.0,4.0),
            mode="trilinear",
            align_corners=False,
        )

@tvm.script.ir_module
class expected6:
    @R.function
    def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
        (1, 3, 8, 40, 40), dtype="float32"
    ):

        with R.dataflow():
            lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d(
                input_5,
                (8, 40, 40),
                roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
                layout="NCDHW",
                method="linear",
                coordinate_transformation_mode="half_pixel",
                rounding_method="",
                cubic_alpha=-0.75,
                cubic_exclude=0,
                extrapolation_value=0,
                out_dtype="",
            )
            gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
            R.output(gv)
        return gv

verify_model(Interpolate6(), [([1, 3, 4, 10, 10], "float32")], {}, expected6)

class Interpolate7(Module):
    def forward(self, input):
        return torch.nn.functional.interpolate(
            input,
            size=(8,40,40),
            mode="trilinear",
            align_corners=False,
        )

@tvm.script.ir_module
class expected7:
    @R.function
    def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
        (1, 3, 8, 40, 40), dtype="float32"
    ):

        with R.dataflow():
            lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d(
                input_5,
                (8, 40, 40),
                roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
                layout="NCDHW",
                method="linear",
                coordinate_transformation_mode="half_pixel",
                rounding_method="",
                cubic_alpha=-0.75,
                cubic_exclude=0,
                extrapolation_value=0,
                out_dtype="",
            )
            gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
            R.output(gv)
        return gv

verify_model(Interpolate7(), [([1, 3, 4, 10, 10], "float32")], {}, expected7)

Contributor

medium

These new tests for 3D interpolation (Interpolate5, Interpolate6, Interpolate7) are great for coverage. However, they contain a lot of duplicated code.

Some suggestions to improve maintainability:

  • Consider refactoring them into a single parameterized test using pytest.mark.parametrize. You could parameterize the torch.nn.functional.interpolate arguments, the input shape, and the expected IR module for each case.
  • The input parameter is named input_5 in expected5, expected6, and expected7. While this doesn't affect test correctness, it could be confusing. Consider using a more generic name like inp or input_tensor for clarity.

Applying these changes would make the test more concise and easier to maintain and extend in the future.
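
One way the suggested consolidation could look is sketched below. To keep the example dependency-free it uses the standard library's `unittest.subTest` in place of `pytest.mark.parametrize`, and it checks only the expected spatial sizes instead of calling `verify_model` against a full IR module. The helper `expected_spatial_size` and the case table are hypothetical, not code from the PR.

```python
import math
import unittest


def expected_spatial_size(in_spatial, size=None, scale_factor=None):
    """Hypothetical stand-in for the per-case expected (D, H, W)."""
    if size is not None:
        return tuple(size)
    if isinstance(scale_factor, (int, float)):
        scale_factor = (scale_factor,) * len(in_spatial)
    # Assumed rule: output extent is floor(input extent * scale).
    return tuple(math.floor(d * s) for d, s in zip(in_spatial, scale_factor))


# One row per former Interpolate5/6/7 class: (name, kwargs, expected D/H/W).
CASES = [
    ("scalar_scale", {"scale_factor": 2.0}, (8, 20, 20)),
    ("tuple_scale", {"scale_factor": (2.0, 4.0, 4.0)}, (8, 40, 40)),
    ("explicit_size", {"size": (8, 40, 40)}, (8, 40, 40)),
]


class Interpolate3DShapeTest(unittest.TestCase):
    def test_trilinear_cases(self):
        for name, kwargs, expected in CASES:
            # Each row is reported separately if it fails.
            with self.subTest(name=name):
                actual = expected_spatial_size((4, 10, 10), **kwargs)
                self.assertEqual(actual, expected)


if __name__ == "__main__":
    unittest.main()
```

With pytest, the same `CASES` table could be passed to `pytest.mark.parametrize` so each row becomes its own test item; adding an align_corners=True or channels-last case would then be a one-line change.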

@tlopex left a comment

Member

A few things to fix here:

  • layout="NCDHW" is hardcoded in the 5D path, so it ignores self.default_image_layout. The 2D path already respects the configurable layout; the 3D path should do the same, e.g. map NHWC to NDHWC.
  • shape = self.shape_of(data) was moved outside the if size is None guard, but len(shape) is now used unconditionally. If struct_info.shape is None, this can crash. Using data.struct_info.ndim would be safer.
  • method[0:2] == "bi" / method[0:3] == "tri" is a bit inconsistent with the nearest branch, which uses startswith. Better to use startswith consistently.
  • Test coverage also looks incomplete: there should be a case for align_corners=True with trilinear, and a 5D test with a non-default layout to catch the hardcoded layout issue.
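
The layout and `startswith` points above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code that was merged: the function names are hypothetical, and the strings "nearest_neighbor", "asymmetric", and "align_corners" are assumed values; only "linear" and "half_pixel" appear in the tests quoted in this thread.

```python
def normalize_method(mode: str) -> str:
    """Map a PyTorch interpolate mode to a resize method name (assumed names)."""
    if mode.startswith("nearest"):
        return "nearest_neighbor"
    # Strip the dimensionality prefix: "bilinear" and "trilinear" -> "linear".
    for prefix in ("bi", "tri"):
        if mode.startswith(prefix):
            return mode[len(prefix):]
    return mode


def coordinate_transformation_mode(method: str, align_corners: bool) -> str:
    """Pick a coordinate transformation mode (assumed mapping)."""
    if method == "nearest_neighbor":
        return "asymmetric"
    return "align_corners" if align_corners else "half_pixel"


# 2D image layout -> corresponding 3D layout, as suggested in the review.
LAYOUT_2D_TO_3D = {"NCHW": "NCDHW", "NHWC": "NDHWC"}


def layout_for_ndim(default_image_layout: str, ndim: int) -> str:
    """Respect the configured layout for 5D inputs instead of hardcoding it."""
    if ndim == 5:
        return LAYOUT_2D_TO_3D[default_image_layout]
    return default_image_layout


method = normalize_method("trilinear")
print(method, coordinate_transformation_mode(method, align_corners=True))
print(layout_for_ndim("NHWC", 5))
```

Using `startswith` for every branch keeps the prefix checks uniform, and the explicit layout table fails loudly with a `KeyError` on an unexpected layout string rather than silently producing a wrong one.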

Signed-off-by: nirdesh17 <nirdeshdevadiya17@gmail.com>
@nirdesh17 force-pushed the pytorch-interpolate branch from e1ace83 to 73551f6 on March 26, 2026 10:08
@nirdesh17

Author

@tlopex I’ve addressed all the requested changes:

  • Fixed layout handling for 5D (respecting default_image_layout, including NDHWC)
  • Used struct_info.ndim instead of relying on shape
  • Cleaned up method handling with startswith
  • Added tests for align_corners=True and 5D NHWC layout

All tests are passing locally. Please take another look.
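
For readers following the second fix in the list above, the sketch below shows why reading the rank from `ndim` is safer than `len(shape)` when the shape is unknown. `FakeTensorStructInfo` is a hypothetical stand-in written for this example; it is not the TVM class, and it assumes only that a tensor's struct info can carry a known rank while its shape is `None`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeTensorStructInfo:
    """Stand-in for a tensor's struct info (not the TVM class)."""
    ndim: int
    shape: Optional[Tuple[int, ...]] = None


def rank_from_shape(sinfo: FakeTensorStructInfo) -> int:
    # Breaks when the shape is unknown: len(None) raises TypeError.
    return len(sinfo.shape)


def rank_from_ndim(sinfo: FakeTensorStructInfo) -> int:
    # Works as long as the rank itself is known.
    return sinfo.ndim


unknown_shape = FakeTensorStructInfo(ndim=5, shape=None)

print(rank_from_ndim(unknown_shape))
try:
    rank_from_shape(unknown_shape)
except TypeError as err:
    print("len(shape) failed:", err)
```

The 5D dispatch only needs the rank, so keying it on `ndim` avoids depending on a fully known shape.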

@tlopex left a comment

Member

LGTM! Thank you

@tlopex merged commit bf6ed31 into apache:main on Mar 26, 2026
6 checks passed