Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 72 additions & 109 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
# pylint: disable=import-outside-toplevel, use-list-literal
# pylint: disable=no-value-for-parameter, unused-variable
# pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args
# ruff: noqa: RUF005, F821, F841
# ruff: noqa: RUF005
# F821: _qnn and _expr references are in unreachable code paths (guarded by NotImplementedError)
# and will be resolved when quantization and vision op support are added.
# ruff: noqa: F821
"""Tensorflow lite frontend."""

import functools
Expand Down Expand Up @@ -468,7 +471,9 @@ def get_tensors(self, tensors_idx_list):
qnn_params = dict()
qnn_params["scale"] = relax.const(scale, "float32")
qnn_params["zero_point"] = relax.const(zero_point, "int32")
raise NotImplementedError("Quantized operators not supported now")
raise NotImplementedError(
"Quantized TFLite models are not yet supported in the Relax frontend"
)
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list

Expand Down Expand Up @@ -530,20 +535,14 @@ def get_tensor_type_str(self, tensor_type):
return "bool"
raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported")

def flatten_to_nd(self, x, x_shape, nd=3):
def flatten_to_nd(self, x, nd=3):
"""Flatten input tensor to nd rank"""
ndims = self._infer_shape(x_shape)[0]
shape = x.struct_info.shape
ndims = len(shape)
if ndims == nd:
return x
newshape = relax.op.concat(
[
relax.const([-1], dtype=self._infer_type(x_shape).checked_type.dtype),
relax.op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
],
0,
)
out = relax.op.reshape(x, self._fold_constant(newshape))
return out
new_shape = [-1] + [int(shape[i]) for i in range(ndims - nd + 1, ndims)]
return relax.op.reshape(x, new_shape)

def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
lhs_scale = lhs_tensor.qnn_params["scale"]
Expand Down Expand Up @@ -709,7 +708,7 @@ def _convert_resize(self, method, op):

# ResizeNearestNeighborOptions was added in tflite v1.13
tflite_ver = 1120
if "ResizeNearestNeighborOptions" in dir(tflite.BuiltinOptions):
if hasattr(BuiltinOptions, "ResizeNearestNeighborOptions"):
tflite_ver = 1130

input_tensors = self.get_input_tensors(op)
Expand Down Expand Up @@ -947,8 +946,7 @@ def convert_shape(self, op):
shape_options = ShapeOptions()
shape_options.Init(op_options.Bytes, op_options.Pos)

out_type = self.get_tensor_type_str(shape_options.OutType())
out = shape_of(self.get_tensor_expr(input_tensors[0]), dtype=out_type)
out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))

return out

Expand Down Expand Up @@ -1428,6 +1426,7 @@ def convert_gather(self, op):

from tflite.BuiltinOptions import BuiltinOptions
from tflite.GatherOptions import GatherOptions
from tflite.TensorType import TensorType

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
Expand Down Expand Up @@ -2804,115 +2803,66 @@ def convert_batch_matmul(self, op):

assert len(input_tensors) == 2, "two input tensor arguments expected"

if self.is_quantized(op):
raise NotImplementedError(
"Quantized BATCH_MATMUL is not yet supported in the Relax frontend"
)

batch_matmul_options = BatchMatMulOptions()
op_options = op.BuiltinOptions()
batch_matmul_options.Init(op_options.Bytes, op_options.Pos)

input_a = self.get_expr(input_tensors[0].tensor_idx)
input_b = self.get_expr(input_tensors[1].tensor_idx)

shape_a = shape_of(input_a)
shape_b = shape_of(input_b)
rank_a = self._infer_shape(shape_a)[0]
rank_b = self._infer_shape(shape_b)[0]
shape_a = list(input_a.struct_info.shape)
shape_b = list(input_b.struct_info.shape)
rank_a = len(shape_a)
rank_b = len(shape_b)

if rank_a > 2 or rank_b > 2:
# Determine the output batch dimension
new_a_shape = shape_a
new_b_shape = shape_b
if rank_a > rank_b:
rank_diff = rank_a - rank_b
new_b_shape = relax.op.concat(
[
relax.const(
[1] * rank_diff, dtype=self._infer_type(new_b_shape).checked_type.dtype
),
shape_b,
],
0,
)
elif rank_a < rank_b:
rank_diff = rank_b - rank_a
new_a_shape = relax.op.concat(
[
relax.const(
[1] * rank_diff, dtype=self._infer_type(new_a_shape).checked_type.dtype
),
shape_a,
],
0,
)
else:
pass
# Broadcast batch dimensions
new_a_shape = [1] * max(0, rank_b - rank_a) + [int(s) for s in shape_a]
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]
max_rank = max(rank_a, rank_b)

out_batch = relax.op.concat(
[
relax.op.maximum(
relax.op.strided_slice(new_b_shape, [i], [i + 1]),
relax.op.strided_slice(new_a_shape, [i], [i + 1]),
)
for i in range(max(rank_a, rank_b) - 2)
],
0,
)
batch_shape = [
max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)
]

a_broadcasted_shape = _fold_constant(
_op.concat([out_batch, _op.strided_slice(shape_a, [rank_a - 2], [rank_a])], 0)
)
b_broadcasted_shape = _fold_constant(
_op.concat([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0)
)
if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
input_a = relax.op.transform.broadcast_to(input_a, a_broadcasted_shape)
if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
input_b = relax.op.transform.broadcast_to(input_b, b_broadcasted_shape)
a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]

input_a = self.flatten_to_nd(input_a, shape_a, 3)
input_b = self.flatten_to_nd(input_b, shape_b, 3)
if [int(s) for s in shape_a] != a_broadcast:
input_a = relax.op.broadcast_to(input_a, a_broadcast)
if [int(s) for s in shape_b] != b_broadcast:
input_b = relax.op.broadcast_to(input_b, b_broadcast)

if batch_matmul_options.AdjX():
input_a = self.flatten_to_nd(input_a, 3)
input_b = self.flatten_to_nd(input_b, 3)

adj_x = batch_matmul_options.AdjX()
adj_y = batch_matmul_options.AdjY()

if adj_x:
input_a = relax.op.permute_dims(input_a, [0, 2, 1])
if not batch_matmul_options.AdjY():
if adj_y:
input_b = relax.op.permute_dims(input_b, [0, 2, 1])

if self.is_quantized(op):
output = _qnn.op.batch_matmul(
input_a,
input_b,
relax.const(0, "int32"),
relax.const(0, "int32"),
relax.const(1.0, "float32"),
relax.const(1.0, "float32"),
)
else:
output = relax.op.nn.batch_matmul(input_a, input_b)
output = relax.op.matmul(input_a, input_b)

# Reshape output to original dimensions.
output_shape = shape_of(output)
# Compute output matmul dims from original shapes
m_dim = int(shape_a[-1]) if adj_x else int(shape_a[-2])
n_dim = int(shape_b[-2]) if adj_y else int(shape_b[-1])
final_shape = [int(s) for s in shape_a[: rank_a - 2]] + [m_dim, n_dim]
return relax.op.reshape(output, final_shape)

rank_out = self._infer_shape(output_shape)[0]

final_shape = relax.op.concat(
[
relax.op.strided_slice(shape_a, [0], [rank_a - 2]),
relax.op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
],
0,
)

reshape = relax.op.reshape(output, self._fold_constant(final_shape))
# qnn batch matmul returns a int32 tensor so we need to requantize
if self.is_quantized(op):
return _qnn.op.requantize(
reshape,
relax.const(1.0, "float32"),
relax.const(0, "int32"),
relax.const(1.0, "float32"),
relax.const(0, "int32"),
out_dtype="int8",
)
else:
return reshape
# rank <= 2: use matmul directly
if batch_matmul_options.AdjX():
input_a = relax.op.permute_dims(input_a)
if batch_matmul_options.AdjY():
input_b = relax.op.permute_dims(input_b)
return relax.op.matmul(input_a, input_b)

def convert_space_to_batch_nd(self, op):
"""space_to_batch_nd implementation."""
Expand Down Expand Up @@ -2974,6 +2924,7 @@ def convert_space_to_depth(self, op):

def convert_sparse_to_dense(self, op):
"""Convert TFLite SPARSE_TO_DENSE"""
from tflite.TensorType import TensorType

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 4, "input tensors length should be 4"
Expand Down Expand Up @@ -3029,6 +2980,7 @@ def convert_transpose_conv(self, op):

from tflite.BuiltinOptions import BuiltinOptions
from tflite.Padding import Padding
from tflite.TensorType import TensorType
from tflite.TransposeConvOptions import TransposeConvOptions

input_tensors = self.get_input_tensors(op)
Expand Down Expand Up @@ -3226,6 +3178,7 @@ def convert_quantize(self, op):

def convert_dequantize(self, op):
"""Convert TFLite Dequantize"""
from tflite.TensorType import TensorType

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
Expand All @@ -3251,6 +3204,11 @@ def convert_dequantize(self, op):

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
raise NotImplementedError(
"DETECTION_POSTPROCESS requires vision ops (multibox_transform_loc, "
"non_max_suppression, get_valid_counts) not yet available in Relax. "
"See https://github.com/apache/tvm/issues/XXXX"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's good practice to link to a tracking issue for NotImplementedError. Please replace XXXX with the actual GitHub issue number that tracks the implementation of the missing vision ops. This will help developers track progress and contribute more effectively.

)
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()

Expand Down Expand Up @@ -3381,6 +3339,11 @@ def convert_detection_postprocess(self, op):
def convert_nms_v5(self, op):
"""Convert TFLite NonMaxSuppressionV5"""
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
raise NotImplementedError(
"NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, "
"non_max_suppression) not yet available in Relax. "
"See https://github.com/apache/tvm/issues/XXXX"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the comment on convert_detection_postprocess, please replace the XXXX placeholder with a specific GitHub issue number for tracking the implementation of the required vision ops. This provides a clear reference for tracking the required feature.

)

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 6, "input tensor length should be 6"
Expand Down Expand Up @@ -3843,7 +3806,7 @@ def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx):

def get_scalar_from_constant(expr):
"""Returns scalar value from Relax constant scalar."""
assert isinstance(expr, _expr.Constant) and not expr.data.shape, (
assert isinstance(expr, relax.Constant) and not expr.data.shape, (
"Expr is not a constant scalar."
)
value = expr.data.numpy()
Expand Down Expand Up @@ -4091,7 +4054,7 @@ def func(self, data):

with bb.function("main"):
input_list = []
with bb.dataflow() as df: # pylint: disable=invalid-name, unused-variable
with bb.dataflow() as df: # noqa: F841 # pylint: disable=invalid-name, unused-variable
exp_tab = ExprTable()
for model_input in model_inputs:
model_input_name = get_tensor_name(subgraph, model_input)
Expand Down
28 changes: 28 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,5 +825,33 @@ def func(self, data):
verify(concrete_func)


def test_batch_matmul():
class BatchMatMul(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
tf.TensorSpec(shape=(2, 4, 5), dtype=tf.float32),
]
)
def func(self, x, y):
return tf.matmul(x, y)

verify(BatchMatMul)


def test_batch_matmul_adj():
class BatchMatMulAdj(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(2, 4, 3), dtype=tf.float32),
tf.TensorSpec(shape=(2, 5, 4), dtype=tf.float32),
]
)
def func(self, x, y):
return tf.matmul(x, y, transpose_a=True, transpose_b=True)

verify(BatchMatMulAdj)


if __name__ == "__main__":
pytest.main(["-s", __file__])
Loading