Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3787,6 +3787,42 @@ def _impl_v17(cls, bb, inputs, attr, params):
return relax.Tuple([output, placeholder, placeholder])


class RMSNormalization(OnnxOpConverter):
"""Converts an onnx RMSNormalization node into an equivalent Relax expression."""

@classmethod
def _impl_v23(cls, bb, inputs, attr, params):
data = inputs[0]
scale = inputs[1]
axis = attr.get("axis", -1)
epsilon = attr.get("epsilon", 1e-05)
stash_type = attr.get("stash_type", 1)

# Determine normalization axes: from `axis` to the last dimension
ndim = _get_known_tensor_rank(data)
if ndim is None:
raise ValueError("RMSNormalization requires a statically known input rank.")
axis = _normalize_constant_axes([axis], ndim, "RMSNormalization")[0]
axes = list(range(axis, ndim))

# If stash_type requires float32 computation and input is not float32, cast
input_dtype = data.struct_info.dtype
if stash_type == 1 and input_dtype != "float32":
data_compute = relax.op.astype(data, "float32")
scale_compute = relax.op.astype(scale, "float32")
else:
data_compute = data
scale_compute = scale

output = relax.op.nn.rms_norm(data_compute, scale_compute, axes, epsilon)

# Cast back to original dtype if needed
if stash_type == 1 and input_dtype != "float32":
output = relax.op.astype(output, input_dtype)

return output

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ONNX RMSNormalization operator (opset 23) specifies two outputs: the normalized tensor Y and an optional stashed inverse standard deviation inv_std_dev. To ensure compatibility with ONNX models that might request both outputs, the converter should return a relax.Tuple. Following the pattern used in other normalization converters in this frontend, a placeholder can be provided for the unused second output.

Suggested change
return output
# ONNX RMSNormalization has 2 outputs: Y and inv_std_dev.
# We return a placeholder for the second output.
placeholder = relax.const(0, dtype="float32" if stash_type == 1 else input_dtype)
return relax.Tuple([output, placeholder])

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://onnx.ai/onnx/operators/onnx__RMSNormalization.html .RMSNorm only have one ouput in onnx official spec.(Outputs:
• Y (heterogeneous) - V: Output data tensor. Same shape as X)



class ReduceMax(OnnxOpConverter):
"""Converts an onnx ReduceMax node into an equivalent Relax expression."""

Expand Down Expand Up @@ -5130,6 +5166,7 @@ def _get_convert_map():
# Normalization
"BatchNormalization": BatchNormalization,
"LayerNormalization": LayerNormalization,
"RMSNormalization": RMSNormalization,
"SkipLayerNormalization": SkipLayerNormalization,
"EmbedLayerNormalization": EmbedLayerNormalization,
"InstanceNormalization": InstanceNormalization,
Expand Down
62 changes: 62 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,6 +2309,68 @@ def test_layer_norm_with_nd_gamma_beta():
check_correctness(model)


def test_rms_norm():
# Basic test: default axis=-1
rms_norm_node = helper.make_node(
"RMSNormalization", ["input", "scale"], ["Y"], epsilon=1e-05
)

graph = helper.make_graph(
[rms_norm_node],
"rms_norm_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 8, 32]),
helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 8, 32]),
],
)

model = helper.make_model(graph, producer_name="rms_norm_test")
check_correctness(model, opset=23)

# Test with explicit axis=1 (normalize over last 2 dims)
rms_norm_node = helper.make_node(
"RMSNormalization", ["input", "scale"], ["Y"], axis=1, epsilon=1e-06
)

graph = helper.make_graph(
[rms_norm_node],
"rms_norm_axis_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8, 16]),
helper.make_tensor_value_info("scale", TensorProto.FLOAT, [8, 16]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8, 16]),
],
)

model = helper.make_model(graph, producer_name="rms_norm_axis_test")
check_correctness(model, opset=23)

# Test with float16 input (stash_type=1 means compute in float32)
rms_norm_node = helper.make_node(
"RMSNormalization", ["input", "scale"], ["Y"], epsilon=1e-05, stash_type=1
)

graph = helper.make_graph(
[rms_norm_node],
"rms_norm_fp16_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT16, [2, 8, 32]),
helper.make_tensor_value_info("scale", TensorProto.FLOAT16, [32]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 8, 32]),
],
)

model = helper.make_model(graph, producer_name="rms_norm_fp16_test")
check_correctness(model, opset=23, rtol=1e-2, atol=1e-2)


# TODO Enable dynamism
@pytest.mark.parametrize("dynamic", [False])
def test_skiplayernormalization(dynamic):
Expand Down