From 21229ea7d3406a054f242e48ba09bb70d9b519fc Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Wed, 8 Apr 2026 14:52:17 +0900 Subject: [PATCH 1/3] [BugFix] Align tir.round to ties-to-even across all backends tir.round constant-folds using std::nearbyint (ties-to-even), but backends lowered it to platform round() (ties-away-from-zero). This inconsistency meant compiled code could produce different results from constant-folded code for midpoint values like 0.5, 2.5, etc. Fix all backends to use ties-to-even intrinsics: - LLVM/ROCm/Hexagon: llvm::Intrinsic::round -> nearbyint - NVPTX: __nv_round -> __nv_nearbyint - CUDA: round/roundf -> nearbyint/nearbyintf (f16/bf16 already used hrint) - Metal/OpenCL: round -> rint - Vulkan/SPIR-V: GLSLstd450Round -> GLSLstd450RoundEven - OpenCL codegen: fix nearbyint mapping from round() to rint() Also update op.h documentation to explicitly state ties-to-even semantics and add test_round_ties_to_even regression test. Follow-up to #19367. Co-Authored-By: Claude Opus 4.6 (1M context) --- include/tvm/tirx/op.h | 12 +++++++++--- src/target/llvm/intrin_rule_hexagon.cc | 2 +- src/target/llvm/intrin_rule_llvm.cc | 2 +- src/target/llvm/intrin_rule_nvptx.cc | 16 +++++++++++++++- src/target/llvm/intrin_rule_rocm.cc | 2 +- src/target/source/codegen_opencl.cc | 2 +- src/target/source/intrin_rule_cuda.cc | 3 +++ src/target/source/intrin_rule_metal.cc | 11 ++++++++++- src/target/source/intrin_rule_opencl.cc | 11 ++++++++++- src/target/spirv/intrin_rule_spirv.cc | 6 ++++-- tests/python/tirx-base/test_tir_intrin.py | 21 +++++++++++++++++++++ 11 files changed, 76 insertions(+), 12 deletions(-) diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 66d9d932b3fa..c953f12e3870 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -654,7 +654,11 @@ TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span()); TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span()); /*! 
- * \brief Calculate round(x) + * \brief Round x to the nearest integer, ties to even. + * + * Uses IEEE 754 default rounding mode (ties-to-even / banker's rounding). + * Constant-folding and all backends consistently use std::nearbyint semantics. + * * \param x The input expression. * \param span The location of this operation in the source. * \return The result expression. @@ -662,11 +666,13 @@ TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span()); TVM_DLL PrimExpr round(PrimExpr x, Span span = Span()); /*! - * \brief Calculates std::nearbyint(x) + * \brief Round x to the nearest integer, ties to even. + * + * Equivalent to round(). Both use IEEE 754 default rounding mode (ties-to-even). + * * \param x The input expression. * \param span The location of this operation in the source. * \return The result expression. - * This is a faster alternate to round. */ TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span()); diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index 79e91c20a3b8..e330dba4e1c7 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -93,7 +93,7 @@ TVM_REGISTER_OP("tirx.fabs") TVM_REGISTER_OP("tirx.round") .set_attr("hexagon.FLowerIntrinsic", - DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tirx.ctpop") .set_attr("hexagon.FLowerIntrinsic", diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 468f0fb7b59f..3244deab875b 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -90,7 +90,7 @@ TVM_REGISTER_OP("tirx.fabs") TVM_REGISTER_OP("tirx.round") .set_attr("llvm.FLowerIntrinsic", - DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tirx.nearbyint") .set_attr("llvm.FLowerIntrinsic", diff --git 
a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 4560205a6094..08196cffeb74 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -66,7 +66,21 @@ TVM_REGISTER_OP("tirx.ceil") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); TVM_REGISTER_OP("tirx.round") - .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); + .set_attr("nvptx.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + // Use nearbyint (ties-to-even) instead of round (ties-away-from-zero) + // to match constant-folding semantics. + using namespace tirx; + const CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) + << "Only support float32 or float64."; + std::string name = call->dtype.bits() == 32 ? "__nv_nearbyintf" : "__nv_nearbyint"; + ffi::Array new_args = {StringImm(name)}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + return Call(call->dtype, builtin::call_pure_extern(), new_args); + }); TVM_REGISTER_OP("tirx.nearbyint") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 6d72c777834c..4d542c1299ec 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -132,7 +132,7 @@ TVM_REGISTER_OP("tirx.ceil") TVM_REGISTER_OP("tirx.round") .set_attr("rocm.FLowerIntrinsic", - DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tirx.nearbyint") .set_attr("rocm.FLowerIntrinsic", diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 5d9135ef223b..b2f78c2dbd37 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -526,7 +526,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { 
this->PrintCallExtern(GetType(ffi::GetRef(op)), "atomic_add_float_emu", op->args, true, os); } else if (func->value == "nearbyint") { - this->PrintCallExtern(GetType(ffi::GetRef(op)), "round", op->args, true, os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), "rint", op->args, true, os); } else { if (func->value == "atomic_add") { enable_atomics_ = true; diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index bcd158432bcd..d38db9fe8372 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -37,8 +37,11 @@ struct CUDAMath { if (t.is_float()) { switch (t.bits()) { case 64: + // Use nearbyint (ties-to-even) for round to match constant-folding semantics. + if (name == "round") return "nearbyint"; return name; case 32: + if (name == "round") return "nearbyintf"; return name + 'f'; case 16: { if (name == "fabs") { diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index d61bf1256f64..cea19519ca7f 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -68,7 +68,16 @@ TVM_REGISTER_OP("tirx.fabs") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.round") - .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); + .set_attr("metal.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + // Metal's rint() uses ties-to-even, matching constant-folding semantics. 
+ const tirx::CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr); + ffi::Array new_args = {tirx::StringImm("rint")}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); + }); TVM_REGISTER_OP("tirx.nearbyint") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 85084b1a1649..ba1873bde694 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -47,7 +47,16 @@ TVM_REGISTER_OP("tirx.fabs") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.round") - .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); + .set_attr("opencl.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + // OpenCL's rint() uses ties-to-even, matching constant-folding semantics. + const tirx::CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr); + ffi::Array new_args = {tirx::StringImm("rint")}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); + }); TVM_REGISTER_OP("tirx.nearbyint") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index cde1e0165f82..4b1ffc4b6d7f 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -68,10 +68,12 @@ TVM_REGISTER_OP("tirx.ceil") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tirx.round") - .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + .set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); TVM_REGISTER_OP("tirx.nearbyint") - .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + .set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); TVM_REGISTER_OP("tirx.trunc") 
.set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); diff --git a/tests/python/tirx-base/test_tir_intrin.py b/tests/python/tirx-base/test_tir_intrin.py index 0dd06dee934a..30676715b899 100644 --- a/tests/python/tirx-base/test_tir_intrin.py +++ b/tests/python/tirx-base/test_tir_intrin.py @@ -56,6 +56,27 @@ def test_nearbyint(): tvm.testing.assert_allclose(a_rounded.numpy(), np.rint(a.numpy())) +def test_round_ties_to_even(): + """Test that tir.round uses ties-to-even (banker's rounding) semantics.""" + m = te.var("m") + A = te.placeholder((m,), name="A") + A_rounded = te.compute((m,), lambda *i: tvm.tirx.round(A(*i)), name="A") + + mod = te.create_prim_func([A, A_rounded]) + sch = tvm.s_tir.Schedule(mod) + func = tvm.compile(sch.mod, target="llvm") + + dev = tvm.cpu(0) + # Midpoint values where ties-to-even and ties-away differ + test_values = np.array([0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5], dtype="float32") + expected = np.array([0.0, 2.0, 2.0, 4.0, 0.0, -2.0, -2.0, -4.0], dtype="float32") + + a = tvm.runtime.tensor(test_values, dev) + a_rounded = tvm.runtime.tensor(np.zeros(len(test_values), dtype="float32"), dev) + func(a, a_rounded) + tvm.testing.assert_allclose(a_rounded.numpy(), expected) + + def test_round_intrinsics_on_int(): i = tvm.tirx.Var("i", "int32") for op in [tvm.tirx.round, tvm.tirx.trunc, tvm.tirx.ceil, tvm.tirx.floor, tvm.tirx.nearbyint]: From 3fa6d6abff6d015323f2d4666f90c58298314f03 Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Wed, 8 Apr 2026 15:16:32 +0900 Subject: [PATCH 2/3] refactor(nvptx): simplify round lowering by reusing nearbyint dispatch Redirect tirx.round to tirx.nearbyint op and let DispatchPureExternLibDevice generate __nv_nearbyint[f], instead of duplicating the name construction and arg copying in a lambda. 
--- src/target/llvm/intrin_rule_nvptx.cc | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 08196cffeb74..0707a9a78771 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -67,19 +67,13 @@ TVM_REGISTER_OP("tirx.ceil") TVM_REGISTER_OP("tirx.round") .set_attr("nvptx.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - // Use nearbyint (ties-to-even) instead of round (ties-away-from-zero) - // to match constant-folding semantics. + // Redirect to nearbyint (ties-to-even) to match constant-folding semantics. using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - TVM_FFI_ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) - << "Only support float32 or float64."; - std::string name = call->dtype.bits() == 32 ? "__nv_nearbyintf" : "__nv_nearbyint"; - ffi::Array new_args = {StringImm(name)}; - for (auto arg : call->args) { - new_args.push_back(arg); - } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + auto nearbyint_op = Op::Get("tirx.nearbyint"); + auto new_call = Call(call->dtype, nearbyint_op, call->args); + return DispatchPureExternLibDevice(new_call); }); TVM_REGISTER_OP("tirx.nearbyint") From e1e04f963694bacc729bd8b478d3abc53d615aad Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Wed, 8 Apr 2026 16:50:02 +0900 Subject: [PATCH 3/3] fix(roi_pool): use ties-away rounding for ROI coordinate mapping tir.round semantics changed to ties-to-even (nearbyint) in this branch, but ONNX runtime uses std::round (ties-away-from-zero) for MaxRoiPool coordinate mapping. Use floor(x + 0.5) explicitly: it equals std::round for non-negative inputs (the two differ on negative ties, e.g. -0.5 -> 0 vs. -1), so it matches ONNX runtime behavior for non-negative ROI coordinates and is independent of tir.round semantics. Also fix roi_pool_python.py reference which had the same issue: Python's built-in round() is ties-to-even and was already inconsistent with ONNX runtime.
--- python/tvm/topi/testing/roi_pool_python.py | 10 ++++++---- python/tvm/topi/vision/roi_pool.py | 15 +++++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/testing/roi_pool_python.py b/python/tvm/topi/testing/roi_pool_python.py index 0f7120b46650..583800e9828b 100644 --- a/python/tvm/topi/testing/roi_pool_python.py +++ b/python/tvm/topi/testing/roi_pool_python.py @@ -36,10 +36,12 @@ def roi_pool_nchw_python(a_np, rois_np, pooled_size, spatial_scale): for i in range(num_roi): roi = rois_np[i] batch_index = int(roi[0]) - roi_start_w = round(roi[1] * spatial_scale) - roi_start_h = round(roi[2] * spatial_scale) - roi_end_w = round(roi[3] * spatial_scale) - roi_end_h = round(roi[4] * spatial_scale) + # Round half up via floor(x + 0.5); this equals ONNX runtime's std::round for x >= 0. + # Python's built-in round() uses ties-to-even, so use floor(x + 0.5) explicitly. + roi_start_w = math.floor(roi[1] * spatial_scale + 0.5) + roi_start_h = math.floor(roi[2] * spatial_scale + 0.5) + roi_end_w = math.floor(roi[3] * spatial_scale + 0.5) + roi_end_h = math.floor(roi[4] * spatial_scale + 0.5) roi_h = max(roi_end_h - roi_start_h + 1, 1) roi_w = max(roi_end_w - roi_start_w + 1, 1) diff --git a/python/tvm/topi/vision/roi_pool.py b/python/tvm/topi/vision/roi_pool.py index 54a4aeba50be..2e86066c5b23 100644 --- a/python/tvm/topi/vision/roi_pool.py +++ b/python/tvm/topi/vision/roi_pool.py @@ -36,12 +36,19 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale): neg_inf = tvm.tirx.const(float("-inf"), data.dtype) + def _round_away(x): + # ONNX MaxRoiPool spec uses ties-away-from-zero rounding for coordinate + # mapping (matching std::round semantics in the reference implementation). + # floor(x + 0.5) matches that for x >= 0 and is independent of tir.round semantics.
+ half = tvm.tirx.const(0.5, roi_dtype) + return te.floor(x + half) + def _bin_bounds(i, ph, pw): roi = rois[i] - roi_start_w = te.round(roi[1] * spatial_scale).astype("int32") - roi_start_h = te.round(roi[2] * spatial_scale).astype("int32") - roi_end_w = te.round(roi[3] * spatial_scale).astype("int32") - roi_end_h = te.round(roi[4] * spatial_scale).astype("int32") + roi_start_w = _round_away(roi[1] * spatial_scale).astype("int32") + roi_start_h = _round_away(roi[2] * spatial_scale).astype("int32") + roi_end_w = _round_away(roi[3] * spatial_scale).astype("int32") + roi_end_h = _round_away(roi[4] * spatial_scale).astype("int32") roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32")) roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))