|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | """Tests for detection_generator.py."""
|
| 16 | +from unittest import mock |
| 17 | + |
16 | 18 | # Import libraries
|
17 | 19 |
|
18 | 20 | from absl.testing import parameterized
|
19 | 21 | import numpy as np
|
20 | 22 | import tensorflow as tf, tf_keras
|
21 | 23 |
|
| 24 | +from official.vision.configs import common |
22 | 25 | from official.vision.modeling.layers import detection_generator
|
23 | 26 | from official.vision.ops import anchor
|
24 | 27 |
|
@@ -327,6 +330,107 @@ def test_decode_multilevel_outputs_and_pre_nms_top_k(self):
|
327 | 330 | ]),
|
328 | 331 | ]))
|
329 | 332 |
|
| 333 | + def test_decode_multilevel_with_tflite_nms(self): |
| 334 | + config = common.TFLitePostProcessingConfig().as_dict() |
| 335 | + generator = detection_generator.MultilevelDetectionGenerator( |
| 336 | + apply_nms=True, |
| 337 | + nms_version='tflite', |
| 338 | + box_coder_weights=[9, 8, 7, 6], |
| 339 | + tflite_post_processing_config=config, |
| 340 | + ) |
| 341 | + raw_scores = { |
| 342 | + '4': tf.zeros(shape=[1, 8, 8, 3 * 2], dtype=tf.float32), |
| 343 | + '5': tf.zeros(shape=[1, 4, 4, 3 * 2], dtype=tf.float32), |
| 344 | + } |
| 345 | + raw_boxes = { |
| 346 | + '4': tf.zeros(shape=[1, 8, 8, 4 * 2], dtype=tf.float32), |
| 347 | + '5': tf.zeros(shape=[1, 4, 4, 4 * 2], dtype=tf.float32), |
| 348 | + } |
| 349 | + anchor_boxes = { |
| 350 | + '4': tf.zeros(shape=[1, 8, 8, 4 * 2], dtype=tf.float32), |
| 351 | + '5': tf.zeros(shape=[1, 4, 4, 4 * 2], dtype=tf.float32), |
| 352 | + } |
| 353 | + |
| 354 | + expected_signature = ( |
| 355 | + 'name: "TFLite_Detection_PostProcess" attr { key: "max_detections"' |
| 356 | + ' value { i: 200 } } attr { key: "max_classes_per_detection" value { i:' |
| 357 | + ' 5 } } attr { key: "detections_per_class" value { i: 5 } } attr { key:' |
| 358 | + ' "use_regular_nms" value { b: false } } attr { key:' |
| 359 | + ' "nms_score_threshold" value { f: 0.100000 } } attr { key:' |
| 360 | + ' "nms_iou_threshold" value { f: 0.500000 } } attr { key: "y_scale"' |
| 361 | + ' value { f: 9.000000 } } attr { key: "x_scale" value { f: 8.000000 } }' |
| 362 | + ' attr { key: "h_scale" value { f: 7.000000 } } attr { key: "w_scale"' |
| 363 | + ' value { f: 6.000000 } } attr { key: "num_classes" value { i: 3 } }' |
| 364 | + ) |
| 365 | + |
| 366 | + with mock.patch.object( |
| 367 | + tf, 'function', wraps=tf.function |
| 368 | + ) as mock_tf_function: |
| 369 | + test_output = generator( |
| 370 | + raw_boxes=raw_boxes, |
| 371 | + raw_scores=raw_scores, |
| 372 | + anchor_boxes=anchor_boxes, |
| 373 | + image_shape=tf.constant([], dtype=tf.int32), |
| 374 | + ) |
| 375 | + mock_tf_function.assert_called_once_with( |
| 376 | + experimental_implements=expected_signature |
| 377 | + ) |
| 378 | + |
| 379 | + self.assertEqual( |
| 380 | + test_output['num_detections'], tf.constant(0.0, dtype=tf.float32) |
| 381 | + ) |
| 382 | + self.assertEqual( |
| 383 | + test_output['detection_boxes'], tf.constant(0.0, dtype=tf.float32) |
| 384 | + ) |
| 385 | + self.assertEqual( |
| 386 | + test_output['detection_classes'], tf.constant(0.0, dtype=tf.float32) |
| 387 | + ) |
| 388 | + self.assertEqual( |
| 389 | + test_output['detection_scores'], tf.constant(0.0, dtype=tf.float32) |
| 390 | + ) |
| 391 | + |
| 392 | + def test_decode_multilevel_tflite_nms_error_on_wrong_boxes_shape(self): |
| 393 | + config = common.TFLitePostProcessingConfig().as_dict() |
| 394 | + generator = detection_generator.MultilevelDetectionGenerator( |
| 395 | + apply_nms=True, |
| 396 | + nms_version='tflite', |
| 397 | + tflite_post_processing_config=config, |
| 398 | + ) |
| 399 | + raw_scores = {'4': tf.zeros(shape=[1, 4, 4, 3 * 2], dtype=tf.float32)} |
| 400 | + raw_boxes = {'4': tf.zeros(shape=[1, 4, 4, 3], dtype=tf.float32)} |
| 401 | + anchor_boxes = {'4': tf.zeros(shape=[1, 4, 4, 4 * 2], dtype=tf.float32)} |
| 402 | + with self.assertRaisesRegex( |
| 403 | + ValueError, |
| 404 | + 'The last dimension of predicted boxes should be divisible by 4.', |
| 405 | + ): |
| 406 | + generator( |
| 407 | + raw_boxes=raw_boxes, |
| 408 | + raw_scores=raw_scores, |
| 409 | + anchor_boxes=anchor_boxes, |
| 410 | + image_shape=tf.constant([], dtype=tf.int32), |
| 411 | + ) |
| 412 | + |
| 413 | + def test_decode_multilevel_tflite_nms_error_on_wrong_scores_shape(self): |
| 414 | + config = common.TFLitePostProcessingConfig().as_dict() |
| 415 | + generator = detection_generator.MultilevelDetectionGenerator( |
| 416 | + apply_nms=True, |
| 417 | + nms_version='tflite', |
| 418 | + tflite_post_processing_config=config, |
| 419 | + ) |
| 420 | + raw_scores = {'4': tf.zeros(shape=[1, 4, 4, 7 * 3], dtype=tf.float32)} |
| 421 | + raw_boxes = {'4': tf.zeros(shape=[1, 4, 4, 4 * 5], dtype=tf.float32)} |
| 422 | + anchor_boxes = {'4': tf.zeros(shape=[1, 4, 4, 4 * 5], dtype=tf.float32)} |
| 423 | + with self.assertRaisesRegex( |
| 424 | + ValueError, |
| 425 | + 'The last dimension of predicted scores should be divisible by', |
| 426 | + ): |
| 427 | + generator( |
| 428 | + raw_boxes=raw_boxes, |
| 429 | + raw_scores=raw_scores, |
| 430 | + anchor_boxes=anchor_boxes, |
| 431 | + image_shape=tf.constant([], dtype=tf.int32), |
| 432 | + ) |
| 433 | + |
330 | 434 | def test_serialize_deserialize(self):
|
331 | 435 | tflite_post_processing_config = {
|
332 | 436 | 'max_detections': 100,
|
|
0 commit comments