[ONNX] Add RMSNormalization converter for ONNX opset 23#19590
Add support for the ONNX RMSNormalization operator (opset 23) in the Relax ONNX frontend. This operator is essential for importing LLM models (LLaMA, Gemma, etc.) that use RMS normalization.

The implementation:
- Maps ONNX RMSNormalization to relax.op.nn.rms_norm
- Supports the axis, epsilon, and stash_type attributes
- Handles float16 inputs with stash_type=1 (compute in float32)
- Includes unit tests comparing against ONNX Runtime
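For readers unfamiliar with the operator, the following is a minimal NumPy sketch of the computation the converter is expected to reproduce. It is not the code from this PR: the function name `rms_norm_reference` is hypothetical, and treating any `stash_type` other than 1 as "compute in the input dtype" is an assumption made for illustration.

```python
import numpy as np


def rms_norm_reference(x, scale, axis=-1, epsilon=1e-5, stash_type=1):
    """Illustrative RMS normalization over axes [axis, rank).

    Hypothetical reference helper, not part of TVM or ONNX.
    """
    rank = x.ndim
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis = axis % rank
    axes = tuple(range(axis, rank))

    # stash_type == 1 means the reduction runs in float32.
    # Assumption: any other value keeps the input dtype.
    compute_dtype = np.float32 if stash_type == 1 else x.dtype
    xc = x.astype(compute_dtype)

    # Root mean square over the normalized axes, kept broadcastable.
    mean_sq = np.mean(np.square(xc), axis=axes, keepdims=True)
    normalized = xc / np.sqrt(mean_sq + epsilon)

    # Apply the scale, then cast back to the input dtype.
    return (normalized * scale.astype(compute_dtype)).astype(x.dtype)


x = np.array([[3.0, 4.0]], dtype=np.float32)
y = rms_norm_reference(x, np.ones(2, dtype=np.float32), epsilon=0.0)
```

With `epsilon=0.0` and a unit scale, the mean of the squared outputs over the normalized axes is 1, which makes a convenient sanity check alongside the ONNX Runtime comparison.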
Code Review
This pull request implements the RMSNormalization operator (opset 23) in the TVM Relax ONNX frontend, including support for axis attributes, epsilon, and stash_type for high-precision computation. Unit tests were added to verify correctness across different configurations. Feedback suggests enhancing the implementation by using standard utility functions for safer rank and axis handling, and modifying the return type to a tuple to align with the ONNX specification's multi-output requirement.
ndim = len(data.struct_info.shape)
if axis < 0:
    axis = ndim + axis
axes = list(range(axis, ndim))
Accessing data.struct_info.shape directly can be unsafe if the input rank is unknown (i.e., ndim == -1), as shape would be None. It is recommended to use _get_known_tensor_rank to safely retrieve the rank and provide a descriptive error message if it's unavailable. Additionally, using _normalize_constant_axes is a more robust way to handle negative axis values and perform bounds checking.
- ndim = len(data.struct_info.shape)
- if axis < 0:
-     axis = ndim + axis
- axes = list(range(axis, ndim))
+ ndim = _get_known_tensor_rank(data)
+ if ndim is None:
+     raise ValueError("RMSNormalization requires a statically known input rank.")
+ axis = _normalize_constant_axes([axis], ndim, "RMSNormalization")[0]
+ axes = list(range(axis, ndim))
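To make the intent of the suggestion concrete, here is a small standalone sketch of the rank check and axis normalization it asks for. The helpers `normalize_axis` and `normalized_axes` are hypothetical stand-ins written for this example; they are not the frontend's `_get_known_tensor_rank` or `_normalize_constant_axes`, whose exact behavior is not shown in this thread.

```python
def normalize_axis(axis, rank, op_name="RMSNormalization"):
    """Map a possibly negative axis into [0, rank), with bounds checking.

    Hypothetical helper for illustration only.
    """
    # An unknown rank is modeled here as None or a negative value.
    if rank is None or rank < 0:
        raise ValueError(f"{op_name} requires a statically known input rank.")
    if not -rank <= axis < rank:
        raise ValueError(f"{op_name}: axis {axis} is out of range for rank {rank}.")
    return axis + rank if axis < 0 else axis


def normalized_axes(axis, rank):
    """Return the axes [axis, rank) that RMS normalization reduces over."""
    start = normalize_axis(axis, rank)
    return list(range(start, rank))


print(normalized_axes(-1, 3))  # [2]
print(normalized_axes(1, 4))   # [1, 2, 3]
```

Unlike the bare `axis = ndim + axis` adjustment, this rejects an out-of-range axis such as 5 for a rank-3 tensor instead of silently producing an empty axis list.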
if stash_type == 1 and input_dtype != "float32":
    output = relax.op.astype(output, input_dtype)

return output
The ONNX RMSNormalization operator (opset 23) specifies two outputs: the normalized tensor Y and an optional stashed inverse standard deviation inv_std_dev. To ensure compatibility with ONNX models that might request both outputs, the converter should return a relax.Tuple. Following the pattern used in other normalization converters in this frontend, a placeholder can be provided for the unused second output.
- return output
+ # ONNX RMSNormalization has 2 outputs: Y and inv_std_dev.
+ # We return a placeholder for the second output.
+ placeholder = relax.const(0, dtype="float32" if stash_type == 1 else input_dtype)
+ return relax.Tuple([output, placeholder])
According to the official ONNX spec (https://onnx.ai/onnx/operators/onnx__RMSNormalization.html), RMSNormalization has only one output:
Outputs:
• Y (heterogeneous) - V: Output data tensor. Same shape as X.
…es for safer rank/axis handling
Add support for the ONNX RMSNormalization operator (opset 23) in the Relax ONNX frontend. This operator is essential for importing LLM models (LLaMA, Gemma, etc.) that use RMS normalization.
The implementation: