Skip to content
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,8 @@ def _impl_v14(cls, bb, inputs, attr, params):

if isinstance(axis, relax.Constant):
axis = int(axis.data.numpy())
elif isinstance(axis, relax.Var):
axis = 0
data = relax.op.cumsum(data, axis)
if attr.get("reverse", 0) != 0:
data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
Expand Down
23 changes: 23 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from tvm.contrib import graph_executor, utils
from tvm.relay.frontend.common import infer_type
from tvm.relay.build_module import bind_params_by_name
from tvm.relax.frontend.onnx import from_onnx
from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span

import onnx
Expand Down Expand Up @@ -5386,6 +5387,28 @@ def verify_softplus(indata):
verify_softplus(input_data)


def test_load_cumsum():
"""test_load_cumsum"""

def create_cumsum_model():
input_shape = [2, 3]

graph = helper.make_graph(
[
helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
],
"cumsum_graph",
inputs=[
helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape),
helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"),
],
outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)],
)
return helper.make_model(graph)

from_onnx(create_cumsum_model())


@tvm.testing.parametrize_targets
def test_cumsum(target, dev):
"""test_cumsum"""
Expand Down