Skip to content

[ONNX] Add RMSNormalization converter for ONNX opset 23#19590

Merged
tlopex merged 2 commits into
apache:mainfrom
q55180514:feat/onnx-rmsnormalization
May 21, 2026
Merged

[ONNX] Add RMSNormalization converter for ONNX opset 23#19590
tlopex merged 2 commits into
apache:mainfrom
q55180514:feat/onnx-rmsnormalization

Conversation

@q55180514

Copy link
Copy Markdown
Contributor

Add support for the ONNX RMSNormalization operator (opset 23) in the Relax ONNX frontend. This operator is essential for importing LLM models (LLaMA, Gemma, etc.) that use RMS normalization.

The implementation:

  • Maps ONNX RMSNormalization to relax.op.nn.rms_norm
  • Supports the axis, epsilon, and stash_type attributes
  • Handles float16 inputs with stash_type=1 (compute in float32)
  • Includes unit tests comparing against ONNX Runtime

Add support for the ONNX RMSNormalization operator (opset 23) in the
Relax ONNX frontend. This operator is essential for importing LLM models
(LLaMA, Gemma, etc.) that use RMS normalization.

The implementation:
- Maps ONNX RMSNormalization to relax.op.nn.rms_norm
- Supports the axis, epsilon, and stash_type attributes
- Handles float16 inputs with stash_type=1 (compute in float32)
- Includes unit tests comparing against ONNX Runtime

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the RMSNormalization operator (opset 23) in the TVM Relax ONNX frontend, including support for axis attributes, epsilon, and stash_type for high-precision computation. Unit tests were added to verify correctness across different configurations. Feedback suggests enhancing the implementation by using standard utility functions for safer rank and axis handling, and modifying the return type to a tuple to align with the ONNX specification's multi-output requirement.

Comment on lines +3802 to +3805
ndim = len(data.struct_info.shape)
if axis < 0:
axis = ndim + axis
axes = list(range(axis, ndim))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Accessing data.struct_info.shape directly can be unsafe if the input rank is unknown (i.e., ndim == -1), as shape would be None. It is recommended to use _get_known_tensor_rank to safely retrieve the rank and provide a descriptive error message if it's unavailable. Additionally, using _normalize_constant_axes is a more robust way to handle negative axis values and perform bounds checking.

Suggested change
ndim = len(data.struct_info.shape)
if axis < 0:
axis = ndim + axis
axes = list(range(axis, ndim))
ndim = _get_known_tensor_rank(data)
if ndim is None:
raise ValueError("RMSNormalization requires a statically known input rank.")
axis = _normalize_constant_axes([axis], ndim, "RMSNormalization")[0]
axes = list(range(axis, ndim))

if stash_type == 1 and input_dtype != "float32":
output = relax.op.astype(output, input_dtype)

return output

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ONNX RMSNormalization operator (opset 23) specifies two outputs: the normalized tensor Y and an optional stashed inverse standard deviation inv_std_dev. To ensure compatibility with ONNX models that might request both outputs, the converter should return a relax.Tuple. Following the pattern used in other normalization converters in this frontend, a placeholder can be provided for the unused second output.

Suggested change
return output
# ONNX RMSNormalization has 2 outputs: Y and inv_std_dev.
# We return a placeholder for the second output.
placeholder = relax.const(0, dtype="float32" if stash_type == 1 else input_dtype)
return relax.Tuple([output, placeholder])

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://onnx.ai/onnx/operators/onnx__RMSNormalization.html .RMSNorm only have one ouput in onnx official spec.(Outputs:
• Y (heterogeneous) - V: Output data tensor. Same shape as X)

@tlopex tlopex merged commit a1e4cd8 into apache:main May 21, 2026
6 checks passed
@tlopex tlopex changed the title [Relay/ONNX] Add RMSNormalization converter for ONNX opset 23 [ONNX] Add RMSNormalization converter for ONNX opset 23 May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants