From 221d7091685794a88fa9b29c5fe70aee77355633 Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Fri, 6 Mar 2026 02:10:03 +0000 Subject: [PATCH 1/2] [Relax][ONNX] Support dynamic repeats for Tile --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 41 +++++++++++++++++-- python/tvm/topi/transform.py | 22 ++++++++++ src/topi/transform.cc | 5 +++ tests/python/relax/test_frontend_onnx.py | 24 +++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b3c2d06eab07..9a2ef07d492f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1974,14 +1974,49 @@ def _impl_v11(cls, bb, inputs, attr, params): class Tile(OnnxOpConverter): """Converts an onnx Tile node into an equivalent Relax expression.""" + @staticmethod + def _tensor_length(expr): + shape = expr.struct_info.shape + if not isinstance(shape, relax.ShapeExpr): + return None + + length = shape.values[0] + if not isinstance(length, tir.IntImm): + return None + return length.value + @classmethod def _impl_v13(cls, bb, inputs, attr, params): reps = get_constant(inputs[1], params) if isinstance(reps, relax.Constant): reps = reps.data.numpy().tolist() - else: - raise ValueError("Dynamic reps for Tile are supported yet.") - return bb.emit_te(topi.tile, inputs[0], reps) + return bb.emit_te(topi.tile, inputs[0], reps) + + data = inputs[0] + data_ndim = data.struct_info.ndim + reps_len = cls._tensor_length(reps) + if data_ndim == -1 or reps_len is None: + raise ValueError("Dynamic Tile requires known input rank and repeats length.") + + if reps.struct_info.dtype != "int64": + reps = bb.normalize(relax.op.astype(reps, "int64")) + + data_shape = list(data.struct_info.shape.values) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(relax.ShapeExpr(data_shape))) + output_shape_tensor = reps + + if data_ndim > reps_len: + reps_prefix = relax.const(_np.ones((data_ndim - reps_len,), dtype="int64"), "int64") + output_shape_tensor = bb.normalize(relax.op.concat([reps_prefix, output_shape_tensor], axis=0)) + elif reps_len > data_ndim: + data_prefix = relax.const(_np.ones((reps_len - data_ndim,), dtype="int64"), "int64") + data_shape_tensor = bb.normalize(relax.op.concat([data_prefix, data_shape_tensor], axis=0)) + + output_shape_tensor = bb.normalize(relax.op.multiply(output_shape_tensor, data_shape_tensor)) + output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor)) + output_shape_vars = [tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len))] + bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars)) + return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len) class Expand(OnnxOpConverter): diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index bc187e3f269b..a3e736644607 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -657,6 +657,28 @@ def tile(a, reps): return cpp.tile(a, reps) +def dyn_tile(a, new_shape, rdim): + """Repeats the whole array multiple times with dynamic output shape. + + Parameters + ---------- + a : tvm.te.Tensor + The tensor to be tiled. + + new_shape : tuple of PrimExpr + The output shape after tiling. + + rdim : int + The rank of the repeats input. + + Returns + ------- + ret : tvm.te.Tensor + """ + + return cpp.dyn_tile(a, new_shape, rdim) + + def layout_transform(array, src_layout, dst_layout, schedule_rule="None"): """Transform the layout according to src_layout and dst_layout diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 5e2ffd4cbd9a..09f9a9be5ea7 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -156,6 +156,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::PackedArgs args, ffi::Any* rv) { *rv = tile(args[0].cast(), args[1].cast>()); }) + .def_packed("topi.dyn_tile", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = dyn_tile(args[0].cast(), args[1].cast>(), + args[2].cast()); + }) .def_packed("topi.gather", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = gather(args[0].cast(), args[1].cast(), diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 12f9f0f35368..7892b5cb2017 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2700,6 +2700,30 @@ def verify_tile(in_shape, repeats, out_shape): verify_tile(x.shape, repeats, z_array.shape) +@pytest.mark.parametrize("dynamic_input", [True, False]) +def test_tile_dynamic_repeats(dynamic_input): + x = np.random.rand(2, 3).astype(np.float32) + repeats = np.array([2, 2], dtype=np.int64) + out_shape = np.tile(x, repeats).shape + + input_shape = ["?", "?"] if dynamic_input else list(x.shape) + output_shape = ["?" for _ in out_shape] if dynamic_input else list(out_shape) + + node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"]) + graph = helper.make_graph( + [node], + "tile_dynamic_repeats_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("repeats", TensorProto.INT64, [len(repeats)]), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="tile_dynamic_repeats_test") + + check_correctness(model, inputs={"input": x, "repeats": repeats}, opset=13) + + def _generate_roi_cases(): # Base case when with_roi is False roi_list = [ From 8f551c5fd3b3260aefc7461e4b6439f1d1c94cd8 Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Fri, 6 Mar 2026 02:30:39 +0000 Subject: [PATCH 2/2] Address Tile dynamic repeats review comments --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 20 +++++++++++++------ tests/python/relax/test_frontend_onnx.py | 15 ++++++++++---- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 9a2ef07d492f..3dc575ae778c 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2001,20 +2001,28 @@ def _impl_v13(cls, bb, inputs, attr, params): if reps.struct_info.dtype != "int64": reps = bb.normalize(relax.op.astype(reps, "int64")) - data_shape = list(data.struct_info.shape.values) - data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(relax.ShapeExpr(data_shape))) + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) output_shape_tensor = reps if data_ndim > reps_len: reps_prefix = relax.const(_np.ones((data_ndim - reps_len,), dtype="int64"), "int64") - output_shape_tensor = bb.normalize(relax.op.concat([reps_prefix, output_shape_tensor], axis=0)) + output_shape_tensor = bb.normalize( + relax.op.concat([reps_prefix, output_shape_tensor], axis=0) + ) elif reps_len > data_ndim: data_prefix = relax.const(_np.ones((reps_len - data_ndim,), dtype="int64"), "int64") - data_shape_tensor = bb.normalize(relax.op.concat([data_prefix, data_shape_tensor], axis=0)) + data_shape_tensor = bb.normalize( + relax.op.concat([data_prefix, data_shape_tensor], axis=0) + ) - output_shape_tensor = bb.normalize(relax.op.multiply(output_shape_tensor, data_shape_tensor)) + output_shape_tensor = bb.normalize( + relax.op.multiply(output_shape_tensor, data_shape_tensor) + ) output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor)) - output_shape_vars = [tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len))] + output_shape_vars = [ + tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len)) + ] bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars)) return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7892b5cb2017..ecbc6c9e8a5e 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2701,12 +2701,19 @@ def verify_tile(in_shape, repeats, out_shape): @pytest.mark.parametrize("dynamic_input", [True, False]) -def test_tile_dynamic_repeats(dynamic_input): - x = np.random.rand(2, 3).astype(np.float32) - repeats = np.array([2, 2], dtype=np.int64) +@pytest.mark.parametrize( + "in_shape,repeats", + [ + ((2, 3), np.array([2, 2], dtype=np.int64)), + ((2, 3, 4), np.array([2, 2, 1], dtype=np.int64)), + ((2, 3, 4, 5), np.array([1, 2, 1, 2], dtype=np.int64)), + ], +) +def test_tile_dynamic_repeats(dynamic_input, in_shape, repeats): + x = np.random.rand(*in_shape).astype(np.float32) out_shape = np.tile(x, repeats).shape - input_shape = ["?", "?"] if dynamic_input else list(x.shape) + input_shape = ["?" for _ in in_shape] if dynamic_input else list(x.shape) output_shape = ["?" for _ in out_shape] if dynamic_input else list(out_shape) node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"])