Expected behavior
When importing an ONNX model containing a BatchNormalization node with training_mode=0 (inference mode), the TVM Relax ONNX frontend should emit R.nn.batch_norm(..., training=False) and normalize with the provided running mean and variance (the input_mean and input_var inputs).
According to the ONNX BatchNormalization specification:
- training_mode=0 (default): normalize with the provided running statistics (inference mode)
- training_mode=1: compute mean and variance from the input batch (training mode)
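For reference, the two modes differ only in which per-channel statistics enter the normalization. Following the ONNX definition for an NCHW input with per-channel scale gamma_c and bias beta_c:

$$
Y_{n,c,h,w} = \gamma_c \, \frac{X_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
$$

With training_mode=0, mu_c and sigma_c^2 are the supplied input_mean and input_var tensors. With training_mode=1, they are computed from the current batch (population, i.e. biased, variance):

$$
\mu_c = \frac{1}{NHW}\sum_{n,h,w} X_{n,c,h,w}, \qquad
\sigma_c^2 = \frac{1}{NHW}\sum_{n,h,w} \left(X_{n,c,h,w} - \mu_c\right)^2
$$

As a quick check against the ORT sample below: channel 0 has scale 1, bias 0, mean 0.5 and var 0.25, and the first input value from np.random.seed(42) is about 0.496714, so (0.496714 - 0.5) / sqrt(0.25001) is about -0.00657, which matches ORT's first output value.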
Actual behavior
The TVM Relax ONNX frontend ignores the training_mode attribute and always emits R.nn.batch_norm(..., training=True). As a result, TVM normalizes with mean and variance computed from the input batch instead of the provided running_mean and running_var.
Generated IR (incorrect):
lv: R.Tuple(...) = R.nn.batch_norm(X, scale, bias, mean, var, axis=1, epsilon=1e-05, training=True)
The last argument should be training=False.
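A minimal sketch of the attribute handling that appears to be missing: read training_mode (defaulting to 0, as the ONNX spec does) and forward it instead of a hard-coded True. This assumes the converter sees the ONNX node attributes as a plain dict; bn_training_flag is a hypothetical helper, not part of TVM, and the batch_norm call shape in the comment is an assumption.

```python
def bn_training_flag(attr: dict) -> bool:
    """Map ONNX BatchNormalization's training_mode to a Relax `training` flag.

    Hypothetical helper: per the ONNX spec, training_mode defaults to 0
    (inference mode, use the provided running statistics).
    """
    training_mode = int(attr.get("training_mode", 0))
    if training_mode not in (0, 1):
        raise ValueError(f"training_mode must be 0 or 1, got {training_mode}")
    return training_mode == 1


# In the converter the flag would replace the hard-coded value, e.g. (assumed call shape):
#   relax.op.nn.batch_norm(..., training=bn_training_flag(attr))
print(bn_training_flag({}))                    # False: attribute absent -> inference
print(bn_training_flag({"training_mode": 0}))  # False
print(bn_training_flag({"training_mode": 1}))  # True
```

Defaulting to inference when the attribute is absent keeps models that never set training_mode on the running-statistics path, which is what ONNX Runtime does.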
Environment
- TVM version: 0.23.dev0 (commit hash if available)
- OS: Ubuntu Linux
- Target: llvm (CPU)
- Python: 3.11
- ONNX opset: 15
Steps to reproduce
Minimal reproduction script
Save as reproduce_bn_bug.py and run with python reproduce_bn_bug.py:
#!/usr/bin/env python3
"""
TVM BatchNormalization training_mode Bug - Minimal Reproduction
"""
import sys

import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort
import tvm
from tvm.relax.frontend.onnx import from_onnx
from tvm import relax


def create_minimal_bn_model():
    """Create a minimal ONNX model with only BatchNormalization (training_mode=0)."""
    batch, channels, height, width = 2, 3, 4, 4
    epsilon = 1e-5
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [batch, channels, height, width])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [batch, channels, height, width])
    # Per-channel parameters, stored as graph initializers
    scale = numpy_helper.from_array(np.array([1.0, 2.0, 0.5], dtype=np.float32), name='scale')
    bias = numpy_helper.from_array(np.array([0.0, 1.0, -1.0], dtype=np.float32), name='bias')
    mean = numpy_helper.from_array(np.array([0.5, 1.0, 2.0], dtype=np.float32), name='mean')
    var = numpy_helper.from_array(np.array([0.25, 1.0, 4.0], dtype=np.float32), name='var')
    bn_node = helper.make_node(
        'BatchNormalization',
        inputs=['X', 'scale', 'bias', 'mean', 'var'],
        outputs=['Y'],
        epsilon=epsilon,
        momentum=0.9,
        training_mode=0,  # KEY: inference mode!
    )
    graph = helper.make_graph([bn_node], 'bn_test', [X], [Y], [scale, bias, mean, var])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 15)])
    model.ir_version = 8
    onnx.checker.check_model(model)
    return model


def main():
    print("Creating ONNX model with BatchNormalization (training_mode=0)...")
    model = create_minimal_bn_model()

    # Verify ONNX attribute
    for node in model.graph.node:
        if node.op_type == 'BatchNormalization':
            training_mode = next((a.i for a in node.attribute if a.name == 'training_mode'), 0)
            print(f"ONNX training_mode = {training_mode}")

    # Test input
    np.random.seed(42)
    input_data = np.random.randn(2, 3, 4, 4).astype(np.float32)

    # ONNX Runtime (reference)
    model_bytes = model.SerializeToString()
    sess = ort.InferenceSession(model_bytes, providers=['CPUExecutionProvider'])
    ort_output = sess.run(None, {'X': input_data})[0]
    print(f"ORT output sample: {ort_output[0, 0, 0, :3]}")

    # TVM Relax
    shape_dict = {'X': list(input_data.shape)}
    mod = from_onnx(model, shape_dict=shape_dict)

    # Check IR for the bug
    ir_text = mod.script()
    if 'training=True' in ir_text:
        print("\n[BUG] TVM IR contains training=True (should be False)!")
        for line in ir_text.split('\n'):
            if 'batch_norm' in line:
                print(f"  {line.strip()}")

    # Compile and run
    target = tvm.target.Target("llvm")
    ex = tvm.compile(mod, target)
    device = tvm.cpu()
    vm = relax.VirtualMachine(ex, device)
    tvm_input = tvm.runtime.tensor(input_data, device=device)
    tvm_output = vm['main'](tvm_input).numpy()
    print(f"TVM output sample: {tvm_output[0, 0, 0, :3]}")

    # Compare
    max_diff = np.max(np.abs(ort_output - tvm_output))
    print(f"\nMax difference (ORT vs TVM): {max_diff:.6f}")
    if max_diff > 0.001:
        print("\n[BUG CONFIRMED] TVM produces incorrect results!")
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())
Observed output
Creating ONNX model with BatchNormalization (training_mode=0)...
ONNX training_mode = 0
ORT output sample: [-0.00657159 -1.2765031 0.29537117]
[BUG] TVM IR contains training=True (should be False)!
lv: R.Tuple(...) = R.nn.batch_norm(X, ..., training=True)
TVM output sample: [ 0.66183543 -0.05753053 0.8328741 ]
Max difference (ORT vs TVM): 2.758021
[BUG CONFIRMED] TVM produces incorrect results!
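To confirm that the mismatch comes from TVM using batch statistics (rather than some other numerical issue), a NumPy sketch can compute both references and report which one a backend's output is closer to. The function names are illustrative, not from TVM or ONNX; the parameters mirror the reproduction script.

```python
import numpy as np


def bn_inference(x, scale, bias, mean, var, eps):
    """training_mode=0: normalize NCHW input with the provided running statistics."""
    shape = (1, -1, 1, 1)  # broadcast per-channel parameters over N, H, W
    normalized = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return normalized * scale.reshape(shape) + bias.reshape(shape)


def bn_batch_stats(x, scale, bias, eps):
    """training_mode=1 behaviour: normalize with statistics computed from x itself."""
    batch_mean = x.mean(axis=(0, 2, 3))
    batch_var = x.var(axis=(0, 2, 3))  # population (biased) variance
    return bn_inference(x, scale, bias, batch_mean, batch_var, eps)


def which_statistics(x, out, scale, bias, mean, var, eps=1e-5):
    """Report which reference output `out` is closer to (max absolute error)."""
    d_running = np.max(np.abs(out - bn_inference(x, scale, bias, mean, var, eps)))
    d_batch = np.max(np.abs(out - bn_batch_stats(x, scale, bias, eps)))
    return "running statistics" if d_running <= d_batch else "batch statistics"


# Same input and parameters as the reproduction script
np.random.seed(42)
x = np.random.randn(2, 3, 4, 4).astype(np.float32)
scale = np.array([1.0, 2.0, 0.5], dtype=np.float32)
bias = np.array([0.0, 1.0, -1.0], dtype=np.float32)
mean = np.array([0.5, 1.0, 2.0], dtype=np.float32)
var = np.array([0.25, 1.0, 4.0], dtype=np.float32)

ref = bn_inference(x, scale, bias, mean, var, 1e-5)
print(ref[0, 0, 0, :3])  # should agree with the ORT sample above
# Passing tvm_output from the script as `out` shows which statistics TVM used:
# print(which_statistics(x, tvm_output, scale, bias, mean, var))
```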
Triage
cc @KJlaccHoeUM9l @junrushao