diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 784be639dd72..61ab45d308c8 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2435,8 +2435,18 @@ def _impl_v15(cls, bb, inputs, attr, params): mean = inputs[3] var = inputs[4] epsilon = attr.get("epsilon", 1e-05) + momentum = attr.get("momentum", 0.9) + training_mode = attr.get("training_mode", 0) return relax.op.nn.batch_norm( - data, gamma=scale, beta=bias, moving_mean=mean, moving_var=var, epsilon=epsilon, axis=1 + data, + gamma=scale, + beta=bias, + moving_mean=mean, + moving_var=var, + axis=1, + epsilon=epsilon, + momentum=momentum, + training=training_mode, )