diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index b7a7e42c488d..72063cace6c4 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -136,6 +136,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "DEPTHWISE_CONV_2D": functools.partial(self.convert_conv, conv_type="depthwise"),
             "DEQUANTIZE": self.convert_dequantize,
             "DETECTION_POSTPROCESS": self.convert_detection_postprocess,
+            "DILATE": self.convert_dilate,
             "DIV": functools.partial(self._convert_elemwise, relax_op=_op.divide),
             "ELU": self.convert_elu,
             "EQUAL": functools.partial(
@@ -3416,6 +3417,90 @@ def convert_dequantize(self, op):
 
         return out
 
+    def convert_dilate(self, op):
+        """Convert TFLite DILATE.
+
+        DILATE takes ``(input, dilations, padding_value)`` and spreads the input
+        elements apart: an axis of extent ``d`` with dilation ``s`` grows to
+        ``(d - 1) * s + 1``, with ``padding_value`` filling the ``s - 1`` gaps
+        between neighbouring elements.
+
+        ``dilations`` is either a constant tensor, which is read and validated
+        at conversion time, or a runtime tensor, whose values are bound to
+        symbolic TIR variables so the output shape stays symbolic.
+        """
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+
+        in_expr = self.get_tensor_expr(input_tensors[0])
+        in_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
+        in_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type())
+        n_dims = len(in_shape)
+
+        dilations_tensor = input_tensors[1]
+        padding_expr = self.get_tensor_expr(input_tensors[2])
+
+        # Runtime dilations bind tensor values to TIR Vars for symbolic
+        # per-axis math.
+        if self.has_expr(dilations_tensor.tensor_idx):
+            dilations_expr = self.get_expr(dilations_tensor.tensor_idx)
+            dilations_expr = self.bb.match_cast(
+                dilations_expr, relax.TensorStructInfo([n_dims], "int32")
+            )
+            dilations_int64 = self.bb.normalize(relax.op.astype(dilations_expr, "int64"))
+            shape_var = self.bb.emit(relax.op.tensor_to_shape(dilations_int64))
+            stride_vars = [tirx.Var(f"dilate_stride_{i}", "int64") for i in range(n_dims)]
+            self.bb.match_cast(shape_var, relax.ShapeStructInfo(stride_vars))
+            strides = stride_vars
+        else:
+            strides = to_int_list(self.get_tensor_value(dilations_tensor))
+            # Constant dilations can be checked up front. Without this, a short
+            # list fails with a bare IndexError in the loop below and a value
+            # below 1 yields a negative pad extent (s - 1).
+            if len(strides) != n_dims:
+                raise ValueError(
+                    f"DILATE expects {n_dims} dilations (one per input axis), got {len(strides)}"
+                )
+            if any(s < 1 for s in strides):
+                raise ValueError(f"DILATE dilations must be >= 1, got {strides}")
+
+        # Per axis: reshape to add a size-1 stride-axis, concat (s-1) padding
+        # values along it, reshape to merge axes (length d*s), trim trailing
+        # pad to TFLite's output dim formula (d-1)*s + 1.
+        result = in_expr
+        current_shape = list(in_shape)
+        axes = list(range(n_dims))
+        begin = [0] * n_dims
+        ones = [1] * n_dims
+        for axis in range(n_dims):
+            d = current_shape[axis]
+            s = strides[axis]
+            # A constant dilation of 1 leaves the axis unchanged, so skip it
+            # instead of emitting a zero-extent pad plus no-op reshapes/slice.
+            # Symbolic strides are not Python ints and always take the full path.
+            if isinstance(s, int) and s == 1:
+                continue
+            expanded_shape = current_shape[: axis + 1] + [1] + current_shape[axis + 1 :]
+            expanded = relax.op.reshape(result, expanded_shape)
+            pad_shape = list(expanded_shape)
+            pad_shape[axis + 1] = s - 1
+            pad = relax.op.full(pad_shape, padding_expr, dtype=in_dtype)
+            concatted = relax.op.concat([expanded, pad], axis=axis + 1)
+            merged_shape = list(current_shape)
+            merged_shape[axis] = d * s
+            merged = relax.op.reshape(concatted, merged_shape)
+            # (d - 1) * s + 1 is the output dim along this axis.
+            final_dim = (d - 1) * s + 1
+            end = list(merged_shape)
+            end[axis] = final_dim
+            result = relax.op.strided_slice(merged, axes=axes, begin=begin, end=end, strides=ones)
+            current_shape = list(merged_shape)
+            current_shape[axis] = final_dim
+
+        return result
+
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 69aab2d43b93..4fcb2a1a65ac 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -2994,6 +2994,7 @@ def _get_tflite_schema_enum(enum_name):
 _tfl_add_options = _get_tflite_schema_module("AddOptions")
 _tfl_buffer = _get_tflite_schema_module("Buffer")
 _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions")
+_tfl_dilate_options = _get_tflite_schema_module("DilateOptions")
 _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
 _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions")
 _tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
@@ -3006,6 +3007,7 @@ def _get_tflite_schema_enum(enum_name):
 
 _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator")
 _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions")
+_tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2")
 _tfl_dimension_type = _get_tflite_schema_enum("DimensionType")
 _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat")
 _tfl_padding = _get_tflite_schema_enum("Padding")
@@ -3061,8 +3063,10 @@ def _tflite_shape(builder, shape):
     return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape)
 
 
-def _build_tensor(builder, buffer_idx, shape, sparsity=None):
+def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None):
     """Helper to build a TFLite tensor."""
+    if tensor_type is None:
+        tensor_type = _tfl_tensor_type.FLOAT32
     shape_vec = _tflite_shape(builder, shape)
     _tfl_tensor.TensorStart(builder)
     _tfl_tensor.TensorAddBuffer(builder, buffer_idx)
@@ -3071,7 +3075,7 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None):
     _tfl_tensor.TensorAddShape(builder, shape_vec)
     if sparsity is not None:
         _tfl_tensor.TensorAddSparsity(builder, sparsity)
-    _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+    _tfl_tensor.TensorAddType(builder, tensor_type)
     return _tfl_tensor.TensorEnd(builder)
 
 
@@ -3088,7 +3092,14 @@ def _build_buffer(builder, data=None):
 
 
 def _build_operator(
-    builder, opcode_index, inputs, outputs, builtin_options_type, builtin_options=None
+    builder,
+    opcode_index,
+    inputs,
+    outputs,
+    builtin_options_type=None,
+    builtin_options=None,
+    builtin_options2_type=None,
+    builtin_options2=None,
 ):
     inputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartInputsVector, inputs)
     outputs_vec = _tflite_int32_vector(
@@ -3098,15 +3109,23 @@ def _build_operator(
     _tfl_operator.OperatorAddOpcodeIndex(builder, opcode_index)
     _tfl_operator.OperatorAddInputs(builder, inputs_vec)
     _tfl_operator.OperatorAddOutputs(builder, outputs_vec)
-    _tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type)
+    if builtin_options_type is not None:
+        _tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type)
     if builtin_options is not None:
         _tfl_operator.OperatorAddBuiltinOptions(builder, builtin_options)
+    if builtin_options2_type is not None:
+        _tfl_operator.OperatorAddBuiltinOptions2Type(builder, builtin_options2_type)
+    if builtin_options2 is not None:
+        _tfl_operator.OperatorAddBuiltinOptions2(builder, builtin_options2)
     return _tfl_operator.OperatorEnd(builder)
 
 
 def _build_operator_code(builder, builtin_op):
+    # deprecated_builtin_code is int8 (max 127). Ops past that write 127 as a
+    # placeholder and use the full builtin_code field.
+    deprecated_code = min(builtin_op, 127)
     _tfl_operator_code.OperatorCodeStart(builder)
-    _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, builtin_op)
+    _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, deprecated_code)
     _tfl_operator_code.OperatorCodeAddBuiltinCode(builder, builtin_op)
     _tfl_operator_code.OperatorCodeAddVersion(builder, 1)
     return _tfl_operator_code.OperatorCodeEnd(builder)
@@ -3144,7 +3163,7 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers):
     _tfl_model.ModelAddVersion(builder, 3)
     model = _tfl_model.ModelEnd(builder)
 
-    builder.Finish(model)
+    builder.Finish(model, b"TFL3")
     return bytes(builder.Output())
 
 
@@ -3516,5 +3535,225 @@ def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((1, 4), dtype="float3
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def _build_dilate_only_case(
+    builder, *, input_shape, dilations, dilation_value, dynamic_dilations=False
+):
+    """Build a subgraph holding a single DILATE op, plus the model buffers.
+
+    Returns ``(subgraph, operator_codes, buffers)``. With ``dynamic_dilations``
+    the dilations tensor is a second subgraph input backed by an empty buffer,
+    and ``dilations`` only sizes the declared output shape.
+    """
+    input_tensor_idx = 0
+    dilations_tensor_idx = 1
+    padding_value_tensor_idx = 2
+    output_tensor_idx = 3
+
+    output_shape = tuple((d - 1) * s + 1 for d, s in zip(input_shape, dilations))
+
+    input_tensor = _build_tensor(builder, 1, input_shape)
+    dilations_tensor = _build_tensor(
+        builder, 2, [len(dilations)], tensor_type=_tfl_tensor_type.INT32
+    )
+    padding_value_tensor = _build_tensor(builder, 3, [])
+    output_tensor = _build_tensor(builder, 4, output_shape)
+
+    _tfl_dilate_options.DilateOptionsStart(builder)
+    dilate_opts = _tfl_dilate_options.DilateOptionsEnd(builder)
+
+    dilate_op = _build_operator(
+        builder,
+        0,
+        [input_tensor_idx, dilations_tensor_idx, padding_value_tensor_idx],
+        [output_tensor_idx],
+        builtin_options2_type=_tfl_builtin_options2.DilateOptions,
+        builtin_options2=dilate_opts,
+    )
+    sg_inputs = (
+        [input_tensor_idx, dilations_tensor_idx] if dynamic_dilations else [input_tensor_idx]
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=[input_tensor, dilations_tensor, padding_value_tensor, output_tensor],
+        operators=[dilate_op],
+        inputs=sg_inputs,
+        outputs=[output_tensor_idx],
+    )
+    operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DILATE)]
+
+    # Buffer 0 stays empty; buffers 1-4 back the four tensors in index order.
+    # Runtime dilations carry no constant data, so their buffer is empty too.
+    dilations_data = None
+    if not dynamic_dilations:
+        dilations_data = np.asarray(dilations, dtype=np.int32).tobytes()
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder),
+        _build_buffer(builder, dilations_data),
+        _build_buffer(builder, np.asarray([dilation_value], dtype=np.float32).tobytes()),
+        _build_buffer(builder),
+    ]
+    return subgraph, operator_codes, buffers
+
+
+def _convert_dilate_only_case(*, input_shape, dilations, dilation_value, dynamic_dilations=False):
+    """Build the single-DILATE model, import it, and strip the ``params`` attr."""
+    builder = flatbuffers.Builder(1024)
+    subgraph, operator_codes, buffers = _build_dilate_only_case(
+        builder,
+        input_shape=input_shape,
+        dilations=dilations,
+        dilation_value=dilation_value,
+        dynamic_dilations=dynamic_dilations,
+    )
+    buf = _finish_tflite_model(
+        builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers
+    )
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+    mod = from_tflite(tflite_model)
+    mod["main"] = mod["main"].without_attr("params")
+    return mod
+
+
+def test_dilate():
+    """TFLite DILATE with constant dilations"""
+    mod = _convert_dilate_only_case(input_shape=(3, 4), dilations=[2, 2], dilation_value=0.5)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((5, 7), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 1, 4), dtype="float32") = R.reshape(
+                    tvmgen_tensor_0, R.shape([3, 1, 4])
+                )
+                lv1: R.Tensor((3, 1, 4), dtype="float32") = R.full(
+                    R.shape([3, 1, 4]), R.const(0.5, "float32"), dtype="float32"
+                )
+                lv2: R.Tensor((3, 2, 4), dtype="float32") = R.concat((lv, lv1), axis=1)
+                lv3: R.Tensor((6, 4), dtype="float32") = R.reshape(lv2, R.shape([6, 4]))
+                lv4: R.Tensor((5, 4), dtype="float32") = R.strided_slice(
+                    lv3, [0, 1], [0, 0], [5, 4], [1, 1], assume_inbound=False
+                )
+                lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape(
+                    lv4, R.shape([5, 4, 1])
+                )
+                lv6: R.Tensor((5, 4, 1), dtype="float32") = R.full(
+                    R.shape([5, 4, 1]), R.const(0.5, "float32"), dtype="float32"
+                )
+                lv7: R.Tensor((5, 4, 2), dtype="float32") = R.concat((lv5, lv6), axis=2)
+                lv8: R.Tensor((5, 8), dtype="float32") = R.reshape(lv7, R.shape([5, 8]))
+                gv: R.Tensor((5, 7), dtype="float32") = R.strided_slice(
+                    lv8, [0, 1], [0, 0], [5, 7], [1, 1], assume_inbound=False
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_dilate_dynamic_dilations():
+    """DILATE with runtime dilations"""
+    # The dilation values only size the declared output tensor; the op itself
+    # reads them from the second graph input at runtime.
+    mod = _convert_dilate_only_case(
+        input_shape=(3, 4), dilations=[2, 2], dilation_value=0.5, dynamic_dilations=True
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"),
+            tvmgen_tensor_1: R.Tensor((2,), dtype="int32"),
+        ) -> R.Tensor(dtype="float32", ndim=2):
+            R.func_attr({"num_input": 2})
+            dilate_stride_0 = T.int64()
+            dilate_stride_1 = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.match_cast(
+                    tvmgen_tensor_1, R.Tensor((2,), dtype="int32")
+                )
+                lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, dtype="int64")
+                lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1)
+                _lv3: R.Shape([dilate_stride_0, dilate_stride_1]) = R.match_cast(
+                    lv2, R.Shape([dilate_stride_0, dilate_stride_1])
+                )
+                lv4: R.Tensor((3, 1, 4), dtype="float32") = R.reshape(
+                    tvmgen_tensor_0, R.shape([3, 1, 4])
+                )
+                lv5: R.Tensor((3, dilate_stride_0 - 1, 4), dtype="float32") = R.full(
+                    R.shape([3, dilate_stride_0 - 1, 4]),
+                    R.const(0.5, "float32"),
+                    dtype="float32",
+                )
+                lv6: R.Tensor(
+                    (3, 1 + (dilate_stride_0 - 1), 4), dtype="float32"
+                ) = R.concat((lv4, lv5), axis=1)
+                lv7: R.Tensor((3 * dilate_stride_0, 4), dtype="float32") = R.reshape(
+                    lv6, R.shape([3 * dilate_stride_0, 4])
+                )
+                lv8: R.Tensor(
+                    (T.min(dilate_stride_0 * 2 + 1, dilate_stride_0 * 3), 4),
+                    dtype="float32",
+                ) = R.strided_slice(
+                    lv7,
+                    [0, 1],
+                    [0, 0],
+                    [2 * dilate_stride_0 + 1, 4],
+                    [1, 1],
+                    assume_inbound=False,
+                )
+                lv9: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4, 1), dtype="float32"
+                ) = R.reshape(lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1]))
+                lv10: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1), dtype="float32"
+                ) = R.full(
+                    R.shape([2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1]),
+                    R.const(0.5, "float32"),
+                    dtype="float32",
+                )
+                lv11: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4, 1 + (dilate_stride_1 - 1)),
+                    dtype="float32",
+                ) = R.concat((lv9, lv10), axis=2)
+                lv12: R.Tensor(
+                    (2 * dilate_stride_0 + 1, 4 * dilate_stride_1), dtype="float32"
+                ) = R.reshape(
+                    lv11, R.shape([2 * dilate_stride_0 + 1, 4 * dilate_stride_1])
+                )
+                gv: R.Tensor(
+                    (
+                        dilate_stride_0 * 2 + 1,
+                        T.min(dilate_stride_1 * 3 + 1, dilate_stride_1 * 4),
+                    ),
+                    dtype="float32",
+                ) = R.strided_slice(
+                    lv12,
+                    [0, 1],
+                    [0, 0],
+                    [2 * dilate_stride_0 + 1, 3 * dilate_stride_1 + 1],
+                    [1, 1],
+                    assume_inbound=False,
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_dilate_rejects_non_positive_dilations():
+    """Constant dilations below 1 are rejected at conversion time"""
+    with pytest.raises(ValueError, match="dilations must be >= 1"):
+        _convert_dilate_only_case(input_shape=(3, 4), dilations=[0, 2], dilation_value=0.5)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])