diff --git a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py index dc8ce1df21e2..5152b6996ecf 100644 --- a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py +++ b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py @@ -78,12 +78,7 @@ def __init__(self, buffer): def indirect_jump(self, offset, byte_width): """Helper function to read the offset value and jump""" - unpack_str = "" - if byte_width == 1: - unpack_str = "> 2) - value_bytes = self.buffer[end + i * byte_width : end + (i + 1) * byte_width] + value_type_packed = self.buffer[value_type_pos] + value_type = FlexBufferType(value_type_packed >> 2) + value_bit_width = BitWidth(value_type_packed & 3) + value_byte_width = 1 << value_bit_width + value_bytes = self.buffer[end + i * byte_width : end + i * byte_width + value_byte_width] if value_type == FlexBufferType.FBT_BOOL: value = bool(value_bytes[0]) elif value_type == FlexBufferType.FBT_INT: - value = struct.unpack(" 0.") + if max_detections <= 0: + raise ValueError("DETECTION_POSTPROCESS requires max_detections > 0.") + if detections_per_class <= 0: + raise ValueError("DETECTION_POSTPROCESS requires detections_per_class > 0.") + if not 0.0 <= iou_threshold <= 1.0: + raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold in [0, 1].") + if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <= 0.0: + raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to be > 0.") inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" @@ -3296,67 +3331,139 @@ def convert_detection_postprocess(self, op): # attributes for multibox_transform_loc multibox_transform_loc_attrs = {} multibox_transform_loc_attrs["clip"] = False - multibox_transform_loc_attrs["threshold"] = ( - 0.0 if use_regular_nms else custom_options["nms_score_threshold"] - ) + multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms else score_threshold multibox_transform_loc_attrs["variances"] = ( - 1 / custom_options["x_scale"], - 1 / custom_options["y_scale"], - 1 / custom_options["w_scale"], - 1 / custom_options["h_scale"], + 1 / x_scale, + 1 / y_scale, + 1 / w_scale, + 1 / h_scale, ) multibox_transform_loc_attrs["keep_background"] = use_regular_nms - ret = relax.op.vision.multibox_transform_loc( - # reshape cls_pred so it can be consumed by - # multibox_transform_loc - relax.op.permute_dims(cls_pred, [0, 2, 1]), - loc_prob, - anchor_expr, - **multibox_transform_loc_attrs, + multibox_res = self.bb.emit( + relax.op.vision.multibox_transform_loc( + # reshape cls_pred so it can be consumed by + # multibox_transform_loc + relax.op.permute_dims(cls_pred, [0, 2, 1]), + loc_prob, + anchor_expr, + **multibox_transform_loc_attrs, + ) + ) + transformed_boxes = self.bb.emit(relax.TupleGetItem(multibox_res, 0)) + transformed_scores = self.bb.emit(relax.TupleGetItem(multibox_res, 1)) + + if use_regular_nms: + nms_out = self.bb.emit( + relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + transformed_scores, + relax.const(detections_per_class, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) + ) + selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0)) + selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1)) + num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2)) + class_id_from_score = None + else: + topk_res = self.bb.emit( + relax.op.topk(transformed_scores, k=1, axis=1, ret_type="both", largest=True) + ) + max_scores = self.bb.emit(relax.TupleGetItem(topk_res, 0)) + class_id_from_score = self.bb.emit(relax.TupleGetItem(topk_res, 1)) + nms_out = self.bb.emit( + relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + max_scores, + relax.const(max_detections, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) + ) + selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0)) + selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1)) + num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2)) + class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1]) + + selected_score_slots = selected_scores.struct_info.shape.values[1] + selected_detection_positions = relax.op.expand_dims( + relax.op.arange(selected_score_slots, dtype="int64"), axis=0 + ) + selected_valid_detection_mask = relax.op.less( + selected_detection_positions, relax.op.expand_dims(num_detections, axis=1) + ) + masked_selected_scores = relax.op.where( + selected_valid_detection_mask, + selected_scores, + relax.const(-1.0, "float32"), + ) + topk_scores_res = self.bb.emit( + relax.op.topk( + masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True + ) + ) + detection_scores = self.bb.emit(relax.TupleGetItem(topk_scores_res, 0)) + top_positions = self.bb.emit(relax.TupleGetItem(topk_scores_res, 1)) + num_detections = relax.op.minimum( + num_detections, relax.const([max_detections], dtype="int64") + ) + detection_positions = relax.op.expand_dims( + relax.op.arange(max_detections, dtype="int64"), axis=0 + ) + valid_detection_mask = relax.op.less( + detection_positions, relax.op.expand_dims(num_detections, axis=1) + ) + top_positions_expanded = relax.op.expand_dims(top_positions, axis=2) + top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, axis=2) + top_index_pairs = relax.op.gather_elements( + selected_indices, top_positions_for_pairs, axis=1 + ) + top_box_ids = relax.op.squeeze( + relax.op.strided_slice(top_index_pairs, axes=[2], begin=[1], end=[2]), + axis=[2], + ) + top_box_ids_for_gather = relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2) + detection_boxes = relax.op.gather_nd( + transformed_boxes, top_box_ids_for_gather, batch_dims=1 ) if use_regular_nms: - # box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax) - _, transformed_boxes = relax.op.split(ret[0], (2,), axis=2) - box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4, axis=2) - transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r], axis=2) - - return relax.op.vision.regular_non_max_suppression( - boxes=transformed_boxes, - scores=cls_pred, - max_detections_per_class=custom_options["detections_per_class"], - max_detections=custom_options["max_detections"], - num_classes=custom_options["num_classes"], - iou_threshold=custom_options["nms_iou_threshold"], - score_threshold=custom_options["nms_score_threshold"], + detection_classes = relax.op.squeeze( + relax.op.strided_slice(top_index_pairs, axes=[2], begin=[0], end=[1]), + axis=[2], + ) + detection_classes = relax.op.astype(detection_classes, "int32") + else: + top_box_ids_for_class = relax.op.expand_dims( + relax.op.astype(top_box_ids, "int64"), axis=2 + ) + detection_classes = relax.op.gather_nd( + class_id_from_score, top_box_ids_for_class, batch_dims=1 ) - # attributes for non_max_suppression - non_max_suppression_attrs = {} - non_max_suppression_attrs["return_indices"] = False - non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"] - non_max_suppression_attrs["force_suppress"] = True - non_max_suppression_attrs["top_k"] = anchor_boxes - non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"] - non_max_suppression_attrs["invalid_to_bottom"] = False - - ret = relax.op.vision.non_max_suppression( - ret[0], ret[1], ret[1], **non_max_suppression_attrs + detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2) + detection_boxes = relax.op.where( + detection_mask, + detection_boxes, + relax.op.zeros((batch_size, max_detections, 4), dtype="float32"), + ) + detection_classes = relax.op.where( + valid_detection_mask, + detection_classes, + relax.op.zeros((batch_size, max_detections), dtype="int32"), ) - ret = relax.op.vision.get_valid_counts(ret, 0) - valid_count = ret[0] - # keep only the top 'max_detections' rows - ret = relax.op.strided_slice( - ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], 6] + detection_scores = relax.op.where( + valid_detection_mask, + detection_scores, + relax.op.zeros((batch_size, max_detections), dtype="float32"), ) - # the output needs some reshaping to match tflite - ret = relax.op.split(ret, 6, axis=2) - cls_ids = relax.op.reshape(ret[0], [batch_size, -1]) - scores = relax.op.reshape(ret[1], [batch_size, -1]) - boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2) - ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]), size=4) - return ret + detection_classes = relax.op.astype(detection_classes, "float32") + num_detections = relax.op.astype(num_detections, "float32") + return relax.Tuple([detection_boxes, detection_classes, detection_scores, num_detections]) def convert_nms_v5(self, op): """Convert TFLite NonMaxSuppressionV5""" diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index 7d8586ab5288..c515fc8fe81a 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -32,11 +32,15 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E Returns ------- - result : Tuple[Tensor, Tensor] - A tuple of (trimmed_indices, num_total_detections) where: - - trimmed_indices: Tensor of shape (num_total_detections, 3) containing only - valid detection indices (batch_id, class_id, box_id) - - num_total_detections: Tensor of shape (1,) with the count of valid detections + result : Expr + The legalized NMS result. + + - For ONNX output format, returns a tuple of + `(trimmed_indices, num_total_detections)`, where `trimmed_indices` + contains only valid detection indices. + - For TensorFlow output format, returns the TOPI result directly to + preserve the `(selected_indices, selected_scores, num_detections)` + layout expected by the Relax op. """ boxes = call.args[0] scores = call.args[1] @@ -69,8 +73,9 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E output_format, ) - # Dynamic output trimming using dynamic_strided_slice - # Extract selected_indices and num_total_detections from the NMS result + if output_format == "tensorflow": + return nms_result + selected_indices = block_builder.emit(TupleGetItem(nms_result, 0)) num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1)) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 02282f3d41c9..c237d4db8f8a 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -27,6 +27,7 @@ from tensorflow.keras import applications as keras_app import tvm +import tvm.relax.frontend.tflite.tflite_frontend as tflite_frontend from tvm import relax from tvm.relax.frontend.tflite import from_tflite from tvm.script.parser import ir as I @@ -1082,6 +1083,142 @@ def func(self, boxes, scores): return mod, instance.func +class _StubDetectionPostprocessTensor: + def __init__(self, shape, name): + self._shape = list(shape) + self._name = name + + def Shape(self, index): + return self._shape[index] + + def Name(self): + return self._name + + def Type(self): + return 0 + + +class _StubDetectionPostprocessOp: + def __init__(self, custom_options): + self._custom_options = _encode_detection_postprocess_custom_options(custom_options) + + def CustomOptionsAsNumpy(self): + return np.frombuffer(self._custom_options, dtype="uint8") + + +_DETECTION_POSTPROCESS_ANCHORS = np.array( + [ + [0.5, 0.5, 1.0, 1.0], + [0.5, 0.2, 1.0, 1.0], + [0.1, 0.1, 0.5, 0.5], + [0.8, 0.8, 0.2, 0.2], + ], + dtype="float32", +) + + +def _encode_detection_postprocess_custom_options(custom_options): + from flatbuffers import flexbuffers + + builder = flexbuffers.Builder() + with builder.Map(): + for key, value in custom_options.items(): + if isinstance(value, bool): + builder.Bool(key, value) + elif isinstance(value, int): + builder.Int(key, value) + else: + builder.Float(key, float(value)) + return bytes(builder.Finish()) + + +def _make_detection_postprocess_tensor_wrapper(tensor_idx, shape, name): + return tflite_frontend.TensorWrapper( + tensor_idx, + _StubDetectionPostprocessTensor(shape, name), + None, + ) + + +def _build_detection_postprocess_mod( + *, + num_classes=1, + max_detections=4, + detections_per_class=4, + use_regular_nms=False, + nms_iou_threshold=0.5, + nms_score_threshold=0.3, + x_scale=10.0, + y_scale=10.0, + w_scale=5.0, + h_scale=5.0, + batch_size=2, + num_anchors=4, + input_num_classes=None, +): + custom_options = { + "num_classes": num_classes, + "max_detections": max_detections, + "detections_per_class": detections_per_class, + "nms_iou_threshold": nms_iou_threshold, + "nms_score_threshold": nms_score_threshold, + "x_scale": x_scale, + "y_scale": y_scale, + "w_scale": w_scale, + "h_scale": h_scale, + "use_regular_nms": use_regular_nms, + } + return _convert_detection_postprocess_with_options( + custom_options, + batch_size=batch_size, + num_anchors=num_anchors, + num_classes=num_classes, + input_num_classes=input_num_classes, + ) + + +def _convert_detection_postprocess_with_options( + custom_options, + *, + batch_size=2, + num_anchors=4, + num_classes=1, + input_num_classes=None, + build_module=True, +): + input_num_classes = num_classes if input_num_classes is None else input_num_classes + loc = relax.Var("loc", relax.TensorStructInfo((batch_size, num_anchors, 4), "float32")) + cls = relax.Var( + "cls", relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32") + ) + inputs = [ + _make_detection_postprocess_tensor_wrapper(0, (batch_size, num_anchors, 4), "loc"), + _make_detection_postprocess_tensor_wrapper( + 1, (batch_size, num_anchors, input_num_classes), "cls" + ), + _make_detection_postprocess_tensor_wrapper(2, (num_anchors, 4), "anchors"), + ] + converter = tflite_frontend.OperatorConverter.__new__(tflite_frontend.OperatorConverter) + converter.bb = relax.BlockBuilder() + converter.exp_tab = tflite_frontend.ExprTable() + converter.get_input_tensors = lambda op: inputs + converter.get_expr = lambda tensor_idx: {0: loc, 1: cls}[tensor_idx] + converter.get_tensor_value = ( + lambda tensor: _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx == 2 else None + ) + converter.get_tensor_type_str = lambda tensor_type: "float32" + op = _StubDetectionPostprocessOp(custom_options) + if not build_module: + return converter.convert_detection_postprocess(op) + bb = converter.bb + with bb.function("main", [loc, cls]): + with bb.dataflow(): + output = converter.convert_detection_postprocess(op) + gv = bb.emit_output(output) + bb.emit_func_output(gv) + return bb.get() + + def _make_valid_boxes(rng, n): """Generate n random boxes with y1<=y2, x1<=x2 using the given RNG.""" raw = rng.random((n, 4), dtype=np.float32) @@ -1207,6 +1344,137 @@ def test_nms_v5_ir(): assert f"R.Tensor(({max_output_size},)" in ir +_DETECTION_POSTPROCESS_SMOKE_CASES = [ + pytest.param( + { + "num_classes": 2, + "input_num_classes": 3, + "max_detections": 2, + "detections_per_class": 2, + "use_regular_nms": False, + "nms_iou_threshold": 0.5, + "nms_score_threshold": 0.5, + "batch_size": 1, + "num_anchors": 4, + }, + 2, + False, + id="basic_fast_nms", + ), + pytest.param( + { + "num_classes": 2, + "input_num_classes": 3, + "max_detections": 3, + "detections_per_class": 2, + "use_regular_nms": True, + "nms_iou_threshold": 0.45, + "nms_score_threshold": 0.25, + "batch_size": 2, + "num_anchors": 4, + }, + 1, + True, + id="regular_nms_multi_batch", + ), +] + + +_DETECTION_POSTPROCESS_SHAPE_CASES = [ + pytest.param( + { + "num_classes": 2, + "input_num_classes": 5, + "max_detections": 2, + "detections_per_class": 2, + "use_regular_nms": False, + "nms_iou_threshold": 0.5, + "nms_score_threshold": 0.5, + "batch_size": 1, + "num_anchors": 4, + }, + id="wider_input_classes", + ), + pytest.param( + { + "num_classes": 2, + "input_num_classes": 3, + "max_detections": 4, + "detections_per_class": 4, + "use_regular_nms": False, + "nms_iou_threshold": 0.5, + "nms_score_threshold": 0.5, + "batch_size": 1, + "num_anchors": 4, + }, + id="larger_max_detections", + ), +] + + +@pytest.mark.parametrize( + "build_kwargs,expected_topk_count,expected_keep_background", + _DETECTION_POSTPROCESS_SMOKE_CASES, +) +def test_detection_postprocess_smoke( + build_kwargs, expected_topk_count, expected_keep_background +): + mod = _build_detection_postprocess_mod(**build_kwargs) + ir = mod.script() + + assert "R.vision.multibox_transform_loc" in ir + assert "R.vision.all_class_non_max_suppression" in ir + assert 'output_format="tensorflow"' in ir + assert "R.where" in ir + assert "R.gather_elements" in ir + assert "R.gather_nd" in ir + assert ir.count("R.topk(") == expected_topk_count + assert f"keep_background={expected_keep_background}" in ir + expected_batch = build_kwargs["batch_size"] + expected_max_detections = build_kwargs["max_detections"] + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((expected_batch, expected_max_detections, 4), "float32"), + relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), + relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), + relax.TensorStructInfo((expected_batch,), "float32"), + ] + ), + ) + + legalized = relax.transform.LegalizeOps()(mod) + legalized_ir = legalized.script() + assert "R.vision.all_class_non_max_suppression(" not in legalized_ir + assert "R.call_tir(" in legalized_ir + tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, mod["main"].ret_struct_info) + + +@pytest.mark.parametrize("build_kwargs", _DETECTION_POSTPROCESS_SHAPE_CASES) +def test_detection_postprocess_shape_variations(build_kwargs): + mod = _build_detection_postprocess_mod(**build_kwargs) + batch_size = build_kwargs["batch_size"] + num_anchors = build_kwargs["num_anchors"] + input_num_classes = build_kwargs["input_num_classes"] + max_detections = build_kwargs["max_detections"] + + tvm.ir.assert_structural_equal( + mod["main"].params[1].struct_info, + relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32"), + ) + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((batch_size, max_detections, 4), "float32"), + relax.TensorStructInfo((batch_size, max_detections), "float32"), + relax.TensorStructInfo((batch_size, max_detections), "float32"), + relax.TensorStructInfo((batch_size,), "float32"), + ] + ), + ) + def _make_resize_expected( input_shape, output_size, method, coordinate_transformation_mode, rounding_method ):