diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 979bbbb867ba..f395c95b6d99 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -154,7 +154,7 @@ class OperatorConverter: } ) - def __init__(self, model, subgraph, exp_tab, ctx): + def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): from tflite.ActivationFunctionType import ActivationFunctionType from tflite.BuiltinOperator import BuiltinOperator from tflite.BuiltinOptions import BuiltinOptions @@ -168,6 +168,17 @@ def __init__(self, model, subgraph, exp_tab, ctx): self.prefetched_nodes = {} self.allow_custom_ops = False self.bb = ctx + if conversion_state is None: + conversion_state = { + "lowered_subgraphs": {}, + "lowered_if_functions": {}, + "lowered_while_functions": {}, + "lowering_stack": [], + "module_builder": ctx, + } + else: + conversion_state.setdefault("module_builder", ctx) + self.conversion_state = conversion_state # Add more operators self.convert_map = { @@ -183,6 +194,8 @@ def __init__(self, model, subgraph, exp_tab, ctx): "BITCAST": self.convert_bitcast, "BROADCAST_TO": self.convert_broadcast_to, "BROADCAST_ARGS": self.convert_broadcast_args, + "CALL": self.convert_call, + "CALL_ONCE": self.convert_call_once, "CAST": self.convert_cast, "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil), "CONCATENATION": self.convert_concatenation, @@ -221,6 +234,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): ), "GELU": self.convert_gelu, "HARD_SWISH": self.convert_hard_swish, + "IF": self.convert_if, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"), "LEAKY_RELU": self.convert_leaky_relu, @@ -375,6 +389,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): ), # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, "WHERE": self.convert_select, + "WHILE": self.convert_while, "ZEROS_LIKE": self.convert_zeros_like, "NON_MAX_SUPPRESSION_V4": self.convert_nms_v4, "NON_MAX_SUPPRESSION_V5": self.convert_nms_v5, @@ -562,7 +577,7 @@ def get_output_tensors(self, op): def get_tensors(self, tensors_idx_list): """Get tensor wrapper list from given TFLite tensor index list""" return_list = list() - for tensor_idx in tensors_idx_list: + for tensor_idx in self._indices_or_empty(tensors_idx_list): if tensor_idx < 0: return_list.append(TensorWrapper(tensor_idx, 0, 0)) continue @@ -1888,6 +1903,417 @@ def _convert_stablehlo_sort(self, op): relax.op.sort(data, axis=int(opts.Dimension()), descending=descending) ) + def _get_builtin_options(self, op, options_cls): + """Parse BuiltinOptions for a TFLite builtin operator.""" + from tflite.BuiltinOptions import BuiltinOptions + + op_options = op.BuiltinOptions() + if op_options is None: + raise tvm.error.OpNotImplemented(f"{options_cls.__name__} is required") + + options_type = getattr(BuiltinOptions, options_cls.__name__, None) + if options_type is not None and op.BuiltinOptionsType() != options_type: + raise tvm.error.OpNotImplemented( + f"Unexpected BuiltinOptions type: expected " + f"{options_cls.__name__}, got {op.BuiltinOptionsType()}" + ) + result = options_cls() + result.Init(op_options.Bytes, op_options.Pos) + return result + + def _get_subgraph(self, subgraph_index, op_name, allow_main=False): + """Return a validated TFLite subgraph by index.""" + if subgraph_index < 0 or subgraph_index >= self.model.SubgraphsLength(): + raise tvm.error.OpNotImplemented(f"{op_name} requires a valid subgraph index") + if not allow_main and subgraph_index == 0: + raise tvm.error.OpNotImplemented(f"{op_name} cannot target the main subgraph") + return self.model.Subgraphs(subgraph_index) + + def _make_tuple_or_single(self, exprs): + """Return a single expression or Relax tuple for a list of expressions.""" + if len(exprs) == 1: + return exprs[0] + return relax.Tuple(exprs) + + def _indices_or_empty(self, indices): + """Return a TFLite index vector, using an empty list for absent vectors.""" + return indices if indices is not None else [] + + def _check_subgraph_io(self, subgraph_index, op_name, input_count=None, output_count=None): + """Validate a referenced subgraph's input and output counts.""" + subgraph = self._get_subgraph(subgraph_index, op_name) + if input_count is not None and subgraph.InputsLength() != input_count: + raise tvm.error.OpNotImplemented(f"{op_name} subgraph input count mismatch") + if output_count is not None and subgraph.OutputsLength() != output_count: + raise tvm.error.OpNotImplemented(f"{op_name} subgraph output count mismatch") + return subgraph + + def _check_subgraph_interface( + self, + subgraph_index, + op_name, + input_tensors=None, + output_tensors=None, + input_count=None, + output_count=None, + ): + """Validate a referenced subgraph's arity and tensor metadata.""" + if input_tensors is not None: + input_count = len(input_tensors) + if output_tensors is not None: + output_count = len(output_tensors) + + subgraph = self._check_subgraph_io( + subgraph_index, op_name, input_count=input_count, output_count=output_count + ) + if input_tensors is not None: + self._check_subgraph_tensor_metadata( + subgraph, + op_name, + "subgraph input", + subgraph.InputsAsNumpy(), + input_tensors, + ) + if output_tensors is not None: + self._check_subgraph_tensor_metadata( + subgraph, + op_name, + "subgraph output", + subgraph.OutputsAsNumpy(), + output_tensors, + ) + return subgraph + + def _get_tensor_metadata(self, tensor): + """Return static shape and dtype metadata for a TFLite tensor.""" + if isinstance(tensor, TensorWrapper): + tensor = tensor.tensor + shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else () + dtype = self.get_tensor_type_str(tensor.Type()) + return shape, dtype + + def _check_tensor_metadata_match(self, actual, expected, op_name, tensor_role): + """Validate that two TFLite tensors have matching static metadata.""" + if self._get_tensor_metadata(actual) != self._get_tensor_metadata(expected): + raise tvm.error.OpNotImplemented(f"{op_name} {tensor_role} tensor metadata mismatch") + + def _check_subgraph_tensor_metadata( + self, subgraph, op_name, tensor_role, subgraph_indices, expected_tensors + ): + """Validate referenced subgraph tensor metadata against caller tensors.""" + for subgraph_index, expected_tensor in zip( + self._indices_or_empty(subgraph_indices), expected_tensors + ): + self._check_tensor_metadata_match( + subgraph.Tensors(int(subgraph_index)), + expected_tensor, + op_name, + tensor_role, + ) + + def _require_scalar_bool_tensor(self, tensor, op_name): + """Validate that a TFLite tensor is a scalar bool tensor.""" + if isinstance(tensor, TensorWrapper): + tensor = tensor.tensor + dtype = self.get_tensor_type_str(tensor.Type()) + if dtype != "bool" or tensor.ShapeLength() != 0: + raise tvm.error.OpNotImplemented(f"{op_name} requires a scalar bool condition") + + def _get_subgraph_params(self, subgraph): + """Create Relax parameters for a TFLite subgraph.""" + params = [] + exp_tab = ExprTable() + for input_index in self._indices_or_empty(subgraph.InputsAsNumpy()): + tensor = subgraph.Tensors(int(input_index)) + input_name = get_tensor_name(subgraph, int(input_index)) + shape = tuple(tensor.ShapeAsNumpy()) if tensor.ShapeLength() > 0 else [] + dtype = self.get_tensor_type_str(tensor.Type()) + param = relax.Var(input_name, relax.TensorStructInfo(shape=shape, dtype=dtype)) + exp_tab.set_expr(input_name, param) + params.append(param) + return params, exp_tab + + def _get_tensor_param(self, tensor_wrapper): + """Create a Relax parameter from TFLite tensor metadata.""" + name = get_tensor_name(self.subgraph, tensor_wrapper.tensor_idx) + shape = ( + tuple(tensor_wrapper.tensor.ShapeAsNumpy()) + if tensor_wrapper.tensor.ShapeLength() > 0 + else [] + ) + dtype = self.get_tensor_type_str(tensor_wrapper.tensor.Type()) + return relax.Var(name, relax.TensorStructInfo(shape=shape, dtype=dtype)) + + def _lower_subgraph_to_function(self, subgraph_index, function_name_hint, op_name="CALL"): + """Lower a TFLite subgraph into a private Relax function.""" + lowered_subgraphs = self.conversion_state["lowered_subgraphs"] + if subgraph_index in lowered_subgraphs: + return lowered_subgraphs[subgraph_index] + + lowering_stack = self.conversion_state["lowering_stack"] + if subgraph_index in lowering_stack: + raise tvm.error.OpNotImplemented( + f"Recursive TFLite {op_name} subgraphs are not supported" + ) + + subgraph = self._get_subgraph(subgraph_index, op_name) + lowering_stack.append(subgraph_index) + try: + params, subgraph_exp_tab = self._get_subgraph_params(subgraph) + subgraph_bb = relax.BlockBuilder() + with subgraph_bb.function(function_name_hint, params=params, private=True): + with subgraph_bb.dataflow(): + subgraph_converter = type(self)( + self.model, + subgraph, + subgraph_exp_tab, + subgraph_bb, + self.conversion_state, + ) + subgraph_converter.check_unsupported_ops() + subgraph_converter.convert_op_to_relax() + output_tensors = subgraph_converter.get_tensors(subgraph.OutputsAsNumpy()) + outputs = [ + subgraph_converter.get_tensor_expr(tensor) for tensor in output_tensors + ] + output = subgraph_bb.emit_output(self._make_tuple_or_single(outputs)) + subgraph_bb.emit_func_output(output) + + subgraph_mod = subgraph_bb.get() + module_builder = self.conversion_state["module_builder"] + gv = module_builder.add_func(subgraph_mod[function_name_hint], function_name_hint) + lowered_subgraphs[subgraph_index] = gv + return gv + finally: + lowering_stack.pop() + + def _bind_call_outputs(self, call, output_count): + """Return per-output expressions from a single or tuple-valued call.""" + if output_count == 1: + return [call] + return [call[index] for index in range(output_count)] + + def _lower_if_to_function( + self, + then_subgraph_index, + else_subgraph_index, + input_tensors, + branch_input_count, + output_count, + ): + """Lower a TFLite IF op into a private Relax function.""" + cache_key = (then_subgraph_index, else_subgraph_index, branch_input_count, output_count) + lowered_if_functions = self.conversion_state["lowered_if_functions"] + if cache_key in lowered_if_functions: + return lowered_if_functions[cache_key] + + then_func = self._lower_subgraph_to_function( + then_subgraph_index, + f"tflite_if_then_subgraph_{then_subgraph_index}", + op_name="IF", + ) + else_func = self._lower_subgraph_to_function( + else_subgraph_index, + f"tflite_if_else_subgraph_{else_subgraph_index}", + op_name="IF", + ) + if_name = f"tflite_if_subgraph_{then_subgraph_index}_{else_subgraph_index}" + params = [self._get_tensor_param(tensor) for tensor in input_tensors] + cond = params[0] + branch_args = params[1:] + + if_bb = relax.BlockBuilder() + with if_bb.function(if_name, params=params, private=True): + result = relax.If( + cond, + relax.Call(then_func, branch_args), + relax.Call(else_func, branch_args), + ) + if_bb.emit_func_output(result) + if_func = if_bb.get()[if_name] + module_builder = self.conversion_state["module_builder"] + gv = module_builder.add_func(if_func, if_name) + lowered_if_functions[cache_key] = gv + return gv + + def _lower_while_to_function( + self, + cond_subgraph_index, + body_subgraph_index, + loop_var_count, + cond_func, + body_func, + body_subgraph, + ): + """Lower a TFLite WHILE op into a recursive private Relax function.""" + cache_key = (cond_subgraph_index, body_subgraph_index, loop_var_count) + lowered_while_functions = self.conversion_state["lowered_while_functions"] + if cache_key in lowered_while_functions: + return lowered_while_functions[cache_key] + + loop_name = f"tflite_while_subgraph_{cond_subgraph_index}_{body_subgraph_index}" + params, _ = self._get_subgraph_params(body_subgraph) + dummy_body = self._make_tuple_or_single(params) + module_builder = self.conversion_state["module_builder"] + loop_gv = module_builder.add_func(relax.Function(params, dummy_body), loop_name) + lowered_while_functions[cache_key] = loop_gv + + loop_bb = relax.BlockBuilder() + with loop_bb.function(loop_name, params=params, private=True): + cond = loop_bb.emit(relax.Call(cond_func, params), "while_cond") + next_state = relax.Call(body_func, params) + next_args = self._bind_call_outputs(next_state, loop_var_count) + true_branch = relax.Call(loop_gv, next_args) + false_branch = self._make_tuple_or_single(params) + result = relax.If(cond, true_branch, false_branch) + loop_bb.emit_func_output(result) + loop_func = loop_bb.get()[loop_name] + module_builder.update_func(loop_gv, loop_func) + return loop_gv + + def convert_call(self, op): + """Convert TFLite CALL to a Relax private function call.""" + from tflite.CallOptions import CallOptions + + opts = self._get_builtin_options(op, CallOptions) + subgraph_index = int(opts.Subgraph()) + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + self._check_subgraph_interface( + subgraph_index, + "CALL", + input_tensors=input_tensors, + output_tensors=output_tensors, + ) + + callee = self._lower_subgraph_to_function( + subgraph_index, f"tflite_call_subgraph_{subgraph_index}", op_name="CALL" + ) + args = [self.get_tensor_expr(tensor) for tensor in input_tensors] + return relax.Call(callee, args) + + def convert_if(self, op): + """Convert TFLite IF to Relax If with private branch functions.""" + from tflite.IfOptions import IfOptions + + opts = self._get_builtin_options(op, IfOptions) + then_subgraph_index = int(opts.ThenSubgraphIndex()) + else_subgraph_index = int(opts.ElseSubgraphIndex()) + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) < 1: + raise tvm.error.OpNotImplemented("IF requires a condition input") + + self._require_scalar_bool_tensor(input_tensors[0], "IF") + branch_input_count = len(input_tensors) - 1 + output_count = len(output_tensors) + branch_input_tensors = input_tensors[1:] + self._check_subgraph_interface( + then_subgraph_index, + "IF", + input_tensors=branch_input_tensors, + output_tensors=output_tensors, + ) + self._check_subgraph_interface( + else_subgraph_index, + "IF", + input_tensors=branch_input_tensors, + output_tensors=output_tensors, + ) + + if_func = self._lower_if_to_function( + then_subgraph_index, + else_subgraph_index, + input_tensors, + branch_input_count, + output_count, + ) + args = [self.get_tensor_expr(tensor) for tensor in input_tensors] + return relax.Call(if_func, args) + + def convert_while(self, op): + """Convert TFLite WHILE to a recursive Relax private function.""" + from tflite.WhileOptions import WhileOptions + + opts = self._get_builtin_options(op, WhileOptions) + cond_subgraph_index = int(opts.CondSubgraphIndex()) + body_subgraph_index = int(opts.BodySubgraphIndex()) + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + loop_var_count = len(input_tensors) + if loop_var_count == 0: + raise tvm.error.OpNotImplemented("WHILE requires loop-carried inputs") + if len(output_tensors) != loop_var_count: + raise tvm.error.OpNotImplemented("WHILE output count must match input count") + + cond_subgraph = self._check_subgraph_interface( + cond_subgraph_index, + "WHILE", + input_tensors=input_tensors, + output_count=1, + ) + body_subgraph = self._check_subgraph_interface( + body_subgraph_index, + "WHILE", + input_tensors=input_tensors, + output_tensors=input_tensors, + ) + for input_tensor, output_tensor in zip(input_tensors, output_tensors): + self._check_tensor_metadata_match(input_tensor, output_tensor, "WHILE", "loop state") + cond_output = cond_subgraph.Tensors(int(cond_subgraph.Outputs(0))) + self._require_scalar_bool_tensor(cond_output, "WHILE") + + cond_func = self._lower_subgraph_to_function( + cond_subgraph_index, + f"tflite_while_cond_subgraph_{cond_subgraph_index}", + op_name="WHILE", + ) + body_func = self._lower_subgraph_to_function( + body_subgraph_index, + f"tflite_while_body_subgraph_{body_subgraph_index}", + op_name="WHILE", + ) + + loop_gv = self._lower_while_to_function( + cond_subgraph_index, + body_subgraph_index, + loop_var_count, + cond_func, + body_func, + body_subgraph, + ) + + args = [self.get_tensor_expr(tensor) for tensor in input_tensors] + return relax.Call(loop_gv, args) + + def convert_call_once(self, op): + """Convert the no-op subset of TFLite CALL_ONCE. + + Non-empty CALL_ONCE init subgraphs are used for resource initialization + side effects in TFLite. The Relax TFLite frontend does not yet support + TFLite resource variable operators, so only the empty no-op form is safe + to import. + """ + from tflite.CallOnceOptions import CallOnceOptions + + opts = self._get_builtin_options(op, CallOnceOptions) + init_subgraph_index = int(opts.InitSubgraphIndex()) + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 0 or len(output_tensors) != 0: + raise tvm.error.OpNotImplemented("CALL_ONCE with inputs or outputs is not supported") + + init_subgraph = self._get_subgraph(init_subgraph_index, "CALL_ONCE") + if init_subgraph.InputsLength() != 0 or init_subgraph.OutputsLength() != 0: + raise tvm.error.OpNotImplemented( + "CALL_ONCE with non-empty init subgraph I/O is not supported" + ) + if init_subgraph.OperatorsLength() != 0: + raise tvm.error.OpNotImplemented( + "CALL_ONCE with non-empty init subgraphs is not supported" + ) + return None + def _convert_stablehlo_convert(self, op): """Convert STABLEHLO_CONVERT to Relax (astype). @@ -6201,8 +6627,8 @@ def func(self, data): _dtype_dict.update(dtype_dict) # Only Subgraphs(0) is converted into Relax main. Additional subgraphs are - # region bodies referenced by specific TFLite ops and are consumed by those - # op converters as needed. + # region/control-flow bodies referenced by specific TFLite ops and are + # consumed by those op converters as needed. assert model.SubgraphsLength() >= 1, "TFLite model must contain at least one subgraph" subgraph = model.Subgraphs(0) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index d03de3b6a9c4..be762d5cb4f8 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3695,8 +3695,11 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions") _tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions") _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions") +_tfl_call_options = _get_tflite_schema_module("CallOptions") +_tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") +_tfl_if_options = _get_tflite_schema_module("IfOptions") _tfl_int32_vector = _get_tflite_schema_module("Int32Vector") _tfl_model = _get_tflite_schema_module("Model") _tfl_operator = _get_tflite_schema_module("Operator") @@ -3705,6 +3708,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = _get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") +_tfl_while_options = _get_tflite_schema_module("WhileOptions") _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator") _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions") @@ -3909,6 +3913,32 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers, extra_su return bytes(builder.Output()) +def _build_call_options(builder, subgraph_index): + _tfl_call_options.CallOptionsStart(builder) + _tfl_call_options.CallOptionsAddSubgraph(builder, subgraph_index) + return _tfl_call_options.CallOptionsEnd(builder) + + +def _build_if_options(builder, then_subgraph_index, else_subgraph_index): + _tfl_if_options.IfOptionsStart(builder) + _tfl_if_options.IfOptionsAddThenSubgraphIndex(builder, then_subgraph_index) + _tfl_if_options.IfOptionsAddElseSubgraphIndex(builder, else_subgraph_index) + return _tfl_if_options.IfOptionsEnd(builder) + + +def _build_while_options(builder, cond_subgraph_index, body_subgraph_index): + _tfl_while_options.WhileOptionsStart(builder) + _tfl_while_options.WhileOptionsAddCondSubgraphIndex(builder, cond_subgraph_index) + _tfl_while_options.WhileOptionsAddBodySubgraphIndex(builder, body_subgraph_index) + return _tfl_while_options.WhileOptionsEnd(builder) + + +def _build_call_once_options(builder, init_subgraph_index): + _tfl_call_once_options.CallOnceOptionsStart(builder) + _tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder, init_subgraph_index) + return _tfl_call_once_options.CallOnceOptionsEnd(builder) + + def _load_model_from_buffer(model_bytes): if hasattr(tflite.Model, "Model"): tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) @@ -3919,6 +3949,1328 @@ def _load_model_from_buffer(model_bytes): return mod +def _get_builtin_operator(builtin_name): + if not hasattr(_tfl_builtin_operator, builtin_name): + pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") + return getattr(_tfl_builtin_operator, builtin_name) + + +def _build_tflite_call_model( + call_subgraph_index=1, + callee_inputs=None, + callee_outputs=None, + callee_output_shape=None, + callee_output_type=None, +): + """Build a TFLite model where main CALLs a subgraph computing x + 1.""" + builder = flatbuffers.Builder(1024) + + callee_inputs = [0] if callee_inputs is None else callee_inputs + callee_outputs = [2] if callee_outputs is None else callee_outputs + callee_output_shape = [2, 2] if callee_output_shape is None else callee_output_shape + callee_output_type = ( + _tfl_tensor_type.FLOAT32 if callee_output_type is None else callee_output_type + ) + call_options = _build_call_options(builder, call_subgraph_index) + one = np.array(1.0, dtype=np.float32) + + main_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 2, [2, 2]), + ] + main_call = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.CallOptions, + builtin_options=call_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_call], + inputs=[0], + outputs=[1], + ) + + callee_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 1, []), + _build_tensor(builder, 2, callee_output_shape, tensor_type=callee_output_type), + ] + callee_add = _build_operator(builder, 1, [0, 1], [2]) + callee_subgraph = _build_subgraph( + builder, + tensors=callee_tensors, + operators=[callee_add], + inputs=callee_inputs, + outputs=callee_outputs, + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[callee_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_call_subgraph(): + """Test TFLite CALL conversion to a private Relax function.""" + mod = _load_model_from_buffer(_build_tflite_call_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_call_subgraph_1( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.add( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + R.output(gv) + return gv + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_1(tvmgen_tensor_0) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_tflite_multi_output_call_model(): + """Build a TFLite model where CALL returns x + 1 and x - 1.""" + builder = flatbuffers.Builder(1024) + + call_options = _build_call_options(builder, 1) + one = np.array(1.0, dtype=np.float32) + + main_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 2, [2, 2]), + _build_tensor(builder, 3, [2, 2]), + ] + main_call = _build_operator( + builder, + 0, + [0], + [1, 2], + builtin_options_type=_tfl_builtin_options.CallOptions, + builtin_options=call_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_call], + inputs=[0], + outputs=[1, 2], + ) + + callee_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 1, []), + _build_tensor(builder, 2, [2, 2]), + _build_tensor(builder, 3, [2, 2]), + ] + callee_add = _build_operator(builder, 1, [0, 1], [2]) + callee_sub = _build_operator(builder, 2, [0, 1], [3]) + callee_subgraph = _build_subgraph( + builder, + tensors=callee_tensors, + operators=[callee_add, callee_sub], + inputs=[0], + outputs=[2, 3], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + _build_operator_code(builder, _get_builtin_operator("SUB")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[callee_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_call_subgraph_multi_output(): + """Test CALL tuple returns are split and rebound to TFLite output tensors.""" + mod = _load_model_from_buffer(_build_tflite_multi_output_call_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_call_subgraph_1( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.add( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + gv1: R.Tensor((2, 2), dtype="float32") = R.subtract( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + gv2: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = (gv, gv1) + R.output(gv2) + return gv2 + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + lv: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = cls.tflite_call_subgraph_1(tvmgen_tensor_0) + lv1: R.Tensor((2, 2), dtype="float32") = lv[0] + lv2: R.Tensor((2, 2), dtype="float32") = lv[1] + gv: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = (lv1, lv2) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_tflite_nested_call_model(): + """Build a TFLite model where main CALLs subgraph A, which CALLs subgraph B.""" + builder = flatbuffers.Builder(1024) + + main_call_options = _build_call_options(builder, 1) + nested_call_options = _build_call_options(builder, 2) + one = np.array(1.0, dtype=np.float32) + + main_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 3, [2, 2]), + ] + main_call = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.CallOptions, + builtin_options=main_call_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_call], + inputs=[0], + outputs=[1], + ) + + caller_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 3, [2, 2]), + ] + nested_call = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.CallOptions, + builtin_options=nested_call_options, + ) + caller_subgraph = _build_subgraph( + builder, + tensors=caller_tensors, + operators=[nested_call], + inputs=[0], + outputs=[1], + ) + + callee_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 1, []), + _build_tensor(builder, 3, [2, 2]), + ] + callee_add = _build_operator(builder, 1, [0, 1], [2]) + callee_subgraph = _build_subgraph( + builder, + tensors=callee_tensors, + operators=[callee_add], + inputs=[0], + outputs=[2], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[caller_subgraph, callee_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_call_subgraph_nested_call(): + """Test nested CALL subgraphs register all generated private functions.""" + mod = _load_model_from_buffer(_build_tflite_nested_call_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_call_subgraph_2( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.add( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_call_subgraph_1( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + cls = Expected + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_2(tvmgen_tensor_0) + R.output(gv) + return gv + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_1(tvmgen_tensor_0) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_call_subgraph_invalid_index_unsupported(): + """Test CALL rejects invalid subgraph indices before lowering.""" + with pytest.raises(tvm.error.OpNotImplemented, match="CALL requires a valid subgraph index"): + _load_model_from_buffer(_build_tflite_call_model(call_subgraph_index=2)) + + +def test_call_subgraph_io_mismatch_unsupported(): + """Test CALL rejects callees whose input arity does not match the call site.""" + with pytest.raises(tvm.error.OpNotImplemented, match="CALL subgraph input count mismatch"): + _load_model_from_buffer(_build_tflite_call_model(callee_inputs=[])) + + +def test_call_subgraph_output_metadata_mismatch_unsupported(): + """Test CALL rejects callees whose output metadata does not match the call site.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="CALL subgraph output tensor metadata mismatch" + ): + _load_model_from_buffer(_build_tflite_call_model(callee_output_shape=[2])) + + +def _build_tflite_if_model( + condition_type=_tfl_tensor_type.BOOL, + then_subgraph_index=1, + else_subgraph_index=2, + then_outputs=None, + else_outputs=None, + else_input_shape=None, + else_input_type=None, + else_output_shape=None, + else_output_type=None, +): + """Build a TFLite model where IF selects x + 1 or x - 1.""" + builder = flatbuffers.Builder(1024) + + then_outputs = [2] if then_outputs is None else then_outputs + else_outputs = [2] if else_outputs is None else else_outputs + else_input_shape = [2, 2] if else_input_shape is None else else_input_shape + else_input_type = _tfl_tensor_type.FLOAT32 if else_input_type is None else else_input_type + else_output_shape = [2, 2] if else_output_shape is None else else_output_shape + else_output_type = _tfl_tensor_type.FLOAT32 if else_output_type is None else else_output_type + if_options = _build_if_options(builder, then_subgraph_index, else_subgraph_index) + one = np.array(1.0, dtype=np.float32) + + main_tensors = [ + _build_tensor(builder, 0, [], tensor_type=condition_type), + _build_tensor(builder, 1, [2, 2]), + _build_tensor(builder, 3, [2, 2]), + ] + main_if = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.IfOptions, + builtin_options=if_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_if], + inputs=[0, 1], + outputs=[2], + ) + + then_tensors = [ + _build_tensor(builder, 1, [2, 2]), + _build_tensor(builder, 2, []), + _build_tensor(builder, 3, [2, 2]), + ] + then_add = _build_operator(builder, 1, [0, 1], [2]) + then_subgraph = _build_subgraph( + builder, + tensors=then_tensors, + operators=[then_add], + inputs=[0], + outputs=then_outputs, + ) + + else_tensors = [ + _build_tensor(builder, 1, else_input_shape, tensor_type=else_input_type), + _build_tensor(builder, 2, []), + _build_tensor(builder, 3, else_output_shape, tensor_type=else_output_type), + ] + else_sub = _build_operator(builder, 2, [0, 1], [2]) + else_subgraph = _build_subgraph( + builder, + tensors=else_tensors, + operators=[else_sub], + inputs=[0], + outputs=else_outputs, + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("IF")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + _build_operator_code(builder, _get_builtin_operator("SUB")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[then_subgraph, else_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_if_subgraphs(): + """Test TFLite IF conversion to Relax If.""" + mod = _load_model_from_buffer(_build_tflite_if_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_if_then_subgraph_1( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.add( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_if_else_subgraph_2( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.subtract( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_if_subgraph_1_2( + tvmgen_tensor_0: R.Tensor((), dtype="bool"), + tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + cls = Expected + if tvmgen_tensor_0: + gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_then_subgraph_1( + tvmgen_tensor_1 + ) + cond_result: R.Tensor((2, 2), dtype="float32") = gv + else: + gv1: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_else_subgraph_2( + tvmgen_tensor_1 + ) + cond_result: R.Tensor((2, 2), dtype="float32") = gv1 + return cond_result + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((), dtype="bool"), + tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + cls = Expected + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_subgraph_1_2( + tvmgen_tensor_0, tvmgen_tensor_1 + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_tflite_multi_output_if_model(): + """Build a TFLite model where IF returns two tensor outputs.""" + builder = flatbuffers.Builder(1024) + + if_options = _build_if_options(builder, 1, 2) + one = np.array(1.0, dtype=np.float32) + + main_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.BOOL), + _build_tensor(builder, 1, [2, 2]), + _build_tensor(builder, 4, [2, 2]), + _build_tensor(builder, 5, [2, 2]), + ] + main_if = _build_operator( + builder, + 0, + [0, 1], + [2, 3], + builtin_options_type=_tfl_builtin_options.IfOptions, + builtin_options=if_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_if], + inputs=[0, 1], + outputs=[2, 3], + ) + + then_tensors = [ + _build_tensor(builder, 1, [2, 2]), + _build_tensor(builder, 2, []), + _build_tensor(builder, 3, [2, 2]), + _build_tensor(builder, 4, [2, 2]), + ] + then_add = _build_operator(builder, 1, [0, 1], [2]) + then_sub = _build_operator(builder, 2, [0, 1], [3]) + then_subgraph = _build_subgraph( + builder, + tensors=then_tensors, + operators=[then_add, then_sub], + inputs=[0], + outputs=[2, 3], + ) + + else_tensors = [ + _build_tensor(builder, 1, [2, 2]), + _build_tensor(builder, 2, []), + _build_tensor(builder, 3, [2, 2]), + _build_tensor(builder, 4, [2, 2]), + ] + else_sub = _build_operator(builder, 2, [0, 1], [2]) + else_add = _build_operator(builder, 1, [0, 1], [3]) + else_subgraph = _build_subgraph( + builder, + tensors=else_tensors, + operators=[else_sub, else_add], + inputs=[0], + outputs=[2, 3], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("IF")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + _build_operator_code(builder, _get_builtin_operator("SUB")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[then_subgraph, else_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_if_subgraphs_multi_output(): + """Test IF tuple returns are preserved through the private wrapper function.""" + mod = _load_model_from_buffer(_build_tflite_multi_output_if_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_if_then_subgraph_1( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.add( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + gv1: R.Tensor((2, 2), dtype="float32") = R.subtract( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + gv2: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = (gv, gv1) + R.output(gv2) + return gv2 + + @R.function(private=True) + def tflite_if_else_subgraph_2( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.subtract( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + gv1: R.Tensor((2, 2), dtype="float32") = R.add( + tvmgen_tensor_0, R.const(1.0, "float32") + ) + gv2: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = (gv, gv1) + R.output(gv2) + return gv2 + + @R.function(private=True) + def tflite_if_subgraph_1_2( + tvmgen_tensor_0: R.Tensor((), dtype="bool"), + tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): + cls = Expected + if tvmgen_tensor_0: + gv: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = cls.tflite_if_then_subgraph_1(tvmgen_tensor_1) + cond_result: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = gv + else: + gv1: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = cls.tflite_if_else_subgraph_2(tvmgen_tensor_1) + cond_result: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = gv1 + return cond_result + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((), dtype="bool"), + tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): + R.func_attr({"num_input": 2}) + cls = Expected + with R.dataflow(): + lv: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = cls.tflite_if_subgraph_1_2(tvmgen_tensor_0, tvmgen_tensor_1) + lv1: R.Tensor((2, 2), dtype="float32") = lv[0] + lv2: R.Tensor((2, 2), dtype="float32") = lv[1] + gv: R.Tuple( + R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") + ) = (lv1, lv2) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_if_subgraphs_non_bool_condition_unsupported(): + """Test IF rejects non-bool condition tensors.""" + with pytest.raises(tvm.error.OpNotImplemented, match="IF requires a scalar bool condition"): + _load_model_from_buffer(_build_tflite_if_model(condition_type=_tfl_tensor_type.INT32)) + + +def test_if_subgraphs_invalid_index_unsupported(): + """Test IF rejects invalid branch subgraph indices before lowering.""" + with pytest.raises(tvm.error.OpNotImplemented, match="IF requires a valid subgraph index"): + _load_model_from_buffer(_build_tflite_if_model(then_subgraph_index=3)) + + +def test_if_subgraphs_output_count_mismatch_unsupported(): + """Test IF rejects branches whose output arity does not match the call site.""" + with pytest.raises(tvm.error.OpNotImplemented, match="IF subgraph output count mismatch"): + _load_model_from_buffer(_build_tflite_if_model(else_outputs=[])) + + +def test_if_subgraphs_input_metadata_mismatch_unsupported(): + """Test IF rejects branches whose input metadata does not match the call site.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="IF subgraph input tensor metadata mismatch" + ): + _load_model_from_buffer(_build_tflite_if_model(else_input_shape=[2])) + + +def test_if_subgraphs_output_metadata_mismatch_unsupported(): + """Test IF rejects branches whose output metadata does not match the call site.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="IF subgraph output tensor metadata mismatch" + ): + _load_model_from_buffer(_build_tflite_if_model(else_output_shape=[2])) + + +def _build_tflite_while_model( + cond_subgraph_index=1, + body_subgraph_index=2, + cond_output_type=_tfl_tensor_type.BOOL, + cond_input_type=_tfl_tensor_type.INT32, + body_outputs=None, + body_input_type=_tfl_tensor_type.INT32, + body_output_type=_tfl_tensor_type.INT32, + main_output_type=_tfl_tensor_type.INT32, +): + """Build a TFLite WHILE model incrementing an int32 scalar until i < 3 is false.""" + builder = flatbuffers.Builder(1024) + + body_outputs = [2] if body_outputs is None else body_outputs + while_options = _build_while_options(builder, cond_subgraph_index, body_subgraph_index) + one = np.array(1, dtype=np.int32) + three = np.array(3, dtype=np.int32) + + main_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=main_output_type), + ] + main_while = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.WhileOptions, + builtin_options=while_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_while], + inputs=[0], + outputs=[1], + ) + + cond_tensors = [ + _build_tensor(builder, 0, [], tensor_type=cond_input_type), + _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=cond_output_type), + ] + cond_less = _build_operator(builder, 1, [0, 1], [2]) + cond_subgraph = _build_subgraph( + builder, + tensors=cond_tensors, + operators=[cond_less], + inputs=[0], + outputs=[2], + ) + + body_tensors = [ + _build_tensor(builder, 0, [], tensor_type=body_input_type), + _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=body_output_type), + ] + body_add = _build_operator(builder, 2, [0, 1], [2]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[body_add], + inputs=[0], + outputs=body_outputs, + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("WHILE")), + _build_operator_code(builder, _get_builtin_operator("LESS")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, three.tobytes()), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[cond_subgraph, body_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_repeated_while_model(): + """Build a TFLite model where two WHILE ops share the same cond/body subgraphs.""" + builder = flatbuffers.Builder(1024) + + while_options = _build_while_options(builder, 1, 2) + one = np.array(1, dtype=np.int32) + three = np.array(3, dtype=np.int32) + + main_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32), + ] + main_while_0 = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.WhileOptions, + builtin_options=while_options, + ) + main_while_1 = _build_operator( + builder, + 0, + [1], + [2], + builtin_options_type=_tfl_builtin_options.WhileOptions, + builtin_options=while_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_while_0, main_while_1], + inputs=[0], + outputs=[2], + ) + + cond_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.BOOL), + ] + cond_less = _build_operator(builder, 1, [0, 1], [2]) + cond_subgraph = _build_subgraph( + builder, + tensors=cond_tensors, + operators=[cond_less], + inputs=[0], + outputs=[2], + ) + + body_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32), + ] + body_add = _build_operator(builder, 2, [0, 1], [2]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[body_add], + inputs=[0], + outputs=[2], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("WHILE")), + _build_operator_code(builder, _get_builtin_operator("LESS")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, three.tobytes()), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[cond_subgraph, body_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_zero_var_while_model(): + """Build a TFLite WHILE model with no loop-carried tensors.""" + builder = flatbuffers.Builder(1024) + + while_options = _build_while_options(builder, 1, 2) + main_while = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_tfl_builtin_options.WhileOptions, + builtin_options=while_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[], + operators=[main_while], + inputs=[], + outputs=[], + ) + cond_subgraph = _build_subgraph(builder, tensors=[], operators=[], inputs=[], outputs=[]) + body_subgraph = _build_subgraph(builder, tensors=[], operators=[], inputs=[], outputs=[]) + + operator_codes = [_build_operator_code(builder, _get_builtin_operator("WHILE"))] + buffers = [_build_buffer(builder)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[cond_subgraph, body_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_while_subgraphs(): + """Test TFLite WHILE conversion to a recursive Relax private function.""" + mod = _load_model_from_buffer(_build_tflite_while_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_while_cond_subgraph_1( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="bool"): + with R.dataflow(): + gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32")) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_while_body_subgraph_2( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32")) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_while_subgraph_1_2( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="int32"): + cls = Expected + while_cond: R.Tensor((), dtype="bool") = cls.tflite_while_cond_subgraph_1( + tvmgen_tensor_0 + ) + if while_cond: + gv: R.Tensor((), dtype="int32") = cls.tflite_while_body_subgraph_2(tvmgen_tensor_0) + gv1: R.Tensor((), dtype="int32") = cls.tflite_while_subgraph_1_2(gv) + cond_result: R.Tensor((), dtype="int32") = gv1 + else: + cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0 + return cond_result + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="int32"): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = cls.tflite_while_subgraph_1_2(tvmgen_tensor_0) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_while_subgraphs_repeated_cond_body_pair(): + """Test repeated WHILE ops reuse the same recursive private function.""" + mod = _load_model_from_buffer(_build_tflite_repeated_while_model()) + names = [gv.name_hint for gv in mod.get_global_vars()] + assert names.count("tflite_while_subgraph_1_2") == 1 + + +def _build_tflite_two_var_while_model(): + """Build a TFLite WHILE model with two int32 loop-carried scalar tensors.""" + builder = flatbuffers.Builder(1024) + + while_options = _build_while_options(builder, 1, 2) + one = np.array(1, dtype=np.int32) + three = np.array(3, dtype=np.int32) + + main_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 5, [], tensor_type=_tfl_tensor_type.INT32), + ] + main_while = _build_operator( + builder, + 0, + [0, 1], + [2, 3], + builtin_options_type=_tfl_builtin_options.WhileOptions, + builtin_options=while_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_while], + inputs=[0, 1], + outputs=[2, 3], + ) + + cond_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL), + ] + cond_less = _build_operator(builder, 1, [0, 2], [3]) + cond_subgraph = _build_subgraph( + builder, + tensors=cond_tensors, + operators=[cond_less], + inputs=[0, 1], + outputs=[3], + ) + + body_tensors = [ + _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 5, [], tensor_type=_tfl_tensor_type.INT32), + ] + body_add_i = _build_operator(builder, 2, [0, 2], [3]) + body_add_acc = _build_operator(builder, 2, [1, 0], [4]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[body_add_i, body_add_acc], + inputs=[0, 1], + outputs=[3, 4], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("WHILE")), + _build_operator_code(builder, _get_builtin_operator("LESS")), + _build_operator_code(builder, _get_builtin_operator("ADD")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder, three.tobytes()), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[cond_subgraph, body_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_while_subgraphs_two_loop_vars(): + """Test WHILE tuple loop state with two loop-carried variables.""" + mod = _load_model_from_buffer(_build_tflite_two_var_while_model()) + + @I.ir_module + class Expected: + @R.function(private=True) + def tflite_while_cond_subgraph_1( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + tvmgen_tensor_1: R.Tensor((), dtype="int32"), + ) -> R.Tensor((), dtype="bool"): + with R.dataflow(): + gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32")) + R.output(gv) + return gv + + @R.function(private=True) + def tflite_while_body_subgraph_2( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + tvmgen_tensor_1: R.Tensor((), dtype="int32"), + ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")): + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32")) + gv1: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_1, tvmgen_tensor_0) + gv2: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( + gv, + gv1, + ) + R.output(gv2) + return gv2 + + @R.function(private=True) + def tflite_while_subgraph_1_2( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + tvmgen_tensor_1: R.Tensor((), dtype="int32"), + ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")): + cls = Expected + while_cond: R.Tensor((), dtype="bool") = cls.tflite_while_cond_subgraph_1( + tvmgen_tensor_0, tvmgen_tensor_1 + ) + if while_cond: + gv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( + cls.tflite_while_body_subgraph_2(tvmgen_tensor_0, tvmgen_tensor_1) + ) + gv1: R.Tensor((), dtype="int32") = gv[0] + gv2: R.Tensor((), dtype="int32") = gv[1] + gv3: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( + cls.tflite_while_subgraph_1_2(gv1, gv2) + ) + cond_result: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = gv3 + else: + cond_result: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( + tvmgen_tensor_0, + tvmgen_tensor_1, + ) + return cond_result + + @R.function + def main( + tvmgen_tensor_0: R.Tensor((), dtype="int32"), + tvmgen_tensor_1: R.Tensor((), dtype="int32"), + ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")): + R.func_attr({"num_input": 2}) + cls = Expected + with R.dataflow(): + lv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( + cls.tflite_while_subgraph_1_2(tvmgen_tensor_0, tvmgen_tensor_1) + ) + lv1: R.Tensor((), dtype="int32") = lv[0] + lv2: R.Tensor((), dtype="int32") = lv[1] + gv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( + lv1, + lv2, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_while_subgraphs_non_bool_condition_unsupported(): + """Test WHILE rejects cond subgraphs that do not return scalar bool.""" + with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires a scalar bool condition"): + _load_model_from_buffer(_build_tflite_while_model(cond_output_type=_tfl_tensor_type.INT32)) + + +def test_while_subgraphs_invalid_index_unsupported(): + """Test WHILE rejects invalid cond/body subgraph indices before lowering.""" + with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires a valid subgraph index"): + _load_model_from_buffer(_build_tflite_while_model(cond_subgraph_index=3)) + + +def test_while_subgraphs_zero_loop_vars_unsupported(): + """Test WHILE rejects operators without loop-carried tensors.""" + with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires loop-carried inputs"): + _load_model_from_buffer(_build_tflite_zero_var_while_model()) + + +def test_while_subgraphs_loop_state_metadata_mismatch_unsupported(): + """Test WHILE rejects loop outputs whose metadata does not match loop inputs.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="WHILE loop state tensor metadata mismatch" + ): + _load_model_from_buffer( + _build_tflite_while_model(main_output_type=_tfl_tensor_type.FLOAT32) + ) + + +def test_while_subgraphs_output_count_mismatch_unsupported(): + """Test WHILE rejects body subgraphs whose output arity does not match loop vars.""" + with pytest.raises(tvm.error.OpNotImplemented, match="WHILE subgraph output count mismatch"): + _load_model_from_buffer(_build_tflite_while_model(body_outputs=[])) + + +def test_while_subgraphs_input_metadata_mismatch_unsupported(): + """Test WHILE rejects cond subgraph inputs whose metadata does not match loop vars.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="WHILE subgraph input tensor metadata mismatch" + ): + _load_model_from_buffer(_build_tflite_while_model(cond_input_type=_tfl_tensor_type.FLOAT32)) + + +def test_while_subgraphs_output_metadata_mismatch_unsupported(): + """Test WHILE rejects body outputs whose metadata does not match loop vars.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="WHILE subgraph output tensor metadata mismatch" + ): + _load_model_from_buffer( + _build_tflite_while_model(body_output_type=_tfl_tensor_type.FLOAT32) + ) + + +def _build_tflite_call_once_model( + init_has_op=False, + init_subgraph_index=1, + call_once_inputs=None, + call_once_outputs=None, + init_inputs=None, + init_outputs=None, +): + """Build a TFLite model with CALL_ONCE and one pass-through output.""" + builder = flatbuffers.Builder(1024) + + call_once_inputs = [] if call_once_inputs is None else call_once_inputs + call_once_outputs = [] if call_once_outputs is None else call_once_outputs + init_inputs = [] if init_inputs is None else init_inputs + init_outputs = [] if init_outputs is None else init_outputs + + call_once_options = _build_call_once_options(builder, init_subgraph_index) + main_tensors = [_build_tensor(builder, 0, [2, 2])] + main_call_once = _build_operator( + builder, + 0, + call_once_inputs, + call_once_outputs, + builtin_options_type=_tfl_builtin_options.CallOnceOptions, + builtin_options=call_once_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_call_once], + inputs=[0], + outputs=[0], + ) + + if init_has_op: + one = np.array(1.0, dtype=np.float32) + init_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 1, []), + _build_tensor(builder, 2, [2, 2]), + ] + init_op = _build_operator(builder, 1, [0, 1], [2]) + buffers = [ + _build_buffer(builder), + _build_buffer(builder, one.tobytes()), + _build_buffer(builder), + ] + else: + init_tensors = ( + [_build_tensor(builder, 0, [2, 2])] + if len(init_inputs) != 0 or len(init_outputs) != 0 + else [] + ) + init_op = None + buffers = [_build_buffer(builder)] + + init_subgraph = _build_subgraph( + builder, + tensors=init_tensors, + operators=[] if init_op is None else [init_op], + inputs=init_inputs, + outputs=init_outputs, + ) + + operator_codes = [_build_operator_code(builder, _get_builtin_operator("CALL_ONCE"))] + if init_has_op: + operator_codes.append(_build_operator_code(builder, _get_builtin_operator("ADD"))) + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def test_call_once_empty_init_subgraph(): + """Test the no-op CALL_ONCE subset.""" + mod = _load_model_from_buffer(_build_tflite_call_once_model()) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = tvmgen_tensor_0 + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_call_once_non_empty_init_subgraph_unsupported(): + """Test CALL_ONCE rejects init subgraphs with side-effect-like bodies.""" + with pytest.raises(tvm.error.OpNotImplemented, match="CALL_ONCE"): + _load_model_from_buffer(_build_tflite_call_once_model(init_has_op=True)) + + +def test_call_once_inputs_outputs_unsupported(): + """Test CALL_ONCE rejects operator inputs and outputs.""" + with pytest.raises(tvm.error.OpNotImplemented, match="CALL_ONCE with inputs or outputs"): + _load_model_from_buffer( + _build_tflite_call_once_model(call_once_inputs=[0], call_once_outputs=[0]) + ) + + +def test_call_once_init_subgraph_io_unsupported(): + """Test CALL_ONCE rejects init subgraphs with inputs or outputs.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="CALL_ONCE with non-empty init subgraph I/O" + ): + _load_model_from_buffer(_build_tflite_call_once_model(init_inputs=[0], init_outputs=[0])) + + +def test_call_once_invalid_index_unsupported(): + """Test CALL_ONCE rejects invalid init subgraph indices before lowering.""" + with pytest.raises( + tvm.error.OpNotImplemented, match="CALL_ONCE requires a valid subgraph index" + ): + _load_model_from_buffer(_build_tflite_call_once_model(init_subgraph_index=2)) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")