[Relax][PyTorch] Add 3D interpolate support using resize3d#18937
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Relax frontend's capability by integrating support for 3D interpolation from PyTorch's torch.nn.functional.interpolate. It specifically addresses 5D tensor inputs (NCDHW layout), translating these operations into relax.op.image.resize3d calls. This expansion allows for more robust and accurate conversion of deep learning models involving 3D data processing, ensuring proper handling of spatial dimensions and interpolation methods within the Relax framework. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds support for trilinear interpolation for 5D tensors in the PyTorch frontend by introducing relax.op.image.resize3d calls and updating the interpolation method parsing. New test cases have been added to validate this functionality. The review comments highlight a potential inconsistency in layout handling for 5D tensors when self.default_image_layout is "NHWC", which could lead to incorrect behavior, and suggest adding a test case to expose this. Additionally, it is recommended to refactor the newly added test cases to reduce code duplication and improve clarity, possibly by using parameterized tests and more generic input parameter names.
| if len(shape) == 5: | ||
| return self.block_builder.emit( | ||
| relax.op.image.resize3d( | ||
| data, | ||
| size, | ||
| layout="NCDHW", | ||
| method=method, | ||
| coordinate_transformation_mode=coord_trans, | ||
| ) | ||
| ) |
There was a problem hiding this comment.
There's a potential inconsistency in layout handling for 5D tensors. The size of the output tensor is calculated based on self.default_image_layout (lines 530-535). If self.default_image_layout is "NHWC", the size calculation will assume an NDHWC-like layout for 5D tensors.
However, this block correctly hardcodes the layout to "NCDHW" for relax.op.image.resize3d, as PyTorch's interpolate for 5D inputs only supports NCDHW. This creates a mismatch between how size is calculated and how resize3d is called, which will lead to incorrect behavior when self.default_image_layout is "NHWC".
A robust solution would involve adjusting the size calculation logic to handle 5D tensors specifically, assuming an NCDHW layout regardless of self.default_image_layout. Since that logic is outside the current diff, I'm pointing out the issue here. Adding a test case for a 5D input with default_image_layout="NHWC" would help expose this issue.
| class Interpolate5(Module): | ||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, | ||
| size=None, | ||
| scale_factor=(2.0), | ||
| mode="trilinear", | ||
| align_corners=False, | ||
| ) | ||
| @tvm.script.ir_module | ||
| class expected5: | ||
| @R.function | ||
| def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( | ||
| (1, 3, 8, 20, 20), dtype="float32" | ||
| ): | ||
|
|
||
| with R.dataflow(): | ||
| lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = R.image.resize3d( | ||
| input_5, | ||
| (8, 20, 20), | ||
| roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], | ||
| layout="NCDHW", | ||
| method="linear", | ||
| coordinate_transformation_mode="half_pixel", | ||
| rounding_method="", | ||
| cubic_alpha=-0.75, | ||
| cubic_exclude=0, | ||
| extrapolation_value=0, | ||
| out_dtype="", | ||
| ) | ||
| gv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = lv | ||
| R.output(gv) | ||
| return gv | ||
|
|
||
| verify_model(Interpolate5(), [([1, 3, 4, 10, 10], "float32")], {}, expected5) | ||
|
|
||
| class Interpolate6(Module): | ||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, | ||
| size=None, | ||
| scale_factor=(2.0,4.0,4.0), | ||
| mode="trilinear", | ||
| align_corners=False, | ||
| ) | ||
| @tvm.script.ir_module | ||
| class expected6: | ||
| @R.function | ||
| def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( | ||
| (1, 3, 8, 40, 40), dtype="float32" | ||
| ): | ||
|
|
||
| with R.dataflow(): | ||
| lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( | ||
| input_5, | ||
| (8, 40, 40), | ||
| roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], | ||
| layout="NCDHW", | ||
| method="linear", | ||
| coordinate_transformation_mode="half_pixel", | ||
| rounding_method="", | ||
| cubic_alpha=-0.75, | ||
| cubic_exclude=0, | ||
| extrapolation_value=0, | ||
| out_dtype="", | ||
| ) | ||
| gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv | ||
| R.output(gv) | ||
| return gv | ||
|
|
||
| verify_model(Interpolate6(), [([1, 3, 4, 10, 10], "float32")], {}, expected6) | ||
|
|
||
| class Interpolate7(Module): | ||
| def forward(self, input): | ||
| return torch.nn.functional.interpolate( | ||
| input, | ||
| size=(8,40,40), | ||
| mode="trilinear", | ||
| align_corners=False, | ||
| ) | ||
| @tvm.script.ir_module | ||
| class expected7: | ||
| @R.function | ||
| def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( | ||
| (1, 3, 8, 40, 40), dtype="float32" | ||
| ): | ||
|
|
||
| with R.dataflow(): | ||
| lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( | ||
| input_5, | ||
| (8, 40, 40), | ||
| roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000], | ||
| layout="NCDHW", | ||
| method="linear", | ||
| coordinate_transformation_mode="half_pixel", | ||
| rounding_method="", | ||
| cubic_alpha=-0.75, | ||
| cubic_exclude=0, | ||
| extrapolation_value=0, | ||
| out_dtype="", | ||
| ) | ||
| gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv | ||
| R.output(gv) | ||
| return gv | ||
|
|
||
| verify_model(Interpolate7(), [([1, 3, 4, 10, 10], "float32")], {}, expected7) |
There was a problem hiding this comment.
These new tests for 3D interpolation (Interpolate5, Interpolate6, Interpolate7) are great for coverage. However, they contain a lot of duplicated code.
Some suggestions to improve maintainability:
- Consider refactoring them into a single parameterized test using
pytest.mark.parametrize. You could parameterize thetorch.nn.functional.interpolatearguments, the input shape, and the expected IR module for each case. - The input parameter is named
input_5inexpected5,expected6, andexpected7. While this doesn't affect test correctness, it could be confusing. Consider using a more generic name likeinporinput_tensorfor clarity.
Applying these changes would make the test more concise and easier to maintain and extend in the future.
tlopex
left a comment
There was a problem hiding this comment.
A few things to fix here:
layout="NCDHW"is hardcoded in the 5D path, so it ignoresself.default_image_layout. The 2D path already respects the configurable layout; the 3D path should do the same, e.g. mapNHWCtoNDHWC.shape = self.shape_of(data)was moved outside theif size is Noneguard, butlen(shape)is now used unconditionally. Ifstruct_info.shapeisNone, this can crash. Usingdata.struct_info.ndimwould be safer.method[0:2] == "bi"/method[0:3] == "tri"is a bit inconsistent with thenearestbranch, which usesstartswith. Better to usestartswithconsistently.- Test coverage also looks incomplete: there should be a case for
align_corners=Truewith trilinear, and a 5D test with a non-default layout to catch the hardcoded layout issue.
Signed-off-by: nirdesh17 <nirdeshdevadiya17@gmail.com>
e1ace83 to
73551f6
Compare
|
@tlopex I’ve addressed all the requested changes:
All tests are passing locally. Please take another look. |
Adds support for torch.nn.functional.interpolate 3D mode in Relax frontend.
All tests pass locally.
part of #18928