From e07e7d69dfeea05b48f29770fc521d6088c3f55a Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Sun, 31 May 2026 00:37:34 +0800 Subject: [PATCH] [Relax][Frontend][TFLite] Support STABLEHLO_WHILE Add Relax TFLite frontend support for STABLEHLO_WHILE by parsing StablehloWhileOptions from BuiltinOptions2 and reusing the existing TFLite while lowering path. Refactor the native WHILE converter through a shared _convert_while_like helper so both native WHILE and STABLEHLO_WHILE validate cond/body subgraph boundaries, loop-carried tensor metadata, scalar bool conditions, and output arity consistently. Include the function prefix in the cached recursive loop key so native and StableHLO while functions cannot collide when they reference the same subgraph indices. Add manually-built StableHLO while flatbuffer tests covering the recursive Relax private function lowering plus invalid cond/body index, non-bool condition, subgraph output count mismatch, input metadata mismatch, and body output metadata mismatch guards. --- .../relax/frontend/tflite/tflite_frontend.py | 64 +++-- tests/python/relax/test_frontend_tflite.py | 219 ++++++++++++++++++ 2 files changed, 264 insertions(+), 19 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 7046e43bbe68..45cd41ce5b14 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -387,6 +387,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): self._convert_stablehlo_binary, relax_op=_op.subtract ), "STABLEHLO_TANH": functools.partial(self._convert_stablehlo_unary, relax_op=_op.tanh), + "STABLEHLO_WHILE": self._convert_stablehlo_while, "SQUEEZE": self.convert_squeeze, "STRIDED_SLICE": self.convert_strided_slice, "SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract), @@ -2161,6 +2162,19 @@ def _convert_stablehlo_sort(self, op): relax.op.sort(data, axis=int(opts.Dimension()), descending=descending) ) + def _convert_stablehlo_while(self, op): + """Convert STABLEHLO_WHILE to a recursive Relax private function.""" + from tflite.StablehloWhileOptions import StablehloWhileOptions + + opts = self._get_stablehlo_options(op, StablehloWhileOptions) + return self._convert_while_like( + op, + "STABLEHLO_WHILE", + int(opts.CondSubgraphIndex()), + int(opts.BodySubgraphIndex()), + "tflite_stablehlo_while", + ) + def _get_builtin_options(self, op, options_cls): """Parse BuiltinOptions for a TFLite builtin operator.""" from tflite.BuiltinOptions import BuiltinOptions @@ -2402,14 +2416,15 @@ def _lower_while_to_function( cond_func, body_func, body_subgraph, + function_prefix="tflite_while", ): """Lower a TFLite WHILE op into a recursive private Relax function.""" - cache_key = (cond_subgraph_index, body_subgraph_index, loop_var_count) + cache_key = (function_prefix, cond_subgraph_index, body_subgraph_index, loop_var_count) lowered_while_functions = self.conversion_state["lowered_while_functions"] if cache_key in lowered_while_functions: return lowered_while_functions[cache_key] - loop_name = f"tflite_while_subgraph_{cond_subgraph_index}_{body_subgraph_index}" + loop_name = f"{function_prefix}_subgraph_{cond_subgraph_index}_{body_subgraph_index}" params, _ = self._get_subgraph_params(body_subgraph) dummy_body = self._make_tuple_or_single(params) module_builder = self.conversion_state["module_builder"] @@ -2489,47 +2504,44 @@ def convert_if(self, op): args = [self.get_tensor_expr(tensor) for tensor in input_tensors] return relax.Call(if_func, args) - def convert_while(self, op): - """Convert TFLite WHILE to a recursive Relax private function.""" - from tflite.WhileOptions import WhileOptions - - opts = self._get_builtin_options(op, WhileOptions) - cond_subgraph_index = int(opts.CondSubgraphIndex()) - body_subgraph_index = int(opts.BodySubgraphIndex()) + def _convert_while_like( + self, op, op_name, cond_subgraph_index, body_subgraph_index, function_prefix + ): + """Convert a TFLite while-like operator with referenced cond/body subgraphs.""" input_tensors = self.get_input_tensors(op) output_tensors = self.get_output_tensors(op) loop_var_count = len(input_tensors) if loop_var_count == 0: - raise tvm.error.OpNotImplemented("WHILE requires loop-carried inputs") + raise tvm.error.OpNotImplemented(f"{op_name} requires loop-carried inputs") if len(output_tensors) != loop_var_count: - raise tvm.error.OpNotImplemented("WHILE output count must match input count") + raise tvm.error.OpNotImplemented(f"{op_name} output count must match input count") cond_subgraph = self._check_subgraph_interface( cond_subgraph_index, - "WHILE", + op_name, input_tensors=input_tensors, output_count=1, ) body_subgraph = self._check_subgraph_interface( body_subgraph_index, - "WHILE", + op_name, input_tensors=input_tensors, output_tensors=input_tensors, ) for input_tensor, output_tensor in zip(input_tensors, output_tensors): - self._check_tensor_metadata_match(input_tensor, output_tensor, "WHILE", "loop state") + self._check_tensor_metadata_match(input_tensor, output_tensor, op_name, "loop state") cond_output = cond_subgraph.Tensors(int(cond_subgraph.Outputs(0))) - self._require_scalar_bool_tensor(cond_output, "WHILE") + self._require_scalar_bool_tensor(cond_output, op_name) cond_func = self._lower_subgraph_to_function( cond_subgraph_index, - f"tflite_while_cond_subgraph_{cond_subgraph_index}", - op_name="WHILE", + f"{function_prefix}_cond_subgraph_{cond_subgraph_index}", + op_name=op_name, ) body_func = self._lower_subgraph_to_function( body_subgraph_index, - f"tflite_while_body_subgraph_{body_subgraph_index}", - op_name="WHILE", + f"{function_prefix}_body_subgraph_{body_subgraph_index}", + op_name=op_name, ) loop_gv = self._lower_while_to_function( @@ -2539,11 +2551,25 @@ def convert_while(self, op): cond_func, body_func, body_subgraph, + function_prefix=function_prefix, ) args = [self.get_tensor_expr(tensor) for tensor in input_tensors] return relax.Call(loop_gv, args) + def convert_while(self, op): + """Convert TFLite WHILE to a recursive Relax private function.""" + from tflite.WhileOptions import WhileOptions + + opts = self._get_builtin_options(op, WhileOptions) + return self._convert_while_like( + op, + "WHILE", + int(opts.CondSubgraphIndex()), + int(opts.BodySubgraphIndex()), + "tflite_while", + ) + def convert_call_once(self, op): """Convert TFLite CALL_ONCE for no-op and resource-variable initialization subsets.""" from tflite.CallOnceOptions import CallOnceOptions diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 05a6c1e5e5fa..cc3a84e2fd91 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3695,6 +3695,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions") _tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions") _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions") +_tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions") _tfl_call_options = _get_tflite_schema_module("CallOptions") _tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") @@ -3946,6 +3947,17 @@ def _build_while_options(builder, cond_subgraph_index, body_subgraph_index): return _tfl_while_options.WhileOptionsEnd(builder) +def _build_stablehlo_while_options(builder, cond_subgraph_index, body_subgraph_index): + _tfl_stablehlo_while_opts.StablehloWhileOptionsStart(builder) + _tfl_stablehlo_while_opts.StablehloWhileOptionsAddCondSubgraphIndex( + builder, cond_subgraph_index + ) + _tfl_stablehlo_while_opts.StablehloWhileOptionsAddBodySubgraphIndex( + builder, body_subgraph_index + ) + return _tfl_stablehlo_while_opts.StablehloWhileOptionsEnd(builder) + + def _build_call_once_options(builder, init_subgraph_index): _tfl_call_once_options.CallOnceOptionsStart(builder) _tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder, init_subgraph_index) @@ -6296,6 +6308,107 @@ def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_d ) +def _build_stablehlo_while_model( + cond_subgraph_index=1, + body_subgraph_index=2, + cond_output_type=_tfl_tensor_type.BOOL, + cond_input_type=_tfl_tensor_type.INT32, + body_outputs=None, + body_input_type=_tfl_tensor_type.INT32, + body_output_type=_tfl_tensor_type.INT32, + main_output_type=_tfl_tensor_type.INT32, +): + """Build a STABLEHLO_WHILE model incrementing an int32 scalar until i < 3 is false.""" + builder = flatbuffers.Builder(1024) + + body_outputs = [2] if body_outputs is None else body_outputs + while_options = _build_stablehlo_while_options( + builder, cond_subgraph_index, body_subgraph_index + ) + _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) + _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection( + builder, + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, + ) + compare_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) + one = np.array(1, dtype=np.int32) + three = np.array(3, dtype=np.int32) + + main_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=main_output_type), + ] + main_while = _build_operator( + builder, + 0, + [0], + [1], + builtin_options2_type=_tfl_builtin_options2.StablehloWhileOptions, + builtin_options2=while_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_while], + inputs=[0], + outputs=[1], + ) + + cond_tensors = [ + _build_tensor(builder, 0, [], tensor_type=cond_input_type), + _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=cond_output_type), + ] + cond_compare = _build_operator( + builder, + 1, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, + builtin_options2=compare_opts, + ) + cond_subgraph = _build_subgraph( + builder, + tensors=cond_tensors, + operators=[cond_compare], + inputs=[0], + outputs=[2], + ) + + body_tensors = [ + _build_tensor(builder, 0, [], tensor_type=body_input_type), + _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=body_output_type), + ] + body_add = _build_operator(builder, 2, [0, 1], [2]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[body_add], + inputs=[0], + outputs=body_outputs, + ) + + operator_codes = [ + _build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_WHILE")), + _build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_COMPARE")), + _build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_ADD")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, three.tobytes()), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[cond_subgraph, body_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + def _build_stablehlo_composite_model(with_attributes=False, use_main_input_after_composite=False): """Build a STABLEHLO_COMPOSITE model that decomposes to STABLEHLO_NEGATE.""" builder = flatbuffers.Builder(1024) @@ -6699,6 +6812,112 @@ def test_stablehlo_scatter_update_window_unsupported(): from_tflite(tflite_model) +def test_stablehlo_while(): + """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function.""" + mod = _load_model_from_buffer(_build_stablehlo_while_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_stablehlo_while_cond_subgraph_1( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="bool"): + with R.dataflow(): + gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32")) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_stablehlo_while_body_subgraph_2( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32")) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_stablehlo_while_subgraph_1_2( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="int32"): + cls = Expected + while_cond: R.Tensor((), dtype="bool") = cls.tflite_stablehlo_while_cond_subgraph_1( + tvmgen_tensor_0 + ) + if while_cond: + gv: R.Tensor((), dtype="int32") = cls.tflite_stablehlo_while_body_subgraph_2( + tvmgen_tensor_0 + ) + gv1: R.Tensor((), dtype="int32") = cls.tflite_stablehlo_while_subgraph_1_2(gv) + cond_result: R.Tensor((), dtype="int32") = gv1 + else: + cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0 + return cond_result + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="int32"): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = cls.tflite_stablehlo_while_subgraph_1_2( + tvmgen_tensor_0 + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_while_non_bool_condition_unsupported(): + """STABLEHLO_WHILE rejects cond subgraphs that do not return scalar bool.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a scalar bool condition" + ): + _load_model_from_buffer( + _build_stablehlo_while_model(cond_output_type=_tfl_tensor_type.INT32) + ) + + +def test_stablehlo_while_invalid_index_unsupported(): + """STABLEHLO_WHILE rejects invalid cond/body subgraph indices before lowering.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="STABLEHLO_WHILE requires a valid subgraph index" + ): + _load_model_from_buffer(_build_stablehlo_while_model(cond_subgraph_index=3)) + + +def test_stablehlo_while_output_count_mismatch_unsupported(): + """STABLEHLO_WHILE rejects body subgraphs whose output arity does not match loop vars.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="STABLEHLO_WHILE subgraph output count mismatch" + ): + _load_model_from_buffer(_build_stablehlo_while_model(body_outputs=[])) + + +def test_stablehlo_while_input_metadata_mismatch_unsupported(): + """STABLEHLO_WHILE rejects cond subgraph inputs whose metadata does not match loop vars.""" + with pytest.raises( + tvm.error.OpNotImplemented, + match="STABLEHLO_WHILE subgraph input tensor metadata mismatch", + ): + _load_model_from_buffer( + _build_stablehlo_while_model(cond_input_type=_tfl_tensor_type.FLOAT32) + ) + + +def test_stablehlo_while_output_metadata_mismatch_unsupported(): + """STABLEHLO_WHILE rejects body outputs whose metadata does not match loop vars.""" + with pytest.raises( + tvm.error.OpNotImplemented, + match="STABLEHLO_WHILE subgraph output tensor metadata mismatch", + ): + _load_model_from_buffer( + _build_stablehlo_while_model(body_output_type=_tfl_tensor_type.FLOAT32) + ) + + def test_stablehlo_composite(): """TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph.""" mod = _load_model_from_buffer(_build_stablehlo_composite_model())