From 5f75581a0f332b0f078cdb9225524457c66788a6 Mon Sep 17 00:00:00 2001 From: Adam Scott Date: Thu, 30 Apr 2026 00:23:15 -0400 Subject: [PATCH] [Relax][Frontend][TFLite] Add DILATE operator mapping This PR adds TFLite frontend support for the DILATE operator which extends a tensor by inserting a padding value between existing elements per axis according to the dilation strides. Decomposes into existing Relax primitives instead of registering a new op: - relax.op.full pre-fills the output with padding_value - relax.op.scatter_nd places input elements at strided positions Both static and dynamic dilations are supported. Frontend tests use hand-rolled .tflite fixtures since DILATE has no public TF Python emitter through tf.lite.TFLiteConverter, so the standard verify(TestClass, Expected) pattern can't reach it. Extends DENSIFY's fixture builders to handle BuiltinOptions2 and non-FLOAT32 tensors. _finish_tflite_model now writes the TFL3 file identifier so the produced buffer is a valid input for tf.lite.Interpreter in the nightly E2E path. Validation: python -m pytest tests/python/relax/test_frontend_tflite.py -k dilate -v Addresses the DILATE item under #19412. --- .../relax/frontend/tflite/tflite_frontend.py | 62 +++++ tests/python/relax/test_frontend_tflite.py | 261 +++++++++++++++++- 2 files changed, 317 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index b7a7e42c488d..72063cace6c4 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -136,6 +136,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "DEPTHWISE_CONV_2D": functools.partial(self.convert_conv, conv_type="depthwise"), "DEQUANTIZE": self.convert_dequantize, "DETECTION_POSTPROCESS": self.convert_detection_postprocess, + "DILATE": self.convert_dilate, "DIV": functools.partial(self._convert_elemwise, relax_op=_op.divide), "ELU": self.convert_elu, "EQUAL": functools.partial( @@ -3416,6 +3417,67 @@ def convert_dequantize(self, op): return out + def convert_dilate(self, op): + """Convert TFLite DILATE""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + assert len(output_tensors) == 1, "output tensors length should be 1" + + in_expr = self.get_tensor_expr(input_tensors[0]) + in_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) + in_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type()) + n_dims = len(in_shape) + + dilations_tensor = input_tensors[1] + padding_expr = self.get_tensor_expr(input_tensors[2]) + + # Runtime dilations bind tensor values to TIR Vars for symbolic + # per-axis math. + if self.has_expr(dilations_tensor.tensor_idx): + dilations_expr = self.get_expr(dilations_tensor.tensor_idx) + dilations_expr = self.bb.match_cast( + dilations_expr, relax.TensorStructInfo([n_dims], "int32") + ) + dilations_int64 = self.bb.normalize(relax.op.astype(dilations_expr, "int64")) + shape_var = self.bb.emit(relax.op.tensor_to_shape(dilations_int64)) + stride_vars = [tirx.Var(f"dilate_stride_{i}", "int64") for i in range(n_dims)] + self.bb.match_cast(shape_var, relax.ShapeStructInfo(stride_vars)) + strides = stride_vars + else: + strides = to_int_list(self.get_tensor_value(dilations_tensor)) + + # Per axis: reshape to add a size-1 stride-axis, concat (s-1) padding + # values along it, reshape to merge axes (length d*s), trim trailing + # pad to TFLite's output dim formula (d-1)*s + 1. + result = in_expr + current_shape = list(in_shape) + axes = list(range(n_dims)) + ones = [1] * n_dims + for axis in range(n_dims): + d = current_shape[axis] + s = strides[axis] + expanded_shape = current_shape[: axis + 1] + [1] + current_shape[axis + 1 :] + expanded = relax.op.reshape(result, expanded_shape) + pad_shape = list(expanded_shape) + pad_shape[axis + 1] = s - 1 + pad = relax.op.full(pad_shape, padding_expr, dtype=in_dtype) + concatted = relax.op.concat([expanded, pad], axis=axis + 1) + merged_shape = list(current_shape) + merged_shape[axis] = d * s + merged = relax.op.reshape(concatted, merged_shape) + # (d - 1) * s + 1 is the output dim along this axis. + final_dim = (d - 1) * s + 1 + end = list(merged_shape) + end[axis] = final_dim + result = relax.op.strided_slice( + merged, axes=axes, begin=[0] * n_dims, end=end, strides=ones + ) + current_shape = list(merged_shape) + current_shape[axis] = final_dim + + return result + def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" flexbuffer = op.CustomOptionsAsNumpy().tobytes() diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 69aab2d43b93..4fcb2a1a65ac 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -2994,6 +2994,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_add_options = _get_tflite_schema_module("AddOptions") _tfl_buffer = _get_tflite_schema_module("Buffer") _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions") +_tfl_dilate_options = _get_tflite_schema_module("DilateOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") _tfl_int32_vector = _get_tflite_schema_module("Int32Vector") @@ -3006,6 +3007,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator") _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions") +_tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2") _tfl_dimension_type = _get_tflite_schema_enum("DimensionType") _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat") _tfl_padding = _get_tflite_schema_enum("Padding") @@ -3061,8 +3063,10 @@ def _tflite_shape(builder, shape): return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) -def _build_tensor(builder, buffer_idx, shape, sparsity=None): +def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): """Helper to build a TFLite tensor.""" + if tensor_type is None: + tensor_type = _tfl_tensor_type.FLOAT32 shape_vec = _tflite_shape(builder, shape) _tfl_tensor.TensorStart(builder) _tfl_tensor.TensorAddBuffer(builder, buffer_idx) @@ -3071,7 +3075,7 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None): _tfl_tensor.TensorAddShape(builder, shape_vec) if sparsity is not None: _tfl_tensor.TensorAddSparsity(builder, sparsity) - _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + _tfl_tensor.TensorAddType(builder, tensor_type) return _tfl_tensor.TensorEnd(builder) @@ -3088,7 +3092,14 @@ def _build_buffer(builder, data=None): def _build_operator( - builder, opcode_index, inputs, outputs, builtin_options_type, builtin_options=None + builder, + opcode_index, + inputs, + outputs, + builtin_options_type=None, + builtin_options=None, + builtin_options2_type=None, + builtin_options2=None, ): inputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartInputsVector, inputs) outputs_vec = _tflite_int32_vector( @@ -3098,15 +3109,23 @@ def _build_operator( _tfl_operator.OperatorAddOpcodeIndex(builder, opcode_index) _tfl_operator.OperatorAddInputs(builder, inputs_vec) _tfl_operator.OperatorAddOutputs(builder, outputs_vec) - _tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type) + if builtin_options_type is not None: + _tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type) if builtin_options is not None: _tfl_operator.OperatorAddBuiltinOptions(builder, builtin_options) + if builtin_options2_type is not None: + _tfl_operator.OperatorAddBuiltinOptions2Type(builder, builtin_options2_type) + if builtin_options2 is not None: + _tfl_operator.OperatorAddBuiltinOptions2(builder, builtin_options2) return _tfl_operator.OperatorEnd(builder) def _build_operator_code(builder, builtin_op): + # deprecated_builtin_code is int8 (max 127). Ops past that write 127 as a + # placeholder and use the full builtin_code field. + deprecated_code = builtin_op if builtin_op < 127 else 127 _tfl_operator_code.OperatorCodeStart(builder) - _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, builtin_op) + _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, deprecated_code) _tfl_operator_code.OperatorCodeAddBuiltinCode(builder, builtin_op) _tfl_operator_code.OperatorCodeAddVersion(builder, 1) return _tfl_operator_code.OperatorCodeEnd(builder) @@ -3144,7 +3163,7 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers): _tfl_model.ModelAddVersion(builder, 3) model = _tfl_model.ModelEnd(builder) - builder.Finish(model) + builder.Finish(model, b"TFL3") return bytes(builder.Output()) @@ -3516,5 +3535,235 @@ def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((1, 4), dtype="float3 tvm.ir.assert_structural_equal(mod, Expected) +def _build_dilate_only_case( + builder, *, input_shape, dilations, dilation_value, dynamic_dilations=False +): + input_tensor_idx = 0 + dilations_tensor_idx = 1 + padding_value_tensor_idx = 2 + output_tensor_idx = 3 + + output_shape = tuple((input_shape[i] - 1) * dilations[i] + 1 for i in range(len(input_shape))) + + input_tensor = _build_tensor(builder, 1, input_shape) + dilations_tensor = _build_tensor( + builder, 2, [len(dilations)], tensor_type=_tfl_tensor_type.INT32 + ) + padding_value_tensor = _build_tensor(builder, 3, []) + output_tensor = _build_tensor(builder, 4, output_shape) + + _tfl_dilate_options.DilateOptionsStart(builder) + dilate_opts = _tfl_dilate_options.DilateOptionsEnd(builder) + + dilate_op = _build_operator( + builder, + 0, + [input_tensor_idx, dilations_tensor_idx, padding_value_tensor_idx], + [output_tensor_idx], + builtin_options2_type=_tfl_builtin_options2.DilateOptions, + builtin_options2=dilate_opts, + ) + sg_inputs = ( + [input_tensor_idx, dilations_tensor_idx] if dynamic_dilations else [input_tensor_idx] + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, dilations_tensor, padding_value_tensor, output_tensor], + operators=[dilate_op], + inputs=sg_inputs, + outputs=[output_tensor_idx], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DILATE)] + return subgraph, operator_codes + + +def test_dilate(): + """TFLite DILATE with constant dilations""" + builder = flatbuffers.Builder(1024) + input_shape = (3, 4) + dilations = [2, 2] + dilation_value = 0.5 + + subgraph, operator_codes = _build_dilate_only_case( + builder, + input_shape=input_shape, + dilations=dilations, + dilation_value=dilation_value, + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder, np.asarray(dilations, dtype=np.int32).tobytes()), + _build_buffer( + builder, np.asarray([dilation_value], dtype=np.float32).tobytes() + ), + _build_buffer(builder), + ] + + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers + ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tensor((5, 7), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((3, 1, 4), dtype="float32") = R.reshape( + tvmgen_tensor_0, R.shape([3, 1, 4]) + ) + lv1: R.Tensor((3, 1, 4), dtype="float32") = R.full( + R.shape([3, 1, 4]), R.const(0.5, "float32"), dtype="float32" + ) + lv2: R.Tensor((3, 2, 4), dtype="float32") = R.concat((lv, lv1), axis=1) + lv3: R.Tensor((6, 4), dtype="float32") = R.reshape(lv2, R.shape([6, 4])) + lv4: R.Tensor((5, 4), dtype="float32") = R.strided_slice( + lv3, [0, 1], [0, 0], [5, 4], [1, 1], assume_inbound=False + ) + lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape( + lv4, R.shape([5, 4, 1]) + ) + lv6: R.Tensor((5, 4, 1), dtype="float32") = R.full( + R.shape([5, 4, 1]), R.const(0.5, "float32"), dtype="float32" + ) + lv7: R.Tensor((5, 4, 2), dtype="float32") = R.concat((lv5, lv6), axis=2) + lv8: R.Tensor((5, 8), dtype="float32") = R.reshape(lv7, R.shape([5, 8])) + gv: R.Tensor((5, 7), dtype="float32") = R.strided_slice( + lv8, [0, 1], [0, 0], [5, 7], [1, 1], assume_inbound=False + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dilate_dynamic_dilations(): + """DILATE with runtime dilations""" + builder = flatbuffers.Builder(1024) + input_shape = (3, 4) + dilations_for_shape = [2, 2] + dilation_value = 0.5 + + subgraph, operator_codes = _build_dilate_only_case( + builder, + input_shape=input_shape, + dilations=dilations_for_shape, + dilation_value=dilation_value, + dynamic_dilations=True, + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder), # dilations is a runtime input so empty buffer + _build_buffer( + builder, np.asarray([dilation_value], dtype=np.float32).tobytes() + ), + _build_buffer(builder), + ] + + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers + ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"), + tvmgen_tensor_1: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor(dtype="float32", ndim=2): + R.func_attr({"num_input": 2}) + dilate_stride_0 = T.int64() + dilate_stride_1 = T.int64() + with R.dataflow(): + lv: R.Tensor((2,), dtype="int32") = R.match_cast( + tvmgen_tensor_1, R.Tensor((2,), dtype="int32") + ) + lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, dtype="int64") + lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1) + _lv3: R.Shape([dilate_stride_0, dilate_stride_1]) = R.match_cast( + lv2, R.Shape([dilate_stride_0, dilate_stride_1]) + ) + lv4: R.Tensor((3, 1, 4), dtype="float32") = R.reshape( + tvmgen_tensor_0, R.shape([3, 1, 4]) + ) + lv5: R.Tensor((3, dilate_stride_0 - 1, 4), dtype="float32") = R.full( + R.shape([3, dilate_stride_0 - 1, 4]), + R.const(0.5, "float32"), + dtype="float32", + ) + lv6: R.Tensor( + (3, 1 + (dilate_stride_0 - 1), 4), dtype="float32" + ) = R.concat((lv4, lv5), axis=1) + lv7: R.Tensor((3 * dilate_stride_0, 4), dtype="float32") = R.reshape( + lv6, R.shape([3 * dilate_stride_0, 4]) + ) + lv8: R.Tensor( + (T.min(dilate_stride_0 * 2 + 1, dilate_stride_0 * 3), 4), + dtype="float32", + ) = R.strided_slice( + lv7, + [0, 1], + [0, 0], + [2 * dilate_stride_0 + 1, 4], + [1, 1], + assume_inbound=False, + ) + lv9: R.Tensor( + (2 * dilate_stride_0 + 1, 4, 1), dtype="float32" + ) = R.reshape(lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1])) + lv10: R.Tensor( + (2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1), dtype="float32" + ) = R.full( + R.shape([2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1]), + R.const(0.5, "float32"), + dtype="float32", + ) + lv11: R.Tensor( + (2 * dilate_stride_0 + 1, 4, 1 + (dilate_stride_1 - 1)), + dtype="float32", + ) = R.concat((lv9, lv10), axis=2) + lv12: R.Tensor( + (2 * dilate_stride_0 + 1, 4 * dilate_stride_1), dtype="float32" + ) = R.reshape( + lv11, R.shape([2 * dilate_stride_0 + 1, 4 * dilate_stride_1]) + ) + gv: R.Tensor( + ( + dilate_stride_0 * 2 + 1, + T.min(dilate_stride_1 * 3 + 1, dilate_stride_1 * 4), + ), + dtype="float32", + ) = R.strided_slice( + lv12, + [0, 1], + [0, 0], + [2 * dilate_stride_0 + 1, 3 * dilate_stride_1 + 1], + [1, 1], + assume_inbound=False, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": pytest.main(["-s", __file__])