8
8
images_to_levels , multi_apply , multiclass_nms , unmap )
9
9
from ..builder import HEADS , build_loss
10
10
from .base_dense_head import BaseDenseHead
11
+ from .dense_test_mixins import BBoxTestMixin
11
12
12
13
13
14
@HEADS .register_module ()
14
- class AnchorHead (BaseDenseHead ):
15
+ class AnchorHead (BaseDenseHead , BBoxTestMixin ):
15
16
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
16
17
17
18
Args:
@@ -502,7 +503,8 @@ def get_bboxes(self,
502
503
bbox_preds ,
503
504
img_metas ,
504
505
cfg = None ,
505
- rescale = False ):
506
+ rescale = False ,
507
+ with_nms = True ):
506
508
"""Transform network output for a batch into bbox predictions.
507
509
508
510
Args:
@@ -516,6 +518,8 @@ def get_bboxes(self,
516
518
if None, test_cfg would be used
517
519
rescale (bool): If True, return boxes in original image space.
518
520
Default: False.
521
+ with_nms (bool): If True, do nms before return boxes.
522
+ Default: True.
519
523
520
524
Returns:
521
525
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
@@ -569,9 +573,18 @@ def get_bboxes(self,
569
573
]
570
574
img_shape = img_metas [img_id ]['img_shape' ]
571
575
scale_factor = img_metas [img_id ]['scale_factor' ]
572
- proposals = self ._get_bboxes_single (cls_score_list , bbox_pred_list ,
573
- mlvl_anchors , img_shape ,
574
- scale_factor , cfg , rescale )
576
+ if with_nms :
577
+ # some heads don't support with_nms argument
578
+ proposals = self ._get_bboxes_single (cls_score_list ,
579
+ bbox_pred_list ,
580
+ mlvl_anchors , img_shape ,
581
+ scale_factor , cfg , rescale )
582
+ else :
583
+ proposals = self ._get_bboxes_single (cls_score_list ,
584
+ bbox_pred_list ,
585
+ mlvl_anchors , img_shape ,
586
+ scale_factor , cfg , rescale ,
587
+ with_nms )
575
588
result_list .append (proposals )
576
589
return result_list
577
590
@@ -582,7 +595,8 @@ def _get_bboxes_single(self,
582
595
img_shape ,
583
596
scale_factor ,
584
597
cfg ,
585
- rescale = False ):
598
+ rescale = False ,
599
+ with_nms = True ):
586
600
"""Transform outputs for a single batch item into bbox predictions.
587
601
588
602
Args:
@@ -599,6 +613,9 @@ def _get_bboxes_single(self,
599
613
cfg (mmcv.Config): Test / postprocessing configuration,
600
614
if None, test_cfg would be used.
601
615
rescale (bool): If True, return boxes in original image space.
616
+ Default: False.
617
+ with_nms (bool): If True, do nms before return boxes.
618
+ Default: True.
602
619
603
620
Returns:
604
621
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
@@ -647,7 +664,29 @@ def _get_bboxes_single(self,
647
664
# BG cat_id: num_class
648
665
padding = mlvl_scores .new_zeros (mlvl_scores .shape [0 ], 1 )
649
666
mlvl_scores = torch .cat ([mlvl_scores , padding ], dim = 1 )
650
- det_bboxes , det_labels = multiclass_nms (mlvl_bboxes , mlvl_scores ,
651
- cfg .score_thr , cfg .nms ,
652
- cfg .max_per_img )
653
- return det_bboxes , det_labels
667
+
668
+ if with_nms :
669
+ det_bboxes , det_labels = multiclass_nms (mlvl_bboxes , mlvl_scores ,
670
+ cfg .score_thr , cfg .nms ,
671
+ cfg .max_per_img )
672
+ return det_bboxes , det_labels
673
+ else :
674
+ return mlvl_bboxes , mlvl_scores
675
+
676
+ def aug_test (self , feats , img_metas , rescale = False ):
677
+ """Test function with test time augmentation.
678
+
679
+ Args:
680
+ feats (list[Tensor]): the outer list indicates test-time
681
+ augmentations and inner Tensor should have a shape NxCxHxW,
682
+ which contains features for all images in the batch.
683
+ img_metas (list[list[dict]]): the outer list indicates test-time
684
+ augs (multiscale, flip, etc.) and the inner list indicates
685
+ images in a batch. each dict has image information.
686
+ rescale (bool, optional): Whether to rescale the results.
687
+ Defaults to False.
688
+
689
+ Returns:
690
+ list[ndarray]: bbox results of each class
691
+ """
692
+ return self .aug_test_bboxes (feats , img_metas , rescale = rescale )
0 commit comments