diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b3c2d06eab07..3dc575ae778c 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1974,14 +1974,57 @@ def _impl_v11(cls, bb, inputs, attr, params): class Tile(OnnxOpConverter): """Converts an onnx Tile node into an equivalent Relax expression.""" + @staticmethod + def _tensor_length(expr): + shape = expr.struct_info.shape + if not isinstance(shape, relax.ShapeExpr): + return None + + length = shape.values[0] + if not isinstance(length, tir.IntImm): + return None + return length.value + @classmethod def _impl_v13(cls, bb, inputs, attr, params): reps = get_constant(inputs[1], params) if isinstance(reps, relax.Constant): reps = reps.data.numpy().tolist() - else: - raise ValueError("Dynamic reps for Tile are supported yet.") - return bb.emit_te(topi.tile, inputs[0], reps) + return bb.emit_te(topi.tile, inputs[0], reps) + + data = inputs[0] + data_ndim = data.struct_info.ndim + reps_len = cls._tensor_length(reps) + if data_ndim == -1 or reps_len is None: + raise ValueError("Dynamic Tile requires known input rank and repeats length.") + + if reps.struct_info.dtype != "int64": + reps = bb.normalize(relax.op.astype(reps, "int64")) + + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = reps + + if data_ndim > reps_len: + reps_prefix = relax.const(_np.ones((data_ndim - reps_len,), dtype="int64"), "int64") + output_shape_tensor = bb.normalize( + relax.op.concat([reps_prefix, output_shape_tensor], axis=0) + ) + elif reps_len > data_ndim: + data_prefix = relax.const(_np.ones((reps_len - data_ndim,), dtype="int64"), "int64") + data_shape_tensor = bb.normalize( + relax.op.concat([data_prefix, data_shape_tensor], axis=0) + ) + + output_shape_tensor = bb.normalize( + relax.op.multiply(output_shape_tensor, data_shape_tensor) + ) + output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor)) + output_shape_vars = [ + tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len)) + ] + bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars)) + return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len) class Expand(OnnxOpConverter): diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index bc187e3f269b..a3e736644607 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -657,6 +657,28 @@ def tile(a, reps): return cpp.tile(a, reps) +def dyn_tile(a, new_shape, rdim): + """Repeats the whole array multiple times with dynamic output shape. + + Parameters + ---------- + a : tvm.te.Tensor + The tensor to be tiled. + + new_shape : tuple of PrimExpr + The output shape after tiling. + + rdim : int + The rank of the repeats input. + + Returns + ------- + ret : tvm.te.Tensor + """ + + return cpp.dyn_tile(a, new_shape, rdim) + + def layout_transform(array, src_layout, dst_layout, schedule_rule="None"): """Transform the layout according to src_layout and dst_layout diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 5e2ffd4cbd9a..09f9a9be5ea7 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -156,6 +156,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::PackedArgs args, ffi::Any* rv) { *rv = tile(args[0].cast(), args[1].cast>()); }) + .def_packed("topi.dyn_tile", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = dyn_tile(args[0].cast(), args[1].cast>(), + args[2].cast()); + }) .def_packed("topi.gather", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = gather(args[0].cast(), args[1].cast(), diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 12f9f0f35368..ecbc6c9e8a5e 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2700,6 +2700,37 @@ def verify_tile(in_shape, repeats, out_shape): verify_tile(x.shape, repeats, z_array.shape) +@pytest.mark.parametrize("dynamic_input", [True, False]) +@pytest.mark.parametrize( + "in_shape,repeats", + [ + ((2, 3), np.array([2, 2], dtype=np.int64)), + ((2, 3, 4), np.array([2, 2, 1], dtype=np.int64)), + ((2, 3, 4, 5), np.array([1, 2, 1, 2], dtype=np.int64)), + ], +) +def test_tile_dynamic_repeats(dynamic_input, in_shape, repeats): + x = np.random.rand(*in_shape).astype(np.float32) + out_shape = np.tile(x, repeats).shape + + input_shape = ["?" for _ in in_shape] if dynamic_input else list(x.shape) + output_shape = ["?" for _ in out_shape] if dynamic_input else list(out_shape) + + node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"]) + graph = helper.make_graph( + [node], + "tile_dynamic_repeats_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("repeats", TensorProto.INT64, [len(repeats)]), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="tile_dynamic_repeats_test") + + check_correctness(model, inputs={"input": x, "repeats": repeats}, opset=13) + + def _generate_roi_cases(): # Base case when with_roi is False roi_list = [