diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3f25d2ff3bb5..b42a3a4d9c86 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -526,6 +526,20 @@ class Div(BinaryBase): @classmethod def _impl_v7(cls, bb, inputs, attr, params): + try: + lhs_code = DataType(inputs[0].struct_info.dtype).type_code + rhs_code = DataType(inputs[1].struct_info.dtype).type_code + except (AttributeError, ValueError, TypeError, TVMError): + return cls.base_impl(bb, inputs, attr, params) + + lhs_is_integer = lhs_code == DataTypeCode.INT or lhs_code == DataTypeCode.UINT + rhs_is_integer = rhs_code == DataTypeCode.INT or rhs_code == DataTypeCode.UINT + if not (lhs_is_integer and rhs_is_integer): + return cls.base_impl(bb, inputs, attr, params) + + if isinstance(inputs[1], relax.Constant) and bool(_np.any(inputs[1].data.numpy() == 0)): + raise ValueError("ONNX Div with integer inputs encountered divisor value 0.") + return cls.base_impl(bb, inputs, attr, params) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 0d1d9f2d7c4b..a0c95fb703c8 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -591,6 +591,25 @@ def test_binary(op_name: str): verify_binary_scalar(op_name) +def test_div_integer_constant_zero_divisor_raises_valueerror(): + b_init = numpy_helper.from_array(np.array([3, 0, -2, 1], dtype=np.int32), name="b") + node = helper.make_node("Div", ["a", "b"], ["y"]) + graph = helper.make_graph( + [node], + "div_const_zero", + [helper.make_tensor_value_info("a", TensorProto.INT32, [4])], + [helper.make_tensor_value_info("y", TensorProto.INT32, [4])], + initializer=[b_init], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + model.ir_version = 9 + + with pytest.raises( + ValueError, match="ONNX Div with integer inputs encountered divisor value 0" + ): + from_onnx(model, opset=18, keep_params_in_input=False) + + @pytest.mark.parametrize("int_mode", [True, False]) def test_mod(int_mode: bool): if int_mode: diff --git a/tests/python/relax/test_frontend_onnx_backend.py b/tests/python/relax/test_frontend_onnx_backend.py index 3eb63f153598..301b95f640c4 100644 --- a/tests/python/relax/test_frontend_onnx_backend.py +++ b/tests/python/relax/test_frontend_onnx_backend.py @@ -77,12 +77,10 @@ def run(self, inputs, **kwargs): self._vm.invoke_stateful("main") output = self._vm.get_outputs("main") - if isinstance(output, (tvm.runtime.Tensor, np.ndarray)): + if isinstance(output, tvm.runtime.Tensor | np.ndarray): return (output.numpy() if hasattr(output, "numpy") else output,) - if isinstance(output, (tuple, list)): - return tuple( - o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output - ) + if isinstance(output, tuple | list): + return tuple(o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output) return (np.array(output),) @@ -110,9 +108,7 @@ def prepare(cls, model, device="CPU", **kwargs): func_param_names = [p.name_hint for p in func.params] graph_input_names = [inp.name for inp in model.graph.input] - return TVMRelaxBackendRep( - tvm_model, params, func_param_names, graph_input_names - ) + return TVMRelaxBackendRep(tvm_model, params, func_param_names, graph_input_names) @classmethod def supports_device(cls, device: str) -> bool: @@ -133,32 +129,77 @@ def supports_device(cls, device: str) -> bool: # validated against the ONNX Backend Test Suite. They can be added # incrementally as the importer improves. _INCLUDE_OPS = [ - "abs", "acos", "acosh", "add", "and", "argmax", "argmin", - "averagepool", "bitshift", - "bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor", - "ceil", "clip", "compress", "concat", - "conv", "cos", "cosh", - "depthtospace", "div", - "einsum", "erf", "exp", - "flatten", "floor", - "gathernd", "gemm", - "globalaveragepool", "globalmaxpool", "greater", "greater_equal", - "hardmax", "hardswish", + "abs", + "acos", + "acosh", + "add", + "and", + "argmax", + "argmin", + "averagepool", + "bitshift", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bitwise_xor", + "ceil", + "clip", + "compress", + "concat", + "conv", + "cos", + "cosh", + "depthtospace", + "div", + "einsum", + "erf", + "exp", + "flatten", + "floor", + "gathernd", + "gemm", + "globalaveragepool", + "globalmaxpool", + "greater", + "greater_equal", + "hardmax", + "hardswish", "isnan", - "less", "less_equal", "lrn", - "matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg", - "nonzero", "not", + "less", + "less_equal", + "lrn", + "matmul", + "matmulinteger", + "mean", + "min", + "mod", + "mul", + "neg", + "nonzero", + "not", "or", "reciprocal", "round", "scatternd", - "sigmoid", "sign", - "sin", "sinh", "size", "slice", + "sigmoid", + "sign", + "sin", + "sinh", + "size", + "slice", "spacetodepth", - "sqrt", "squeeze", "sub", "sum", - "tan", "tanh", "tile", "transpose", - "unique", "unsqueeze", - "where", "xor", + "sqrt", + "squeeze", + "sub", + "sum", + "tan", + "tanh", + "tile", + "transpose", + "unique", + "unsqueeze", + "where", + "xor", ] for _op in _INCLUDE_OPS: