diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 5ff0444e0bfa..0abd700562e2 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -18,7 +18,10 @@ # pylint: disable=import-outside-toplevel, use-list-literal # pylint: disable=no-value-for-parameter, unused-variable # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args -# ruff: noqa: RUF005, F821, F841 +# ruff: noqa: RUF005 +# F821: _qnn and _expr references are in unreachable code paths (guarded by NotImplementedError) +# and will be resolved when quantization and vision op support are added. +# ruff: noqa: F821 """Tensorflow lite frontend.""" import functools @@ -468,7 +471,9 @@ def get_tensors(self, tensors_idx_list): qnn_params = dict() qnn_params["scale"] = relax.const(scale, "float32") qnn_params["zero_point"] = relax.const(zero_point, "int32") - raise NotImplementedError("Quantized operators not supported now") + raise NotImplementedError( + "Quantized TFLite models are not yet supported in the Relax frontend" + ) return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list @@ -530,20 +535,14 @@ def get_tensor_type_str(self, tensor_type): return "bool" raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported") - def flatten_to_nd(self, x, x_shape, nd=3): + def flatten_to_nd(self, x, nd=3): """Flatten input tensor to nd rank""" - ndims = self._infer_shape(x_shape)[0] + shape = x.struct_info.shape + ndims = len(shape) if ndims == nd: return x - newshape = relax.op.concat( - [ - relax.const([-1], dtype=self._infer_type(x_shape).checked_type.dtype), - relax.op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), - ], - 0, - ) - out = relax.op.reshape(x, self._fold_constant(newshape)) - return out + new_shape = [-1] + [int(shape[i]) for i in range(ndims - nd + 1, ndims)] + return relax.op.reshape(x, new_shape) def has_same_qnn_params(self, lhs_tensor, rhs_tensor): lhs_scale = lhs_tensor.qnn_params["scale"] @@ -709,7 +708,7 @@ def _convert_resize(self, method, op): # ResizeNearestNeighborOptions was added in tflite v1.13 tflite_ver = 1120 - if "ResizeNearestNeighborOptions" in dir(tflite.BuiltinOptions): + if hasattr(BuiltinOptions, "ResizeNearestNeighborOptions"): tflite_ver = 1130 input_tensors = self.get_input_tensors(op) @@ -947,8 +946,7 @@ def convert_shape(self, op): shape_options = ShapeOptions() shape_options.Init(op_options.Bytes, op_options.Pos) - out_type = self.get_tensor_type_str(shape_options.OutType()) - out = shape_of(self.get_tensor_expr(input_tensors[0]), dtype=out_type) + out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0])) return out @@ -1428,6 +1426,7 @@ def convert_gather(self, op): from tflite.BuiltinOptions import BuiltinOptions from tflite.GatherOptions import GatherOptions + from tflite.TensorType import TensorType input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -2804,6 +2803,11 @@ def convert_batch_matmul(self, op): assert len(input_tensors) == 2, "two input tensor arguments expected" + if self.is_quantized(op): + raise NotImplementedError( + "Quantized BATCH_MATMUL is not yet supported in the Relax frontend" + ) + batch_matmul_options = BatchMatMulOptions() op_options = op.BuiltinOptions() batch_matmul_options.Init(op_options.Bytes, op_options.Pos) @@ -2811,108 +2815,54 @@ def convert_batch_matmul(self, op): input_a = self.get_expr(input_tensors[0].tensor_idx) input_b = self.get_expr(input_tensors[1].tensor_idx) - shape_a = shape_of(input_a) - shape_b = shape_of(input_b) - rank_a = self._infer_shape(shape_a)[0] - rank_b = self._infer_shape(shape_b)[0] + shape_a = list(input_a.struct_info.shape) + shape_b = list(input_b.struct_info.shape) + rank_a = len(shape_a) + rank_b = len(shape_b) if rank_a > 2 or rank_b > 2: - # Determine the output batch dimension - new_a_shape = shape_a - new_b_shape = shape_b - if rank_a > rank_b: - rank_diff = rank_a - rank_b - new_b_shape = relax.op.concat( - [ - relax.const( - [1] * rank_diff, dtype=self._infer_type(new_b_shape).checked_type.dtype - ), - shape_b, - ], - 0, - ) - elif rank_a < rank_b: - rank_diff = rank_b - rank_a - new_a_shape = relax.op.concat( - [ - relax.const( - [1] * rank_diff, dtype=self._infer_type(new_a_shape).checked_type.dtype - ), - shape_a, - ], - 0, - ) - else: - pass + # Broadcast batch dimensions + new_a_shape = [1] * max(0, rank_b - rank_a) + [int(s) for s in shape_a] + new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b] + max_rank = max(rank_a, rank_b) - out_batch = relax.op.concat( - [ - relax.op.maximum( - relax.op.strided_slice(new_b_shape, [i], [i + 1]), - relax.op.strided_slice(new_a_shape, [i], [i + 1]), - ) - for i in range(max(rank_a, rank_b) - 2) - ], - 0, - ) + batch_shape = [ + max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2) + ] - a_broadcasted_shape = _fold_constant( - _op.concat([out_batch, _op.strided_slice(shape_a, [rank_a - 2], [rank_a])], 0) - ) - b_broadcasted_shape = _fold_constant( - _op.concat([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0) - ) - if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape): - input_a = relax.op.transform.broadcast_to(input_a, a_broadcasted_shape) - if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape): - input_b = relax.op.transform.broadcast_to(input_b, b_broadcasted_shape) + a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])] + b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])] - input_a = self.flatten_to_nd(input_a, shape_a, 3) - input_b = self.flatten_to_nd(input_b, shape_b, 3) + if [int(s) for s in shape_a] != a_broadcast: + input_a = relax.op.broadcast_to(input_a, a_broadcast) + if [int(s) for s in shape_b] != b_broadcast: + input_b = relax.op.broadcast_to(input_b, b_broadcast) - if batch_matmul_options.AdjX(): + input_a = self.flatten_to_nd(input_a, 3) + input_b = self.flatten_to_nd(input_b, 3) + + adj_x = batch_matmul_options.AdjX() + adj_y = batch_matmul_options.AdjY() + + if adj_x: input_a = relax.op.permute_dims(input_a, [0, 2, 1]) - if not batch_matmul_options.AdjY(): + if adj_y: input_b = relax.op.permute_dims(input_b, [0, 2, 1]) - if self.is_quantized(op): - output = _qnn.op.batch_matmul( - input_a, - input_b, - relax.const(0, "int32"), - relax.const(0, "int32"), - relax.const(1.0, "float32"), - relax.const(1.0, "float32"), - ) - else: - output = relax.op.nn.batch_matmul(input_a, input_b) + output = relax.op.matmul(input_a, input_b) - # Reshape output to original dimensions. - output_shape = shape_of(output) + # Compute output matmul dims from original shapes + m_dim = int(shape_a[-1]) if adj_x else int(shape_a[-2]) + n_dim = int(shape_b[-2]) if adj_y else int(shape_b[-1]) + final_shape = [int(s) for s in shape_a[: rank_a - 2]] + [m_dim, n_dim] + return relax.op.reshape(output, final_shape) - rank_out = self._infer_shape(output_shape)[0] - - final_shape = relax.op.concat( - [ - relax.op.strided_slice(shape_a, [0], [rank_a - 2]), - relax.op.strided_slice(output_shape, [rank_out - 2], [rank_out]), - ], - 0, - ) - - reshape = relax.op.reshape(output, self._fold_constant(final_shape)) - # qnn batch matmul returns a int32 tensor so we need to requantize - if self.is_quantized(op): - return _qnn.op.requantize( - reshape, - relax.const(1.0, "float32"), - relax.const(0, "int32"), - relax.const(1.0, "float32"), - relax.const(0, "int32"), - out_dtype="int8", - ) - else: - return reshape + # rank <= 2: use matmul directly + if batch_matmul_options.AdjX(): + input_a = relax.op.permute_dims(input_a) + if batch_matmul_options.AdjY(): + input_b = relax.op.permute_dims(input_b) + return relax.op.matmul(input_a, input_b) def convert_space_to_batch_nd(self, op): """space_to_batch_nd implementation.""" @@ -2974,6 +2924,7 @@ def convert_space_to_depth(self, op): def convert_sparse_to_dense(self, op): """Convert TFLite SPARSE_TO_DENSE""" + from tflite.TensorType import TensorType input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 4, "input tensors length should be 4" @@ -3029,6 +2980,7 @@ def convert_transpose_conv(self, op): from tflite.BuiltinOptions import BuiltinOptions from tflite.Padding import Padding + from tflite.TensorType import TensorType from tflite.TransposeConvOptions import TransposeConvOptions input_tensors = self.get_input_tensors(op) @@ -3226,6 +3178,7 @@ def convert_quantize(self, op): def convert_dequantize(self, op): """Convert TFLite Dequantize""" + from tflite.TensorType import TensorType input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -3251,6 +3204,11 @@ def convert_dequantize(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" + raise NotImplementedError( + "DETECTION_POSTPROCESS requires vision ops (multibox_transform_loc, " + "non_max_suppression, get_valid_counts) not yet available in Relax. " + "See https://github.com/apache/tvm/issues/XXXX" + ) flexbuffer = op.CustomOptionsAsNumpy().tobytes() custom_options = FlexBufferDecoder(flexbuffer).decode() @@ -3381,6 +3339,11 @@ def convert_detection_postprocess(self, op): def convert_nms_v5(self, op): """Convert TFLite NonMaxSuppressionV5""" # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5 + raise NotImplementedError( + "NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, " + "non_max_suppression) not yet available in Relax. " + "See https://github.com/apache/tvm/issues/XXXX" + ) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 6, "input tensor length should be 6" @@ -3843,7 +3806,7 @@ def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx): def get_scalar_from_constant(expr): """Returns scalar value from Relax constant scalar.""" - assert isinstance(expr, _expr.Constant) and not expr.data.shape, ( + assert isinstance(expr, relax.Constant) and not expr.data.shape, ( "Expr is not a constant scalar." ) value = expr.data.numpy() @@ -4091,7 +4054,7 @@ def func(self, data): with bb.function("main"): input_list = [] - with bb.dataflow() as df: # pylint: disable=invalid-name, unused-variable + with bb.dataflow() as df: # noqa: F841 # pylint: disable=invalid-name, unused-variable exp_tab = ExprTable() for model_input in model_inputs: model_input_name = get_tensor_name(subgraph, model_input) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 9d3d6a9aeb6a..e7d81cf5fea0 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -825,5 +825,33 @@ def func(self, data): verify(concrete_func) +def test_batch_matmul(): + class BatchMatMul(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32), + tf.TensorSpec(shape=(2, 4, 5), dtype=tf.float32), + ] + ) + def func(self, x, y): + return tf.matmul(x, y) + + verify(BatchMatMul) + + +def test_batch_matmul_adj(): + class BatchMatMulAdj(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(2, 4, 3), dtype=tf.float32), + tf.TensorSpec(shape=(2, 5, 4), dtype=tf.float32), + ] + ) + def func(self, x, y): + return tf.matmul(x, y, transpose_a=True, transpose_b=True) + + verify(BatchMatMulAdj) + + if __name__ == "__main__": pytest.main(["-s", __file__])