Skip to content
Merged
14 changes: 14 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,20 @@ class Div(BinaryBase):

@classmethod
def _impl_v7(cls, bb, inputs, attr, params):
try:
lhs_code = DataType(inputs[0].struct_info.dtype).type_code
rhs_code = DataType(inputs[1].struct_info.dtype).type_code
except (AttributeError, ValueError, TypeError, TVMError):
return cls.base_impl(bb, inputs, attr, params)

lhs_is_integer = lhs_code == DataTypeCode.INT or lhs_code == DataTypeCode.UINT
rhs_is_integer = rhs_code == DataTypeCode.INT or rhs_code == DataTypeCode.UINT
if not (lhs_is_integer and rhs_is_integer):
return cls.base_impl(bb, inputs, attr, params)

if isinstance(inputs[1], relax.Constant) and bool(_np.any(inputs[1].data.numpy() == 0)):
raise ValueError("ONNX Div with integer inputs encountered divisor value 0.")

return cls.base_impl(bb, inputs, attr, params)


Expand Down
19 changes: 19 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,25 @@ def test_binary(op_name: str):
verify_binary_scalar(op_name)


def test_div_integer_constant_zero_divisor_raises_valueerror():
b_init = numpy_helper.from_array(np.array([3, 0, -2, 1], dtype=np.int32), name="b")
node = helper.make_node("Div", ["a", "b"], ["y"])
graph = helper.make_graph(
[node],
"div_const_zero",
[helper.make_tensor_value_info("a", TensorProto.INT32, [4])],
[helper.make_tensor_value_info("y", TensorProto.INT32, [4])],
initializer=[b_init],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 9

with pytest.raises(
ValueError, match="ONNX Div with integer inputs encountered divisor value 0"
):
from_onnx(model, opset=18, keep_params_in_input=False)


@pytest.mark.parametrize("int_mode", [True, False])
def test_mod(int_mode: bool):
if int_mode:
Expand Down
97 changes: 69 additions & 28 deletions tests/python/relax/test_frontend_onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,10 @@ def run(self, inputs, **kwargs):
self._vm.invoke_stateful("main")
output = self._vm.get_outputs("main")

if isinstance(output, (tvm.runtime.Tensor, np.ndarray)):
if isinstance(output, tvm.runtime.Tensor | np.ndarray):
return (output.numpy() if hasattr(output, "numpy") else output,)
if isinstance(output, (tuple, list)):
return tuple(
o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output
)
if isinstance(output, tuple | list):
return tuple(o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output)
return (np.array(output),)


Expand Down Expand Up @@ -110,9 +108,7 @@ def prepare(cls, model, device="CPU", **kwargs):
func_param_names = [p.name_hint for p in func.params]
graph_input_names = [inp.name for inp in model.graph.input]

return TVMRelaxBackendRep(
tvm_model, params, func_param_names, graph_input_names
)
return TVMRelaxBackendRep(tvm_model, params, func_param_names, graph_input_names)

@classmethod
def supports_device(cls, device: str) -> bool:
Expand All @@ -133,32 +129,77 @@ def supports_device(cls, device: str) -> bool:
# validated against the ONNX Backend Test Suite. They can be added
# incrementally as the importer improves.
_INCLUDE_OPS = [
"abs", "acos", "acosh", "add", "and", "argmax", "argmin",
"averagepool", "bitshift",
"bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor",
"ceil", "clip", "compress", "concat",
"conv", "cos", "cosh",
"depthtospace", "div",
"einsum", "erf", "exp",
"flatten", "floor",
"gathernd", "gemm",
"globalaveragepool", "globalmaxpool", "greater", "greater_equal",
"hardmax", "hardswish",
"abs",
"acos",
"acosh",
"add",
"and",
"argmax",
"argmin",
"averagepool",
"bitshift",
"bitwise_and",
"bitwise_not",
"bitwise_or",
"bitwise_xor",
"ceil",
"clip",
"compress",
"concat",
"conv",
"cos",
"cosh",
"depthtospace",
"div",
"einsum",
"erf",
"exp",
"flatten",
"floor",
"gathernd",
"gemm",
"globalaveragepool",
"globalmaxpool",
"greater",
"greater_equal",
"hardmax",
"hardswish",
"isnan",
"less", "less_equal", "lrn",
"matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg",
"nonzero", "not",
"less",
"less_equal",
"lrn",
"matmul",
"matmulinteger",
"mean",
"min",
"mod",
"mul",
"neg",
"nonzero",
"not",
"or",
"reciprocal",
"round",
"scatternd",
"sigmoid", "sign",
"sin", "sinh", "size", "slice",
"sigmoid",
"sign",
"sin",
"sinh",
"size",
"slice",
"spacetodepth",
"sqrt", "squeeze", "sub", "sum",
"tan", "tanh", "tile", "transpose",
"unique", "unsqueeze",
"where", "xor",
"sqrt",
"squeeze",
"sub",
"sum",
"tan",
"tanh",
"tile",
"transpose",
"unique",
"unsqueeze",
"where",
"xor",
]

for _op in _INCLUDE_OPS:
Expand Down
Loading