Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"TRANSPOSE_CONV": self.convert_transpose_conv,
"TRANSPOSE": self.convert_transpose,
"UNPACK": self.convert_unpack,
"UNIDIRECTIONAL_SEQUENCE_RNN": self.convert_unidirectional_sequence_rnn,
"UNSORTED_SEGMENT_MIN": functools.partial(
self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
),
Expand Down Expand Up @@ -4477,6 +4478,106 @@ def convert_unpack(self, op):

return squeezed

def convert_unidirectional_sequence_rnn(self, op):
"""Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.

Inputs (5 tensors):
[0] input [batch, time, input_size] (or [time, batch, input_size] if time_major)
[1] input_weights [num_units, input_size]
[2] recurrent_weights [num_units, num_units]
[3] bias [num_units]
[4] hidden_state [batch, num_units] (variable, zero-initialised)

Output:
[0] output [batch, time, num_units]

Cell equation:
h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
"""
from tflite.BuiltinOptions import BuiltinOptions
from tflite.SequenceRNNOptions import SequenceRNNOptions

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported yet."
)

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 5, "input tensors length should be 5"

input_tensor = input_tensors[0]
weights_tensor = input_tensors[1]
recurrent_tensor = input_tensors[2]
bias_tensor = input_tensors[3]
hidden_state_tensor = input_tensors[4]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) >= 1, "output tensors length should be at least 1"

assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
op_options = op.BuiltinOptions()
seq_rnn_options = SequenceRNNOptions()
seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
time_major = seq_rnn_options.TimeMajor()
fused_activation_fn = seq_rnn_options.FusedActivationFunction()

# Constant weight/bias expressions.
weights_expr = self.get_tensor_expr(weights_tensor) # [num_units, input_size]
recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units, num_units]

# bias is optional (tensor_idx == -1 when absent); default to zeros.
if bias_tensor.tensor_idx != -1:
bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
else:
num_units = int(self.get_tensor_shape(weights_tensor)[0])
bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
bias_expr = relax.op.zeros((num_units,), dtype=bias_dtype)

# Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T.
w_t = relax.op.permute_dims(weights_expr)
wr_t = relax.op.permute_dims(recurrent_expr)

# Resolve the input expression; normalise to batch-major [batch, time, input_size].
# Only the time dimension must be static (needed for unrolling); batch may be dynamic.
in_expr = self.get_tensor_expr(input_tensor)
in_shape = self.get_tensor_shape(input_tensor)
if time_major:
in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
num_steps = int(in_shape[0])
else:
num_steps = int(in_shape[1])

# Initial hidden state: use the model's tensor value when available (non-zero init or
# graph input), otherwise fall back to zeros for the common variable-tensor case.
h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
if self.has_expr(hidden_state_tensor.tensor_idx) or (
hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0
):
h = self.get_tensor_expr(hidden_state_tensor)
else:
h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
h = relax.op.zeros(h_shape, dtype=h_dtype)

# Unroll over the time axis.
# relax.op.split with 1 section returns the tensor directly; handle uniformly.
if num_steps == 1:
steps = [relax.op.squeeze(in_expr, axis=[1])]
else:
splits = relax.op.split(in_expr, num_steps, axis=1)
steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]

outputs = []
for x_t in steps: # x_t: [batch, input_size]
gates = relax.op.add(
relax.op.add(relax.op.matmul(x_t, w_t), relax.op.matmul(h, wr_t)),
bias_expr,
)
h = self.convert_fused_activation_function(gates, fused_activation_fn)
outputs.append(h)

# Stack timestep outputs: [batch, time, num_units].
return relax.op.stack(outputs, axis=1)

"""
def convert_unidirectional_sequence_lstm(self, op):
### Long Short Term Memory for TFLite implementation. ###
Expand Down
207 changes: 207 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3710,6 +3710,8 @@ def _get_tflite_schema_enum(enum_name):
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")

_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")

_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
_DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
_DENSIFY_ROW_PTRS = [0, 1, 2]
Expand Down Expand Up @@ -6731,5 +6733,210 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


# ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────


def _build_unidirectional_sequence_rnn_model(
batch,
time,
input_size,
num_units,
weights,
recurrent_weights,
bias,
activation,
*,
time_major=False,
):
"""Build a minimal TFLite flatbuffer model containing one UNIDIRECTIONAL_SEQUENCE_RNN op.

Tensor layout (indices 0-5):
0 - input [batch, time, input_size] (or [time, batch, input_size] if time_major)
1 - input_weights [num_units, input_size] (constant)
2 - recurrent_wts [num_units, num_units] (constant)
3 - bias [num_units] (constant)
4 - hidden_state [batch, num_units] (variable, zero-initialised)
5 - output [batch, time, num_units]
"""
builder = flatbuffers.Builder(4096)

_tfl_sequence_rnn_options.SequenceRNNOptionsStart(builder)
_tfl_sequence_rnn_options.SequenceRNNOptionsAddTimeMajor(builder, time_major)
_tfl_sequence_rnn_options.SequenceRNNOptionsAddFusedActivationFunction(builder, activation)
rnn_opts = _tfl_sequence_rnn_options.SequenceRNNOptionsEnd(builder)

rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_RNN)

input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]

def _t(buf_idx, shape, is_variable=False):
shape_vec = _tflite_shape(builder, shape)
_tfl_tensor.TensorStart(builder)
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
_tfl_tensor.TensorAddHasRank(builder, True)
_tfl_tensor.TensorAddIsVariable(builder, is_variable)
_tfl_tensor.TensorAddShape(builder, shape_vec)
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
return _tfl_tensor.TensorEnd(builder)

tensors = [
_t(0, input_shape),
_t(1, [num_units, input_size]),
_t(2, [num_units, num_units]),
_t(3, [num_units]),
_t(4, [batch, num_units], is_variable=True),
_t(5, [batch, time, num_units]),
]

rnn_op = _build_operator(
builder,
0,
[0, 1, 2, 3, 4],
[5],
builtin_options_type=_tfl_builtin_options.SequenceRNNOptions,
builtin_options=rnn_opts,
)

subgraph = _build_subgraph(
builder,
tensors=tensors,
operators=[rnn_op],
inputs=[0],
outputs=[5],
)

buffers = [
_build_buffer(builder),
_build_buffer(builder, weights.tobytes()),
_build_buffer(builder, recurrent_weights.tobytes()),
_build_buffer(builder, bias.tobytes()),
_build_buffer(builder),
_build_buffer(builder),
]

return _finish_tflite_model(
builder,
subgraph=subgraph,
operator_codes=[rnn_op_code],
buffers=buffers,
)


def test_unidirectional_sequence_rnn_none_activation():
"""UNIDIRECTIONAL_SEQUENCE_RNN with NONE activation, time=1, lowers to matmul/add/stack.

Cell equation: h_t = x_t @ W.T + h_{t-1} @ Wr.T + b (no activation for NONE)
"""
from tflite.ActivationFunctionType import ActivationFunctionType

batch, time, input_size, num_units = 2, 1, 2, 2
weights = np.eye(num_units, input_size, dtype=np.float32)
recurrent_weights = np.eye(num_units, dtype=np.float32)
bias = np.zeros(num_units, dtype=np.float32)

mod = _load_model_from_buffer(
_build_unidirectional_sequence_rnn_model(
batch,
time,
input_size,
num_units,
weights,
recurrent_weights,
bias,
ActivationFunctionType.NONE,
)
)

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 1, 2), dtype="float32")) -> R.Tensor((2, 1, 2), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
lv1: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv2: R.Tensor((2, 2), dtype="float32") = R.matmul(lv, lv1, out_dtype="void")
lv3: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
lv4: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv5: R.Tensor((2, 2), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
lv6: R.Tensor((2, 2), dtype="float32") = R.add(lv2, lv5)
lv7: R.Tensor((2, 2), dtype="float32") = R.add(
lv6, R.const(np.zeros(2, dtype=np.float32))
)
gv: R.Tensor((2, 1, 2), dtype="float32") = R.stack((lv7,), axis=1)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


def test_unidirectional_sequence_rnn_relu_activation():
"""UNIDIRECTIONAL_SEQUENCE_RNN with RELU activation and multiple time steps."""
from tflite.ActivationFunctionType import ActivationFunctionType

batch, time, input_size, num_units = 2, 3, 4, 8
np.random.seed(42)
weights = np.random.randn(num_units, input_size).astype(np.float32)
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
bias = np.random.randn(num_units).astype(np.float32)

mod = _load_model_from_buffer(
_build_unidirectional_sequence_rnn_model(
batch,
time,
input_size,
num_units,
weights,
recurrent_weights,
bias,
ActivationFunctionType.RELU,
)
)

fn = mod["main"]
assert len(fn.params) == 1, "only the sequence input should be a graph input"
in_shape = fn.params[0].struct_info.shape
assert tuple(int(d) for d in in_shape) == (batch, time, input_size)
out_shape = fn.ret_struct_info.shape
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)


def test_unidirectional_sequence_rnn_time_major():
"""UNIDIRECTIONAL_SEQUENCE_RNN with time_major=True transposes before unrolling."""
from tflite.ActivationFunctionType import ActivationFunctionType

batch, time, input_size, num_units = 3, 4, 2, 5
np.random.seed(7)
weights = np.random.randn(num_units, input_size).astype(np.float32)
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
bias = np.zeros(num_units, dtype=np.float32)

mod = _load_model_from_buffer(
_build_unidirectional_sequence_rnn_model(
batch,
time,
input_size,
num_units,
weights,
recurrent_weights,
bias,
ActivationFunctionType.NONE,
time_major=True,
)
)

fn = mod["main"]
# Input to the graph is the raw time-major tensor [time, batch, input_size].
in_shape = fn.params[0].struct_info.shape
assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
# Output is always batch-major [batch, time, num_units].
out_shape = fn.ret_struct_info.shape
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)


if __name__ == "__main__":
pytest.main(["-s", __file__])
Loading