diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b82fceff1d6c..8bacd1b329ec 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4531,7 +4531,12 @@ class Sign(OnnxOpConverter): @classmethod def _impl_v9(cls, bb, inputs, attr, params): - return relax.op.sign(inputs[0]) + x = inputs[0] + x_dtype = x.struct_info.dtype if isinstance(x.struct_info, relax.TensorStructInfo) else None + y = relax.op.sign(x) + if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype): + return relax.op.where(relax.op.isnan(x), x, y) + return y class Not(OnnxOpConverter): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7ee10993a4e9..f8e8ed150bd9 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -771,6 +771,35 @@ def test_unary(op_name: str): verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype) +def test_sign_nan_preserve(): + sign_node = helper.make_node("Sign", ["x"], ["y"]) + graph = helper.make_graph( + [sign_node], + "sign_nan_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])], + ) + model = helper.make_model(graph, producer_name="sign_nan_test") + model.ir_version = 8 + for opset_import in model.opset_import: + if opset_import.domain in ["", "ai.onnx"]: + opset_import.version = 18 + break + x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32) + + ort_out = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ).run([], {"x": x})[0] + + tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18) + out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else tvm_out).numpy() + + np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out)) + np.testing.assert_allclose( + out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7, atol=1e-5 + ) + + @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"]) def test_softmax_family_opset11_default_axis_semantics(op_name: str): verify_unary(op_name, [2, 3, 4], opset=11)