fix: Support 5D volumetric inputs in ONNX GridSample frontend converter #19816
Code Review
This pull request extends the ONNX frontend's GridSample operator to support 5D inputs (NCDHW layout) in addition to the existing 4D inputs (NCHW layout), and adds corresponding unit tests. The feedback suggests using the existing helper function _get_known_tensor_rank(data) to determine the input's dimensionality more robustly and idiomatically, rather than manually checking attributes on struct_info.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
    if hasattr(data.struct_info, "ndim"):
        ndim = data.struct_info.ndim
    else:
        ndim = len(data.struct_info.shape)
Instead of manually checking hasattr(data.struct_info, "ndim") and falling back to len(data.struct_info.shape), it is more robust and idiomatic to use the existing helper function _get_known_tensor_rank(data). This helper is used throughout the codebase and safely handles various expression types (e.g., relax.Constant, relax.ShapeExpr, relax.PrimValue, and relax.TensorStructInfo) while avoiding potential AttributeError or TypeError if struct_info or shape is not fully defined.
    ndim = _get_known_tensor_rank(data)
Code Review
This pull request adds support for 5D inputs in the ONNX GridSample operator, mapping it to the NCDHW layout and permuting the grid dimensions accordingly, along with corresponding unit tests. Feedback suggests using the helper function _get_known_tensor_rank(data) instead of directly accessing data.struct_info to safely retrieve the input tensor's rank and avoid potential AttributeErrors.
    if hasattr(data.struct_info, "ndim"):
        ndim = data.struct_info.ndim
    else:
        ndim = len(data.struct_info.shape)
Directly accessing data.struct_info can lead to an AttributeError if struct_info is None. It is safer and more consistent with the rest of the codebase to use the helper function _get_known_tensor_rank(data) to retrieve the rank of the input tensor.
    - if hasattr(data.struct_info, "ndim"):
    -     ndim = data.struct_info.ndim
    - else:
    -     ndim = len(data.struct_info.shape)
    + ndim = _get_known_tensor_rank(data)
    + if ndim is None:
    +     raise ValueError("GridSample requires a statically known input rank.")
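The guard suggested above can be sketched without a TVM build. In this sketch `known_rank` is a hypothetical stand-in for `_get_known_tensor_rank` (whose real implementation is not shown in this thread), and shapes are plain tuples rather than Relax struct info, so it only illustrates the control flow, not the actual frontend code.

```python
from typing import Optional, Sequence


def known_rank(shape: Optional[Sequence[int]]) -> Optional[int]:
    """Hypothetical stand-in for _get_known_tensor_rank.

    Returns the rank when the shape is known, or None when it is not.
    """
    return None if shape is None else len(shape)


def grid_sample_layout(shape: Optional[Sequence[int]]) -> str:
    """Pick a layout string from the input rank, failing early on unknown rank."""
    ndim = known_rank(shape)
    if ndim is None:
        raise ValueError("GridSample requires a statically known input rank.")
    if ndim == 4:
        return "NCHW"
    if ndim == 5:
        return "NCDHW"
    raise ValueError(f"GridSample expects a 4D or 5D input, got rank {ndim}")
```

Checking for `None` before dispatching means an unknown rank produces a clear error at import time instead of an `AttributeError` deeper in the converter.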
The ONNX frontend dispatches 5D GridSample to the relax grid_sample op with layout=NCDHW, and TOPI already implements the 3D compute, but InferStructInfoGridSample hardcoded NCHW so 5D inputs hit a fatal layout error during StructInfo inference. Branch on the NCDHW layout and derive the output spatial extents from grid->values[2:], mirroring the existing Resize3D inference. The 2D NCHW path is unchanged.
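As a rough illustration of the inference rule described in this commit message, the sketch below computes the output shape from plain Python tuples. The function name `infer_grid_sample_shape` is hypothetical, and the grid is assumed to be already permuted to the TOPI layout (`[N, 2, H_out, W_out]` or `[N, 3, D_out, H_out, W_out]`); the real logic lives in the C++ `InferStructInfoGridSample`.

```python
from typing import Sequence, Tuple

# Number of spatial axes per layout; only the two layouts discussed here.
_SPATIAL_RANK = {"NCHW": 2, "NCDHW": 3}


def infer_grid_sample_shape(
    data_shape: Sequence[int], grid_shape: Sequence[int], layout: str
) -> Tuple[int, ...]:
    """Batch and channel come from data; spatial extents come from grid[2:]."""
    spatial_rank = _SPATIAL_RANK.get(layout)
    if spatial_rank is None:
        raise ValueError(f"unsupported layout: {layout}")
    expected = spatial_rank + 2
    if len(data_shape) != expected or len(grid_shape) != expected:
        raise ValueError(f"layout {layout} expects rank-{expected} data and grid")
    return tuple(data_shape[:2]) + tuple(grid_shape[2:])
```

For example, a `[1, 4, 8, 8, 8]` input sampled with a `[1, 3, 2, 3, 5]` grid yields `(1, 4, 2, 3, 5)` under this rule.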
tlopex left a comment
Thanks for the fix! @mvanhorn
I found two issues worth addressing before merge:
- 5D mode="cubic" is accepted by the ONNX frontend but is not supported by TOPI.
GridSample._impl_v16 translates ONNX mode="cubic" to method="bicubic" and then sends 5D inputs to relax.op.image.grid_sample(..., layout="NCDHW"). However, the TOPI 3D implementation only supports ("bilinear", "nearest"):
    assert method in ("bilinear", "nearest"), f"{method} is not supported"
So a valid ONNX 5D cubic GridSample model will import successfully but fail later during legalization/compile. Could we either implement 3D cubic, or explicitly reject ndim == 5 and method == "bicubic" in the frontend with a clear NotImplementedError? A regression test for this case would be good too.
- While touching InferStructInfoGridSample, there is an existing 4D shape inference mismatch that this PR could fix.
The ONNX frontend permutes the 4D grid from [N, H_out, W_out, 2] to TOPI layout [N, 2, H_out, W_out], but the 4D inference path still reads:
    out_tgt_shape.Set(2, grid_shape->values[1]);
    out_tgt_shape.Set(3, grid_shape->values[2]);
For a non-square output grid like [N, 2, 3, 5], this infers [N, C, 2, 3] instead of [N, C, 3, 5]. The current 4D test uses [1, 2, 2, 2], so it does not catch this. Since the 5D branch correctly reads the permuted spatial dims from grid_shape->values[2:], I think the 4D branch should do the same (values[2], values[3]) and add a non-square 4D test.
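To make the 4D mismatch concrete, here is a small Python sketch that mirrors the two indexing choices on plain tuples. The helper names are hypothetical and the functions only imitate the index arithmetic; they are not the C++ inference code.

```python
from typing import Sequence, Tuple


def infer_4d_old(data_shape: Sequence[int], grid_shape: Sequence[int]) -> Tuple[int, ...]:
    """Reads values[1] and values[2], as the 4D path did before the fix."""
    return (data_shape[0], data_shape[1], grid_shape[1], grid_shape[2])


def infer_4d_fixed(data_shape: Sequence[int], grid_shape: Sequence[int]) -> Tuple[int, ...]:
    """Reads values[2] and values[3], matching the permuted [N, 2, H_out, W_out] grid."""
    return (data_shape[0], data_shape[1], grid_shape[2], grid_shape[3])


data = (1, 4, 8, 8)

# Non-square output grid: the two versions disagree.
print(infer_4d_old(data, (1, 2, 3, 5)))    # (1, 4, 2, 3)
print(infer_4d_fixed(data, (1, 2, 3, 5)))  # (1, 4, 3, 5)

# The grid shape used by the existing test: both versions agree, hiding the bug.
print(infer_4d_old(data, (1, 2, 2, 2)) == infer_4d_fixed(data, (1, 2, 2, 2)))  # True
```

The last line shows why a `[1, 2, 2, 2]` grid cannot distinguish the two versions: every axis the code could read has extent 2.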
…rence
- onnx frontend: raise NotImplementedError for 5D mode='cubic' (TOPI 3D grid_sample supports only bilinear/nearest), instead of importing a model that fails later at legalization
- InferStructInfoGridSample: 4D branch now reads the permuted grid spatial dims (values[2]/values[3]) to match the frontend's NCHW permutation; fixes non-square output shape inference
- tests: add 5D cubic rejection test and non-square 4D output shape test
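The frontend-side rejection described in this commit could look roughly like the sketch below. `resolve_method` is a hypothetical helper, and only the `cubic` to `bicubic` translation is taken from the review; passing every other mode through unchanged is an assumption made to keep the example short.

```python
def resolve_method(mode: str, ndim: int) -> str:
    """Translate an ONNX GridSample mode and reject the unsupported 5D cubic case."""
    # Only the cubic -> bicubic translation is confirmed by the review;
    # other modes are passed through unchanged (assumption).
    method = "bicubic" if mode == "cubic" else mode
    if ndim == 5 and method == "bicubic":
        # Fail at import time rather than later during legalization.
        raise NotImplementedError(
            "GridSample with mode='cubic' is not supported for 5D inputs"
        )
    return method
```

Raising here keeps the failure close to the model import, where the user can still see which ONNX node and attribute caused it.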
Thanks for the careful review! Both addressed in 9518915:
I couldn't run pytest locally (no TVM build here), so I'm relying on CI for the compiled checks.

LGTM! Thanks for the fix!
The 5D GridSample change (apache#19816) landed with a clang-format violation on the structured binding for CheckTensorLayout, which fails the repo-wide pre-commit lint (clang-format v20.1.8). Reformat to satisfy the hook.
Summary
The Relax ONNX frontend's GridSample._impl_v16 converter unconditionally permutes the grid from ONNX [N,H,W,2] to TVM [N,2,H,W] and calls image.grid_sample with layout="NCHW". For 5D volumetric inputs ([N,C,D,H,W] with grid [N,D,H,W,3]) this crashes at permute_dims with an InternalError ('PermuteDims expects the number of input axes to equal the ndim of the input tensor.').
Changes
In GridSample._impl_v16, read data.struct_info.ndim and dispatch on rank. For ndim==4, keep the existing permute_dims(grid,[0,3,1,2]) + grid_sample(layout="NCHW").
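The rank dispatch can be illustrated on plain shape tuples. In this sketch the 4D axes `[0, 3, 1, 2]` come from the description above, while the 5D axes `[0, 4, 1, 2, 3]` are an assumption derived from moving the last grid axis of `[N,D,H,W,3]` to position 1; both helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def grid_permutation(ndim: int) -> List[int]:
    """Axes that move the trailing coordinate axis of the grid to position 1."""
    if ndim == 4:
        return [0, 3, 1, 2]      # [N, H, W, 2] -> [N, 2, H, W]
    if ndim == 5:
        return [0, 4, 1, 2, 3]   # [N, D, H, W, 3] -> [N, 3, D, H, W] (assumed)
    raise ValueError(f"GridSample expects a 4D or 5D input, got rank {ndim}")


def permute_shape(shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Apply a permutation to a shape, rejecting mismatched axis counts."""
    if len(axes) != len(shape):
        raise ValueError("number of axes must equal the rank of the tensor")
    return tuple(shape[a] for a in axes)
```

Applying the 4D axes to a 5D grid shape fails the length check in `permute_shape`, which is the same kind of mismatch the `permute_dims` error in the Summary reports.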
Fixes #19688