
[Bug] Relax ONNX Clip with NaN max bound returns all NaNs #19533

Description

@ALinrunrun

Expected behavior

TVM Relax should produce the same output as ONNX Runtime for ONNX Clip when the max bound is NaN.

For the following inputs:

X  = [0.5, -3.0, 4.5, 11.0, NaN]
MN = 0.0
MX = NaN

ONNX Runtime returns:

[0.5, 0.0, 4.5, 11.0, nan]
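One way to read the ONNX Runtime result is that a NaN bound behaves like an absent bound, while a NaN input element passes through unchanged. The NumPy sketch below encodes that reading. `clip_nan_bound_ignored` is a hypothetical helper, and it is an interpretation of the observed ORT output, not ORT's actual implementation.

```python
import numpy as np


def clip_nan_bound_ignored(x, lo, hi):
    """Clip x to [lo, hi], treating a NaN bound as 'no bound'.

    Hypothetical reference helper mirroring the ORT output in this report.
    NaN elements of x stay NaN because every comparison with NaN is False.
    """
    x = np.asarray(x, dtype=np.float32)
    lo = np.float32(lo)
    hi = np.float32(hi)
    y = x.copy()
    if not np.isnan(lo):
        # Replace only elements known to be below lo (NaN < lo is False).
        y = np.where(y < lo, lo, y)
    if not np.isnan(hi):
        # Replace only elements known to be above hi (NaN > hi is False).
        y = np.where(y > hi, hi, y)
    return y


x = np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32)
print(clip_nan_bound_ignored(x, 0.0, np.nan).tolist())
# [0.5, 0.0, 4.5, 11.0, nan]
```

Replacing only the elements that are clearly out of range is what lets both the NaN bound and the NaN input element survive unchanged.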

Actual behavior

TVM Relax returns NaN for every element:

ORT: [0.5, 0.0, 4.5, 11.0, nan]
TVM: [nan, nan, nan, nan, nan]

The discrepancy appears when importing an ONNX Clip model through the Relax ONNX frontend and compiling it for the llvm target.
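A plausible explanation, not yet confirmed against the TVM source: if the frontend lowers Clip to an elementwise maximum with MN followed by a minimum with MX, and the backend emits min(a, b) as select(a < b, a, b), then every comparison against a NaN max is False and the select always picks the NaN operand. The NumPy sketch below simulates only that assumed select rule. `select_min` and `select_max` are hypothetical names, not TVM functions.

```python
import numpy as np


def select_min(a, b):
    # Simulated lowering (assumption): min(a, b) -> select(a < b, a, b).
    # When b is NaN, a < b is False for every a, so the NaN operand b is picked.
    return np.where(a < b, a, b)


def select_max(a, b):
    # Simulated lowering (assumption): max(a, b) -> select(a > b, a, b).
    return np.where(a > b, a, b)


x = np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32)
mn = np.float32(0.0)
mx = np.float32(np.nan)

# Assumed frontend order: maximum with MN first, then minimum with MX.
y = select_min(select_max(x, mn), mx)
print(y.tolist())  # [nan, nan, nan, nan, nan]

# Every comparison against the NaN bound is False, which forces the NaN pick.
print((select_max(x, mn) < mx).tolist())  # [False, False, False, False, False]
```

NumPy's own `np.minimum` propagates NaN as well, so it would also return all NaN here. Matching the ORT output requires comparison logic that replaces only elements clearly outside a bound.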

Environment

TVM: 0.14 (Relax ONNX frontend)
ONNX Runtime: 1.23
Python: 3.11
Target: llvm
OS: Linux
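The versions above are listed by hand. A small script like the following can collect them from the environment that actually runs the repro. It is a sketch: `collect_versions` is a hypothetical helper, and the distribution names it queries are assumptions. A TVM source build may not be registered under any of them, in which case `tvm.__version__` is the better source.

```python
import platform
import sys
from importlib import metadata


def collect_versions(packages=("apache-tvm", "tvm", "onnx", "onnxruntime", "numpy")):
    """Return interpreter, OS, and installed distribution versions.

    Hypothetical helper; the distribution names are assumptions and any that
    are not found are reported as "not installed".
    """
    info = {"python": sys.version.split()[0], "os": platform.platform()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```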

Steps to reproduce

import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx


node = helper.make_node("Clip", ["X", "MN", "MX"], ["Y"])

graph = helper.make_graph(
    [node],
    "g",
    [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [5]),
        helper.make_tensor_value_info("MN", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("MX", TensorProto.FLOAT, []),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [5])],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
model.ir_version = 9

x = np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32)
mn = np.array(0.0, dtype=np.float32)
mx = np.array(np.nan, dtype=np.float32)

ort_out = ort.InferenceSession(
    model.SerializeToString(),
    providers=["CPUExecutionProvider"],
).run(None, {"X": x, "MN": mn, "MX": mx})[0]

mod = from_onnx(model, keep_params_in_input=False)

with tvm.transform.PassContext(opt_level=3):
    ex = tvm.compile(mod, target=tvm.target.Target("llvm"))

vm = relax.VirtualMachine(ex, tvm.cpu())

out = vm["main"](
    tvm.runtime.tensor(x, tvm.cpu()),
    tvm.runtime.tensor(mn, tvm.cpu()),
    tvm.runtime.tensor(mx, tvm.cpu()),
)

tvm_out = (out[0] if isinstance(out, (list, tuple)) else out).numpy()

print("ORT:", ort_out.tolist())
print("TVM:", tvm_out.tolist())
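Comparing the two outputs with `==` is misleading because NaN never compares equal to NaN. The sketch below treats positions where both outputs are NaN as matching and lists the remaining differences. `report_mismatches` is a hypothetical helper, and here it is fed the values printed in this report rather than live ORT/TVM results.

```python
import numpy as np


def report_mismatches(expected, actual, rtol=1e-6, atol=0.0):
    """Return (index, expected, actual) for each element that differs.

    Positions where both values are NaN count as equal (equal_nan=True).
    """
    expected = np.asarray(expected, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    close = np.isclose(actual, expected, rtol=rtol, atol=atol, equal_nan=True)
    return [(int(i), float(expected[i]), float(actual[i])) for i in np.flatnonzero(~close)]


ort_values = [0.5, 0.0, 4.5, 11.0, float("nan")]
tvm_values = [float("nan")] * 5
for idx, exp, got in report_mismatches(ort_values, tvm_values):
    print(f"index {idx}: ORT={exp} TVM={got}")
```

With the values from this report, indices 0 to 3 are flagged and index 4 is not, since both runtimes return NaN there. In the repro script, `ort_out` and `tvm_out` can be passed in place of the literal lists.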

Triage

  • needs-triage

Labels: needs-triage, type: bug
