From f1f9651bf421d3a268c17c25177c1434719d47f9 Mon Sep 17 00:00:00 2001 From: q55180514 Date: Wed, 20 May 2026 10:07:14 +0800 Subject: [PATCH 1/2] [Relay/ONNX] Add RMSNormalization converter for ONNX opset 23 Add support for the ONNX RMSNormalization operator (opset 23) in the Relax ONNX frontend. This operator is essential for importing LLM models (LLaMA, Gemma, etc.) that use RMS normalization. The implementation: - Maps ONNX RMSNormalization to relax.op.nn.rms_norm - Supports the axis, epsilon, and stash_type attributes - Handles float16 inputs with stash_type=1 (compute in float32) - Includes unit tests comparing against ONNX Runtime --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 36 +++++++++++ tests/python/relax/test_frontend_onnx.py | 62 +++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5f41644149db..4ce358bb52ad 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3787,6 +3787,41 @@ def _impl_v17(cls, bb, inputs, attr, params): return relax.Tuple([output, placeholder, placeholder]) +class RMSNormalization(OnnxOpConverter): + """Converts an onnx RMSNormalization node into an equivalent Relax expression.""" + + @classmethod + def _impl_v23(cls, bb, inputs, attr, params): + data = inputs[0] + scale = inputs[1] + axis = attr.get("axis", -1) + epsilon = attr.get("epsilon", 1e-05) + stash_type = attr.get("stash_type", 1) + + # Determine normalization axes: from `axis` to the last dimension + ndim = len(data.struct_info.shape) + if axis < 0: + axis = ndim + axis + axes = list(range(axis, ndim)) + + # If stash_type requires float32 computation and input is not float32, cast + input_dtype = data.struct_info.dtype + if stash_type == 1 and input_dtype != "float32": + data_compute = relax.op.astype(data, "float32") + scale_compute = relax.op.astype(scale, "float32") + else: + data_compute = data + scale_compute = scale + + output = relax.op.nn.rms_norm(data_compute, scale_compute, axes, epsilon) + + # Cast back to original dtype if needed + if stash_type == 1 and input_dtype != "float32": + output = relax.op.astype(output, input_dtype) + + return output + + class ReduceMax(OnnxOpConverter): """Converts an onnx ReduceMax node into an equivalent Relax expression.""" @@ -5130,6 +5165,7 @@ def _get_convert_map(): # Normalization "BatchNormalization": BatchNormalization, "LayerNormalization": LayerNormalization, + "RMSNormalization": RMSNormalization, "SkipLayerNormalization": SkipLayerNormalization, "EmbedLayerNormalization": EmbedLayerNormalization, "InstanceNormalization": InstanceNormalization, diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index b658a2aabaea..2b0194f08578 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2309,6 +2309,68 @@ def test_layer_norm_with_nd_gamma_beta(): check_correctness(model) +def test_rms_norm(): + # Basic test: default axis=-1 + rms_norm_node = helper.make_node( + "RMSNormalization", ["input", "scale"], ["Y"], epsilon=1e-05 + ) + + graph = helper.make_graph( + [rms_norm_node], + "rms_norm_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 8, 32]), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 8, 32]), + ], + ) + + model = helper.make_model(graph, producer_name="rms_norm_test") + check_correctness(model, opset=23) + + # Test with explicit axis=1 (normalize over last 2 dims) + rms_norm_node = helper.make_node( + "RMSNormalization", ["input", "scale"], ["Y"], axis=1, epsilon=1e-06 + ) + + graph = helper.make_graph( + [rms_norm_node], + "rms_norm_axis_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8, 16]), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [8, 16]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8, 16]), + ], + ) + + model = helper.make_model(graph, producer_name="rms_norm_axis_test") + check_correctness(model, opset=23) + + # Test with float16 input (stash_type=1 means compute in float32) + rms_norm_node = helper.make_node( + "RMSNormalization", ["input", "scale"], ["Y"], epsilon=1e-05, stash_type=1 + ) + + graph = helper.make_graph( + [rms_norm_node], + "rms_norm_fp16_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT16, [2, 8, 32]), + helper.make_tensor_value_info("scale", TensorProto.FLOAT16, [32]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 8, 32]), + ], + ) + + model = helper.make_model(graph, producer_name="rms_norm_fp16_test") + check_correctness(model, opset=23, rtol=1e-2, atol=1e-2) + + # TODO Enable dynamism @pytest.mark.parametrize("dynamic", [False]) def test_skiplayernormalization(dynamic): From 6ce32c801a8caa78102a64f91878a2bf9c515572 Mon Sep 17 00:00:00 2001 From: q55180514 Date: Wed, 20 May 2026 10:40:37 +0800 Subject: [PATCH 2/2] address review: use _get_known_tensor_rank and _normalize_constant_axes for safer rank/axis handling --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4ce358bb52ad..a1e82afcf332 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3799,9 +3799,10 @@ def _impl_v23(cls, bb, inputs, attr, params): stash_type = attr.get("stash_type", 1) # Determine normalization axes: from `axis` to the last dimension - ndim = len(data.struct_info.shape) - if axis < 0: - axis = ndim + axis + ndim = _get_known_tensor_rank(data) + if ndim is None: + raise ValueError("RMSNormalization requires a statically known input rank.") + axis = _normalize_constant_axes([axis], ndim, "RMSNormalization")[0] axes = list(range(axis, ndim)) # If stash_type requires float32 computation and input is not float32, cast