Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1974,14 +1974,57 @@ def _impl_v11(cls, bb, inputs, attr, params):
class Tile(OnnxOpConverter):
"""Converts an onnx Tile node into an equivalent Relax expression."""

@staticmethod
def _tensor_length(expr):
shape = expr.struct_info.shape
if not isinstance(shape, relax.ShapeExpr):
return None

length = shape.values[0]
if not isinstance(length, tir.IntImm):
return None
return length.value

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
reps = get_constant(inputs[1], params)
if isinstance(reps, relax.Constant):
reps = reps.data.numpy().tolist()
else:
raise ValueError("Dynamic reps for Tile are supported yet.")
return bb.emit_te(topi.tile, inputs[0], reps)
return bb.emit_te(topi.tile, inputs[0], reps)

data = inputs[0]
data_ndim = data.struct_info.ndim
reps_len = cls._tensor_length(reps)
if data_ndim == -1 or reps_len is None:
raise ValueError("Dynamic Tile requires known input rank and repeats length.")

if reps.struct_info.dtype != "int64":
reps = bb.normalize(relax.op.astype(reps, "int64"))

data_shape = bb.normalize(relax.op.shape_of(data))
data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
output_shape_tensor = reps

if data_ndim > reps_len:
reps_prefix = relax.const(_np.ones((data_ndim - reps_len,), dtype="int64"), "int64")
output_shape_tensor = bb.normalize(
relax.op.concat([reps_prefix, output_shape_tensor], axis=0)
)
elif reps_len > data_ndim:
data_prefix = relax.const(_np.ones((reps_len - data_ndim,), dtype="int64"), "int64")
data_shape_tensor = bb.normalize(
relax.op.concat([data_prefix, data_shape_tensor], axis=0)
)

output_shape_tensor = bb.normalize(
relax.op.multiply(output_shape_tensor, data_shape_tensor)
)
output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor))
output_shape_vars = [
tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len))
]
bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars))
return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len)


class Expand(OnnxOpConverter):
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,28 @@ def tile(a, reps):
return cpp.tile(a, reps)


def dyn_tile(a, new_shape, rdim):
"""Repeats the whole array multiple times with dynamic output shape.

Parameters
----------
a : tvm.te.Tensor
The tensor to be tiled.

new_shape : tuple of PrimExpr
The output shape after tiling.

rdim : int
The rank of the repeats input.

Returns
-------
ret : tvm.te.Tensor
"""

return cpp.dyn_tile(a, new_shape, rdim)


def layout_transform(array, src_layout, dst_layout, schedule_rule="None"):
"""Transform the layout according to src_layout and dst_layout

Expand Down
5 changes: 5 additions & 0 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
[](ffi::PackedArgs args, ffi::Any* rv) {
*rv = tile(args[0].cast<te::Tensor>(), args[1].cast<ffi::Array<Integer>>());
})
.def_packed("topi.dyn_tile",
[](ffi::PackedArgs args, ffi::Any* rv) {
*rv = dyn_tile(args[0].cast<te::Tensor>(), args[1].cast<ffi::Array<PrimExpr>>(),
args[2].cast<int>());
})
.def_packed("topi.gather",
[](ffi::PackedArgs args, ffi::Any* rv) {
*rv = gather(args[0].cast<te::Tensor>(), args[1].cast<int>(),
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2700,6 +2700,37 @@ def verify_tile(in_shape, repeats, out_shape):
verify_tile(x.shape, repeats, z_array.shape)


@pytest.mark.parametrize("dynamic_input", [True, False])
@pytest.mark.parametrize(
"in_shape,repeats",
[
((2, 3), np.array([2, 2], dtype=np.int64)),
((2, 3, 4), np.array([2, 2, 1], dtype=np.int64)),
((2, 3, 4, 5), np.array([1, 2, 1, 2], dtype=np.int64)),
],
)
def test_tile_dynamic_repeats(dynamic_input, in_shape, repeats):
x = np.random.rand(*in_shape).astype(np.float32)
out_shape = np.tile(x, repeats).shape

input_shape = ["?" for _ in in_shape] if dynamic_input else list(x.shape)
output_shape = ["?" for _ in out_shape] if dynamic_input else list(out_shape)

node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"])
graph = helper.make_graph(
[node],
"tile_dynamic_repeats_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape),
helper.make_tensor_value_info("repeats", TensorProto.INT64, [len(repeats)]),
],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)],
)
model = helper.make_model(graph, producer_name="tile_dynamic_repeats_test")

check_correctness(model, inputs={"input": x, "repeats": repeats}, opset=13)
Comment thread
LudovicoYIN marked this conversation as resolved.


def _generate_roi_cases():
# Base case when with_roi is False
roi_list = [
Expand Down
Loading