Skip to content

Commit 6aed712

Browse files
No public description
PiperOrigin-RevId: 597723134
1 parent 98684b0 commit 6aed712

File tree

2 files changed

+124
-4
lines changed

2 files changed

+124
-4
lines changed

official/vision/modeling/layers/detection_generator.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,7 @@ def _generate_detections_tflite(
794794
raw_scores: Mapping[str, tf.Tensor],
795795
anchor_boxes: Mapping[str, tf.Tensor],
796796
config: Dict[str, Any],
797+
box_coder_weights: List[float] | None = None,
797798
) -> Sequence[Any]:
798799
"""Generate detections for conversion to TFLite.
799800
@@ -817,7 +818,10 @@ def _generate_detections_tflite(
817818
features and value is a tensor denoting a level of anchors with shape
818819
[num_anchors, 4].
819820
config: A dictionary of configs defining parameters for TFLite NMS op.
820-
821+
box_coder_weights: An optional `list` of 4 positive floats to scale y, x, h,
822+
and w when encoding box coordinates. If set to None, does not perform
823+
scaling. For Faster RCNN, the open-source implementation recommends using
824+
[10.0, 10.0, 5.0, 5.0].
821825
Returns:
822826
A (dummy) tuple of (boxes, scores, classess, num_detections).
823827
@@ -839,15 +843,18 @@ def _generate_detections_tflite(
839843
raise ValueError(
840844
'The last dimension of predicted boxes should be divisible by 4.'
841845
)
846+
842847
num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
843-
if num_anchors_per_locations_times_4 % 4 != 0:
848+
num_classes_times_anchors_per_location = (
849+
raw_scores[str(min_level)].get_shape().as_list()[-1]
850+
)
851+
if num_classes_times_anchors_per_location % num_anchors_per_locations != 0:
844852
raise ValueError(
845853
'The last dimension of predicted scores should be divisible by'
846854
f' {num_anchors_per_locations}.'
847855
)
848856
num_classes = (
849-
raw_scores[str(min_level)].get_shape().as_list()[-1]
850-
// num_anchors_per_locations
857+
num_classes_times_anchors_per_location // num_anchors_per_locations
851858
)
852859
config.update({'num_classes': num_classes})
853860

@@ -865,6 +872,14 @@ def _generate_detections_tflite(
865872
wa = anchors[..., 3] - anchors[..., 1]
866873
anchors = tf.stack([ycenter_a, xcenter_a, ha, wa], axis=-1)
867874

875+
if box_coder_weights:
876+
config.update({
877+
'y_scale': box_coder_weights[0],
878+
'x_scale': box_coder_weights[1],
879+
'h_scale': box_coder_weights[2],
880+
'w_scale': box_coder_weights[3],
881+
})
882+
868883
if config.get('normalize_anchor_coordinates', False):
869884
# TFLite's object detection APIs require normalized anchors.
870885
height, width = config['input_image_size']
@@ -1463,6 +1478,7 @@ def __call__(
14631478
raw_scores,
14641479
anchor_boxes,
14651480
self.get_config()['tflite_post_processing_config'],
1481+
self._config_dict['box_coder_weights'],
14661482
)
14671483
return {
14681484
'num_detections': num_detections,

official/vision/modeling/layers/detection_generator_test.py

+104
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
# limitations under the License.
1414

1515
"""Tests for detection_generator.py."""
16+
from unittest import mock
17+
1618
# Import libraries
1719

1820
from absl.testing import parameterized
1921
import numpy as np
2022
import tensorflow as tf, tf_keras
2123

24+
from official.vision.configs import common
2225
from official.vision.modeling.layers import detection_generator
2326
from official.vision.ops import anchor
2427

@@ -327,6 +330,107 @@ def test_decode_multilevel_outputs_and_pre_nms_top_k(self):
327330
]),
328331
]))
329332

333+
def test_decode_multilevel_with_tflite_nms(self):
334+
config = common.TFLitePostProcessingConfig().as_dict()
335+
generator = detection_generator.MultilevelDetectionGenerator(
336+
apply_nms=True,
337+
nms_version='tflite',
338+
box_coder_weights=[9, 8, 7, 6],
339+
tflite_post_processing_config=config,
340+
)
341+
raw_scores = {
342+
'4': tf.zeros(shape=[1, 8, 8, 3 * 2], dtype=tf.float32),
343+
'5': tf.zeros(shape=[1, 4, 4, 3 * 2], dtype=tf.float32),
344+
}
345+
raw_boxes = {
346+
'4': tf.zeros(shape=[1, 8, 8, 4 * 2], dtype=tf.float32),
347+
'5': tf.zeros(shape=[1, 4, 4, 4 * 2], dtype=tf.float32),
348+
}
349+
anchor_boxes = {
350+
'4': tf.zeros(shape=[1, 8, 8, 4 * 2], dtype=tf.float32),
351+
'5': tf.zeros(shape=[1, 4, 4, 4 * 2], dtype=tf.float32),
352+
}
353+
354+
expected_signature = (
355+
'name: "TFLite_Detection_PostProcess" attr { key: "max_detections"'
356+
' value { i: 200 } } attr { key: "max_classes_per_detection" value { i:'
357+
' 5 } } attr { key: "detections_per_class" value { i: 5 } } attr { key:'
358+
' "use_regular_nms" value { b: false } } attr { key:'
359+
' "nms_score_threshold" value { f: 0.100000 } } attr { key:'
360+
' "nms_iou_threshold" value { f: 0.500000 } } attr { key: "y_scale"'
361+
' value { f: 9.000000 } } attr { key: "x_scale" value { f: 8.000000 } }'
362+
' attr { key: "h_scale" value { f: 7.000000 } } attr { key: "w_scale"'
363+
' value { f: 6.000000 } } attr { key: "num_classes" value { i: 3 } }'
364+
)
365+
366+
with mock.patch.object(
367+
tf, 'function', wraps=tf.function
368+
) as mock_tf_function:
369+
test_output = generator(
370+
raw_boxes=raw_boxes,
371+
raw_scores=raw_scores,
372+
anchor_boxes=anchor_boxes,
373+
image_shape=tf.constant([], dtype=tf.int32),
374+
)
375+
mock_tf_function.assert_called_once_with(
376+
experimental_implements=expected_signature
377+
)
378+
379+
self.assertEqual(
380+
test_output['num_detections'], tf.constant(0.0, dtype=tf.float32)
381+
)
382+
self.assertEqual(
383+
test_output['detection_boxes'], tf.constant(0.0, dtype=tf.float32)
384+
)
385+
self.assertEqual(
386+
test_output['detection_classes'], tf.constant(0.0, dtype=tf.float32)
387+
)
388+
self.assertEqual(
389+
test_output['detection_scores'], tf.constant(0.0, dtype=tf.float32)
390+
)
391+
392+
def test_decode_multilevel_tflite_nms_error_on_wrong_boxes_shape(self):
393+
config = common.TFLitePostProcessingConfig().as_dict()
394+
generator = detection_generator.MultilevelDetectionGenerator(
395+
apply_nms=True,
396+
nms_version='tflite',
397+
tflite_post_processing_config=config,
398+
)
399+
raw_scores = {'4': tf.zeros(shape=[1, 4, 4, 3 * 2], dtype=tf.float32)}
400+
raw_boxes = {'4': tf.zeros(shape=[1, 4, 4, 3], dtype=tf.float32)}
401+
anchor_boxes = {'4': tf.zeros(shape=[1, 4, 4, 4 * 2], dtype=tf.float32)}
402+
with self.assertRaisesRegex(
403+
ValueError,
404+
'The last dimension of predicted boxes should be divisible by 4.',
405+
):
406+
generator(
407+
raw_boxes=raw_boxes,
408+
raw_scores=raw_scores,
409+
anchor_boxes=anchor_boxes,
410+
image_shape=tf.constant([], dtype=tf.int32),
411+
)
412+
413+
def test_decode_multilevel_tflite_nms_error_on_wrong_scores_shape(self):
414+
config = common.TFLitePostProcessingConfig().as_dict()
415+
generator = detection_generator.MultilevelDetectionGenerator(
416+
apply_nms=True,
417+
nms_version='tflite',
418+
tflite_post_processing_config=config,
419+
)
420+
raw_scores = {'4': tf.zeros(shape=[1, 4, 4, 7 * 3], dtype=tf.float32)}
421+
raw_boxes = {'4': tf.zeros(shape=[1, 4, 4, 4 * 5], dtype=tf.float32)}
422+
anchor_boxes = {'4': tf.zeros(shape=[1, 4, 4, 4 * 5], dtype=tf.float32)}
423+
with self.assertRaisesRegex(
424+
ValueError,
425+
'The last dimension of predicted scores should be divisible by',
426+
):
427+
generator(
428+
raw_boxes=raw_boxes,
429+
raw_scores=raw_scores,
430+
anchor_boxes=anchor_boxes,
431+
image_shape=tf.constant([], dtype=tf.int32),
432+
)
433+
330434
def test_serialize_deserialize(self):
331435
tflite_post_processing_config = {
332436
'max_detections': 100,

0 commit comments

Comments
 (0)