Skip to content

[Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS#19763

Merged
tlopex merged 3 commits into
apache:mainfrom
fnhirwa:fft
Jun 16, 2026
Merged

[Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS#19763
tlopex merged 3 commits into
apache:mainfrom
fnhirwa:fft

Conversation

@fnhirwa

@fnhirwa fnhirwa commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Part of #19519

This PR adds support for the FFT and complex operator family in the Relax TFLite frontend.

Key implementations:

  • Registered REAL, IMAG, COMPLEX_ABSto the TFLite op map.
  • Implemented convert_real and convert_imag which extract the real and imaginary parts of a complex tensor via strided_slice + squeeze along the last axis.
  • Implemented convert_complex_abs which computes sqrt(re^2 + im^2) using elementwise Relax ops.
  • All three ops adopt a unified representation convention: TFLite complex64 tensors (which have no native Relax dtype equivalent) are represented as float32[..., 2], where the last axis holds (real, imaginary) interleaved..

Out of scope:

Testing:

  • Added structural equality tests for REAL, IMAG, and COMPLEX_ABS in test_frontend_tflite.py following the verify(TestClass, Expected) pattern.
python3 -m pytest tests/python/relax/test_frontend_tflite.py -k "test_real or test_imag or test_complex_abs"

@fnhirwa fnhirwa marked this pull request as ready for review June 14, 2026 13:56

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for converting TFLite REAL, IMAG, and COMPLEX_ABS operators to TVM Relax, representing complex64 tensors as float32 tensors with an extra trailing dimension of size 2. It also adds a placeholder for RFFT2D which raises an unimplemented error. The review feedback suggests using negative indexing (axes=[-1] and axis=[-1]) instead of calculating the last axis dynamically using the tensor's ndim to prevent potential errors with dynamic or unknown ranks, along with corresponding updates to the test assertions.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread tests/python/relax/test_frontend_tflite.py Outdated
Comment thread tests/python/relax/test_frontend_tflite.py Outdated
Comment thread tests/python/relax/test_frontend_tflite.py Outdated

@tlopex tlopex left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this. I think this needs a small fix before merge.

The complex64 -> float32[..., 2] lowering is currently applied in two different places, which can produce wrong input shapes when callers pass shape_dict or dtype_dict overrides. _input_type() already converts complex inputs to float32[..., 2], then from_tflite() applies user overrides and may append another trailing 2 if the overridden dtype is complex64.

For example:

  • dtype_dict={"x": "complex64"} can turn an already packed (2, 4, 2) shape into (2, 4, 2, 2).
  • shape_dict={"x": (2, 4)} overwrites the packed default shape while the dtype remains float32, so REAL/IMAG slice the wrong last axis.

I think the safer structure is to keep _input_type() returning the logical TFLite shape/dtype, apply user overrides, and then lower complex64 to float32[..., 2] exactly once.

A couple of related scope/test issues:

  • RFFT2D is registered in convert_map, but convert_rfft2d() always raises OpNotImplemented. The PR title/body says RFFT2D support and mentions an RFFT2D test, but the patch only adds tests for REAL, IMAG, and COMPLEX_ABS. Please either implement it, or keep it clearly unsupported and add a negative test / update the PR description.
  • The new complex representation is only wired for model inputs. Constant or non-input complex tensors still go through the generic tensor path, which does not decode TensorType.COMPLEX64. If input-only support is intentional for this PR, please make that scope explicit and cover it with tests.

@fnhirwa fnhirwa changed the title [Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS, RFFT2D [Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS Jun 16, 2026
@fnhirwa

fnhirwa commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

Thanks for the feedback.

Double lowering fix: Removed the complex64 -> float32[..., 2] transformation from _input_type(). The lowering now happens exactly once in from_tflite() after user overrides are merged. Verified that shape_dict and dtype_dict overrides no longer produce double-packed shapes.

Constant tensor support: Added COMPLEX64 to get_tensor_type_as_numpy, get_tensor_type_str, and get_tensor_expr. Constant complex tensors are now reinterpreted as float32[..., 2] by reading the buffer as np.complex64 and calling .view(np.float32).reshape(shape + (2,)).

RFFT2D: Removed from this PR entirely. An O(N²) matmul decomposition using existing Relax ops (matmul, permute_dims, concat) is feasible as an interim implementation and will be contributed in a follow-up PR. The correct long-term fix is a native relax.op.signal.rfft2d with a C++ registered backend, tracked in #19764.

@fnhirwa fnhirwa requested a review from tlopex June 16, 2026 18:25

@tlopex tlopex left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM Thanks!

@tlopex tlopex merged commit 9d6e1cf into apache:main Jun 16, 2026
11 of 12 checks passed
MasterJH5574 pushed a commit to MasterJH5574/tvm that referenced this pull request Jun 16, 2026
…, IMAG, COMPLEX_ABS (apache#19763)

Part of apache#19519

This PR adds support for the FFT and complex operator family in the
Relax TFLite frontend.

**Key implementations:**

- Registered `REAL`, `IMAG`, `COMPLEX_ABS`to the TFLite op map.
- Implemented `convert_real` and `convert_imag` which extract the real
and imaginary parts of a complex tensor via `strided_slice` + `squeeze`
along the last axis.
- Implemented `convert_complex_abs` which computes `sqrt(re^2 + im^2)`
using elementwise Relax ops.
- All three ops adopt a unified representation convention: TFLite
`complex64` tensors (which have no native Relax dtype equivalent) are
represented as `float32[..., 2]`, where the last axis holds `(real,
imaginary)` interleaved..

**Out of scope:**

- `RFFT2D` is not registered in this PR. An O(N²) matmul decomposition
is feasible using existing Relax ops and will be contributed separately
  with benchmarks showing the performance gap versus a native FFT op.
A native `relax.op.signal.rfft2d` is tracked in
apache#19764

**Testing:**

- Added structural equality tests for `REAL`, `IMAG`, and `COMPLEX_ABS`
in `test_frontend_tflite.py` following the `verify(TestClass, Expected)`
pattern.

```bash
python3 -m pytest tests/python/relax/test_frontend_tflite.py -k "test_real or test_imag or test_complex_abs"
```
Aharrypotter added a commit to Aharrypotter/tvm that referenced this pull request Jun 17, 2026
Import TFLite RFFT2D for the static no-padding/no-truncation path. The
op dispatches to a naive O(N^4) DFT primfunc that the caller can override
once a faster radix-2 / mixed-radix kernel is wired in.

The output of the op is a real/imag pair tensor of shape [..., H, W//2+1, 2]
in float32, matching the TFLite convention. The complex pair is appended
on the trailing axis to stay consistent with the COMPLEX_ABS / REAL / IMAG
imports in apache#19763.

Add hand-built FlatBuffers coverage for the static pair-output lowering
and the fft_length mismatch guard.
tlopex pushed a commit that referenced this pull request Jun 17, 2026
…19812)

## Summary

This PR adds Relax TFLite frontend support for the TFLite builtin
`RFFT2D`
operator (issue #19519 item C — FFT / complex operators). It is the
follow-up to upstream PR #19763, which already merged the `REAL` /
`IMAG` /
`COMPLEX_ABS` subset of item C; this PR completes the subset with
`RFFT2D`
itself.

`RFFT2D` computes a 2D real FFT over the last two input axes and returns
a
real/imag pair tensor of shape `[..., H, W // 2 + 1, 2]`. Relax does not
have
a native complex64 dtype, so the pair output is represented as a
`float32`
tensor with a trailing axis of size 2, matching the convention PR #19763
established for `REAL` / `IMAG` / `COMPLEX_ABS`. `RFFT2D` accepts a real
`float32` input and emits a `float32` pair output of the same dtype; the
frontend does not need to materialize any COMPLEX64 in-memory
representation
of its own.

The supported subset is the static no-padding / no-truncation path:

- the input's last two dimensions must match `fft_length`
- `fft_length` must be a length-2 integer constant (int32 or int64)
- the output shape is `input_shape[:-2] + (H, W // 2 + 1, 2)`
- sparse inputs are rejected

## Design

### Dispatch Between Two TIR Kernels

`convert_rfft2d` selects one of two TIR primfuncs at lowering time,
based on
whether the spatial axes are powers of two:

```python
if _is_power_of_2(height) and _is_power_of_2(width):
    prim_func = _build_tflite_rfft2d_fft_primfunc(input_shape, relax_output_shape)
else:
    prim_func = _build_tflite_rfft2d_primfunc(input_shape, relax_output_shape)
```

Both kernels share the same `call_tir` contract, so downstream code is
kernel-agnostic.

#### DFT Reference Kernel (`_build_tflite_rfft2d_primfunc`)

A naive O(B · H · W · H · W) DFT over the last two input axes. The outer
(batch, out_y, out_x) iteration is structured as S-TIR spatial axes so a
downstream `tir.schedule` pass can parallelize it. Trig and accumulation
are
in float32; the result agrees with `np.fft.rfft2` to about `1e-5`
absolute
tolerance for typical input sizes. This kernel is the fallback for odd
or
non-power-of-2 spatial sizes.

#### Cooley-Tukey FFT Kernel (`_build_tflite_rfft2d_fft_primfunc`)

An O(B · H · W · (log2(H) + log2(W))) radix-2 Cooley-Tukey FFT,
dispatched
when both `H` and `W` are positive powers of 2. The primfunc source is
generated as a TIR string at construction time and registered in
`linecache`
so `tirx.parser` can resolve it. The bit-reversal permutation and
butterfly
stages are precomputed in Python and inlined as direct scratch-buffer
assignments, so all loop bounds and twiddle factors are compile-time
literals:

1. Copy the real input into `scratch_real`; initialize `scratch_imag` to
0.
2. For each batch and each row, run an in-place 1D FFT of length `W`
along
   the width axis using scratch buffers.
3. For each batch and each column, run an in-place 1D FFT of length `H`
   along the height axis with stride `W`.
4. Write the first `W // 2 + 1` complex bins per row to the output pair
   representation.

The bit-reversal swap pairs and butterfly stage bodies are precomputed
at
primfunc-construction time because `tirx.parser` does not currently
accept
runtime `T.serial` bounds, and twiddle factors must be `T.float32`
literals
to avoid `Undefined variable` errors. The fake linecache filename
`<tflite_rfft2d_fft_primfunc H=8 W=8 outW=5>` is dimension-aware, so
generated-source stack traces are readable.

### COMPLEX64 Pair Representation Helpers

The frontend represents TFLite COMPLEX64 tensors as float32 real/imag
pairs
with a trailing axis of size 2 (since Relax has no native complex64
dtype).
Four small helpers centralize the rule so future complex ops can plug in
without re-implementing the pair-axis layout:

- `_is_tflite_complex64_type` — checks whether a TFLite tensor type is
  `COMPLEX64`.
- `_unwrap_tflite_tensor` — unwraps a `TensorWrapper` to the raw
  `tflite.Tensor`.
- `_get_relax_tensor_dtype` — returns the Relax dtype used to represent
a
TFLite tensor (`"float32"` for COMPLEX64, otherwise the standard
mapping).
- `_get_relax_tensor_shape` — returns the Relax shape (TFLite shape with
a
  trailing `(2,)` axis appended for COMPLEX64).

The 3 callers that construct Relax parameters from TFLite metadata
(`_get_static_tensor_shape_dtype` / `_set_subgraph_input_params` /
`_get_tensor_param`) now go through these helpers. The pair-axis
invariant
is documented on `get_tensor_value` and `get_tensor_shape`, which both
return the *raw TFLite* shape (no pair axis) for downstream callers that
need to compare against the model's declared output shape.

### Boundary Validation

`convert_rfft2d` validates rank, dtype, fft_length shape, integer-ness,
positivity, fft_length == input spatial shape, output shape agreement,
and
the absence of sparse inputs before emitting the `call_tir`. Edge cases
(sparse, non-integer fft_length, zero/negative fft_length, mismatched
fft_length, 1×1 spatial size) each produce a targeted `OpNotImplemented`
diagnostic.

## Operator Support

| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `RFFT2D` | input `float32`, constant length-2 integer `fft_length`,
`COMPLEX64` output | `call_tir` to a generated TIR kernel | static
no-padding/no-truncation; `H`, `W` arbitrary; Cooley-Tukey dispatched
when both are powers of 2 |

## Safety Checks

- Non-float32 input raises `OpNotImplemented("RFFT2D input must be
float32")`.
- Non-COMPLEX64 output raises `OpNotImplemented("RFFT2D output must be
COMPLEX64")`.
- Sparse inputs raise `OpNotImplemented("RFFT2D does not support sparse
inputs")`.
- Non-constant `fft_length` raises `OpNotImplemented("RFFT2D requires a
constant fft_length")`.
- Non-integer `fft_length` raises `OpNotImplemented("RFFT2D fft_length
must be an integer tensor")`.
- Wrong-shape `fft_length` (not length 2) raises
`OpNotImplemented("RFFT2D fft_length must be a length-2 tensor")`.
- Non-positive `fft_length` raises `OpNotImplemented("RFFT2D fft_length
must be positive")`.
- `fft_length` not matching the input's last two dims raises
`OpNotImplemented("RFFT2D currently supports fft_length matching the
input spatial shape")`.
- Mismatched output shape raises `OpNotImplemented("RFFT2D output shape
does not match fft_length")`.
- Input rank < 2 raises `OpNotImplemented("RFFT2D input rank must be at
least 2")`.

## Not Included

- `RFFT2D` with `fft_length` not matching the input's last two
dimensions
  (padding / truncation path).
- Other complex-data operators: `REAL` / `IMAG` / `COMPLEX_ABS` are
already
  handled by upstream PR #19763 and are out of scope for this PR.
- A frontend guard that rejects COMPLEX64 inputs flowing into
non-complex
  ops. With PR #19763 providing the only other complex ops, models that
  contain only `RFFT2D` (with float32 input / COMPLEX64 output) and the
  three upstream ops are fully supported. A generic COMPLEX64 guard was
  intentionally not added here to keep this PR scoped to `RFFT2D`.
- User-override validation for `shape_dict` / `dtype_dict` on COMPLEX64
  inputs. After PR #19763, the frontend no longer auto-appends a
trailing pair axis to user overrides; a user passing the natural TFLite
  shape without the pair axis will now fall through to the standard
  metadata mismatch path.
- Higher-precision FFT kernel (e.g. SIMD). The float32 twiddle / float32
  accumulation paths match `np.fft.rfft2` to `~1e-4` on the Cooley-Tukey
  path; large spatial dimensions may need a future optimized lowering or
  backend-specific implementation.

## Tests

The tests manually build minimal TFLite flatbuffers, run the frontend,
and
compare against `np.fft.rfft2`. Edge cases raise `OpNotImplemented`. The
DFT-path tests use `atol=1e-5`; the FFT-path tests use `atol=1e-4`
because
twiddle factors are float32 literals.

| Test | Path | Shape | Coverage |
|---|---|---|---|
| `test_rfft2d_static_pair_output` | DFT | `[2, 4]` | Baseline 2D, even
width; also asserts Relax script contains the `tflite_rfft2d` kernel
name and pair-output struct-info |
| `test_rfft2d_static_pair_output_with_batch` | DFT | `[2, 2, 4]` |
Leading batch dims preserved |
| `test_rfft2d_odd_width_pair_output` | DFT | `[3, 5]` | Odd width → `W
// 2 + 1` output bins |
| `test_rfft2d_int64_fft_length` | DFT | `[2, 4]` | INT64 fft_length
constant (TFLite schema allows either int32 or int64) |
| `test_rfft2d_4d_input_pair_output` | DFT | `[2, 3, 4, 5]` | 4D input
with batch and odd width |
| `test_rfft2d_minimal_1x1_pair_output` | DFT | `[1, 1]` | Edge case:
trivial 1×1 FFT |
| `test_rfft2d_mismatched_fft_length_unsupported` | — | `[2, 4]`
(fft=`[4, 4]`) | fft_length != input spatial shape guard |
| `test_rfft2d_dynamic_fft_length_unsupported` | — | `[2, 4]` |
Dynamic/non-constant fft_length guard |
| `test_rfft2d_fft_path_4x4` | **FFT** | `[4, 4]` | Smallest
power-of-two (4×4) where both row and column FFTs do real work |
| `test_rfft2d_fft_path_8x8` | **FFT** | `[8, 8]` | Square 8×8
power-of-two |
| `test_rfft2d_fft_path_16x16` | **FFT** | `[16, 16]` | Larger FFT,
kernel scaling check |
| `test_rfft2d_fft_path_2x2x4x8` | **FFT** | `[2, 2, 4, 8]` | 4D
power-of-two with batch |

Local validation:

```bash
python -m py_compile \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k rfft2d -v
```

Result:

```text
py_compile: passed
ruff check: All checks passed
pre-commit run --files: passed
rfft2d tests: 12 passed
```

## References

- Issue [#19519](#19519) item C: FFT
/
  complex operators (`RFFT2D`, `REAL`, `IMAG`, `COMPLEX_ABS`).
- Upstream PR #19763: `[Relax][Frontend][TFLite] Add support for
FFT/complex operators: REAL, IMAG, COMPLEX_ABS` (merge commit
`9d6e1cf0`).
  This PR's RFFT2D output pair layout matches the COMPLEX64 pair
representation PR #19763 established for `REAL` / `IMAG` /
`COMPLEX_ABS`,
  so downstream ops from PR #19763 can consume RFFT2D output directly.
- Tracking issue [#19764](#19764)
for
the longer-term native `relax.op.signal.rfft2d` / registered TOPI
backend
path. This PR is a frontend-local lowering for TFLite `RFFT2D`, not the
  native Relax signal op tracked there.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants