From 8906b0660039281898c39fedaab5f610cc366e97 Mon Sep 17 00:00:00 2001
From: LudovicoYIN
Date: Mon, 25 May 2026 07:48:54 +0000
Subject: [PATCH] [Relax][Frontend][TFLite] Add UNIDIRECTIONAL_SEQUENCE_RNN converter

Implements convert_unidirectional_sequence_rnn in the Relax TFLite frontend.
The op (BuiltinOperator 35) executes a simple RNN cell over a time sequence
using Relax primitives (matmul / add / activation).

Cell formula:
  h_t = act(x_t @ W^T + h_{t-1} @ Wr^T + b)

Design notes:
- Inputs: input[batch,time,input_size], input_weights[units,input_size],
  recurrent_weights[units,units], bias[units], hidden_state[batch,units]
- time_major=True input is transposed to batch-major before unrolling; the
  stacked output is emitted time-major again ([time,batch,units]) so that it
  matches the output layout the TFLite kernel produces
- Activations supported: NONE, RELU, RELU6, TANH, SIGMOID
- Quantised variant raises OpNotImplemented (guard already present)
- For time=1 the split is skipped; squeeze is applied directly
- Time steps are unrolled at graph-construction time and outputs stacked
  along the time axis (axis=1 for batch-major, axis=0 for time-major)

Tests added (3):
- test_unidirectional_sequence_rnn_none_activation: structural_equal check
  with identity weights / zero bias, NONE activation, time=1
- test_unidirectional_sequence_rnn_relu_activation: shape check with random
  weights, RELU activation, time=3
- test_unidirectional_sequence_rnn_time_major: shape check of the input and
  output layout with time_major=True

Closes part of apache/tvm#19519

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---
 .../relax/frontend/tflite/tflite_frontend.py | 101 +++++++++
 tests/python/relax/test_frontend_tflite.py   | 207 ++++++++++++++++++
 2 files changed, 308 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 28b125eec0b0..e50385e7a6cb 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -312,6 +312,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "TRANSPOSE_CONV": self.convert_transpose_conv,
             "TRANSPOSE": self.convert_transpose,
             "UNPACK": self.convert_unpack,
+            "UNIDIRECTIONAL_SEQUENCE_RNN": self.convert_unidirectional_sequence_rnn,
             "UNSORTED_SEGMENT_MIN": functools.partial(
                 self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
             ),
@@ -4477,6 +4478,106 @@ def convert_unpack(self, op):
 
         return squeezed
 
+    def convert_unidirectional_sequence_rnn(self, op):
+        """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
+
+        Inputs (5 tensors):
+          [0] input              [batch, time, input_size] (time-major: [time, batch, input_size])
+          [1] input_weights      [num_units, input_size]
+          [2] recurrent_weights  [num_units, num_units]
+          [3] bias               [num_units]
+          [4] hidden_state       [batch, num_units] (variable, zero-initialised)
+
+        Output:
+          [0] output             [batch, time, num_units] (time-major: [time, batch, num_units])
+
+        Cell equation:
+          h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.SequenceRNNOptions import SequenceRNNOptions
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 5, "input tensors length should be 5"
+
+        input_tensor = input_tensors[0]
+        weights_tensor = input_tensors[1]
+        recurrent_tensor = input_tensors[2]
+        bias_tensor = input_tensors[3]
+        hidden_state_tensor = input_tensors[4]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) >= 1, "output tensors length should be at least 1"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
+        op_options = op.BuiltinOptions()
+        seq_rnn_options = SequenceRNNOptions()
+        seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
+        time_major = seq_rnn_options.TimeMajor()
+        fused_activation_fn = seq_rnn_options.FusedActivationFunction()
+
+        # Constant weight/bias expressions.
+        weights_expr = self.get_tensor_expr(weights_tensor)  # [num_units, input_size]
+        recurrent_expr = self.get_tensor_expr(recurrent_tensor)  # [num_units, num_units]
+
+        # bias is optional (tensor_idx == -1 when absent); default to zeros.
+        if bias_tensor.tensor_idx != -1:
+            bias_expr = self.get_tensor_expr(bias_tensor)  # [num_units]
+        else:
+            num_units = int(self.get_tensor_shape(weights_tensor)[0])
+            bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
+            bias_expr = relax.op.zeros((num_units,), dtype=bias_dtype)
+
+        # Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T.
+        w_t = relax.op.permute_dims(weights_expr)
+        wr_t = relax.op.permute_dims(recurrent_expr)
+
+        # Resolve the input expression; normalise to batch-major [batch, time, input_size].
+        # Only the time dimension must be static (needed for unrolling); batch may be dynamic.
+        in_expr = self.get_tensor_expr(input_tensor)
+        in_shape = self.get_tensor_shape(input_tensor)
+        if time_major:
+            in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
+            num_steps = int(in_shape[0])
+        else:
+            num_steps = int(in_shape[1])
+
+        # Initial hidden state: use the model's tensor value when available (non-zero init or
+        # graph input), otherwise fall back to zeros for the common variable-tensor case.
+        h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
+        if self.has_expr(hidden_state_tensor.tensor_idx) or (
+            hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0
+        ):
+            h = self.get_tensor_expr(hidden_state_tensor)
+        else:
+            h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
+            h = relax.op.zeros(h_shape, dtype=h_dtype)
+
+        # Unroll over the time axis.
+        # A single step needs no split: squeezing the time axis already yields x_0.
+        if num_steps == 1:
+            steps = [relax.op.squeeze(in_expr, axis=[1])]
+        else:
+            splits = relax.op.split(in_expr, num_steps, axis=1)
+            steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]
+
+        outputs = []
+        for x_t in steps:  # x_t: [batch, input_size]
+            gates = relax.op.add(
+                relax.op.add(relax.op.matmul(x_t, w_t), relax.op.matmul(h, wr_t)),
+                bias_expr,
+            )
+            h = self.convert_fused_activation_function(gates, fused_activation_fn)
+            outputs.append(h)
+
+        # Stack per-step [batch, num_units] outputs back into the layout of the input tensor.
+        return relax.op.stack(outputs, axis=0 if time_major else 1)
+
     """
     def convert_unidirectional_sequence_lstm(self, op):
         ### Long Short Term Memory for TFLite implementation. ###
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 031c1553d8bf..19b499ee8d54 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3710,6 +3710,8 @@ def _get_tflite_schema_enum(enum_name):
 _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
 _tfl_tensor_type = _get_tflite_schema_enum("TensorType")
 
+_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
+
 _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
 _DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
 _DENSIFY_ROW_PTRS = [0, 1, 2]
@@ -6731,5 +6733,210 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+# ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
+
+
+def _build_unidirectional_sequence_rnn_model(
+    batch,
+    time,
+    input_size,
+    num_units,
+    weights,
+    recurrent_weights,
+    bias,
+    activation,
+    *,
+    time_major=False,
+):
+    """Build a minimal TFLite flatbuffer model containing one UNIDIRECTIONAL_SEQUENCE_RNN op.
+
+    Tensor layout (indices 0-5):
+      0 - input          [batch, time, input_size] (or [time, batch, input_size] if time_major)
+      1 - input_weights  [num_units, input_size]   (constant)
+      2 - recurrent_wts  [num_units, num_units]    (constant)
+      3 - bias           [num_units]               (constant)
+      4 - hidden_state   [batch, num_units]        (variable, zero-initialised)
+      5 - output         [batch, time, num_units] (or [time, batch, num_units] if time_major)
+    """
+    builder = flatbuffers.Builder(4096)
+
+    _tfl_sequence_rnn_options.SequenceRNNOptionsStart(builder)
+    _tfl_sequence_rnn_options.SequenceRNNOptionsAddTimeMajor(builder, time_major)
+    _tfl_sequence_rnn_options.SequenceRNNOptionsAddFusedActivationFunction(builder, activation)
+    rnn_opts = _tfl_sequence_rnn_options.SequenceRNNOptionsEnd(builder)
+
+    rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_RNN)
+
+    input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
+
+    def _t(buf_idx, shape, is_variable=False):
+        shape_vec = _tflite_shape(builder, shape)
+        _tfl_tensor.TensorStart(builder)
+        _tfl_tensor.TensorAddBuffer(builder, buf_idx)
+        _tfl_tensor.TensorAddHasRank(builder, True)
+        _tfl_tensor.TensorAddIsVariable(builder, is_variable)
+        _tfl_tensor.TensorAddShape(builder, shape_vec)
+        _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
+        return _tfl_tensor.TensorEnd(builder)
+
+    tensors = [
+        _t(0, input_shape),
+        _t(1, [num_units, input_size]),
+        _t(2, [num_units, num_units]),
+        _t(3, [num_units]),
+        _t(4, [batch, num_units], is_variable=True),
+        _t(5, [time, batch, num_units] if time_major else [batch, time, num_units]),
+    ]
+
+    rnn_op = _build_operator(
+        builder,
+        0,
+        [0, 1, 2, 3, 4],
+        [5],
+        builtin_options_type=_tfl_builtin_options.SequenceRNNOptions,
+        builtin_options=rnn_opts,
+    )
+
+    subgraph = _build_subgraph(
+        builder,
+        tensors=tensors,
+        operators=[rnn_op],
+        inputs=[0],
+        outputs=[5],
+    )
+
+    buffers = [
+        _build_buffer(builder),
+        _build_buffer(builder, weights.tobytes()),
+        _build_buffer(builder, recurrent_weights.tobytes()),
+        _build_buffer(builder, bias.tobytes()),
+        _build_buffer(builder),
+        _build_buffer(builder),
+    ]
+
+    return _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=[rnn_op_code],
+        buffers=buffers,
+    )
+
+
+def test_unidirectional_sequence_rnn_none_activation():
+    """UNIDIRECTIONAL_SEQUENCE_RNN with NONE activation, time=1, lowers to matmul/add/stack.
+
+    Cell equation: h_t = x_t @ W.T + h_{t-1} @ Wr.T + b  (no activation for NONE)
+    """
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 1, 2, 2
+    weights = np.eye(num_units, input_size, dtype=np.float32)
+    recurrent_weights = np.eye(num_units, dtype=np.float32)
+    bias = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_unidirectional_sequence_rnn_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            weights,
+            recurrent_weights,
+            bias,
+            ActivationFunctionType.NONE,
+        )
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 1, 2), dtype="float32")) -> R.Tensor((2, 1, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
+                lv1: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv2: R.Tensor((2, 2), dtype="float32") = R.matmul(lv, lv1, out_dtype="void")
+                lv3: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
+                lv4: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
+                    R.const(np.eye(2, dtype=np.float32)), axes=None
+                )
+                lv5: R.Tensor((2, 2), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
+                lv6: R.Tensor((2, 2), dtype="float32") = R.add(lv2, lv5)
+                lv7: R.Tensor((2, 2), dtype="float32") = R.add(
+                    lv6, R.const(np.zeros(2, dtype=np.float32))
+                )
+                gv: R.Tensor((2, 1, 2), dtype="float32") = R.stack((lv7,), axis=1)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_unidirectional_sequence_rnn_relu_activation():
+    """UNIDIRECTIONAL_SEQUENCE_RNN with RELU activation and multiple time steps."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 2, 3, 4, 8
+    np.random.seed(42)
+    weights = np.random.randn(num_units, input_size).astype(np.float32)
+    recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
+    bias = np.random.randn(num_units).astype(np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_unidirectional_sequence_rnn_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            weights,
+            recurrent_weights,
+            bias,
+            ActivationFunctionType.RELU,
+        )
+    )
+
+    fn = mod["main"]
+    assert len(fn.params) == 1, "only the sequence input should be a graph input"
+    in_shape = fn.params[0].struct_info.shape
+    assert tuple(int(d) for d in in_shape) == (batch, time, input_size)
+    out_shape = fn.ret_struct_info.shape
+    assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
+
+
+def test_unidirectional_sequence_rnn_time_major():
+    """UNIDIRECTIONAL_SEQUENCE_RNN with time_major=True keeps the time-major layout end to end."""
+    from tflite.ActivationFunctionType import ActivationFunctionType
+
+    batch, time, input_size, num_units = 3, 4, 2, 5
+    np.random.seed(7)
+    weights = np.random.randn(num_units, input_size).astype(np.float32)
+    recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
+    bias = np.zeros(num_units, dtype=np.float32)
+
+    mod = _load_model_from_buffer(
+        _build_unidirectional_sequence_rnn_model(
+            batch,
+            time,
+            input_size,
+            num_units,
+            weights,
+            recurrent_weights,
+            bias,
+            ActivationFunctionType.NONE,
+            time_major=True,
+        )
+    )
+
+    fn = mod["main"]
+    # Input to the graph is the raw time-major tensor [time, batch, input_size].
+    in_shape = fn.params[0].struct_info.shape
+    assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
+    # Output follows the input layout: time-major [time, batch, num_units].
+    out_shape = fn.ret_struct_info.shape
+    assert tuple(int(d) for d in out_shape) == (time, batch, num_units)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])