
[Bug] Relax ONNX NonMaxSuppression ignores max_output_boxes_per_class=0 #19544

Description

Opened by @ALinrunrun

Expected behavior

TVM Relax should execute ONNX NonMaxSuppression consistently with ONNX Runtime.

According to the ONNX NonMaxSuppression specification, the optional max_output_boxes_per_class input defaults to 0, which means no boxes are selected. Both of the following cases should therefore return an empty selected_indices tensor, as ONNX Runtime does:

  1. max_output_boxes_per_class is omitted.
  2. max_output_boxes_per_class is explicitly provided as 0.
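To make the expected semantics concrete, here is a minimal pure-Python sketch of greedy per-class NMS with the limit applied as the spec describes. It is not ONNX Runtime's or TVM's implementation. `reference_nms` and `iou` are hypothetical names, boxes use the default `center_point_box=0` corner layout, and the optional `score_threshold` input is left out for brevity.

```python
from typing import List, Sequence

Box = Sequence[float]


def iou(a: Box, b: Box) -> float:
    """IoU of two corner-format boxes (center_point_box=0 layout)."""
    # The spec allows either diagonal corner pair, so normalise first.
    ay1, ay2 = sorted((a[0], a[2]))
    ax1, ax2 = sorted((a[1], a[3]))
    by1, by2 = sorted((b[0], b[2]))
    bx1, bx2 = sorted((b[1], b[3]))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter = inter_h * inter_w
    union = (ay2 - ay1) * (ax2 - ax1) + (by2 - by1) * (bx2 - bx1) - inter
    return inter / union if union > 0 else 0.0


def reference_nms(boxes, scores, max_output_boxes_per_class: int = 0,
                  iou_threshold: float = 0.0) -> List[List[int]]:
    """Greedy per-class NMS returning [batch_index, class_index, box_index] rows.

    boxes: [batch][box][4], scores: [batch][class][box].
    max_output_boxes_per_class defaults to 0, i.e. select nothing.
    """
    limit = max(int(max_output_boxes_per_class), 0)
    selected = []
    for b, class_scores in enumerate(scores):
        for c, box_scores in enumerate(class_scores):
            # Visit candidates from highest to lowest score.
            order = sorted(range(len(box_scores)), key=lambda i: -box_scores[i])
            kept = []
            for i in order:
                if len(kept) >= limit:
                    break  # with limit == 0 this stops before selecting anything
                # Keep box i only if it does not overlap any kept box too much.
                if all(iou(boxes[b][i], boxes[b][k]) <= iou_threshold for k in kept):
                    kept.append(i)
            selected.extend([b, c, i] for i in kept)
    return selected
```

On the boxes and scores from the reproducer, this returns [] both for the default and for an explicit 0, and [[0, 0, 0], [0, 0, 2], [0, 0, 3]] once the limit is 3 or more.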

Actual behavior

TVM Relax ignores the limit and returns selected indices whether max_output_boxes_per_class is omitted or explicitly set to 0:

no max input:
  ORT: []
  TVM: [[0, 0, 0], [0, 0, 2], [0, 0, 3]]

max=0:
  ORT: []
  TVM: [[0, 0, 0], [0, 0, 2], [0, 0, 3]]
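The TVM rows match plain greedy NMS with no per-class cap, which is consistent with the limit being dropped. Box 1 overlaps box 0 heavily, and since the spec's default iou_threshold is 0, any positive IoU suppresses it; boxes 2 and 3 overlap nothing. A quick check of the one overlap that matters (pure Python, values copied from the reproducer):

```python
# Boxes 0 and 1 from the reproducer, as [y1, x1, y2, x2] (center_point_box=0).
box0 = (5.0, 5.0, 6.0, 6.0)
box1 = (5.05, 5.05, 6.05, 6.05)


def area(b):
    return (b[2] - b[0]) * (b[3] - b[1])


inter_h = max(0.0, min(box0[2], box1[2]) - max(box0[0], box1[0]))
inter_w = max(0.0, min(box0[3], box1[3]) - max(box0[1], box1[1]))
inter = inter_h * inter_w                 # 0.95 * 0.95 = 0.9025
union = area(box0) + area(box1) - inter   # 1 + 1 - 0.9025 = 1.0975
iou_01 = inter / union                    # roughly 0.822

# Default iou_threshold is 0, so this overlap suppresses box 1.
print(round(iou_01, 3))  # prints 0.822
```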

The discrepancy appears when importing an ONNX NonMaxSuppression model through the Relax ONNX frontend and compiling it for the llvm target.
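Whatever the root cause turns out to be in the frontend, the import step has to preserve two rules: an omitted third input resolves to 0, and a resolved 0 selects no boxes rather than meaning "no limit". A minimal sketch of those rules follows. `resolve_max_output_boxes` and `apply_limit` are hypothetical helpers, not TVM APIs, and clamping negative values to 0 is an assumption.

```python
from typing import List, Optional, Sequence


def resolve_max_output_boxes(max_input: Optional[Sequence[int]]) -> int:
    """Hypothetical helper: turn the optional ONNX input into a per-class limit.

    max_input is None when the node has no third input.
    """
    if max_input is None or len(max_input) == 0:
        return 0  # spec default: 0, meaning select nothing
    # Assumption: negative values are clamped to 0 (no boxes).
    return max(int(max_input[0]), 0)


def apply_limit(candidates: Sequence[int], limit: int) -> List[int]:
    """Hypothetical helper: cap one class's score-ordered survivors."""
    # limit is already >= 0 here, so a 0 limit yields an empty list,
    # never "all candidates".
    return list(candidates[:limit])
```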

Environment

TVM: 0.14 (Relax ONNX frontend)
ONNX Runtime: 1.23
Python: 3.11
Target: llvm
OS: Linux

Steps to reproduce

import warnings

warnings.filterwarnings("ignore")

import numpy as np
import onnx
import onnxruntime as ort
import tvm
from onnx import TensorProto, helper
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx


def make_nms_model(with_max_input):
    # Single-node graph; the optional third input "m" is max_output_boxes_per_class.
    inputs = ["b", "s"] + (["m"] if with_max_input else [])

    node = helper.make_node("NonMaxSuppression", inputs, ["y"])

    model_inputs = [
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 4, 4]),
        helper.make_tensor_value_info("s", TensorProto.FLOAT, [1, 1, 4]),
    ]

    if with_max_input:
        model_inputs.append(
            helper.make_tensor_value_info("m", TensorProto.INT64, [1])
        )

    graph = helper.make_graph(
        [node],
        "g",
        model_inputs,
        [helper.make_tensor_value_info("y", TensorProto.INT64, None)],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    model.ir_version = 9
    return model


def run_ort(model, feed):
    sess = ort.InferenceSession(
        model.SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    return sess.run(None, feed)[0]


def run_tvm(model, feed):
    # Import with the Relax ONNX frontend, compile for llvm, and run on the Relax VM.
    mod = from_onnx(model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(mod, target=tvm.target.Target("llvm"))

    vm = relax.VirtualMachine(ex, tvm.cpu())

    args = [tvm.runtime.tensor(v, tvm.cpu()) for v in feed.values()]
    out = vm["main"](*args)

    out = out[0] if isinstance(out, (list, tuple)) else out
    return out.numpy()


base = {
    "b": np.array(
        [
            [
                [5.0, 5.0, 6.0, 6.0],
                [5.05, 5.05, 6.05, 6.05],
                [10.0, 10.0, 11.0, 11.0],
                [20.0, 20.0, 21.0, 21.0],
            ]
        ],
        dtype=np.float32,
    ),
    "s": np.array([[[0.95, 0.85, 0.75, 0.65]]], dtype=np.float32),
}

model_no_max = make_nms_model(False)
ort_no_max = run_ort(model_no_max, base)
tvm_no_max = run_tvm(model_no_max, base)

print("no max input:")
print("  ORT:", ort_no_max.tolist())
print("  TVM:", tvm_no_max.tolist())

model_max_zero = make_nms_model(True)
feed_max_zero = {**base, "m": np.array([0], dtype=np.int64)}

ort_max_zero = run_ort(model_max_zero, feed_max_zero)
tvm_max_zero = run_tvm(model_max_zero, feed_max_zero)

print("max=0:")
print("  ORT:", ort_max_zero.tolist())
print("  TVM:", tvm_max_zero.tolist())

Labels

  • needs-triage
  • type: bug