Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,10 @@ def isinf(x):

@tvm.te.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.
"""Round elements of x to nearest integer using ties-to-even (banker's rounding).

Ties are broken by rounding to the nearest even integer, matching the ONNX Round
specification and IEEE 754 default rounding mode.

Parameters
----------
Expand All @@ -459,7 +462,7 @@ def round(x):
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.round(x(*i)))
return te.compute(x.shape, lambda *i: te.nearbyint(x(*i)))


def log(x):
Expand Down
8 changes: 8 additions & 0 deletions src/target/source/intrin_rule_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ TVM_REGISTER_OP("tirx.log2")
TVM_REGISTER_OP("tirx.pow")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>);

struct ReturnRound {
std::string operator()(DataType t, std::string name) const { return "round"; }
};

// WGSL round() uses ties-to-even (banker's rounding), matching IEEE 754 and ONNX Round spec.
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<ReturnRound>);
Comment thread
swjng marked this conversation as resolved.

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>);

Expand Down
21 changes: 21 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,27 @@ def test_unary(op_name: str):
verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype)


def test_round_ties_to_even():
"""ONNX Round must use ties-to-even (banker's rounding), not ties-away-from-zero.

Per the ONNX spec: "For cases where number is exactly halfway between two
integers, it rounds to the nearest even integer."
https://onnx.ai/onnx/operators/onnx__Round.html
"""
round_node = helper.make_node("Round", ["x"], ["y"])
graph = helper.make_graph(
[round_node],
"round_ties_to_even_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [6])],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [6])],
)
model = helper.make_model(graph, producer_name="round_ties_to_even_test")
# Midpoint values: 0.5->0, 1.5->2, 2.5->2, -0.5->0, -1.5->-2, -2.5->-2 (ties-to-even)
# Ties-away would give: 0.5->1, 1.5->2, 2.5->3, -0.5->-1, -1.5->-2, -2.5->-3
inputs = {"x": np.array([0.5, 1.5, 2.5, -0.5, -1.5, -2.5], dtype="float32")}
check_correctness(model, inputs=inputs, opset=11)


@pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16])
@pytest.mark.parametrize("to_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16])
def test_cast(from_type, to_type):
Expand Down
Loading