class RMSNormalization(OnnxOpConverter):
    """Lowers an ONNX RMSNormalization node to a Relax ``rms_norm`` expression."""

    @classmethod
    def _impl_v23(cls, bb, inputs, attr, params):
        """Build the Relax expression for an opset-23 RMSNormalization node.

        Parameters
        ----------
        bb :
            Not used by this converter.
        inputs :
            Node inputs; ``inputs[0]`` is the data tensor and ``inputs[1]`` the scale.
        attr :
            Node attributes. Reads ``axis`` (default -1), ``epsilon`` (default 1e-05)
            and ``stash_type`` (default 1).
        params :
            Not used by this converter.

        Returns
        -------
        The ``relax.op.nn.rms_norm`` result, cast back to the data tensor's dtype
        when the computation was carried out in float32.

        Raises
        ------
        ValueError
            If the rank of the data tensor cannot be determined.
        """
        x, weight = inputs[0], inputs[1]
        first_axis = attr.get("axis", -1)
        eps = attr.get("epsilon", 1e-05)
        stash_type = attr.get("stash_type", 1)

        # Normalization covers every dimension from `axis` through the last one,
        # so the rank has to be known to enumerate them.
        rank = _get_known_tensor_rank(x)
        if rank is None:
            raise ValueError("RMSNormalization requires a statically known input rank.")
        # NOTE(review): presumably maps a negative axis into [0, rank) — confirm
        # against the helper's definition.
        first_axis = _normalize_constant_axes([first_axis], rank, "RMSNormalization")[0]
        norm_axes = list(range(first_axis, rank))

        # stash_type == 1 requests float32 computation; only non-float32 inputs
        # need the round trip through float32.
        original_dtype = x.struct_info.dtype
        needs_upcast = stash_type == 1 and original_dtype != "float32"
        if needs_upcast:
            x = relax.op.astype(x, "float32")
            weight = relax.op.astype(weight, "float32")

        normalized = relax.op.nn.rms_norm(x, weight, norm_axes, eps)

        # Restore the caller-visible dtype when the computation was upcast.
        return relax.op.astype(normalized, original_dtype) if needs_upcast else normalized
def test_rms_norm():
    """Check RMSNormalization conversion at opset 23 for three configurations."""
    # Each case: (graph and producer name, element type, data shape, scale shape,
    #             node attributes, extra keyword arguments for check_correctness).
    cases = [
        # Basic test: default axis=-1
        (
            "rms_norm_test",
            TensorProto.FLOAT,
            [2, 8, 32],
            [32],
            {"epsilon": 1e-05},
            {},
        ),
        # Test with explicit axis=1 (normalize over last 2 dims)
        (
            "rms_norm_axis_test",
            TensorProto.FLOAT,
            [4, 8, 16],
            [8, 16],
            {"axis": 1, "epsilon": 1e-06},
            {},
        ),
        # Test with float16 input (stash_type=1 means compute in float32)
        (
            "rms_norm_fp16_test",
            TensorProto.FLOAT16,
            [2, 8, 32],
            [32],
            {"epsilon": 1e-05, "stash_type": 1},
            {"rtol": 1e-2, "atol": 1e-2},
        ),
    ]

    for name, elem_type, data_shape, scale_shape, node_attrs, check_kwargs in cases:
        node = helper.make_node("RMSNormalization", ["input", "scale"], ["Y"], **node_attrs)

        graph = helper.make_graph(
            [node],
            name,
            inputs=[
                helper.make_tensor_value_info("input", elem_type, data_shape),
                helper.make_tensor_value_info("scale", elem_type, scale_shape),
            ],
            outputs=[
                # The output has the same element type and shape as the data input.
                helper.make_tensor_value_info("Y", elem_type, data_shape),
            ],
        )

        model = helper.make_model(graph, producer_name=name)
        check_correctness(model, opset=23, **check_kwargs)