Skip to content

[Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS tflite operator#19345

Merged
tlopex merged 3 commits into
apache:mainfrom
Aharrypotter:relax-onnx-detection-postprocess
Apr 11, 2026
Merged

[Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS tflite operator#19345
tlopex merged 3 commits into
apache:mainfrom
Aharrypotter:relax-onnx-detection-postprocess

Conversation

@Aharrypotter

Copy link
Copy Markdown
Contributor

Summary

Changes

  • Operator Registration: Implemented convert_detection_postprocess in python/tvm/relax/frontend/tflite/tflite_frontend.py.
  • Core Logic:
    • Integrated multibox_transform_loc for coordinate decoding and variance scaling.
    • Supported use_regular_nms attribute to switch between all-class NMS and class-agnostic NMS paths.
    • Leveraged all_class_non_max_suppression for efficient box filtering.
  • Output Alignment: Used topk, gather_nd, and where operators to ensure the output tensors (boxes, classes, scores, num_detections) match the TFLite specification in terms of shape and layout.
  • Attribute Validation: Added strict validation for required custom options such as num_classes, max_detections, and scaling factors.

Validation

Verified with linting and pre-commit hooks:

# Lint check
python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py

# Pre-commit checks
python -m pre_commit run --files python/tvm/relax/frontend/tflite/tflite_frontend.py

Result:

  • Passed: All static checks and style guidelines are met.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the DETECTION_POSTPROCESS operator in the TFLite frontend for Relax, including attribute validation and NMS logic. Feedback highlights several critical issues: invalid Python unpacking of relax.Call objects (e.g., from topk), premature access to struct_info.dtype before normalization, a potential shape mismatch in mask generation, and robustness concerns regarding dynamic batch sizes in slicing operations.

Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
@tlopex

tlopex commented Apr 4, 2026

Copy link
Copy Markdown
Member

Several issues that need to be addressed:

  1. relax.Call cannot be unpacked directly

Relax ops return relax.Call objects, not Python tuples, so direct unpacking will fail at graph construction time. This applies to at least these places:

  • relax.op.vision.multibox_transform_loc(...)
  • both relax.op.topk(...) calls

Please emit the call first, then extract the fields with relax.TupleGetItem, e.g.

ret = bb.emit(relax.op.vision.multibox_transform_loc(...))
transformed_boxes = relax.TupleGetItem(ret, 0)
transformed_scores = relax.TupleGetItem(ret, 1)
  1. Control-flow / indentation issue

The all_class_non_max_suppression block and the assignments after if use_regular_nms: appear to be indented between the if and else branches. As written, this does not look like valid Python control flow. Please restructure this section so the if and else branches are each self-contained.

  1. strided_slice with dynamic batch_size

Here:

relax.op.strided_slice(top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2])
if batch_size is symbolic, embedding it directly in a Python list will not work correctly. Please switch to a symbolic-friendly form, or avoid slicing on the batch dimension directly.

  1. Unrelated formatting change

The single-line reformatting of batch_shape in convert_batch_matmul appears unrelated to this PR. Please revert it to keep the diff focused.

  1. Missing test coverage

Please add an end-to-end correctness test against a TFLite model that uses DETECTION_POSTPROCESS (for example, an SSD MobileNet variant).

@Aharrypotter

Copy link
Copy Markdown
Contributor Author

Currently thinking about how to write appropriate test cases.

@tlopex tlopex left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, please have a look at those gemini reviews, which I think maybe helpful

@tlopex

tlopex commented Apr 10, 2026

Copy link
Copy Markdown
Member

@Aharrypotter Could you have a look and fix the issue?

@Aharrypotter

Copy link
Copy Markdown
Contributor Author

Sure, will push the fix today.

@Aharrypotter Aharrypotter force-pushed the relax-onnx-detection-postprocess branch from 488d252 to a0dc24b Compare April 10, 2026 10:59
@Aharrypotter

Aharrypotter commented Apr 10, 2026

Copy link
Copy Markdown
Contributor Author

I also fixed a bug in vision.py: the legalization path was not preserving the expected TensorFlow-style output format for DETECTION_POSTPROCESS. That created an implementation mismatch rather than just test noise, so I fixed the legalization logic and kept the updated tests aligned with the corrected behavior.

Please take a look @tlopex

@Aharrypotter Aharrypotter force-pushed the relax-onnx-detection-postprocess branch from a0dc24b to d52742c Compare April 10, 2026 11:52
@Aharrypotter

Copy link
Copy Markdown
Contributor Author

I will debug the error later.

…ator

This commit wires up the TFLite_Detection_PostProcess custom operator in the
Relax TFLite frontend. Key changes include:
- Implemented conversion logic using multibox_transform_loc and all_class_non_max_suppression.
- Added support for both regular NMS and class-agnostic NMS paths via 'use_regular_nms'.
- Properly formatted outputs (boxes, classes, scores, num_detections) to match TFLite spec.
- Added strict validation for required custom options (num_classes, scales, etc.).
Wire up the DETECTION_POSTPROCESS frontend lowering more robustly and
cover the main semantic paths in test_frontend_tflite.py.

This change:
- fixes tuple handling and output post-processing in the TFLite frontend
  DETECTION_POSTPROCESS converter
- masks selected scores using the actual selected-score slots before topk
  and aligns class padding dtypes
- updates all_class_non_max_suppression legalization so the tensorflow
  output_format keeps the TOPI 3-tensor contract
- adds DETECTION_POSTPROCESS frontend tests for IR structure, option
  validation, semantic reference coverage, LegalizeOps coverage, and the
  zero-valid-detection case

Validation:
- python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py python/tvm/relax/transform/legalize_ops/vision.py tests/python/relax/test_frontend_tflite.py
- python -m pre_commit run --files python/tvm/relax/frontend/tflite/tflite_frontend.py python/tvm/relax/transform/legalize_ops/vision.py tests/python/relax/test_frontend_tflite.py

Result:
- ruff check passed
- pre-commit passed
@Aharrypotter Aharrypotter force-pushed the relax-onnx-detection-postprocess branch from d52742c to 7f34fca Compare April 11, 2026 04:04
@Aharrypotter

Copy link
Copy Markdown
Contributor Author

The CI test test_detection_postprocess_smoke[regular_nms_multi_batch] failed with:
struct.error: unpack requires a buffer of 4 bytes

Root Cause: FlexBufferDecoder cannot correctly handle FlexBuffer data with byte_width=8.

When use_regular_nms=True, the flexbuffers.Builder generates data using 8-byte width. However, the decoder had hardcoded assumptions in three places:

Location Issue Fix
indirect_jump Only supported byte_width=1,4; assertion failed for 8-byte Support all widths {1,2,4,8}
decode_vector Only extracted type from type byte, ignoring lower 2 bits (bit_width); always used <i (4-byte) format for unpacking Extract both (type >> 2) and (type & 3), select format based on actual byte width
decode_map Unpacked map_size with <i (4-byte), but slice length is byte_width (may be 8) Dynamically select format based on byte_width

Verification: After fix, both test cases decode correctly:

  • basic_fast_nms (byte_width=4): ✓
  • regular_nms_multi_batch (byte_width=8): ✓

@tlopex tlopex left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you

@tlopex tlopex merged commit b14b023 into apache:main Apr 11, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants