[Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS #19763
Code Review
This pull request adds support for converting TFLite REAL, IMAG, and COMPLEX_ABS operators to TVM Relax, representing complex64 tensors as float32 tensors with an extra trailing dimension of size 2. It also adds a placeholder for RFFT2D which raises an unimplemented error. The review feedback suggests using negative indexing (axes=[-1] and axis=[-1]) instead of calculating the last axis dynamically using the tensor's ndim to prevent potential errors with dynamic or unknown ranks, along with corresponding updates to the test assertions.
tlopex
left a comment
Thanks for working on this. I think this needs a small fix before merge.
The complex64 -> float32[..., 2] lowering is currently applied in two different places, which can produce wrong input shapes when callers pass shape_dict or dtype_dict overrides. _input_type() already converts complex inputs to float32[..., 2], then from_tflite() applies user overrides and may append another trailing 2 if the overridden dtype is complex64.
For example:
- `dtype_dict={"x": "complex64"}` can turn an already packed `(2, 4, 2)` shape into `(2, 4, 2, 2)`.
- `shape_dict={"x": (2, 4)}` overwrites the packed default shape while the dtype remains `float32`, so `REAL`/`IMAG` slice the wrong last axis.
I think the safer structure is to keep _input_type() returning the logical TFLite shape/dtype, apply user overrides, and then lower complex64 to float32[..., 2] exactly once.
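A minimal sketch of the ordering suggested above, in plain Python. The names `lower_complex64` and `resolve_input_type` are hypothetical and do not come from the patch; the point is only that overrides are applied to the logical TFLite shape/dtype first and the trailing pair axis is appended in exactly one place.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def lower_complex64(shape: Shape, dtype: str) -> Tuple[Shape, str]:
    """Lower a logical complex64 tensor to float32[..., 2]; pass others through."""
    if dtype == "complex64":
        return tuple(shape) + (2,), "float32"
    return tuple(shape), dtype


def resolve_input_type(
    logical_shape: Shape,
    logical_dtype: str,
    shape_override: Optional[Shape] = None,
    dtype_override: Optional[str] = None,
) -> Tuple[Shape, str]:
    """Apply user overrides to the logical type, then lower exactly once."""
    # Step 1: start from the logical (unpacked) TFLite metadata.
    shape = tuple(logical_shape)
    dtype = logical_dtype
    # Step 2: user overrides replace the logical values, never the packed ones.
    if shape_override is not None:
        shape = tuple(shape_override)
    if dtype_override is not None:
        dtype = dtype_override
    # Step 3: the single place where the pair axis is appended.
    return lower_complex64(shape, dtype)
```

Under these assumptions, both problem cases from the review collapse to the same result: overriding the dtype with `"complex64"` or overriding the shape with `(2, 4)` on a complex input each yield `((2, 4, 2), "float32")`.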
A couple of related scope/test issues:
- `RFFT2D` is registered in `convert_map`, but `convert_rfft2d()` always raises `OpNotImplemented`. The PR title/body says RFFT2D support and mentions an RFFT2D test, but the patch only adds tests for `REAL`, `IMAG`, and `COMPLEX_ABS`. Please either implement it, or keep it clearly unsupported and add a negative test / update the PR description.
- The new complex representation is only wired for model inputs. Constant or non-input complex tensors still go through the generic tensor path, which does not decode `TensorType.COMPLEX64`. If input-only support is intentional for this PR, please make that scope explicit and cover it with tests.
Thanks for the feedback. Double lowering fix: Removed the Constant tensor support: Added RFFT2D: Removed from this PR entirely. An O(N²) matmul decomposition using existing Relax ops ( |
…, IMAG, COMPLEX_ABS (apache#19763)

Part of apache#19519

This PR adds support for the FFT and complex operator family in the Relax TFLite frontend.

**Key implementations:**

- Registered `REAL`, `IMAG`, and `COMPLEX_ABS` in the TFLite op map.
- Implemented `convert_real` and `convert_imag`, which extract the real and imaginary parts of a complex tensor via `strided_slice` + `squeeze` along the last axis.
- Implemented `convert_complex_abs`, which computes `sqrt(re^2 + im^2)` using elementwise Relax ops.
- All three ops adopt a unified representation convention: TFLite `complex64` tensors (which have no native Relax dtype equivalent) are represented as `float32[..., 2]`, where the last axis holds `(real, imaginary)` interleaved.

**Out of scope:**

- `RFFT2D` is not registered in this PR. An O(N²) matmul decomposition is feasible using existing Relax ops and will be contributed separately with benchmarks showing the performance gap versus a native FFT op. A native `relax.op.signal.rfft2d` is tracked in apache#19764.

**Testing:**

- Added structural equality tests for `REAL`, `IMAG`, and `COMPLEX_ABS` in `test_frontend_tflite.py` following the `verify(TestClass, Expected)` pattern.

```bash
python3 -m pytest tests/python/relax/test_frontend_tflite.py -k "test_real or test_imag or test_complex_abs"
```
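To make the `float32[..., 2]` convention concrete, here is a NumPy sketch of what the three ops compute on the packed representation. It is not the Relax lowering itself: the function names are illustrative, and NumPy slicing stands in for the `strided_slice` + `squeeze` pair described above.

```python
import numpy as np


def pack_complex64(z):
    """Pack a complex array into float32[..., 2] with (real, imag) on the last axis."""
    z = np.asarray(z, dtype=np.complex64)
    return np.stack([z.real, z.imag], axis=-1).astype(np.float32)


def real_part(packed):
    """Slice index 0 of the last axis, then drop that axis (slice + squeeze)."""
    return np.squeeze(packed[..., 0:1], axis=-1)


def imag_part(packed):
    """Slice index 1 of the last axis, then drop that axis (slice + squeeze)."""
    return np.squeeze(packed[..., 1:2], axis=-1)


def complex_abs(packed):
    """Elementwise magnitude: sqrt(re^2 + im^2)."""
    re = real_part(packed)
    im = imag_part(packed)
    return np.sqrt(re * re + im * im)
```

Every helper addresses the pair axis as `-1`, so the same code works for any rank, which is the property the review asked for when it recommended negative indexing.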
Import TFLite RFFT2D for the static no-padding/no-truncation path. The op dispatches to a naive O(N^4) DFT primfunc that the caller can override once a faster radix-2 / mixed-radix kernel is wired in. The output of the op is a real/imag pair tensor of shape [..., H, W//2+1, 2] in float32, matching the TFLite convention. The complex pair is appended on the trailing axis to stay consistent with the COMPLEX_ABS / REAL / IMAG imports in apache#19763. Add hand-built FlatBuffers coverage for the static pair-output lowering and the fft_length mismatch guard.
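As a reference for the semantics described in this commit message, the following NumPy sketch evaluates the 2D DFT directly and writes the result in the `[..., H, W//2+1, 2]` pair layout. It is an illustration only: `naive_rfft2_pairs` is a hypothetical name, the loops run in Python rather than in a TIR primfunc, and the sums are accumulated in a float64 buffer for simplicity.

```python
import math

import numpy as np


def naive_rfft2_pairs(x):
    """Direct 2D DFT over the last two axes, keeping W // 2 + 1 width bins."""
    x = np.asarray(x, dtype=np.float32)
    height, width = x.shape[-2:]
    out_width = width // 2 + 1
    out = np.zeros(x.shape[:-2] + (height, out_width, 2), dtype=np.float32)
    for ky in range(height):
        for kx in range(out_width):
            # One accumulator per leading (batch) element.
            re = np.zeros(x.shape[:-2], dtype=np.float64)
            im = np.zeros(x.shape[:-2], dtype=np.float64)
            for y in range(height):
                for col in range(width):
                    angle = -2.0 * math.pi * (ky * y / height + kx * col / width)
                    re = re + x[..., y, col] * math.cos(angle)
                    im = im + x[..., y, col] * math.sin(angle)
            out[..., ky, kx, 0] = re  # real part on pair index 0
            out[..., ky, kx, 1] = im  # imaginary part on pair index 1
    return out
```

The result can be compared against `np.fft.rfft2(x)` by checking `out[..., 0]` against the real part and `out[..., 1]` against the imaginary part.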
…19812)

## Summary

This PR adds Relax TFLite frontend support for the TFLite builtin `RFFT2D` operator (issue #19519 item C: FFT / complex operators). It is the follow-up to upstream PR #19763, which already merged the `REAL` / `IMAG` / `COMPLEX_ABS` subset of item C; this PR completes the subset with `RFFT2D` itself.

`RFFT2D` computes a 2D real FFT over the last two input axes and returns a real/imag pair tensor of shape `[..., H, W // 2 + 1, 2]`. Relax does not have a native complex64 dtype, so the pair output is represented as a `float32` tensor with a trailing axis of size 2, matching the convention PR #19763 established for `REAL` / `IMAG` / `COMPLEX_ABS`. `RFFT2D` accepts a real `float32` input and emits a `float32` pair output of the same dtype; the frontend does not need to materialize any COMPLEX64 in-memory representation of its own.

The supported subset is the static no-padding / no-truncation path:

- the input's last two dimensions must match `fft_length`
- `fft_length` must be a length-2 integer constant (int32 or int64)
- the output shape is `input_shape[:-2] + (H, W // 2 + 1, 2)`
- sparse inputs are rejected

## Design

### Dispatch Between Two TIR Kernels

`convert_rfft2d` selects one of two TIR primfuncs at lowering time, based on whether the spatial axes are powers of two:

```python
if _is_power_of_2(height) and _is_power_of_2(width):
    prim_func = _build_tflite_rfft2d_fft_primfunc(input_shape, relax_output_shape)
else:
    prim_func = _build_tflite_rfft2d_primfunc(input_shape, relax_output_shape)
```

Both kernels share the same `call_tir` contract, so downstream code is kernel-agnostic.

#### DFT Reference Kernel (`_build_tflite_rfft2d_primfunc`)

A naive O(B · H · W · H · W) DFT over the last two input axes. The outer (batch, out_y, out_x) iteration is structured as S-TIR spatial axes so a downstream `tir.schedule` pass can parallelize it.
Trig and accumulation are in float32; the result agrees with `np.fft.rfft2` to about `1e-5` absolute tolerance for typical input sizes. This kernel is the fallback for odd or non-power-of-2 spatial sizes.

#### Cooley-Tukey FFT Kernel (`_build_tflite_rfft2d_fft_primfunc`)

An O(B · H · W · (log2(H) + log2(W))) radix-2 Cooley-Tukey FFT, dispatched when both `H` and `W` are positive powers of 2. The primfunc source is generated as a TIR string at construction time and registered in `linecache` so `tirx.parser` can resolve it. The bit-reversal permutation and butterfly stages are precomputed in Python and inlined as direct scratch-buffer assignments, so all loop bounds and twiddle factors are compile-time literals:

1. Copy the real input into `scratch_real`; initialize `scratch_imag` to 0.
2. For each batch and each row, run an in-place 1D FFT of length `W` along the width axis using scratch buffers.
3. For each batch and each column, run an in-place 1D FFT of length `H` along the height axis with stride `W`.
4. Write the first `W // 2 + 1` complex bins per row to the output pair representation.

The bit-reversal swap pairs and butterfly stage bodies are precomputed at primfunc-construction time because `tirx.parser` does not currently accept runtime `T.serial` bounds, and twiddle factors must be `T.float32` literals to avoid `Undefined variable` errors. The fake linecache filename `<tflite_rfft2d_fft_primfunc H=8 W=8 outW=5>` is dimension-aware, so generated-source stack traces are readable.

### COMPLEX64 Pair Representation Helpers

The frontend represents TFLite COMPLEX64 tensors as float32 real/imag pairs with a trailing axis of size 2 (since Relax has no native complex64 dtype). Four small helpers centralize the rule so future complex ops can plug in without re-implementing the pair-axis layout:

- `_is_tflite_complex64_type`: checks whether a TFLite tensor type is `COMPLEX64`.
- `_unwrap_tflite_tensor`: unwraps a `TensorWrapper` to the raw `tflite.Tensor`.
- `_get_relax_tensor_dtype`: returns the Relax dtype used to represent a TFLite tensor (`"float32"` for COMPLEX64, otherwise the standard mapping).
- `_get_relax_tensor_shape`: returns the Relax shape (TFLite shape with a trailing `(2,)` axis appended for COMPLEX64).

The 3 callers that construct Relax parameters from TFLite metadata (`_get_static_tensor_shape_dtype` / `_set_subgraph_input_params` / `_get_tensor_param`) now go through these helpers. The pair-axis invariant is documented on `get_tensor_value` and `get_tensor_shape`, which both return the *raw TFLite* shape (no pair axis) for downstream callers that need to compare against the model's declared output shape.

### Boundary Validation

`convert_rfft2d` validates rank, dtype, fft_length shape, integer-ness, positivity, fft_length == input spatial shape, output shape agreement, and the absence of sparse inputs before emitting the `call_tir`. Edge cases (sparse, non-integer fft_length, zero/negative fft_length, mismatched fft_length, 1×1 spatial size) each produce a targeted `OpNotImplemented` diagnostic.

## Operator Support

| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `RFFT2D` | input `float32`, constant length-2 integer `fft_length`, `COMPLEX64` output | `call_tir` to a generated TIR kernel | static no-padding/no-truncation; `H`, `W` arbitrary; Cooley-Tukey dispatched when both are powers of 2 |

## Safety Checks

- Non-float32 input raises `OpNotImplemented("RFFT2D input must be float32")`.
- Non-COMPLEX64 output raises `OpNotImplemented("RFFT2D output must be COMPLEX64")`.
- Sparse inputs raise `OpNotImplemented("RFFT2D does not support sparse inputs")`.
- Non-constant `fft_length` raises `OpNotImplemented("RFFT2D requires a constant fft_length")`.
- Non-integer `fft_length` raises `OpNotImplemented("RFFT2D fft_length must be an integer tensor")`.
- Wrong-shape `fft_length` (not length 2) raises `OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")`.
- Non-positive `fft_length` raises `OpNotImplemented("RFFT2D fft_length must be positive")`.
- `fft_length` not matching the input's last two dims raises `OpNotImplemented("RFFT2D currently supports fft_length matching the input spatial shape")`.
- Mismatched output shape raises `OpNotImplemented("RFFT2D output shape does not match fft_length")`.
- Input rank < 2 raises `OpNotImplemented("RFFT2D input rank must be at least 2")`.

## Not Included

- `RFFT2D` with `fft_length` not matching the input's last two dimensions (padding / truncation path).
- Other complex-data operators: `REAL` / `IMAG` / `COMPLEX_ABS` are already handled by upstream PR #19763 and are out of scope for this PR.
- A frontend guard that rejects COMPLEX64 inputs flowing into non-complex ops. With PR #19763 providing the only other complex ops, models that contain only `RFFT2D` (with float32 input / COMPLEX64 output) and the three upstream ops are fully supported. A generic COMPLEX64 guard was intentionally not added here to keep this PR scoped to `RFFT2D`.
- User-override validation for `shape_dict` / `dtype_dict` on COMPLEX64 inputs. After PR #19763, the frontend no longer auto-appends a trailing pair axis to user overrides; a user passing the natural TFLite shape without the pair axis will now fall through to the standard metadata mismatch path.
- Higher-precision FFT kernel (e.g. SIMD). The float32 twiddle / float32 accumulation paths match `np.fft.rfft2` to `~1e-4` on the Cooley-Tukey path; large spatial dimensions may need a future optimized lowering or backend-specific implementation.

## Tests

The tests manually build minimal TFLite flatbuffers, run the frontend, and compare against `np.fft.rfft2`. Edge cases raise `OpNotImplemented`. The DFT-path tests use `atol=1e-5`; the FFT-path tests use `atol=1e-4` because twiddle factors are float32 literals.

| Test | Path | Shape | Coverage |
|---|---|---|---|
| `test_rfft2d_static_pair_output` | DFT | `[2, 4]` | Baseline 2D, even width; also asserts Relax script contains the `tflite_rfft2d` kernel name and pair-output struct-info |
| `test_rfft2d_static_pair_output_with_batch` | DFT | `[2, 2, 4]` | Leading batch dims preserved |
| `test_rfft2d_odd_width_pair_output` | DFT | `[3, 5]` | Odd width → `W // 2 + 1` output bins |
| `test_rfft2d_int64_fft_length` | DFT | `[2, 4]` | INT64 fft_length constant (TFLite schema allows either int32 or int64) |
| `test_rfft2d_4d_input_pair_output` | DFT | `[2, 3, 4, 5]` | 4D input with batch and odd width |
| `test_rfft2d_minimal_1x1_pair_output` | DFT | `[1, 1]` | Edge case: trivial 1×1 FFT |
| `test_rfft2d_mismatched_fft_length_unsupported` | n/a | `[2, 4]` (fft=`[4, 4]`) | fft_length != input spatial shape guard |
| `test_rfft2d_dynamic_fft_length_unsupported` | n/a | `[2, 4]` | Dynamic/non-constant fft_length guard |
| `test_rfft2d_fft_path_4x4` | **FFT** | `[4, 4]` | Smallest power-of-two (4×4) where both row and column FFTs do real work |
| `test_rfft2d_fft_path_8x8` | **FFT** | `[8, 8]` | Square 8×8 power-of-two |
| `test_rfft2d_fft_path_16x16` | **FFT** | `[16, 16]` | Larger FFT, kernel scaling check |
| `test_rfft2d_fft_path_2x2x4x8` | **FFT** | `[2, 2, 4, 8]` | 4D power-of-two with batch |

Local validation:

```bash
python -m py_compile \
    python/tvm/relax/frontend/tflite/tflite_frontend.py \
    tests/python/relax/test_frontend_tflite.py
python -m ruff check \
    python/tvm/relax/frontend/tflite/tflite_frontend.py \
    tests/python/relax/test_frontend_tflite.py
python -m pytest \
    tests/python/relax/test_frontend_tflite.py \
    -k rfft2d -v
```

Result:

```text
py_compile: passed
ruff check: All checks passed
pre-commit run --files: passed
rfft2d tests: 12 passed
```

## References

- Issue [#19519](#19519) item C: FFT / complex operators (`RFFT2D`, `REAL`, `IMAG`, `COMPLEX_ABS`).
- Upstream PR #19763: `[Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS` (merge commit `9d6e1cf0`). This PR's RFFT2D output pair layout matches the COMPLEX64 pair representation PR #19763 established for `REAL` / `IMAG` / `COMPLEX_ABS`, so downstream ops from PR #19763 can consume RFFT2D output directly.
- Tracking issue [#19764](#19764) for the longer-term native `relax.op.signal.rfft2d` / registered TOPI backend path. This PR is a frontend-local lowering for TFLite `RFFT2D`, not the native Relax signal op tracked there.
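
The Cooley-Tukey path described in the Design section of this commit message (bit-reversal permutation, butterfly stages, row pass then column pass, keep the first `W // 2 + 1` bins) can be sketched in plain Python. This is only an illustration of the algorithm under stated assumptions: every name here is hypothetical, the loops run at call time instead of being unrolled into generated TIR source, and Python's double-precision `complex` replaces the float32 scratch buffers.

```python
import cmath


def is_power_of_2(n):
    return n > 0 and (n & (n - 1)) == 0


def bit_reverse(index, bits):
    """Reverse the lowest `bits` bits of `index`."""
    result = 0
    for _ in range(bits):
        result = (result << 1) | (index & 1)
        index >>= 1
    return result


def fft_radix2(values):
    """Iterative radix-2 decimation-in-time FFT of a power-of-two-length sequence."""
    n = len(values)
    if not is_power_of_2(n):
        raise ValueError("length must be a positive power of 2")
    bits = n.bit_length() - 1
    # Bit-reversal permutation of the input.
    data = [complex(values[bit_reverse(i, bits)]) for i in range(n)]
    size = 2
    while size <= n:  # one butterfly stage per doubling of the block size
        half = size // 2
        step = cmath.exp(-2j * cmath.pi / size)
        for start in range(0, n, size):
            twiddle = 1 + 0j
            for k in range(half):
                a = data[start + k]
                b = data[start + k + half] * twiddle
                data[start + k] = a + b
                data[start + k + half] = a - b
                twiddle *= step
        size *= 2
    return data


def dft_naive(values):
    """Direct O(N^2) DFT, used only as a reference for the FFT."""
    n = len(values)
    return [
        sum(values[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def rfft2_pairs(image):
    """Row FFTs, then column FFTs, keeping W // 2 + 1 bins as (real, imag) pairs."""
    height, width = len(image), len(image[0])
    rows = [fft_radix2(row) for row in image]
    cols = [fft_radix2([rows[y][x] for y in range(height)]) for x in range(width)]
    return [
        [(cols[x][y].real, cols[x][y].imag) for x in range(width // 2 + 1)]
        for y in range(height)
    ]
```

Because the column pass runs over all `W` columns before the output is truncated, this sketch does more work than strictly necessary; it mirrors the four listed steps rather than an optimized schedule.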
Part of #19519
This PR adds support for the FFT and complex operator family in the Relax TFLite frontend.
Key implementations:

- Registered `REAL`, `IMAG`, and `COMPLEX_ABS` in the TFLite op map.
- Implemented `convert_real` and `convert_imag`, which extract the real and imaginary parts of a complex tensor via `strided_slice` + `squeeze` along the last axis.
- Implemented `convert_complex_abs`, which computes `sqrt(re^2 + im^2)` using elementwise Relax ops.
- All three ops adopt a unified representation convention: TFLite `complex64` tensors (which have no native Relax dtype equivalent) are represented as `float32[..., 2]`, where the last axis holds `(real, imaginary)` interleaved.

Out of scope:

- `RFFT2D` is not registered in this PR. An O(N²) matmul decomposition is feasible using existing Relax ops and will be contributed separately with benchmarks showing the performance gap versus a native FFT op. A native `relax.op.signal.rfft2d` is tracked in [Tracking Issue] [TOPI][Signal][Relax] Add native rfft2d op with C++ registered backend #19764.

Testing:

- Added structural equality tests for `REAL`, `IMAG`, and `COMPLEX_ABS` in `test_frontend_tflite.py` following the `verify(TestClass, Expected)` pattern.

`python3 -m pytest tests/python/relax/test_frontend_tflite.py -k "test_real or test_imag or test_complex_abs"`