fix: Support 5D volumetric inputs in ONNX GridSample frontend converter #19816

Merged
tlopex merged 3 commits into apache:main from mvanhorn:fix/19688-onnx-gridsample-5d
Jun 18, 2026

Conversation

@mvanhorn

Contributor

Summary

The Relax ONNX frontend's GridSample._impl_v16 converter unconditionally permutes the grid from the ONNX layout [N,H,W,2] to the TVM layout [N,2,H,W] and calls image.grid_sample with layout="NCHW". For 5D volumetric inputs ([N,C,D,H,W] with grid [N,D,H,W,3]), the four-axis permutation does not match the grid's rank, so the converter crashes at permute_dims with an InternalError: 'PermuteDims expects the number of input axes to equal the ndim of the input tensor.'

Changes

In GridSample._impl_v16, read data.struct_info.ndim and dispatch on rank. For ndim==4, keep the existing permute_dims(grid,[0,3,1,2]) + grid_sample(layout="NCHW"). For ndim==5, permute the grid dimensions accordingly and call grid_sample with layout="NCDHW".

Fixes #19688
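
The rank dispatch described above can be sketched on plain shape lists, without TVM. This is a minimal illustration, not the converter's code: `permute_shape` and `grid_permutation` are hypothetical helpers, and the 5D axis order `[0, 4, 1, 2, 3]` is an assumption made by analogy with the 4D case (moving the coordinate axis next to the batch axis).

```python
def permute_shape(shape, axes):
    """Return the permuted shape, enforcing the same rank check permute_dims relies on."""
    if len(axes) != len(shape):
        raise ValueError(
            f"expected {len(shape)} axes for a rank-{len(shape)} tensor, got {len(axes)}"
        )
    return [shape[a] for a in axes]


def grid_permutation(ndim):
    """Pick the grid permutation and layout string for a given input rank."""
    if ndim == 4:
        # ONNX grid [N, H, W, 2] -> [N, 2, H, W]
        return [0, 3, 1, 2], "NCHW"
    if ndim == 5:
        # Assumed by analogy: ONNX grid [N, D, H, W, 3] -> [N, 3, D, H, W]
        return [0, 4, 1, 2, 3], "NCDHW"
    raise ValueError(f"GridSample expects a 4D or 5D input, got rank {ndim}")


grid_4d = [1, 8, 8, 2]
grid_5d = [1, 4, 8, 8, 3]

axes, layout = grid_permutation(4)
print(permute_shape(grid_4d, axes), layout)  # [1, 2, 8, 8] NCHW

axes, layout = grid_permutation(5)
print(permute_shape(grid_5d, axes), layout)  # [1, 3, 4, 8, 8] NCDHW

# Applying the 4D permutation to a 5D grid is the failure mode from the summary.
try:
    permute_shape(grid_5d, [0, 3, 1, 2])
except ValueError as err:
    print("rejected:", err)
```

Raising on any other rank keeps unsupported inputs from silently falling into the 4D path.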

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request extends the ONNX frontend's GridSample operator to support 5D inputs (NCDHW layout) in addition to the existing 4D inputs (NCHW layout), and adds corresponding unit tests. The feedback suggests using the existing helper function _get_known_tensor_rank(data) to determine the input's dimensionality more robustly and idiomatically, rather than manually checking attributes on struct_info.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +4997 to +5000
        if hasattr(data.struct_info, "ndim"):
            ndim = data.struct_info.ndim
        else:
            ndim = len(data.struct_info.shape)

Contributor

medium

Instead of manually checking hasattr(data.struct_info, "ndim") and falling back to len(data.struct_info.shape), it is more robust and idiomatic to use the existing helper function _get_known_tensor_rank(data). This helper is used throughout the codebase and safely handles various expression types (e.g., relax.Constant, relax.ShapeExpr, relax.PrimValue, and relax.TensorStructInfo) while avoiding potential AttributeError or TypeError if struct_info or shape is not fully defined.

        ndim = _get_known_tensor_rank(data)

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request adds support for 5D inputs in the ONNX GridSample operator, mapping it to the NCDHW layout and permuting the grid dimensions accordingly, along with corresponding unit tests. Feedback suggests using the helper function _get_known_tensor_rank(data) instead of directly accessing data.struct_info to safely retrieve the input tensor's rank and avoid potential AttributeErrors.

Comment on lines +4997 to +5000
        if hasattr(data.struct_info, "ndim"):
            ndim = data.struct_info.ndim
        else:
            ndim = len(data.struct_info.shape)

Contributor

medium

Directly accessing data.struct_info can lead to an AttributeError if struct_info is None. It is safer and more consistent with the rest of the codebase to use the helper function _get_known_tensor_rank(data) to retrieve the rank of the input tensor.

Suggested change
-        if hasattr(data.struct_info, "ndim"):
-            ndim = data.struct_info.ndim
-        else:
-            ndim = len(data.struct_info.shape)
+        ndim = _get_known_tensor_rank(data)
+        if ndim is None:
+            raise ValueError("GridSample requires a statically known input rank.")
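
The pattern behind this suggestion (resolve the rank through one helper, then fail loudly when it is unknown) can be sketched without TVM. Everything here is a hypothetical stand-in: `FakeStructInfo` and `known_rank` are not TVM's classes or its `_get_known_tensor_rank` helper, and the convention that `ndim == -1` means "unknown rank" is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class FakeStructInfo:
    """Hypothetical stand-in for tensor struct info; ndim == -1 means unknown rank."""
    ndim: int = -1
    shape: Optional[Sequence[int]] = None


def known_rank(struct_info: Optional[FakeStructInfo]) -> Optional[int]:
    """Return the static rank if it can be determined, else None."""
    if struct_info is None:
        return None
    if struct_info.ndim >= 0:
        return struct_info.ndim
    if struct_info.shape is not None:
        return len(struct_info.shape)
    return None


def require_rank(struct_info: Optional[FakeStructInfo]) -> int:
    """Resolve the rank or raise the error proposed in the suggested change."""
    ndim = known_rank(struct_info)
    if ndim is None:
        raise ValueError("GridSample requires a statically known input rank.")
    return ndim


print(require_rank(FakeStructInfo(ndim=5)))             # 5
print(require_rank(FakeStructInfo(shape=[1, 3, 4, 4])))  # 4
```

Centralizing the lookup means a missing struct_info produces one clear ValueError instead of an AttributeError at the call site.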

The ONNX frontend dispatches 5D GridSample to the relax grid_sample op with
layout=NCDHW, and TOPI already implements the 3D compute, but
InferStructInfoGridSample hardcoded NCHW so 5D inputs hit a fatal layout
error during StructInfo inference. Branch on the NCDHW layout and derive the
output spatial extents from grid->values[2:], mirroring the existing
Resize3D inference. The 2D NCHW path is unchanged.
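
As a rough illustration of the NCDHW inference rule in this commit message, the output shape takes the batch and channel extents from the data and the spatial extents from the grid's trailing dimensions. The function name `infer_ncdhw_output_shape` is hypothetical, and the grid is assumed to already be in the permuted layout [N, 3, D_out, H_out, W_out].

```python
def infer_ncdhw_output_shape(data_shape, grid_shape):
    """Infer the grid_sample output shape for NCDHW data.

    data_shape: [N, C, D, H, W]
    grid_shape: permuted grid [N, 3, D_out, H_out, W_out] (assumed layout)
    """
    if len(data_shape) != 5 or len(grid_shape) != 5:
        raise ValueError("NCDHW grid_sample expects rank-5 data and grid")
    batch, channels = data_shape[0], data_shape[1]
    # Spatial extents come from grid dimensions 2, 3 and 4.
    return [batch, channels] + list(grid_shape[2:])


print(infer_ncdhw_output_shape([1, 4, 6, 7, 8], [1, 3, 2, 3, 5]))  # [1, 4, 2, 3, 5]
```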

@tlopex tlopex left a comment

Member

Thanks for the fix! @mvanhorn
I found two issues worth addressing before merge:

  1. 5D mode="cubic" is accepted by the ONNX frontend but is not supported by TOPI.

GridSample._impl_v16 translates ONNX mode="cubic" to method="bicubic" and then sends 5D inputs to relax.op.image.grid_sample(..., layout="NCDHW"). However, the TOPI 3D implementation only supports ("bilinear", "nearest"):

assert method in ("bilinear", "nearest"), f"{method} is not supported"

So a valid ONNX 5D cubic GridSample model will import successfully but fail later during legalization/compile. Could we either implement 3D cubic, or explicitly reject ndim == 5 and method == "bicubic" in the frontend with a clear NotImplementedError? A regression test for this case would be good too.

  2. While touching InferStructInfoGridSample, there is an existing 4D shape inference mismatch that this PR could fix.

The ONNX frontend permutes the 4D grid from [N, H_out, W_out, 2] to TOPI layout [N, 2, H_out, W_out], but the 4D inference path still reads:

out_tgt_shape.Set(2, grid_shape->values[1]);
out_tgt_shape.Set(3, grid_shape->values[2]);

For a non-square output grid like [N, 2, 3, 5], this infers [N, C, 2, 3] instead of [N, C, 3, 5]. The current 4D test uses [1, 2, 2, 2], so it does not catch this. Since the 5D branch correctly reads the permuted spatial dims from grid_shape->values[2:], I think the 4D branch should do the same (values[2], values[3]) and add a non-square 4D test.
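
The indexing mismatch in the second point can be reproduced on plain shape lists. This sketch only mirrors the two index choices quoted above; `infer_4d_old` and `infer_4d_fixed` are hypothetical names, and the data shape is chosen arbitrarily for illustration.

```python
def infer_4d_old(data_shape, grid_shape):
    """Index choice quoted in the review: reads grid values[1] and values[2]."""
    return [data_shape[0], data_shape[1], grid_shape[1], grid_shape[2]]


def infer_4d_fixed(data_shape, grid_shape):
    """Proposed index choice: reads the permuted spatial dims values[2] and values[3]."""
    return [data_shape[0], data_shape[1], grid_shape[2], grid_shape[3]]


data = [1, 3, 4, 4]          # [N, C, H, W], arbitrary example
square_grid = [1, 2, 2, 2]   # [N, 2, H_out, W_out] with H_out == W_out == 2
wide_grid = [1, 2, 3, 5]     # non-square output grid

# With a 2x2 grid both versions agree, so the bug stays hidden.
print(infer_4d_old(data, square_grid), infer_4d_fixed(data, square_grid))

# With a 3x5 grid the old indexing picks up the coordinate axis instead of H_out.
print(infer_4d_old(data, wide_grid))    # [1, 3, 2, 3]
print(infer_4d_fixed(data, wide_grid))  # [1, 3, 3, 5]
```

A non-square grid is the smallest input that distinguishes the two index choices, which is why the review asks for one in the test suite.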

…rence

- onnx frontend: raise NotImplementedError for 5D mode='cubic' (TOPI 3D
  grid_sample supports only bilinear/nearest), instead of importing a model
  that fails later at legalization
- InferStructInfoGridSample: 4D branch now reads the permuted grid spatial
  dims (values[2]/values[3]) to match the frontend's NCHW permutation; fixes
  non-square output shape inference
- tests: add 5D cubic rejection test and non-square 4D output shape test
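
A fail-fast guard of the kind described in the first bullet might look like the following. This is a sketch under stated assumptions, not the merged code: `resolve_method` is a hypothetical helper, only the cubic-to-bicubic rename from the review is modeled, and the error message wording is invented for illustration.

```python
def resolve_method(mode, ndim):
    """Translate an ONNX GridSample mode and reject unsupported 5D cubic sampling."""
    # Only the rename mentioned in the review is modeled; other modes pass through.
    method = "bicubic" if mode == "cubic" else mode
    if ndim == 5 and method == "bicubic":
        # Hypothetical message; the 3D compute supports only bilinear and nearest.
        raise NotImplementedError(
            "GridSample with mode='cubic' is not supported for 5D (volumetric) inputs"
        )
    return method


print(resolve_method("cubic", 4))     # bicubic
print(resolve_method("bilinear", 5))  # bilinear

try:
    resolve_method("cubic", 5)
except NotImplementedError as err:
    print("rejected at import:", err)
```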
@mvanhorn

Contributor Author

Thanks for the careful review! Both addressed in 9518915:

  1. The frontend now raises NotImplementedError for 5D mode="cubic" (4D bicubic is still fine via TOPI's 2D path), so an unsupported volumetric-cubic model fails fast at import with a clear message instead of dying later at legalization. Added a regression test.

  2. Fixed the 4D shape-inference mismatch in InferStructInfoGridSample: the 4D branch now reads the permuted spatial dims (values[2]/values[3]) to match the frontend's [N, 2, H_out, W_out] permutation, mirroring the 5D branch. Added a non-square 4D test ([1, 3, 5, 2] -> [1, 3, 3, 5]) that catches the old values[1]/values[2] bug.

I couldn't run pytest locally (no TVM build here), so I'm relying on CI for the compiled checks.

@tlopex

tlopex commented Jun 18, 2026

Member

LGTM! Thanks for the fix!

@tlopex tlopex merged commit da52d7d into apache:main Jun 18, 2026
6 checks passed
tlopex added a commit to tlopex/tvm that referenced this pull request Jun 18, 2026
The 5D GridSample change (apache#19816) landed with a clang-format violation on
the structured binding for CheckTensorLayout, which fails the repo-wide
pre-commit lint (clang-format v20.1.8). Reformat to satisfy the hook.

Development

Successfully merging this pull request may close these issues.

[Bug][Relax][ONNX] GridSample 5D (volumetric) input crashes the frontend

2 participants