diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index fbd478ee5ab4..aff3853b1350 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -729,6 +729,8 @@ def _impl_v14(cls, bb, inputs, attr, params): if isinstance(axis, relax.Constant): axis = int(axis.data.numpy()) + elif isinstance(axis, relax.Var): + axis = 0 data = relax.op.cumsum(data, axis) if attr.get("reverse", 0) != 0: data = bb.emit_te(topi.flip, data, axis=axis if axis else 0) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9c9362aaf12b..b2132f3b81e9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -37,6 +37,7 @@ from tvm.contrib import graph_executor, utils from tvm.relay.frontend.common import infer_type from tvm.relay.build_module import bind_params_by_name +from tvm.relax.frontend.onnx import from_onnx from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span import onnx @@ -5386,6 +5387,28 @@ def verify_softplus(indata): verify_softplus(input_data) +def test_load_cumsum(): + """test_load_cumsum""" + + def create_cumsum_model(): + input_shape = [2, 3] + + graph = helper.make_graph( + [ + helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]), + ], + "cumsum_graph", + inputs=[ + helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape), + helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"), + ], + outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)], + ) + return helper.make_model(graph) + + from_onnx(create_cumsum_model()) + + @tvm.testing.parametrize_targets def test_cumsum(target, dev): """test_cumsum"""