Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import numpy as np

import tvm
from tvm import relax
from tvm import relax, tirx
from tvm.relax import op as _op

from .tflite_flexbuffer import FlexBufferDecoder
Expand Down Expand Up @@ -1770,14 +1770,24 @@ def convert_fill(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

if self.has_expr(input_tensors[0].tensor_idx):
raise tvm.error.OpNotImplemented(
"For dims parameter of Fill operator, only constant values are supported."
)

in_dims = list(self.get_tensor_value(input_tensors[0]))
dims_tensor = input_tensors[0]
in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
out = relax.op.full(in_dims, in_value_expr)

if self.has_expr(dims_tensor.tensor_idx):
dims_expr = self.get_expr(dims_tensor.tensor_idx)
dims_ndim = int(self.get_tensor_shape(dims_tensor)[0])

# Bind runtime dims to fresh symbolic shape vars so the imported
# module remains well formed before LegalizeOps runs.
dims_expr = self.bb.match_cast(dims_expr, relax.TensorStructInfo([dims_ndim], "int32"))
dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64"))
shape_dataflow_var = self.bb.emit(relax.op.tensor_to_shape(dims_expr))
shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in range(dims_ndim)]
self.bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr)
else:
in_dims = list(self.get_tensor_value(dims_tensor))
out = relax.op.full(in_dims, in_value_expr)

return out

Expand Down Expand Up @@ -2331,29 +2341,64 @@ def convert_split(self, op):
def convert_split_v(self, op):
"""SPLIT_V implementation."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)

assert len(input_tensors) == 3, "input tensors length should be 3"

input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)

if self.has_expr(input_tensors[1].tensor_idx):
raise tvm.error.OpNotImplemented(
"For size_splits parameter of SPLIT_V operator, only constant values are supported."
)
size_splits = list(self.get_tensor_value(input_tensors[1]))
size_splits = tuple(np.cumsum(size_splits)[:-1])

axis_tensor = input_tensors[2]
split_axis = self.get_tensor_value(axis_tensor)
split_axis = int(self.get_tensor_value(axis_tensor))

size_splits_tensor = input_tensors[1]

if self.has_expr(size_splits_tensor.tensor_idx):
# Dynamic size_splits case: decompose into dynamic strided slices.
size_splits_expr = self.get_expr(size_splits_tensor.tensor_idx)
cumsum = relax.op.cumsum(size_splits_expr, axis=0, dtype="int64")
# Pad a leading zero so that cumsum[i-1] can be read uniformly
# via strided_slice even for i == 0.
zero = relax.const(np.array([0], dtype="int64"), "int64")
padded_cumsum = relax.op.concat([zero, cumsum], axis=0)
# TFLite fixes the tuple arity in the graph, even when the split
# sizes themselves are supplied at runtime.
num_splits = len(output_tensors)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable num_splits is used but it is better to use len(output_tensors) directly if it is only used once, or rename it to num_outputs to be more descriptive of what it represents in the context of TFLite output tensors.

rank = len(in_expr.struct_info.shape)

# end_base is the full input shape; only split_axis changes per slice.
end_base = relax.op.shape_to_tensor(relax.op.shape_of(in_expr))
begin_base = relax.const(np.zeros((rank,), dtype="int64"), "int64")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable rank is used to create begin_base and strides. It is clearer to define this as input_rank to avoid ambiguity with other potential rank definitions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, these are naming-only suggestions. Since the variables are local and the current intent is already documented by comments, I’d prefer to keep the diff minimal.

strides = relax.const(np.ones((rank,), dtype="int64"), "int64")
scatter_idx = relax.const([split_axis], "int64")

outputs = []
for i in range(num_splits):
start_val = relax.op.strided_slice(
padded_cumsum, axes=[0], begin=[i], end=[i + 1]
)
end_val = relax.op.strided_slice(
padded_cumsum, axes=[0], begin=[i + 1], end=[i + 2]
)

begin = relax.op.scatter_elements(begin_base, scatter_idx, start_val)
end = relax.op.scatter_elements(end_base, scatter_idx, end_val)
slice_i = relax.op.dynamic_strided_slice(in_expr, begin, end, strides)
outputs.append(slice_i)

out = relax.Tuple(outputs)
else:
# Static size_splits case
size_splits = list(self.get_tensor_value(size_splits_tensor))
size_splits = tuple(np.cumsum(size_splits)[:-1])
out = relax.op.split(in_expr, size_splits, axis=split_axis)

out = relax.op.split(in_expr, size_splits, axis=int(split_axis))
# Relay does not like a TupleWrapper of 1 element, further this
# only shows up with tf1.13 if we use a split with num_splits==1.
# In tf 1.14 this doesn't appear as it is automatically a reshape
# operation.
if isinstance(out, relax.Tuple) and out.size == 1:
if isinstance(out, relax.Tuple) and len(out.fields) == 1:
out = out[0]

return out
Expand Down
16 changes: 14 additions & 2 deletions python/tvm/relax/transform/legalize_ops/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from tvm import tirx, topi

from ...block_builder import BlockBuilder
from ...expr import Call, Expr, PrimValue, const
from ...expr import Call, Expr, PrimValue, ShapeExpr, const
from ...struct_info import ShapeStructInfo
from .common import LegalizeFunc, _try_convert_to_scalar_const, register_legalize


Expand All @@ -34,10 +35,21 @@ def full_call_te(bb: BlockBuilder, call: Call) -> Expr:
if fill_value is None
else fill_value
)
shape = call.args[0].struct_info.shape if is_like else call.args[0]

if isinstance(shape, ShapeExpr):
output_shape = shape.values
else:
assert isinstance(shape.struct_info, ShapeStructInfo)
assert shape.struct_info.ndim >= 0

shape = bb.emit(shape)
output_shape = [tirx.Var(f"s{i}", "int64") for i in range(shape.struct_info.ndim)]
bb.match_cast(shape, ShapeStructInfo(output_shape))

return bb.call_te(
topi.full,
call.args[0].struct_info.shape if is_like else call.args[0],
output_shape,
call.struct_info.dtype,
_fill_value,
primfunc_name_hint=primfunc_name,
Expand Down
2 changes: 2 additions & 0 deletions src/relax/op/tensor/create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ TVM_REGISTER_OP("relax.full")
.add_argument("shape", "Shape", "The shape of the created tensor.")
.add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull)
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<Bool>("FDataDependent", Bool(true))
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

Expand Down
42 changes: 42 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,26 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 3, 10), dtype="f
verify(Split, Expected)


def test_split_v_dynamic():
"""SPLIT_V with runtime split sizes imports shape-aware Relax IR."""

class TfSplitVDynamic(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(10,), dtype=tf.float32),
tf.TensorSpec(shape=(3,), dtype=tf.int32),
]
)
def func(self, x, size_splits):
return tf.split(x, size_splits, axis=0)

cf = TfSplitVDynamic().func.get_concrete_function()
mod = _get_mod_from_cfunc(cf)
ir = mod.script()
assert "R.dynamic_strided_slice" in ir
assert "R.scatter_elements" in ir


def test_pack():
class Pack(tf.Module):
@tf.function(
Expand Down Expand Up @@ -592,6 +612,28 @@ def main(
verify(TfInput, Expected)


def test_fill_dynamic_dims():
"""FILL with runtime dims legalizes and compiles."""

class TfFillDynamic(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(2,), dtype=tf.int32),
tf.TensorSpec(shape=(), dtype=tf.float32),
]
)
def func(self, dims, value):
return tf.fill(dims, value)

cf = TfFillDynamic().func.get_concrete_function()
mod = _get_mod_from_cfunc(cf)
ir = mod.script()
assert "R.tensor_to_shape" in ir
assert "R.full" in ir
tvm.compile(mod, tvm.target.Target("llvm"))
verify(cf)


@pytest.mark.parametrize(
"tf_op, relax_op",
[
Expand Down
Loading