4
4
from mmcv .cnn import normal_init
5
5
6
6
from mmdet .core import delta2bbox
7
- from mmdet .ops import nms
7
+ from mmdet .ops import batched_nms
8
8
from ..registry import HEADS
9
9
from .anchor_head import AnchorHead
10
10
@@ -61,7 +61,12 @@ def get_bboxes_single(self,
61
61
scale_factor ,
62
62
cfg ,
63
63
rescale = False ):
64
- mlvl_proposals = []
64
+ # bboxes from different level should be independent during NMS,
65
+ # level_ids are used as labels for batched NMS to separate them
66
+ level_ids = []
67
+ mlvl_scores = []
68
+ mlvl_bbox_preds = []
69
+ mlvl_valid_anchors = []
65
70
for idx in range (len (cls_scores )):
66
71
rpn_cls_score = cls_scores [idx ]
67
72
rpn_bbox_pred = bbox_preds [idx ]
@@ -79,30 +84,37 @@ def get_bboxes_single(self,
79
84
rpn_bbox_pred = rpn_bbox_pred .permute (1 , 2 , 0 ).reshape (- 1 , 4 )
80
85
anchors = mlvl_anchors [idx ]
81
86
if cfg .nms_pre > 0 and scores .shape [0 ] > cfg .nms_pre :
82
- _ , topk_inds = scores .topk (cfg .nms_pre )
87
+ # sort is faster than topk
88
+ # _, topk_inds = scores.topk(cfg.nms_pre)
89
+ ranked_scores , rank_inds = scores .sort (descending = True )
90
+ topk_inds = rank_inds [:cfg .nms_pre ]
91
+ scores = ranked_scores [:cfg .nms_pre ]
83
92
rpn_bbox_pred = rpn_bbox_pred [topk_inds , :]
84
93
anchors = anchors [topk_inds , :]
85
- scores = scores [topk_inds ]
86
- proposals = delta2bbox (anchors , rpn_bbox_pred , self .target_means ,
87
- self .target_stds , img_shape )
88
- if cfg .min_bbox_size > 0 :
89
- w = proposals [:, 2 ] - proposals [:, 0 ]
90
- h = proposals [:, 3 ] - proposals [:, 1 ]
91
- valid_inds = torch .nonzero ((w >= cfg .min_bbox_size ) &
92
- (h >= cfg .min_bbox_size )).squeeze ()
94
+ mlvl_scores .append (scores )
95
+ mlvl_bbox_preds .append (rpn_bbox_pred )
96
+ mlvl_valid_anchors .append (anchors )
97
+ level_ids .append (
98
+ scores .new_full ((scores .size (0 ), ), idx , dtype = torch .long ))
99
+
100
+ scores = torch .cat (mlvl_scores )
101
+ anchors = torch .cat (mlvl_valid_anchors )
102
+ rpn_bbox_pred = torch .cat (mlvl_bbox_preds )
103
+ proposals = delta2bbox (anchors , rpn_bbox_pred , self .target_means ,
104
+ self .target_stds , img_shape )
105
+ ids = torch .cat (level_ids )
106
+
107
+ if cfg .min_bbox_size > 0 :
108
+ w = proposals [:, 2 ] - proposals [:, 0 ]
109
+ h = proposals [:, 3 ] - proposals [:, 1 ]
110
+ valid_inds = torch .nonzero ((w >= cfg .min_bbox_size )
111
+ & (h >= cfg .min_bbox_size )).squeeze ()
112
+ if valid_inds .sum ().item () != len (proposals ):
93
113
proposals = proposals [valid_inds , :]
94
114
scores = scores [valid_inds ]
95
- proposals = torch .cat ([proposals , scores .unsqueeze (- 1 )], dim = - 1 )
96
- proposals , _ = nms (proposals , cfg .nms_thr )
97
- proposals = proposals [:cfg .nms_post , :]
98
- mlvl_proposals .append (proposals )
99
- proposals = torch .cat (mlvl_proposals , 0 )
100
- if cfg .nms_across_levels :
101
- proposals , _ = nms (proposals , cfg .nms_thr )
102
- proposals = proposals [:cfg .max_num , :]
103
- else :
104
- scores = proposals [:, 4 ]
105
- num = min (cfg .max_num , proposals .shape [0 ])
106
- _ , topk_inds = scores .topk (num )
107
- proposals = proposals [topk_inds , :]
108
- return proposals
115
+ ids = ids [valid_inds ]
116
+
117
+ # TODO: remove the hard coded nms type
118
+ nms_cfg = dict (type = 'nms' , iou_thr = cfg .nms_thr )
119
+ dets , keep = batched_nms (proposals , scores , ids , nms_cfg )
120
+ return dets [:cfg .nms_post ]
0 commit comments