Skip to content

Commit 9c95543

Browse files
authored
Support TTA of RetinaNet and GFL (open-mmlab#3638)
* Move RepPoints TTA to mixin class for reuse * Support TTA of RetinaNet * Support TTA of GFL * Update to use BBoxTestMixin in dense_heads * Update for v2.4.0 inference
1 parent 868f7e5 commit 9c95543

7 files changed

+195
-99
lines changed

mmdet/models/dense_heads/anchor_free_head.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from mmdet.core import multi_apply
99
from ..builder import HEADS, build_loss
1010
from .base_dense_head import BaseDenseHead
11+
from .dense_test_mixins import BBoxTestMixin
1112

1213

1314
@HEADS.register_module()
14-
class AnchorFreeHead(BaseDenseHead):
15+
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
1516
"""Anchor-free head (FCOS, Fovea, RepPoints, etc.).
1617
1718
Args:
@@ -328,3 +329,21 @@ def get_points(self, featmap_sizes, dtype, device, flatten=False):
328329
self._get_points_single(featmap_sizes[i], self.strides[i],
329330
dtype, device, flatten))
330331
return mlvl_points
332+
333+
def aug_test(self, feats, img_metas, rescale=False):
334+
"""Test function with test time augmentation.
335+
336+
Args:
337+
feats (list[Tensor]): the outer list indicates test-time
338+
augmentations and inner Tensor should have a shape NxCxHxW,
339+
which contains features for all images in the batch.
340+
img_metas (list[list[dict]]): the outer list indicates test-time
341+
augs (multiscale, flip, etc.) and the inner list indicates
342+
images in a batch. each dict has image information.
343+
rescale (bool, optional): Whether to rescale the results.
344+
Defaults to False.
345+
346+
Returns:
347+
list[ndarray]: bbox results of each class
348+
"""
349+
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)

mmdet/models/dense_heads/anchor_head.py

+49-10
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
images_to_levels, multi_apply, multiclass_nms, unmap)
99
from ..builder import HEADS, build_loss
1010
from .base_dense_head import BaseDenseHead
11+
from .dense_test_mixins import BBoxTestMixin
1112

1213

1314
@HEADS.register_module()
14-
class AnchorHead(BaseDenseHead):
15+
class AnchorHead(BaseDenseHead, BBoxTestMixin):
1516
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
1617
1718
Args:
@@ -502,7 +503,8 @@ def get_bboxes(self,
502503
bbox_preds,
503504
img_metas,
504505
cfg=None,
505-
rescale=False):
506+
rescale=False,
507+
with_nms=True):
506508
"""Transform network output for a batch into bbox predictions.
507509
508510
Args:
@@ -516,6 +518,8 @@ def get_bboxes(self,
516518
if None, test_cfg would be used
517519
rescale (bool): If True, return boxes in original image space.
518520
Default: False.
521+
with_nms (bool): If True, do nms before return boxes.
522+
Default: True.
519523
520524
Returns:
521525
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
@@ -569,9 +573,18 @@ def get_bboxes(self,
569573
]
570574
img_shape = img_metas[img_id]['img_shape']
571575
scale_factor = img_metas[img_id]['scale_factor']
572-
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
573-
mlvl_anchors, img_shape,
574-
scale_factor, cfg, rescale)
576+
if with_nms:
577+
# some heads don't support with_nms argument
578+
proposals = self._get_bboxes_single(cls_score_list,
579+
bbox_pred_list,
580+
mlvl_anchors, img_shape,
581+
scale_factor, cfg, rescale)
582+
else:
583+
proposals = self._get_bboxes_single(cls_score_list,
584+
bbox_pred_list,
585+
mlvl_anchors, img_shape,
586+
scale_factor, cfg, rescale,
587+
with_nms)
575588
result_list.append(proposals)
576589
return result_list
577590

@@ -582,7 +595,8 @@ def _get_bboxes_single(self,
582595
img_shape,
583596
scale_factor,
584597
cfg,
585-
rescale=False):
598+
rescale=False,
599+
with_nms=True):
586600
"""Transform outputs for a single batch item into bbox predictions.
587601
588602
Args:
@@ -599,6 +613,9 @@ def _get_bboxes_single(self,
599613
cfg (mmcv.Config): Test / postprocessing configuration,
600614
if None, test_cfg would be used.
601615
rescale (bool): If True, return boxes in original image space.
616+
Default: False.
617+
with_nms (bool): If True, do nms before return boxes.
618+
Default: True.
602619
603620
Returns:
604621
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
@@ -647,7 +664,29 @@ def _get_bboxes_single(self,
647664
# BG cat_id: num_class
648665
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
649666
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
650-
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
651-
cfg.score_thr, cfg.nms,
652-
cfg.max_per_img)
653-
return det_bboxes, det_labels
667+
668+
if with_nms:
669+
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
670+
cfg.score_thr, cfg.nms,
671+
cfg.max_per_img)
672+
return det_bboxes, det_labels
673+
else:
674+
return mlvl_bboxes, mlvl_scores
675+
676+
def aug_test(self, feats, img_metas, rescale=False):
677+
"""Test function with test time augmentation.
678+
679+
Args:
680+
feats (list[Tensor]): the outer list indicates test-time
681+
augmentations and inner Tensor should have a shape NxCxHxW,
682+
which contains features for all images in the batch.
683+
img_metas (list[list[dict]]): the outer list indicates test-time
684+
augs (multiscale, flip, etc.) and the inner list indicates
685+
images in a batch. each dict has image information.
686+
rescale (bool, optional): Whether to rescale the results.
687+
Defaults to False.
688+
689+
Returns:
690+
list[ndarray]: bbox results of each class
691+
"""
692+
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from inspect import signature
2+
3+
import torch
4+
5+
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
6+
7+
8+
class BBoxTestMixin(object):
9+
"""Mixin class for test time augmentation of bboxes."""
10+
11+
def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
12+
"""Merge augmented detection bboxes and scores.
13+
14+
Args:
15+
aug_bboxes (list[Tensor]): shape (n, 4*#class)
16+
aug_scores (list[Tensor] or None): shape (n, #class)
17+
img_shapes (list[Tensor]): shape (3, ).
18+
19+
Returns:
20+
tuple: (bboxes, scores)
21+
"""
22+
recovered_bboxes = []
23+
for bboxes, img_info in zip(aug_bboxes, img_metas):
24+
img_shape = img_info[0]['img_shape']
25+
scale_factor = img_info[0]['scale_factor']
26+
flip = img_info[0]['flip']
27+
flip_direction = img_info[0]['flip_direction']
28+
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
29+
flip_direction)
30+
recovered_bboxes.append(bboxes)
31+
bboxes = torch.cat(recovered_bboxes, dim=0)
32+
if aug_scores is None:
33+
return bboxes
34+
else:
35+
scores = torch.cat(aug_scores, dim=0)
36+
return bboxes, scores
37+
38+
def aug_test_bboxes(self, feats, img_metas, rescale=False):
39+
"""Test det bboxes with test time augmentation.
40+
41+
Args:
42+
feats (list[Tensor]): the outer list indicates test-time
43+
augmentations and inner Tensor should have a shape NxCxHxW,
44+
which contains features for all images in the batch.
45+
img_metas (list[list[dict]]): the outer list indicates test-time
46+
augs (multiscale, flip, etc.) and the inner list indicates
47+
images in a batch. each dict has image information.
48+
rescale (bool, optional): Whether to rescale the results.
49+
Defaults to False.
50+
51+
Returns:
52+
list[ndarray]: bbox results of each class
53+
"""
54+
# check with_nms argument
55+
gb_sig = signature(self.get_bboxes)
56+
gb_args = [p.name for p in gb_sig.parameters.values()]
57+
gbs_sig = signature(self._get_bboxes_single)
58+
gbs_args = [p.name for p in gbs_sig.parameters.values()]
59+
assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
60+
f'{self.__class__.__name__}' \
61+
' does not support test-time augmentation'
62+
63+
aug_bboxes = []
64+
aug_scores = []
65+
for x, img_meta in zip(feats, img_metas):
66+
# only one image in the batch
67+
outs = self.forward(x)
68+
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
69+
det_bboxes, det_scores = self.get_bboxes(*bbox_inputs)[0]
70+
aug_bboxes.append(det_bboxes)
71+
aug_scores.append(det_scores)
72+
73+
# after merging, bboxes will be rescaled to the original image size
74+
merged_bboxes, merged_scores = self.merge_aug_bboxes(
75+
aug_bboxes, aug_scores, img_metas)
76+
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
77+
self.test_cfg.score_thr,
78+
self.test_cfg.nms,
79+
self.test_cfg.max_per_img)
80+
81+
if rescale:
82+
_det_bboxes = det_bboxes
83+
else:
84+
_det_bboxes = det_bboxes.clone()
85+
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
86+
img_metas[0][0]['scale_factor'])
87+
bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
88+
return bbox_results

mmdet/models/dense_heads/gfl_head.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ def _get_bboxes_single(self,
382382
img_shape,
383383
scale_factor,
384384
cfg,
385-
rescale=False):
385+
rescale=False,
386+
with_nms=True):
386387
"""Transform outputs for a single batch item into labeled boxes.
387388
388389
Args:
@@ -401,6 +402,8 @@ def _get_bboxes_single(self,
401402
if None, test_cfg would be used.
402403
rescale (bool): If True, return boxes in original image space.
403404
Default: False.
405+
with_nms (bool): If True, do nms before return boxes.
406+
Default: True.
404407
405408
Returns:
406409
tuple(Tensor):
@@ -450,10 +453,13 @@ def _get_bboxes_single(self,
450453
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
451454
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
452455

453-
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
454-
cfg.score_thr, cfg.nms,
455-
cfg.max_per_img)
456-
return det_bboxes, det_labels
456+
if with_nms:
457+
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
458+
cfg.score_thr, cfg.nms,
459+
cfg.max_per_img)
460+
return det_bboxes, det_labels
461+
else:
462+
return mlvl_bboxes, mlvl_scores
457463

458464
def get_targets(self,
459465
anchor_list,

mmdet/models/dense_heads/reppoints_head.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ def get_bboxes(self,
664664
img_metas,
665665
cfg=None,
666666
rescale=False,
667-
nms=True):
667+
with_nms=True):
668668
assert len(cls_scores) == len(pts_preds_refine)
669669
bbox_preds_refine = [
670670
self.points2bbox(pts_pred_refine)
@@ -690,7 +690,7 @@ def get_bboxes(self,
690690
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
691691
mlvl_points, img_shape,
692692
scale_factor, cfg, rescale,
693-
nms)
693+
with_nms)
694694
result_list.append(proposals)
695695
return result_list
696696

@@ -702,7 +702,7 @@ def _get_bboxes_single(self,
702702
scale_factor,
703703
cfg,
704704
rescale=False,
705-
nms=True):
705+
with_nms=True):
706706
cfg = self.test_cfg if cfg is None else cfg
707707
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
708708
mlvl_bboxes = []
@@ -749,7 +749,7 @@ def _get_bboxes_single(self,
749749
# BG cat_id: num_class
750750
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
751751
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
752-
if nms:
752+
if with_nms:
753753
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
754754
cfg.score_thr, cfg.nms,
755755
cfg.max_per_img)
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import torch
2-
3-
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
41
from ..builder import DETECTORS
52
from .single_stage import SingleStageDetector
63

@@ -23,77 +20,3 @@ def __init__(self,
2320
super(RepPointsDetector,
2421
self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
2522
pretrained)
26-
27-
def merge_aug_results(self, aug_bboxes, aug_scores, img_metas):
28-
"""Merge augmented detection bboxes and scores.
29-
30-
Args:
31-
aug_bboxes (list[Tensor]): shape (n, 4*#class)
32-
aug_scores (list[Tensor] or None): shape (n, #class)
33-
img_shapes (list[Tensor]): shape (3, ).
34-
35-
Returns:
36-
tuple: (bboxes, scores)
37-
"""
38-
recovered_bboxes = []
39-
for bboxes, img_info in zip(aug_bboxes, img_metas):
40-
img_shape = img_info[0]['img_shape']
41-
scale_factor = img_info[0]['scale_factor']
42-
flip = img_info[0]['flip']
43-
flip_direction = img_info[0]['flip_direction']
44-
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
45-
flip_direction)
46-
recovered_bboxes.append(bboxes)
47-
bboxes = torch.cat(recovered_bboxes, dim=0)
48-
if aug_scores is None:
49-
return bboxes
50-
else:
51-
scores = torch.cat(aug_scores, dim=0)
52-
return bboxes, scores
53-
54-
def aug_test(self, imgs, img_metas, rescale=False):
55-
"""Test function with test time augmentation.
56-
57-
Args:
58-
imgs (list[Tensor]): the outer list indicates test-time
59-
augmentations and inner Tensor should have a shape NxCxHxW,
60-
which contains all images in the batch.
61-
img_metas (list[list[dict]]): the outer list indicates test-time
62-
augs (multiscale, flip, etc.) and the inner list indicates
63-
images in a batch. each dict has image information.
64-
rescale (bool, optional): Whether to rescale the results.
65-
Defaults to False.
66-
67-
Returns:
68-
list[ndarray]: bbox results of each class
69-
"""
70-
# recompute feats to save memory
71-
feats = self.extract_feats(imgs)
72-
73-
aug_bboxes = []
74-
aug_scores = []
75-
for x, img_meta in zip(feats, img_metas):
76-
# only one image in the batch
77-
outs = self.bbox_head(x)
78-
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
79-
det_bboxes, det_scores = self.bbox_head.get_bboxes(*bbox_inputs)[0]
80-
aug_bboxes.append(det_bboxes)
81-
aug_scores.append(det_scores)
82-
83-
# after merging, bboxes will be rescaled to the original image size
84-
merged_bboxes, merged_scores = self.merge_aug_results(
85-
aug_bboxes, aug_scores, img_metas)
86-
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
87-
self.test_cfg.score_thr,
88-
self.test_cfg.nms,
89-
self.test_cfg.max_per_img)
90-
91-
if rescale:
92-
_det_bboxes = det_bboxes
93-
else:
94-
_det_bboxes = det_bboxes.clone()
95-
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
96-
img_metas[0][0]['scale_factor'])
97-
bbox_results = bbox2result(_det_bboxes, det_labels,
98-
self.bbox_head.num_classes)
99-
return bbox_results

0 commit comments

Comments
 (0)