Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
self._convert_stablehlo_binary, relax_op=_op.subtract
),
"STABLEHLO_TANH": functools.partial(self._convert_stablehlo_unary, relax_op=_op.tanh),
"STABLEHLO_WHILE": self._convert_stablehlo_while,
"SQUEEZE": self.convert_squeeze,
"STRIDED_SLICE": self.convert_strided_slice,
"SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract),
Expand Down Expand Up @@ -2161,6 +2162,19 @@ def _convert_stablehlo_sort(self, op):
relax.op.sort(data, axis=int(opts.Dimension()), descending=descending)
)

def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions

opts = self._get_stablehlo_options(op, StablehloWhileOptions)
return self._convert_while_like(
op,
"STABLEHLO_WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_stablehlo_while",
)
Comment on lines +2165 to +2176

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _convert_stablehlo_while, if opts is None (e.g., due to a malformed model or parsing failure), calling opts.CondSubgraphIndex() will raise an AttributeError. Adding a check to ensure opts is not None before extracting the subgraph indices would make the parser more robust and provide a clearer error message.

Suggested change
def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions
opts = self._get_stablehlo_options(op, StablehloWhileOptions)
return self._convert_while_like(
op,
"STABLEHLO_WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_stablehlo_while",
)
def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions
opts = self._get_stablehlo_options(op, StablehloWhileOptions)
if opts is None:
raise tvm.error.OpNotImplemented("STABLEHLO_WHILE requires valid StablehloWhileOptions")
return self._convert_while_like(
op,
"STABLEHLO_WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_stablehlo_while",
)


def _get_builtin_options(self, op, options_cls):
"""Parse BuiltinOptions for a TFLite builtin operator."""
from tflite.BuiltinOptions import BuiltinOptions
Expand Down Expand Up @@ -2402,14 +2416,15 @@ def _lower_while_to_function(
cond_func,
body_func,
body_subgraph,
function_prefix="tflite_while",
):
"""Lower a TFLite WHILE op into a recursive private Relax function."""
cache_key = (cond_subgraph_index, body_subgraph_index, loop_var_count)
cache_key = (function_prefix, cond_subgraph_index, body_subgraph_index, loop_var_count)
lowered_while_functions = self.conversion_state["lowered_while_functions"]
if cache_key in lowered_while_functions:
return lowered_while_functions[cache_key]

loop_name = f"tflite_while_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
loop_name = f"{function_prefix}_subgraph_{cond_subgraph_index}_{body_subgraph_index}"
params, _ = self._get_subgraph_params(body_subgraph)
dummy_body = self._make_tuple_or_single(params)
module_builder = self.conversion_state["module_builder"]
Expand Down Expand Up @@ -2489,47 +2504,44 @@ def convert_if(self, op):
args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
return relax.Call(if_func, args)

def convert_while(self, op):
"""Convert TFLite WHILE to a recursive Relax private function."""
from tflite.WhileOptions import WhileOptions

opts = self._get_builtin_options(op, WhileOptions)
cond_subgraph_index = int(opts.CondSubgraphIndex())
body_subgraph_index = int(opts.BodySubgraphIndex())
def _convert_while_like(
self, op, op_name, cond_subgraph_index, body_subgraph_index, function_prefix
):
"""Convert a TFLite while-like operator with referenced cond/body subgraphs."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
loop_var_count = len(input_tensors)
if loop_var_count == 0:
raise tvm.error.OpNotImplemented("WHILE requires loop-carried inputs")
raise tvm.error.OpNotImplemented(f"{op_name} requires loop-carried inputs")
if len(output_tensors) != loop_var_count:
raise tvm.error.OpNotImplemented("WHILE output count must match input count")
raise tvm.error.OpNotImplemented(f"{op_name} output count must match input count")

cond_subgraph = self._check_subgraph_interface(
cond_subgraph_index,
"WHILE",
op_name,
input_tensors=input_tensors,
output_count=1,
)
body_subgraph = self._check_subgraph_interface(
body_subgraph_index,
"WHILE",
op_name,
input_tensors=input_tensors,
output_tensors=input_tensors,
)
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
self._check_tensor_metadata_match(input_tensor, output_tensor, "WHILE", "loop state")
self._check_tensor_metadata_match(input_tensor, output_tensor, op_name, "loop state")
cond_output = cond_subgraph.Tensors(int(cond_subgraph.Outputs(0)))
self._require_scalar_bool_tensor(cond_output, "WHILE")
self._require_scalar_bool_tensor(cond_output, op_name)

cond_func = self._lower_subgraph_to_function(
cond_subgraph_index,
f"tflite_while_cond_subgraph_{cond_subgraph_index}",
op_name="WHILE",
f"{function_prefix}_cond_subgraph_{cond_subgraph_index}",
op_name=op_name,
)
body_func = self._lower_subgraph_to_function(
body_subgraph_index,
f"tflite_while_body_subgraph_{body_subgraph_index}",
op_name="WHILE",
f"{function_prefix}_body_subgraph_{body_subgraph_index}",
op_name=op_name,
)

loop_gv = self._lower_while_to_function(
Expand All @@ -2539,11 +2551,25 @@ def convert_while(self, op):
cond_func,
body_func,
body_subgraph,
function_prefix=function_prefix,
)

args = [self.get_tensor_expr(tensor) for tensor in input_tensors]
return relax.Call(loop_gv, args)

def convert_while(self, op):
"""Convert TFLite WHILE to a recursive Relax private function."""
from tflite.WhileOptions import WhileOptions

opts = self._get_builtin_options(op, WhileOptions)
return self._convert_while_like(
op,
"WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_while",
)

def convert_call_once(self, op):
"""Convert TFLite CALL_ONCE for no-op and resource-variable initialization subsets."""
from tflite.CallOnceOptions import CallOnceOptions
Expand Down
219 changes: 219 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3695,6 +3695,7 @@ def _get_tflite_schema_enum(enum_name):
_tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions")
_tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions")
_tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
_tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions")
_tfl_call_options = _get_tflite_schema_module("CallOptions")
_tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions")
_tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
Expand Down Expand Up @@ -3946,6 +3947,17 @@ def _build_while_options(builder, cond_subgraph_index, body_subgraph_index):
return _tfl_while_options.WhileOptionsEnd(builder)


def _build_stablehlo_while_options(builder, cond_subgraph_index, body_subgraph_index):
_tfl_stablehlo_while_opts.StablehloWhileOptionsStart(builder)
_tfl_stablehlo_while_opts.StablehloWhileOptionsAddCondSubgraphIndex(
builder, cond_subgraph_index
)
_tfl_stablehlo_while_opts.StablehloWhileOptionsAddBodySubgraphIndex(
builder, body_subgraph_index
)
return _tfl_stablehlo_while_opts.StablehloWhileOptionsEnd(builder)


def _build_call_once_options(builder, init_subgraph_index):
_tfl_call_once_options.CallOnceOptionsStart(builder)
_tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder, init_subgraph_index)
Expand Down Expand Up @@ -6296,6 +6308,107 @@ def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_d
)


def _build_stablehlo_while_model(
cond_subgraph_index=1,
body_subgraph_index=2,
cond_output_type=_tfl_tensor_type.BOOL,
cond_input_type=_tfl_tensor_type.INT32,
body_outputs=None,
body_input_type=_tfl_tensor_type.INT32,
body_output_type=_tfl_tensor_type.INT32,
main_output_type=_tfl_tensor_type.INT32,
):
"""Build a STABLEHLO_WHILE model incrementing an int32 scalar until i < 3 is false."""
builder = flatbuffers.Builder(1024)

body_outputs = [2] if body_outputs is None else body_outputs
while_options = _build_stablehlo_while_options(
builder, cond_subgraph_index, body_subgraph_index
)
_tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
_tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(
builder,
_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT,
)
compare_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
one = np.array(1, dtype=np.int32)
three = np.array(3, dtype=np.int32)

main_tensors = [
_build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32),
_build_tensor(builder, 3, [], tensor_type=main_output_type),
]
main_while = _build_operator(
builder,
0,
[0],
[1],
builtin_options2_type=_tfl_builtin_options2.StablehloWhileOptions,
builtin_options2=while_options,
)
main_subgraph = _build_subgraph(
builder,
tensors=main_tensors,
operators=[main_while],
inputs=[0],
outputs=[1],
)

cond_tensors = [
_build_tensor(builder, 0, [], tensor_type=cond_input_type),
_build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32),
_build_tensor(builder, 3, [], tensor_type=cond_output_type),
]
cond_compare = _build_operator(
builder,
1,
[0, 1],
[2],
builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
builtin_options2=compare_opts,
)
cond_subgraph = _build_subgraph(
builder,
tensors=cond_tensors,
operators=[cond_compare],
inputs=[0],
outputs=[2],
)

body_tensors = [
_build_tensor(builder, 0, [], tensor_type=body_input_type),
_build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32),
_build_tensor(builder, 3, [], tensor_type=body_output_type),
]
body_add = _build_operator(builder, 2, [0, 1], [2])
body_subgraph = _build_subgraph(
builder,
tensors=body_tensors,
operators=[body_add],
inputs=[0],
outputs=body_outputs,
)

operator_codes = [
_build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_WHILE")),
_build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_COMPARE")),
_build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_ADD")),
]
buffers = [
_build_buffer(builder),
_build_buffer(builder, three.tobytes()),
_build_buffer(builder, one.tobytes()),
_build_buffer(builder),
]
return _finish_tflite_model(
builder,
subgraph=main_subgraph,
extra_subgraphs=[cond_subgraph, body_subgraph],
operator_codes=operator_codes,
buffers=buffers,
)


def _build_stablehlo_composite_model(with_attributes=False, use_main_input_after_composite=False):
"""Build a STABLEHLO_COMPOSITE model that decomposes to STABLEHLO_NEGATE."""
builder = flatbuffers.Builder(1024)
Expand Down Expand Up @@ -6699,6 +6812,112 @@ def test_stablehlo_scatter_update_window_unsupported():
from_tflite(tflite_model)


def test_stablehlo_while():
"""TFLite STABLEHLO_WHILE lowers to a recursive Relax private function."""
mod = _load_model_from_buffer(_build_stablehlo_while_model())

@I.ir_module
class Expected:
@R.function(private=True)
def tflite_stablehlo_while_cond_subgraph_1(
tvmgen_tensor_0: R.Tensor((), dtype="int32"),
) -> R.Tensor((), dtype="bool"):
with R.dataflow():
gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32"))
R.output(gv)
return gv

@R.function(private=True)
def tflite_stablehlo_while_body_subgraph_2(
tvmgen_tensor_0: R.Tensor((), dtype="int32"),
) -> R.Tensor((), dtype="int32"):
with R.dataflow():
gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32"))
R.output(gv)
return gv

@R.function(private=True)
def tflite_stablehlo_while_subgraph_1_2(
tvmgen_tensor_0: R.Tensor((), dtype="int32"),
) -> R.Tensor((), dtype="int32"):
cls = Expected
while_cond: R.Tensor((), dtype="bool") = cls.tflite_stablehlo_while_cond_subgraph_1(
tvmgen_tensor_0
)
if while_cond:
gv: R.Tensor((), dtype="int32") = cls.tflite_stablehlo_while_body_subgraph_2(
tvmgen_tensor_0
)
gv1: R.Tensor((), dtype="int32") = cls.tflite_stablehlo_while_subgraph_1_2(gv)
cond_result: R.Tensor((), dtype="int32") = gv1
else:
cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0
return cond_result

@R.function
def main(
tvmgen_tensor_0: R.Tensor((), dtype="int32"),
) -> R.Tensor((), dtype="int32"):
R.func_attr({"num_input": 1})
cls = Expected
with R.dataflow():
gv: R.Tensor((), dtype="int32") = cls.tflite_stablehlo_while_subgraph_1_2(
tvmgen_tensor_0
)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


def test_stablehlo_while_non_bool_condition_unsupported():
"""STABLEHLO_WHILE rejects cond subgraphs that do not return scalar bool."""
with pytest.raises(
tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a scalar bool condition"
):
_load_model_from_buffer(
_build_stablehlo_while_model(cond_output_type=_tfl_tensor_type.INT32)
)


def test_stablehlo_while_invalid_index_unsupported():
"""STABLEHLO_WHILE rejects invalid cond/body subgraph indices before lowering."""
with pytest.raises(
tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a valid subgraph index"
):
_load_model_from_buffer(_build_stablehlo_while_model(cond_subgraph_index=3))


def test_stablehlo_while_output_count_mismatch_unsupported():
"""STABLEHLO_WHILE rejects body subgraphs whose output arity does not match loop vars."""
with pytest.raises(
tvm.error.OpNotImplemented, match="STABLEHLO_WHILE subgraph output count mismatch"
):
_load_model_from_buffer(_build_stablehlo_while_model(body_outputs=[]))


def test_stablehlo_while_input_metadata_mismatch_unsupported():
"""STABLEHLO_WHILE rejects cond subgraph inputs whose metadata does not match loop vars."""
with pytest.raises(
tvm.error.OpNotImplemented,
match="STABLEHLO_WHILE subgraph input tensor metadata mismatch",
):
_load_model_from_buffer(
_build_stablehlo_while_model(cond_input_type=_tfl_tensor_type.FLOAT32)
)


def test_stablehlo_while_output_metadata_mismatch_unsupported():
"""STABLEHLO_WHILE rejects body outputs whose metadata does not match loop vars."""
with pytest.raises(
tvm.error.OpNotImplemented,
match="STABLEHLO_WHILE subgraph output tensor metadata mismatch",
):
_load_model_from_buffer(
_build_stablehlo_while_model(body_output_type=_tfl_tensor_type.FLOAT32)
)


def test_stablehlo_composite():
"""TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph."""
mod = _load_model_from_buffer(_build_stablehlo_composite_model())
Expand Down
Loading