[BugFix][ONNX] Fix Round op to use ties-to-even#19367
Conversation
The ONNX Round operator specification requires ties-to-even rounding: "For cases where number is exactly halfway between two integers, it rounds to the nearest even integer." Previously, `topi.round()` lowered to `te.round` -> `tir.round` -> `llvm::round`, which uses ties-away-from-zero. This caused TVM to return wrong results for midpoint values (e.g., round(0.5) = 1 instead of 0, round(2.5) = 3 instead of 2). Fix by switching `topi.round()` to `te.nearbyint`, which lowers to `tir.nearbyint` -> `llvm::nearbyint`. The `nearbyint` intrinsic respects the IEEE 754 default rounding mode (ties-to-even), matching the ONNX spec and onnxruntime behavior. Also register `tir.nearbyint` for the WebGPU backend, mapping to WGSL `round()` which is already ties-to-even per the WGSL spec. Add a targeted test with midpoint inputs to prevent regression. Fixes apache#18590
There was a problem hiding this comment.
Code Review
This pull request updates the round operation in TVM TOPI to use ties-to-even (banker's rounding) by switching the underlying implementation to te.nearbyint, ensuring alignment with ONNX and IEEE 754 standards. It also adds the necessary intrinsic registration for WebGPU and a new test case to verify rounding behavior. Feedback indicates that the existing tirx.round registration on WebGPU is now inconsistent because it also maps to the ties-to-even round function, whereas it typically implies ties-away-from-zero; a follow-up is suggested to address this discrepancy.
|
As noted in the comment above, the same semantics is intentional. The actual pre-existing inconsistency runs in the opposite direction: LLVM/ROCm/Hexagon backends lower Worth a follow-up issue to align all backends — |
|
@swjng Thanks for the contribution! Could you send a follow-up PR to fix the problem you said? |
|
@tlopex Yes, will make a follow-up |
tir.round constant-folds using std::nearbyint (ties-to-even), but backends lowered it to platform round() (ties-away-from-zero). This inconsistency meant compiled code could produce different results from constant-folded code for midpoint values like 0.5, 2.5, etc. Fix all backends to use ties-to-even intrinsics: - LLVM/ROCm/Hexagon: llvm::Intrinsic::round -> nearbyint - NVPTX: __nv_round -> __nv_nearbyint - CUDA: round/roundf -> nearbyint/nearbyintf (f16/bf16 already used hrint) - Metal/OpenCL: round -> rint - Vulkan/SPIR-V: GLSLstd450Round -> GLSLstd450RoundEven - OpenCL codegen: fix nearbyint mapping from round() to rint() Also update op.h documentation to explicitly state ties-to-even semantics and add test_round_ties_to_even regression test. Follow-up to apache#19367. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
tir.round constant-folds using std::nearbyint (ties-to-even), but backends lowered it to platform round() (ties-away-from-zero). This inconsistency meant compiled code could produce different results from constant-folded code for midpoint values like 0.5, 2.5, etc. Fix all backends to use ties-to-even intrinsics: - LLVM/ROCm/Hexagon: llvm::Intrinsic::round -> nearbyint - NVPTX: __nv_round -> __nv_nearbyint - CUDA: round/roundf -> nearbyint/nearbyintf (f16/bf16 already used hrint) - Metal/OpenCL: round -> rint - Vulkan/SPIR-V: GLSLstd450Round -> GLSLstd450RoundEven - OpenCL codegen: fix nearbyint mapping from round() to rint() Also update op.h documentation to explicitly state ties-to-even semantics and add test_round_ties_to_even regression test. Follow-up to apache#19367. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
## Problem `tir.round` constant-folds using `std::nearbyint` (IEEE 754 ties-to-even), but all backends lower it to platform `round()` which uses ties-away-from-zero. This means compiled code can produce different results from constant-folded code for midpoint values: | Input | Constant-fold (ties-to-even) | Compiled (ties-away) | |-------|-----|------| | 0.5 | 0.0 | 1.0 | | 2.5 | 2.0 | 3.0 | | -0.5 | 0.0 | -1.0 | This was identified as a follow-up to #19367 — see [this comment](#19367 (comment)). ## Fix Align all backends to use ties-to-even intrinsics, matching the constant-folding behavior: | Backend | Before | After | |---------|--------|-------| | LLVM/ROCm/Hexagon | `llvm::Intrinsic::round` | `llvm::Intrinsic::nearbyint` | | NVPTX | `__nv_round[f]` | `__nv_nearbyint[f]` | | CUDA | `round`/`roundf` | `nearbyint`/`nearbyintf` (f16/bf16 already used `hrint`) | | Metal/OpenCL | `round` | `rint` | | Vulkan/SPIR-V | `GLSLstd450Round` | `GLSLstd450RoundEven` | Also fixes OpenCL codegen where `tir.nearbyint` was incorrectly mapped to OpenCL `round()` instead of `rint()`. Updates `op.h` documentation to explicitly state ties-to-even semantics for both `round()` and `nearbyint()`. ## Testing ``` python -m pytest tests/python/tirx-base/test_tir_intrin.py -xvs ``` New `test_round_ties_to_even` verifies midpoint inputs `[0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5]` produce ties-to-even results on the LLVM backend. All 12 tests pass (10 passed, 2 skipped for CUDA). --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Problem
The ONNX
Roundoperator specification requires ties-to-even (banker's) rounding:However, the current TVM implementation produces ties-away-from-zero results on midpoint values:
This was reported in issue #18590.
Root Cause
The lowering chain for
relax.op.round:llvm::roundis defined as ties-away-from-zero (C99round()), whilellvm::nearbyintuses the IEEE 754 default rounding mode (ties-to-even).Fix
python/tvm/topi/math.py: Switchtopi.round()fromte.roundtote.nearbyint. This lowers totir.nearbyint->llvm::nearbyint, which respects IEEE 754 ties-to-even.src/target/source/intrin_rule_webgpu.cc: Registertir.nearbyintfor the WebGPU backend. WGSLround()is already ties-to-even per the WGSL spec, sotir.nearbyint->roundis the correct mapping.tests/python/relax/test_frontend_onnx.py: Addtest_round_ties_to_even()with explicit midpoint inputs to prevent regression.Testing
Both pass. The new test compares TVM output against onnxruntime (which correctly implements ties-to-even) for inputs
[0.5, 1.5, 2.5, -0.5, -1.5, -2.5].Fixes #18590