Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"DEPTHWISE_CONV_2D": functools.partial(self.convert_conv, conv_type="depthwise"),
"DEQUANTIZE": self.convert_dequantize,
"DETECTION_POSTPROCESS": self.convert_detection_postprocess,
"DILATE": self.convert_dilate,
"DIV": functools.partial(self._convert_elemwise, relax_op=_op.divide),
"ELU": self.convert_elu,
"EQUAL": functools.partial(
Expand Down Expand Up @@ -3416,6 +3417,67 @@ def convert_dequantize(self, op):

return out

def convert_dilate(self, op):
"""Convert TFLite DILATE"""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"
assert len(output_tensors) == 1, "output tensors length should be 1"

in_expr = self.get_tensor_expr(input_tensors[0])
in_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
in_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type())
n_dims = len(in_shape)

dilations_tensor = input_tensors[1]
padding_expr = self.get_tensor_expr(input_tensors[2])

# Runtime dilations bind tensor values to TIR Vars for symbolic
# per-axis math.
if self.has_expr(dilations_tensor.tensor_idx):
dilations_expr = self.get_expr(dilations_tensor.tensor_idx)
dilations_expr = self.bb.match_cast(
dilations_expr, relax.TensorStructInfo([n_dims], "int32")
)
dilations_int64 = self.bb.normalize(relax.op.astype(dilations_expr, "int64"))
shape_var = self.bb.emit(relax.op.tensor_to_shape(dilations_int64))
stride_vars = [tirx.Var(f"dilate_stride_{i}", "int64") for i in range(n_dims)]
self.bb.match_cast(shape_var, relax.ShapeStructInfo(stride_vars))
strides = stride_vars
else:
strides = to_int_list(self.get_tensor_value(dilations_tensor))

# Per axis: reshape to add a size-1 stride-axis, concat (s-1) padding
# values along it, reshape to merge axes (length d*s), trim trailing
# pad to TFLite's output dim formula (d-1)*s + 1.
result = in_expr
current_shape = list(in_shape)
axes = list(range(n_dims))
ones = [1] * n_dims
for axis in range(n_dims):
d = current_shape[axis]
s = strides[axis]
expanded_shape = current_shape[: axis + 1] + [1] + current_shape[axis + 1 :]
expanded = relax.op.reshape(result, expanded_shape)
pad_shape = list(expanded_shape)
pad_shape[axis + 1] = s - 1
pad = relax.op.full(pad_shape, padding_expr, dtype=in_dtype)
concatted = relax.op.concat([expanded, pad], axis=axis + 1)
merged_shape = list(current_shape)
merged_shape[axis] = d * s
merged = relax.op.reshape(concatted, merged_shape)
# (d - 1) * s + 1 is the output dim along this axis.
final_dim = (d - 1) * s + 1
end = list(merged_shape)
end[axis] = final_dim
result = relax.op.strided_slice(
merged, axes=axes, begin=[0] * n_dims, end=end, strides=ones
)
current_shape = list(merged_shape)
current_shape[axis] = final_dim

return result

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
Expand Down
261 changes: 255 additions & 6 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2994,6 +2994,7 @@ def _get_tflite_schema_enum(enum_name):
_tfl_add_options = _get_tflite_schema_module("AddOptions")
_tfl_buffer = _get_tflite_schema_module("Buffer")
_tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions")
_tfl_dilate_options = _get_tflite_schema_module("DilateOptions")
_tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
_tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions")
_tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
Expand All @@ -3006,6 +3007,7 @@ def _get_tflite_schema_enum(enum_name):

_tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator")
_tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions")
_tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2")
_tfl_dimension_type = _get_tflite_schema_enum("DimensionType")
_tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat")
_tfl_padding = _get_tflite_schema_enum("Padding")
Expand Down Expand Up @@ -3061,8 +3063,10 @@ def _tflite_shape(builder, shape):
return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape)


def _build_tensor(builder, buffer_idx, shape, sparsity=None):
def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None):
"""Helper to build a TFLite tensor."""
if tensor_type is None:
tensor_type = _tfl_tensor_type.FLOAT32
shape_vec = _tflite_shape(builder, shape)
_tfl_tensor.TensorStart(builder)
_tfl_tensor.TensorAddBuffer(builder, buffer_idx)
Expand All @@ -3071,7 +3075,7 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None):
_tfl_tensor.TensorAddShape(builder, shape_vec)
if sparsity is not None:
_tfl_tensor.TensorAddSparsity(builder, sparsity)
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
_tfl_tensor.TensorAddType(builder, tensor_type)
return _tfl_tensor.TensorEnd(builder)


Expand All @@ -3088,7 +3092,14 @@ def _build_buffer(builder, data=None):


def _build_operator(
builder, opcode_index, inputs, outputs, builtin_options_type, builtin_options=None
builder,
opcode_index,
inputs,
outputs,
builtin_options_type=None,
builtin_options=None,
builtin_options2_type=None,
builtin_options2=None,
):
inputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartInputsVector, inputs)
outputs_vec = _tflite_int32_vector(
Expand All @@ -3098,15 +3109,23 @@ def _build_operator(
_tfl_operator.OperatorAddOpcodeIndex(builder, opcode_index)
_tfl_operator.OperatorAddInputs(builder, inputs_vec)
_tfl_operator.OperatorAddOutputs(builder, outputs_vec)
_tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type)
if builtin_options_type is not None:
_tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type)
if builtin_options is not None:
_tfl_operator.OperatorAddBuiltinOptions(builder, builtin_options)
if builtin_options2_type is not None:
_tfl_operator.OperatorAddBuiltinOptions2Type(builder, builtin_options2_type)
if builtin_options2 is not None:
_tfl_operator.OperatorAddBuiltinOptions2(builder, builtin_options2)
return _tfl_operator.OperatorEnd(builder)


def _build_operator_code(builder, builtin_op):
# deprecated_builtin_code is int8 (max 127). Ops past that write 127 as a
# placeholder and use the full builtin_code field.
deprecated_code = builtin_op if builtin_op < 127 else 127
_tfl_operator_code.OperatorCodeStart(builder)
_tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, builtin_op)
_tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, deprecated_code)
_tfl_operator_code.OperatorCodeAddBuiltinCode(builder, builtin_op)
_tfl_operator_code.OperatorCodeAddVersion(builder, 1)
return _tfl_operator_code.OperatorCodeEnd(builder)
Expand Down Expand Up @@ -3144,7 +3163,7 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers):
_tfl_model.ModelAddVersion(builder, 3)
model = _tfl_model.ModelEnd(builder)

builder.Finish(model)
builder.Finish(model, b"TFL3")
return bytes(builder.Output())


Expand Down Expand Up @@ -3516,5 +3535,235 @@ def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((1, 4), dtype="float3
tvm.ir.assert_structural_equal(mod, Expected)


def _build_dilate_only_case(
builder, *, input_shape, dilations, dilation_value, dynamic_dilations=False
):
input_tensor_idx = 0
dilations_tensor_idx = 1
padding_value_tensor_idx = 2
output_tensor_idx = 3

output_shape = tuple((input_shape[i] - 1) * dilations[i] + 1 for i in range(len(input_shape)))

input_tensor = _build_tensor(builder, 1, input_shape)
dilations_tensor = _build_tensor(
builder, 2, [len(dilations)], tensor_type=_tfl_tensor_type.INT32
)
padding_value_tensor = _build_tensor(builder, 3, [])
output_tensor = _build_tensor(builder, 4, output_shape)

_tfl_dilate_options.DilateOptionsStart(builder)
dilate_opts = _tfl_dilate_options.DilateOptionsEnd(builder)

dilate_op = _build_operator(
builder,
0,
[input_tensor_idx, dilations_tensor_idx, padding_value_tensor_idx],
[output_tensor_idx],
builtin_options2_type=_tfl_builtin_options2.DilateOptions,
builtin_options2=dilate_opts,
)
sg_inputs = (
[input_tensor_idx, dilations_tensor_idx] if dynamic_dilations else [input_tensor_idx]
)
subgraph = _build_subgraph(
builder,
tensors=[input_tensor, dilations_tensor, padding_value_tensor, output_tensor],
operators=[dilate_op],
inputs=sg_inputs,
outputs=[output_tensor_idx],
)
operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DILATE)]
return subgraph, operator_codes


def test_dilate():
"""TFLite DILATE with constant dilations"""
builder = flatbuffers.Builder(1024)
input_shape = (3, 4)
dilations = [2, 2]
dilation_value = 0.5

subgraph, operator_codes = _build_dilate_only_case(
builder,
input_shape=input_shape,
dilations=dilations,
dilation_value=dilation_value,
)

buffers = [
_build_buffer(builder),
_build_buffer(builder),
_build_buffer(builder, np.asarray(dilations, dtype=np.int32).tobytes()),
_build_buffer(
builder, np.asarray([dilation_value], dtype=np.float32).tobytes()
),
_build_buffer(builder),
]

buf = _finish_tflite_model(
builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers
)
if hasattr(tflite.Model, "Model"):
tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
else:
tflite_model = tflite.Model.GetRootAsModel(buf, 0)
mod = from_tflite(tflite_model)
mod["main"] = mod["main"].without_attr("params")

@I.ir_module
class Expected:
@R.function
def main(
tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"),
) -> R.Tensor((5, 7), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((3, 1, 4), dtype="float32") = R.reshape(
tvmgen_tensor_0, R.shape([3, 1, 4])
)
lv1: R.Tensor((3, 1, 4), dtype="float32") = R.full(
R.shape([3, 1, 4]), R.const(0.5, "float32"), dtype="float32"
)
lv2: R.Tensor((3, 2, 4), dtype="float32") = R.concat((lv, lv1), axis=1)
lv3: R.Tensor((6, 4), dtype="float32") = R.reshape(lv2, R.shape([6, 4]))
lv4: R.Tensor((5, 4), dtype="float32") = R.strided_slice(
lv3, [0, 1], [0, 0], [5, 4], [1, 1], assume_inbound=False
)
lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape(
lv4, R.shape([5, 4, 1])
)
lv6: R.Tensor((5, 4, 1), dtype="float32") = R.full(
R.shape([5, 4, 1]), R.const(0.5, "float32"), dtype="float32"
)
lv7: R.Tensor((5, 4, 2), dtype="float32") = R.concat((lv5, lv6), axis=2)
lv8: R.Tensor((5, 8), dtype="float32") = R.reshape(lv7, R.shape([5, 8]))
gv: R.Tensor((5, 7), dtype="float32") = R.strided_slice(
lv8, [0, 1], [0, 0], [5, 7], [1, 1], assume_inbound=False
)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


def test_dilate_dynamic_dilations():
"""DILATE with runtime dilations"""
builder = flatbuffers.Builder(1024)
input_shape = (3, 4)
dilations_for_shape = [2, 2]
dilation_value = 0.5

subgraph, operator_codes = _build_dilate_only_case(
builder,
input_shape=input_shape,
dilations=dilations_for_shape,
dilation_value=dilation_value,
dynamic_dilations=True,
)

buffers = [
_build_buffer(builder),
_build_buffer(builder),
_build_buffer(builder), # dilations is a runtime input so empty buffer
_build_buffer(
builder, np.asarray([dilation_value], dtype=np.float32).tobytes()
),
_build_buffer(builder),
]

buf = _finish_tflite_model(
builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers
)
if hasattr(tflite.Model, "Model"):
tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
else:
tflite_model = tflite.Model.GetRootAsModel(buf, 0)
mod = from_tflite(tflite_model)
mod["main"] = mod["main"].without_attr("params")

@I.ir_module
class Expected:
@R.function
def main(
tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"),
tvmgen_tensor_1: R.Tensor((2,), dtype="int32"),
) -> R.Tensor(dtype="float32", ndim=2):
R.func_attr({"num_input": 2})
dilate_stride_0 = T.int64()
dilate_stride_1 = T.int64()
with R.dataflow():
lv: R.Tensor((2,), dtype="int32") = R.match_cast(
tvmgen_tensor_1, R.Tensor((2,), dtype="int32")
)
lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, dtype="int64")
lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1)
_lv3: R.Shape([dilate_stride_0, dilate_stride_1]) = R.match_cast(
lv2, R.Shape([dilate_stride_0, dilate_stride_1])
)
lv4: R.Tensor((3, 1, 4), dtype="float32") = R.reshape(
tvmgen_tensor_0, R.shape([3, 1, 4])
)
lv5: R.Tensor((3, dilate_stride_0 - 1, 4), dtype="float32") = R.full(
R.shape([3, dilate_stride_0 - 1, 4]),
R.const(0.5, "float32"),
dtype="float32",
)
lv6: R.Tensor(
(3, 1 + (dilate_stride_0 - 1), 4), dtype="float32"
) = R.concat((lv4, lv5), axis=1)
lv7: R.Tensor((3 * dilate_stride_0, 4), dtype="float32") = R.reshape(
lv6, R.shape([3 * dilate_stride_0, 4])
)
lv8: R.Tensor(
(T.min(dilate_stride_0 * 2 + 1, dilate_stride_0 * 3), 4),
dtype="float32",
) = R.strided_slice(
lv7,
[0, 1],
[0, 0],
[2 * dilate_stride_0 + 1, 4],
[1, 1],
assume_inbound=False,
)
lv9: R.Tensor(
(2 * dilate_stride_0 + 1, 4, 1), dtype="float32"
) = R.reshape(lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1]))
lv10: R.Tensor(
(2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1), dtype="float32"
) = R.full(
R.shape([2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1]),
R.const(0.5, "float32"),
dtype="float32",
)
lv11: R.Tensor(
(2 * dilate_stride_0 + 1, 4, 1 + (dilate_stride_1 - 1)),
dtype="float32",
) = R.concat((lv9, lv10), axis=2)
lv12: R.Tensor(
(2 * dilate_stride_0 + 1, 4 * dilate_stride_1), dtype="float32"
) = R.reshape(
lv11, R.shape([2 * dilate_stride_0 + 1, 4 * dilate_stride_1])
)
gv: R.Tensor(
(
dilate_stride_0 * 2 + 1,
T.min(dilate_stride_1 * 3 + 1, dilate_stride_1 * 4),
),
dtype="float32",
) = R.strided_slice(
lv12,
[0, 1],
[0, 0],
[2 * dilate_stride_0 + 1, 3 * dilate_stride_1 + 1],
[1, 1],
assume_inbound=False,
)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
pytest.main(["-s", __file__])
Loading