From 2af96e785119949531be150baa13f3dc7dadbd4d Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Sat, 4 Apr 2026 14:51:21 +0800 Subject: [PATCH 1/3] [Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS custom operator This commit wires up the TFLite_Detection_PostProcess custom operator in the Relax TFLite frontend. Key changes include: - Implemented conversion logic using multibox_transform_loc and all_class_non_max_suppression. - Added support for both regular NMS and class-agnostic NMS paths via 'use_regular_nms'. - Properly formatted outputs (boxes, classes, scores, num_detections) to match TFLite spec. - Added strict validation for required custom options (num_classes, scales, etc.). --- .../relax/frontend/tflite/tflite_frontend.py | 191 +++++++++++++----- 1 file changed, 140 insertions(+), 51 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index b344d9361a7a..a555f56a0f71 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -3225,16 +3225,49 @@ def convert_dequantize(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - raise NotImplementedError( - "DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs " - "Relax NMS / get_valid_counts / related vision helpers (see dead code below). " - "relax.vision.multibox_transform_loc exists; tracking: " - "https://github.com/apache/tvm/issues/18928" - ) flexbuffer = op.CustomOptionsAsNumpy().tobytes() custom_options = FlexBufferDecoder(flexbuffer).decode() - use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"] + use_regular_nms = bool(custom_options.get("use_regular_nms", False)) + + required_attrs = [ + "num_classes", + "max_detections", + "detections_per_class", + "nms_iou_threshold", + "nms_score_threshold", + "x_scale", + "y_scale", + "w_scale", + "h_scale", + ] + missing_attrs = [key for key in required_attrs if key not in custom_options] + if missing_attrs: + raise ValueError( + "DETECTION_POSTPROCESS custom options miss required attributes: " + + ", ".join(missing_attrs) + ) + + num_classes = int(custom_options["num_classes"]) + max_detections = int(custom_options["max_detections"]) + detections_per_class = int(custom_options["detections_per_class"]) + iou_threshold = float(custom_options["nms_iou_threshold"]) + score_threshold = float(custom_options["nms_score_threshold"]) + x_scale = float(custom_options["x_scale"]) + y_scale = float(custom_options["y_scale"]) + w_scale = float(custom_options["w_scale"]) + h_scale = float(custom_options["h_scale"]) + + if num_classes <= 0: + raise ValueError("DETECTION_POSTPROCESS requires num_classes > 0.") + if max_detections <= 0: + raise ValueError("DETECTION_POSTPROCESS requires max_detections > 0.") + if detections_per_class <= 0: + raise ValueError("DETECTION_POSTPROCESS requires detections_per_class > 0.") + if not 0.0 <= iou_threshold <= 1.0: + raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold in [0, 1].") + if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <= 0.0: + raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to be > 0.") inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" @@ -3296,18 +3329,16 @@ def convert_detection_postprocess(self, op): # attributes for multibox_transform_loc multibox_transform_loc_attrs = {} multibox_transform_loc_attrs["clip"] = False - multibox_transform_loc_attrs["threshold"] = ( - 0.0 if use_regular_nms else custom_options["nms_score_threshold"] - ) + multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms else score_threshold multibox_transform_loc_attrs["variances"] = ( - 1 / custom_options["x_scale"], - 1 / custom_options["y_scale"], - 1 / custom_options["w_scale"], - 1 / custom_options["h_scale"], + 1 / x_scale, + 1 / y_scale, + 1 / w_scale, + 1 / h_scale, ) multibox_transform_loc_attrs["keep_background"] = use_regular_nms - ret = relax.op.vision.multibox_transform_loc( + transformed_boxes, transformed_scores = relax.op.vision.multibox_transform_loc( # reshape cls_pred so it can be consumed by # multibox_transform_loc relax.op.permute_dims(cls_pred, [0, 2, 1]), @@ -3317,46 +3348,104 @@ def convert_detection_postprocess(self, op): ) if use_regular_nms: - # box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax) - _, transformed_boxes = relax.op.split(ret[0], (2,), axis=2) - box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4, axis=2) - transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r], axis=2) - - return relax.op.vision.regular_non_max_suppression( - boxes=transformed_boxes, - scores=cls_pred, - max_detections_per_class=custom_options["detections_per_class"], - max_detections=custom_options["max_detections"], - num_classes=custom_options["num_classes"], - iou_threshold=custom_options["nms_iou_threshold"], - score_threshold=custom_options["nms_score_threshold"], + nms_out = relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + transformed_scores, + relax.const(detections_per_class, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) + selected_indices = nms_out[0] + selected_scores = nms_out[1] + num_detections = nms_out[2] + class_id_from_score = None + else: + max_scores, class_id_from_score = relax.op.topk( + transformed_scores, k=1, axis=1, ret_type="both", largest=True ) + nms_out = relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + max_scores, + relax.const(max_detections, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) + selected_indices = nms_out[0] + selected_scores = nms_out[1] + num_detections = nms_out[2] + class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1]) - # attributes for non_max_suppression - non_max_suppression_attrs = {} - non_max_suppression_attrs["return_indices"] = False - non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"] - non_max_suppression_attrs["force_suppress"] = True - non_max_suppression_attrs["top_k"] = anchor_boxes - non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"] - non_max_suppression_attrs["invalid_to_bottom"] = False - - ret = relax.op.vision.non_max_suppression( - ret[0], ret[1], ret[1], **non_max_suppression_attrs + num_detections = relax.op.minimum( + num_detections, relax.const(np.array([max_detections], dtype="int64")) + ) + detection_positions = relax.op.expand_dims( + relax.op.arange(max_detections, dtype="int64"), axis=0 + ) + valid_detection_mask = relax.op.less( + detection_positions, relax.op.expand_dims(num_detections, axis=1) + ) + masked_selected_scores = relax.op.where( + valid_detection_mask, + selected_scores, + relax.const(-1.0, selected_scores.struct_info.dtype), + ) + detection_scores, top_positions = relax.op.topk( + masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True + ) + top_positions_expanded = relax.op.expand_dims(top_positions, axis=2) + top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, axis=2) + top_index_pairs = relax.op.gather_elements( + selected_indices, top_positions_for_pairs, axis=1 + ) + top_box_ids = relax.op.squeeze( + relax.op.strided_slice( + top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2] + ), + axis=[2], + ) + top_box_ids_for_gather = relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2) + detection_boxes = relax.op.gather_nd( + transformed_boxes, top_box_ids_for_gather, batch_dims=1 + ) + + if use_regular_nms: + detection_classes = relax.op.squeeze( + relax.op.strided_slice( + top_index_pairs, begin=[0, 0, 0], end=[batch_size, max_detections, 1] + ), + axis=[2], + ) + else: + top_box_ids_for_class = relax.op.expand_dims( + relax.op.astype(top_box_ids, "int64"), axis=2 + ) + detection_classes = relax.op.gather_nd( + class_id_from_score, top_box_ids_for_class, batch_dims=1 + ) + + detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2) + detection_boxes = relax.op.where( + detection_mask, + detection_boxes, + relax.op.zeros( + (batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype + ), + ) + detection_classes = relax.op.where( + valid_detection_mask, + detection_classes, + relax.op.zeros((batch_size, max_detections), dtype=detection_classes.struct_info.dtype), ) - ret = relax.op.vision.get_valid_counts(ret, 0) - valid_count = ret[0] - # keep only the top 'max_detections' rows - ret = relax.op.strided_slice( - ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], 6] + detection_scores = relax.op.where( + valid_detection_mask, + detection_scores, + relax.op.zeros((batch_size, max_detections), dtype=detection_scores.struct_info.dtype), ) - # the output needs some reshaping to match tflite - ret = relax.op.split(ret, 6, axis=2) - cls_ids = relax.op.reshape(ret[0], [batch_size, -1]) - scores = relax.op.reshape(ret[1], [batch_size, -1]) - boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2) - ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]), size=4) - return ret + detection_classes = relax.op.astype(detection_classes, "float32") + num_detections = relax.op.astype(num_detections, "float32") + return relax.Tuple([detection_boxes, detection_classes, detection_scores, num_detections]) def convert_nms_v5(self, op): """Convert TFLite NonMaxSuppressionV5""" From b4950e8127eeae07a18db37fa55ca63987e49f8d Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:59:16 +0800 Subject: [PATCH 2/3] [Relax][Frontend][TFLite] Fix DETECTION_POSTPROCESS lowering and tests Wire up the DETECTION_POSTPROCESS frontend lowering more robustly and cover the main semantic paths in test_frontend_tflite.py. This change: - fixes tuple handling and output post-processing in the TFLite frontend DETECTION_POSTPROCESS converter - masks selected scores using the actual selected-score slots before topk and aligns class padding dtypes - updates all_class_non_max_suppression legalization so the tensorflow output_format keeps the TOPI 3-tensor contract - adds DETECTION_POSTPROCESS frontend tests for IR structure, option validation, semantic reference coverage, LegalizeOps coverage, and the zero-valid-detection case Validation: - python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py python/tvm/relax/transform/legalize_ops/vision.py tests/python/relax/test_frontend_tflite.py - python -m pre_commit run --files python/tvm/relax/frontend/tflite/tflite_frontend.py python/tvm/relax/transform/legalize_ops/vision.py tests/python/relax/test_frontend_tflite.py Result: - ruff check passed - pre-commit passed --- .../relax/frontend/tflite/tflite_frontend.py | 118 ++++---- .../relax/transform/legalize_ops/vision.py | 19 +- tests/python/relax/test_frontend_tflite.py | 268 ++++++++++++++++++ 3 files changed, 348 insertions(+), 57 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index a555f56a0f71..16d5cb636bbb 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -2832,7 +2832,9 @@ def convert_batch_matmul(self, op): new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b] max_rank = max(rank_a, rank_b) - batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)] + batch_shape = [ + max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2) + ] a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])] b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])] @@ -3338,47 +3340,76 @@ def convert_detection_postprocess(self, op): ) multibox_transform_loc_attrs["keep_background"] = use_regular_nms - transformed_boxes, transformed_scores = relax.op.vision.multibox_transform_loc( - # reshape cls_pred so it can be consumed by - # multibox_transform_loc - relax.op.permute_dims(cls_pred, [0, 2, 1]), - loc_prob, - anchor_expr, - **multibox_transform_loc_attrs, + multibox_res = self.bb.emit( + relax.op.vision.multibox_transform_loc( + # reshape cls_pred so it can be consumed by + # multibox_transform_loc + relax.op.permute_dims(cls_pred, [0, 2, 1]), + loc_prob, + anchor_expr, + **multibox_transform_loc_attrs, + ) ) + transformed_boxes = self.bb.emit(relax.TupleGetItem(multibox_res, 0)) + transformed_scores = self.bb.emit(relax.TupleGetItem(multibox_res, 1)) if use_regular_nms: - nms_out = relax.op.vision.all_class_non_max_suppression( - transformed_boxes, - transformed_scores, - relax.const(detections_per_class, "int64"), - relax.const(iou_threshold, "float32"), - relax.const(score_threshold, "float32"), - output_format="tensorflow", + nms_out = self.bb.emit( + relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + transformed_scores, + relax.const(detections_per_class, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) ) - selected_indices = nms_out[0] - selected_scores = nms_out[1] - num_detections = nms_out[2] + selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0)) + selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1)) + num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2)) class_id_from_score = None else: - max_scores, class_id_from_score = relax.op.topk( - transformed_scores, k=1, axis=1, ret_type="both", largest=True + topk_res = self.bb.emit( + relax.op.topk(transformed_scores, k=1, axis=1, ret_type="both", largest=True) ) - nms_out = relax.op.vision.all_class_non_max_suppression( - transformed_boxes, - max_scores, - relax.const(max_detections, "int64"), - relax.const(iou_threshold, "float32"), - relax.const(score_threshold, "float32"), - output_format="tensorflow", + max_scores = self.bb.emit(relax.TupleGetItem(topk_res, 0)) + class_id_from_score = self.bb.emit(relax.TupleGetItem(topk_res, 1)) + nms_out = self.bb.emit( + relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + max_scores, + relax.const(max_detections, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) ) - selected_indices = nms_out[0] - selected_scores = nms_out[1] - num_detections = nms_out[2] + selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0)) + selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1)) + num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2)) class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1]) + selected_score_slots = selected_scores.struct_info.shape.values[1] + selected_detection_positions = relax.op.expand_dims( + relax.op.arange(selected_score_slots, dtype="int64"), axis=0 + ) + selected_valid_detection_mask = relax.op.less( + selected_detection_positions, relax.op.expand_dims(num_detections, axis=1) + ) + masked_selected_scores = relax.op.where( + selected_valid_detection_mask, + selected_scores, + relax.const(-1.0, "float32"), + ) + topk_scores_res = self.bb.emit( + relax.op.topk( + masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True + ) + ) + detection_scores = self.bb.emit(relax.TupleGetItem(topk_scores_res, 0)) + top_positions = self.bb.emit(relax.TupleGetItem(topk_scores_res, 1)) num_detections = relax.op.minimum( - num_detections, relax.const(np.array([max_detections], dtype="int64")) + num_detections, relax.const([max_detections], dtype="int64") ) detection_positions = relax.op.expand_dims( relax.op.arange(max_detections, dtype="int64"), axis=0 @@ -3386,23 +3417,13 @@ def convert_detection_postprocess(self, op): valid_detection_mask = relax.op.less( detection_positions, relax.op.expand_dims(num_detections, axis=1) ) - masked_selected_scores = relax.op.where( - valid_detection_mask, - selected_scores, - relax.const(-1.0, selected_scores.struct_info.dtype), - ) - detection_scores, top_positions = relax.op.topk( - masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True - ) top_positions_expanded = relax.op.expand_dims(top_positions, axis=2) top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, axis=2) top_index_pairs = relax.op.gather_elements( selected_indices, top_positions_for_pairs, axis=1 ) top_box_ids = relax.op.squeeze( - relax.op.strided_slice( - top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2] - ), + relax.op.strided_slice(top_index_pairs, axes=[2], begin=[1], end=[2]), axis=[2], ) top_box_ids_for_gather = relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2) @@ -3412,11 +3433,10 @@ def convert_detection_postprocess(self, op): if use_regular_nms: detection_classes = relax.op.squeeze( - relax.op.strided_slice( - top_index_pairs, begin=[0, 0, 0], end=[batch_size, max_detections, 1] - ), + relax.op.strided_slice(top_index_pairs, axes=[2], begin=[0], end=[1]), axis=[2], ) + detection_classes = relax.op.astype(detection_classes, "int32") else: top_box_ids_for_class = relax.op.expand_dims( relax.op.astype(top_box_ids, "int64"), axis=2 @@ -3429,19 +3449,17 @@ def convert_detection_postprocess(self, op): detection_boxes = relax.op.where( detection_mask, detection_boxes, - relax.op.zeros( - (batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype - ), + relax.op.zeros((batch_size, max_detections, 4), dtype="float32"), ) detection_classes = relax.op.where( valid_detection_mask, detection_classes, - relax.op.zeros((batch_size, max_detections), dtype=detection_classes.struct_info.dtype), + relax.op.zeros((batch_size, max_detections), dtype="int32"), ) detection_scores = relax.op.where( valid_detection_mask, detection_scores, - relax.op.zeros((batch_size, max_detections), dtype=detection_scores.struct_info.dtype), + relax.op.zeros((batch_size, max_detections), dtype="float32"), ) detection_classes = relax.op.astype(detection_classes, "float32") num_detections = relax.op.astype(num_detections, "float32") diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index 7d8586ab5288..c515fc8fe81a 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -32,11 +32,15 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E Returns ------- - result : Tuple[Tensor, Tensor] - A tuple of (trimmed_indices, num_total_detections) where: - - trimmed_indices: Tensor of shape (num_total_detections, 3) containing only - valid detection indices (batch_id, class_id, box_id) - - num_total_detections: Tensor of shape (1,) with the count of valid detections + result : Expr + The legalized NMS result. + + - For ONNX output format, returns a tuple of + `(trimmed_indices, num_total_detections)`, where `trimmed_indices` + contains only valid detection indices. + - For TensorFlow output format, returns the TOPI result directly to + preserve the `(selected_indices, selected_scores, num_detections)` + layout expected by the Relax op. """ boxes = call.args[0] scores = call.args[1] @@ -69,8 +73,9 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E output_format, ) - # Dynamic output trimming using dynamic_strided_slice - # Extract selected_indices and num_total_detections from the NMS result + if output_format == "tensorflow": + return nms_result + selected_indices = block_builder.emit(TupleGetItem(nms_result, 0)) num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1)) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 02282f3d41c9..c237d4db8f8a 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -27,6 +27,7 @@ from tensorflow.keras import applications as keras_app import tvm +import tvm.relax.frontend.tflite.tflite_frontend as tflite_frontend from tvm import relax from tvm.relax.frontend.tflite import from_tflite from tvm.script.parser import ir as I @@ -1082,6 +1083,142 @@ def func(self, boxes, scores): return mod, instance.func +class _StubDetectionPostprocessTensor: + def __init__(self, shape, name): + self._shape = list(shape) + self._name = name + + def Shape(self, index): + return self._shape[index] + + def Name(self): + return self._name + + def Type(self): + return 0 + + +class _StubDetectionPostprocessOp: + def __init__(self, custom_options): + self._custom_options = _encode_detection_postprocess_custom_options(custom_options) + + def CustomOptionsAsNumpy(self): + return np.frombuffer(self._custom_options, dtype="uint8") + + +_DETECTION_POSTPROCESS_ANCHORS = np.array( + [ + [0.5, 0.5, 1.0, 1.0], + [0.5, 0.2, 1.0, 1.0], + [0.1, 0.1, 0.5, 0.5], + [0.8, 0.8, 0.2, 0.2], + ], + dtype="float32", +) + + +def _encode_detection_postprocess_custom_options(custom_options): + from flatbuffers import flexbuffers + + builder = flexbuffers.Builder() + with builder.Map(): + for key, value in custom_options.items(): + if isinstance(value, bool): + builder.Bool(key, value) + elif isinstance(value, int): + builder.Int(key, value) + else: + builder.Float(key, float(value)) + return bytes(builder.Finish()) + + +def _make_detection_postprocess_tensor_wrapper(tensor_idx, shape, name): + return tflite_frontend.TensorWrapper( + tensor_idx, + _StubDetectionPostprocessTensor(shape, name), + None, + ) + + +def _build_detection_postprocess_mod( + *, + num_classes=1, + max_detections=4, + detections_per_class=4, + use_regular_nms=False, + nms_iou_threshold=0.5, + nms_score_threshold=0.3, + x_scale=10.0, + y_scale=10.0, + w_scale=5.0, + h_scale=5.0, + batch_size=2, + num_anchors=4, + input_num_classes=None, +): + custom_options = { + "num_classes": num_classes, + "max_detections": max_detections, + "detections_per_class": detections_per_class, + "nms_iou_threshold": nms_iou_threshold, + "nms_score_threshold": nms_score_threshold, + "x_scale": x_scale, + "y_scale": y_scale, + "w_scale": w_scale, + "h_scale": h_scale, + "use_regular_nms": use_regular_nms, + } + return _convert_detection_postprocess_with_options( + custom_options, + batch_size=batch_size, + num_anchors=num_anchors, + num_classes=num_classes, + input_num_classes=input_num_classes, + ) + + +def _convert_detection_postprocess_with_options( + custom_options, + *, + batch_size=2, + num_anchors=4, + num_classes=1, + input_num_classes=None, + build_module=True, +): + input_num_classes = num_classes if input_num_classes is None else input_num_classes + loc = relax.Var("loc", relax.TensorStructInfo((batch_size, num_anchors, 4), "float32")) + cls = relax.Var( + "cls", relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32") + ) + inputs = [ + _make_detection_postprocess_tensor_wrapper(0, (batch_size, num_anchors, 4), "loc"), + _make_detection_postprocess_tensor_wrapper( + 1, (batch_size, num_anchors, input_num_classes), "cls" + ), + _make_detection_postprocess_tensor_wrapper(2, (num_anchors, 4), "anchors"), + ] + converter = tflite_frontend.OperatorConverter.__new__(tflite_frontend.OperatorConverter) + converter.bb = relax.BlockBuilder() + converter.exp_tab = tflite_frontend.ExprTable() + converter.get_input_tensors = lambda op: inputs + converter.get_expr = lambda tensor_idx: {0: loc, 1: cls}[tensor_idx] + converter.get_tensor_value = ( + lambda tensor: _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx == 2 else None + ) + converter.get_tensor_type_str = lambda tensor_type: "float32" + op = _StubDetectionPostprocessOp(custom_options) + if not build_module: + return converter.convert_detection_postprocess(op) + bb = converter.bb + with bb.function("main", [loc, cls]): + with bb.dataflow(): + output = converter.convert_detection_postprocess(op) + gv = bb.emit_output(output) + bb.emit_func_output(gv) + return bb.get() + + def _make_valid_boxes(rng, n): """Generate n random boxes with y1<=y2, x1<=x2 using the given RNG.""" raw = rng.random((n, 4), dtype=np.float32) @@ -1207,6 +1344,137 @@ def test_nms_v5_ir(): assert f"R.Tensor(({max_output_size},)" in ir +_DETECTION_POSTPROCESS_SMOKE_CASES = [ + pytest.param( + { + "num_classes": 2, + "input_num_classes": 3, + "max_detections": 2, + "detections_per_class": 2, + "use_regular_nms": False, + "nms_iou_threshold": 0.5, + "nms_score_threshold": 0.5, + "batch_size": 1, + "num_anchors": 4, + }, + 2, + False, + id="basic_fast_nms", + ), + pytest.param( + { + "num_classes": 2, + "input_num_classes": 3, + "max_detections": 3, + "detections_per_class": 2, + "use_regular_nms": True, + "nms_iou_threshold": 0.45, + "nms_score_threshold": 0.25, + "batch_size": 2, + "num_anchors": 4, + }, + 1, + True, + id="regular_nms_multi_batch", + ), +] + + +_DETECTION_POSTPROCESS_SHAPE_CASES = [ + pytest.param( + { + "num_classes": 2, + "input_num_classes": 5, + "max_detections": 2, + "detections_per_class": 2, + "use_regular_nms": False, + "nms_iou_threshold": 0.5, + "nms_score_threshold": 0.5, + "batch_size": 1, + "num_anchors": 4, + }, + id="wider_input_classes", + ), + pytest.param( + { + "num_classes": 2, + "input_num_classes": 3, + "max_detections": 4, + "detections_per_class": 4, + "use_regular_nms": False, + "nms_iou_threshold": 0.5, + "nms_score_threshold": 0.5, + "batch_size": 1, + "num_anchors": 4, + }, + id="larger_max_detections", + ), +] + + +@pytest.mark.parametrize( + "build_kwargs,expected_topk_count,expected_keep_background", + _DETECTION_POSTPROCESS_SMOKE_CASES, +) +def test_detection_postprocess_smoke( + build_kwargs, expected_topk_count, expected_keep_background +): + mod = _build_detection_postprocess_mod(**build_kwargs) + ir = mod.script() + + assert "R.vision.multibox_transform_loc" in ir + assert "R.vision.all_class_non_max_suppression" in ir + assert 'output_format="tensorflow"' in ir + assert "R.where" in ir + assert "R.gather_elements" in ir + assert "R.gather_nd" in ir + assert ir.count("R.topk(") == expected_topk_count + assert f"keep_background={expected_keep_background}" in ir + expected_batch = build_kwargs["batch_size"] + expected_max_detections = build_kwargs["max_detections"] + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((expected_batch, expected_max_detections, 4), "float32"), + relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), + relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), + relax.TensorStructInfo((expected_batch,), "float32"), + ] + ), + ) + + legalized = relax.transform.LegalizeOps()(mod) + legalized_ir = legalized.script() + assert "R.vision.all_class_non_max_suppression(" not in legalized_ir + assert "R.call_tir(" in legalized_ir + tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, mod["main"].ret_struct_info) + + +@pytest.mark.parametrize("build_kwargs", _DETECTION_POSTPROCESS_SHAPE_CASES) +def test_detection_postprocess_shape_variations(build_kwargs): + mod = _build_detection_postprocess_mod(**build_kwargs) + batch_size = build_kwargs["batch_size"] + num_anchors = build_kwargs["num_anchors"] + input_num_classes = build_kwargs["input_num_classes"] + max_detections = build_kwargs["max_detections"] + + tvm.ir.assert_structural_equal( + mod["main"].params[1].struct_info, + relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32"), + ) + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((batch_size, max_detections, 4), "float32"), + relax.TensorStructInfo((batch_size, max_detections), "float32"), + relax.TensorStructInfo((batch_size, max_detections), "float32"), + relax.TensorStructInfo((batch_size,), "float32"), + ] + ), + ) + def _make_resize_expected( input_shape, output_size, method, coordinate_transformation_mode, rounding_method ): From 7f34fca423cf6d68ea2ffa3f28e83a02bccfe81b Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Sat, 11 Apr 2026 12:04:25 +0800 Subject: [PATCH 3/3] [BugFix][TFLite] Fix FlexBufferDecoder for byte_width > 4 --- .../frontend/tflite/tflite_flexbuffer.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py index dc8ce1df21e2..5152b6996ecf 100644 --- a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py +++ b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py @@ -78,12 +78,7 @@ def __init__(self, buffer): def indirect_jump(self, offset, byte_width): """Helper function to read the offset value and jump""" - unpack_str = "" - if byte_width == 1: - unpack_str = "> 2) - value_bytes = self.buffer[end + i * byte_width : end + (i + 1) * byte_width] + value_type_packed = self.buffer[value_type_pos] + value_type = FlexBufferType(value_type_packed >> 2) + value_bit_width = BitWidth(value_type_packed & 3) + value_byte_width = 1 << value_bit_width + value_bytes = self.buffer[end + i * byte_width : end + i * byte_width + value_byte_width] if value_type == FlexBufferType.FBT_BOOL: value = bool(value_bytes[0]) elif value_type == FlexBufferType.FBT_INT: - value = struct.unpack("