Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions include/tvm/tirx/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,19 +654,25 @@ TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());

/*!
* \brief Calculate round(x)
* \brief Round x to the nearest integer, ties to even.
*
* Uses IEEE 754 default rounding mode (ties-to-even / banker's rounding).
* Constant-folding and all backends consistently use std::nearbyint semantics.
*
* \param x The input expression.
* \param span The location of this operation in the source.
* \return The result expression.
*/
TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());

/*!
* \brief Calculates std::nearbyint(x)
* \brief Round x to the nearest integer, ties to even.
*
* Equivalent to round(). Both use IEEE 754 default rounding mode (ties-to-even).
*
* \param x The input expression.
* \param span The location of this operation in the source.
* \return The result expression.
* This is a faster alternate to round.
*/
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());

Expand Down
10 changes: 6 additions & 4 deletions python/tvm/topi/testing/roi_pool_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ def roi_pool_nchw_python(a_np, rois_np, pooled_size, spatial_scale):
for i in range(num_roi):
roi = rois_np[i]
batch_index = int(roi[0])
roi_start_w = round(roi[1] * spatial_scale)
roi_start_h = round(roi[2] * spatial_scale)
roi_end_w = round(roi[3] * spatial_scale)
roi_end_h = round(roi[4] * spatial_scale)
# Use ties-away-from-zero rounding to match ONNX runtime (std::round semantics).
# Python's built-in round() uses ties-to-even, so use floor(x + 0.5) explicitly.
roi_start_w = math.floor(roi[1] * spatial_scale + 0.5)
roi_start_h = math.floor(roi[2] * spatial_scale + 0.5)
roi_end_w = math.floor(roi[3] * spatial_scale + 0.5)
roi_end_h = math.floor(roi[4] * spatial_scale + 0.5)
roi_h = max(roi_end_h - roi_start_h + 1, 1)
roi_w = max(roi_end_w - roi_start_w + 1, 1)

Expand Down
15 changes: 11 additions & 4 deletions python/tvm/topi/vision/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,19 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):

neg_inf = tvm.tirx.const(float("-inf"), data.dtype)

def _round_away(x):
# ONNX MaxRoiPool spec uses ties-away-from-zero rounding for coordinate
# mapping (matching std::round semantics in the reference implementation).
# Use floor(x + 0.5) to be explicit and independent of tir.round semantics.
half = tvm.tirx.const(0.5, roi_dtype)
return te.floor(x + half)

def _bin_bounds(i, ph, pw):
roi = rois[i]
roi_start_w = te.round(roi[1] * spatial_scale).astype("int32")
roi_start_h = te.round(roi[2] * spatial_scale).astype("int32")
roi_end_w = te.round(roi[3] * spatial_scale).astype("int32")
roi_end_h = te.round(roi[4] * spatial_scale).astype("int32")
roi_start_w = _round_away(roi[1] * spatial_scale).astype("int32")
roi_start_h = _round_away(roi[2] * spatial_scale).astype("int32")
roi_end_w = _round_away(roi[3] * spatial_scale).astype("int32")
roi_end_h = _round_away(roi[4] * spatial_scale).astype("int32")

roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32"))
roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/intrin_rule_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ TVM_REGISTER_OP("tirx.fabs")

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

TVM_REGISTER_OP("tirx.ctpop")
.set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TVM_REGISTER_OP("tirx.fabs")

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
Expand Down
10 changes: 9 additions & 1 deletion src/target/llvm/intrin_rule_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,15 @@ TVM_REGISTER_OP("tirx.ceil")
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
// Redirect to nearbyint (ties-to-even) to match constant-folding semantics.
using namespace tirx;
const CallNode* call = e.as<CallNode>();
TVM_FFI_ICHECK(call != nullptr);
auto nearbyint_op = Op::Get("tirx.nearbyint");
auto new_call = Call(call->dtype, nearbyint_op, call->args);
return DispatchPureExternLibDevice(new_call);
});

TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ TVM_REGISTER_OP("tirx.ceil")

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "atomic_add_float_emu", op->args,
true, os);
} else if (func->value == "nearbyint") {
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "round", op->args, true, os);
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "rint", op->args, true, os);
} else {
if (func->value == "atomic_add") {
enable_atomics_ = true;
Expand Down
3 changes: 3 additions & 0 deletions src/target/source/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ struct CUDAMath {
if (t.is_float()) {
switch (t.bits()) {
case 64:
// Use nearbyint (ties-to-even) for round to match constant-folding semantics.
if (name == "round") return "nearbyint";
return name;
case 32:
if (name == "round") return "nearbyintf";
return name + 'f';
case 16: {
if (name == "fabs") {
Expand Down
11 changes: 10 additions & 1 deletion src/target/source/intrin_rule_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,16 @@ TVM_REGISTER_OP("tirx.fabs")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
// Metal's rint() uses ties-to-even, matching constant-folding semantics.
const tirx::CallNode* call = e.as<tirx::CallNode>();
TVM_FFI_ICHECK(call != nullptr);
ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
for (auto arg : call->args) {
new_args.push_back(arg);
}
return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args);
});

TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
Expand Down
11 changes: 10 additions & 1 deletion src/target/source/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,16 @@ TVM_REGISTER_OP("tirx.fabs")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
// OpenCL's rint() uses ties-to-even, matching constant-folding semantics.
const tirx::CallNode* call = e.as<tirx::CallNode>();
TVM_FFI_ICHECK(call != nullptr);
ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
for (auto arg : call->args) {
new_args.push_back(arg);
}
return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args);
});

TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
Expand Down
6 changes: 4 additions & 2 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ TVM_REGISTER_OP("tirx.ceil")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Ceil>);

TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Round>);
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450RoundEven>);

TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Round>);
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450RoundEven>);

TVM_REGISTER_OP("tirx.trunc")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Trunc>);
Expand Down
21 changes: 21 additions & 0 deletions tests/python/tirx-base/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ def test_nearbyint():
tvm.testing.assert_allclose(a_rounded.numpy(), np.rint(a.numpy()))


def test_round_ties_to_even():
"""Test that tir.round uses ties-to-even (banker's rounding) semantics."""
m = te.var("m")
A = te.placeholder((m,), name="A")
A_rounded = te.compute((m,), lambda *i: tvm.tirx.round(A(*i)), name="A")

mod = te.create_prim_func([A, A_rounded])
sch = tvm.s_tir.Schedule(mod)
func = tvm.compile(sch.mod, target="llvm")

dev = tvm.cpu(0)
# Midpoint values where ties-to-even and ties-away differ
test_values = np.array([0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5], dtype="float32")
expected = np.array([0.0, 2.0, 2.0, 4.0, 0.0, -2.0, -2.0, -4.0], dtype="float32")

a = tvm.runtime.tensor(test_values, dev)
a_rounded = tvm.runtime.tensor(np.zeros(len(test_values), dtype="float32"), dev)
func(a, a_rounded)
tvm.testing.assert_allclose(a_rounded.numpy(), expected)


def test_round_intrinsics_on_int():
i = tvm.tirx.Var("i", "int32")
for op in [tvm.tirx.round, tvm.tirx.trunc, tvm.tirx.ceil, tvm.tirx.floor, tvm.tirx.nearbyint]:
Expand Down
Loading