Skip to content

Commit af6e1f8

Browse files
authored
V2.0 speedup rpn (open-mmlab#2420)
* speed inference of rpn * nms return sorted inds * add comment * minor perfect * rename idxs to inds
1 parent 3ed272a commit af6e1f8

File tree

6 files changed

+80
-60
lines changed

6 files changed

+80
-60
lines changed

mmdet/core/post_processing/bbox_nms.py

+3-27
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from mmdet.ops.nms import nms_wrapper
3+
from mmdet.ops.nms import batched_nms
44

55

66
def multiclass_nms(multi_bboxes,
@@ -48,29 +48,5 @@ def multiclass_nms(multi_bboxes,
4848
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
4949
return bboxes, labels
5050

51-
# Modified from https://github.com/pytorch/vision/blob
52-
# /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
53-
# strategy: in order to perform NMS independently per class.
54-
# we add an offset to all the boxes. The offset is dependent
55-
# only on the class idx, and is large enough so that boxes
56-
# from different classes do not overlap
57-
max_coordinate = bboxes.max()
58-
offsets = labels.to(bboxes) * (max_coordinate + 1)
59-
bboxes_for_nms = bboxes + offsets[:, None]
60-
nms_cfg_ = nms_cfg.copy()
61-
nms_type = nms_cfg_.pop('type', 'nms')
62-
nms_op = getattr(nms_wrapper, nms_type)
63-
dets, keep = nms_op(
64-
torch.cat([bboxes_for_nms, scores[:, None]], 1), **nms_cfg_)
65-
bboxes = bboxes[keep]
66-
scores = dets[:, -1] # soft_nms will modify scores
67-
labels = labels[keep]
68-
69-
if keep.size(0) > max_num:
70-
_, inds = scores.sort(descending=True)
71-
inds = inds[:max_num]
72-
bboxes = bboxes[inds]
73-
scores = scores[inds]
74-
labels = labels[inds]
75-
76-
return torch.cat([bboxes, scores[:, None]], 1), labels
51+
dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
52+
return dets[:max_num], labels[keep[:max_num]]

mmdet/models/anchor_heads/rpn_head.py

+37-25
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mmcv.cnn import normal_init
55

66
from mmdet.core import delta2bbox
7-
from mmdet.ops import nms
7+
from mmdet.ops import batched_nms
88
from ..registry import HEADS
99
from .anchor_head import AnchorHead
1010

@@ -61,7 +61,12 @@ def get_bboxes_single(self,
6161
scale_factor,
6262
cfg,
6363
rescale=False):
64-
mlvl_proposals = []
64+
# bboxes from different level should be independent during NMS,
65+
# level_ids are used as labels for batched NMS to separate them
66+
level_ids = []
67+
mlvl_scores = []
68+
mlvl_bbox_preds = []
69+
mlvl_valid_anchors = []
6570
for idx in range(len(cls_scores)):
6671
rpn_cls_score = cls_scores[idx]
6772
rpn_bbox_pred = bbox_preds[idx]
@@ -79,30 +84,37 @@ def get_bboxes_single(self,
7984
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
8085
anchors = mlvl_anchors[idx]
8186
if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
82-
_, topk_inds = scores.topk(cfg.nms_pre)
87+
# sort is faster than topk
88+
# _, topk_inds = scores.topk(cfg.nms_pre)
89+
ranked_scores, rank_inds = scores.sort(descending=True)
90+
topk_inds = rank_inds[:cfg.nms_pre]
91+
scores = ranked_scores[:cfg.nms_pre]
8392
rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
8493
anchors = anchors[topk_inds, :]
85-
scores = scores[topk_inds]
86-
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
87-
self.target_stds, img_shape)
88-
if cfg.min_bbox_size > 0:
89-
w = proposals[:, 2] - proposals[:, 0]
90-
h = proposals[:, 3] - proposals[:, 1]
91-
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
92-
(h >= cfg.min_bbox_size)).squeeze()
94+
mlvl_scores.append(scores)
95+
mlvl_bbox_preds.append(rpn_bbox_pred)
96+
mlvl_valid_anchors.append(anchors)
97+
level_ids.append(
98+
scores.new_full((scores.size(0), ), idx, dtype=torch.long))
99+
100+
scores = torch.cat(mlvl_scores)
101+
anchors = torch.cat(mlvl_valid_anchors)
102+
rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
103+
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
104+
self.target_stds, img_shape)
105+
ids = torch.cat(level_ids)
106+
107+
if cfg.min_bbox_size > 0:
108+
w = proposals[:, 2] - proposals[:, 0]
109+
h = proposals[:, 3] - proposals[:, 1]
110+
valid_inds = torch.nonzero((w >= cfg.min_bbox_size)
111+
& (h >= cfg.min_bbox_size)).squeeze()
112+
if valid_inds.sum().item() != len(proposals):
93113
proposals = proposals[valid_inds, :]
94114
scores = scores[valid_inds]
95-
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
96-
proposals, _ = nms(proposals, cfg.nms_thr)
97-
proposals = proposals[:cfg.nms_post, :]
98-
mlvl_proposals.append(proposals)
99-
proposals = torch.cat(mlvl_proposals, 0)
100-
if cfg.nms_across_levels:
101-
proposals, _ = nms(proposals, cfg.nms_thr)
102-
proposals = proposals[:cfg.max_num, :]
103-
else:
104-
scores = proposals[:, 4]
105-
num = min(cfg.max_num, proposals.shape[0])
106-
_, topk_inds = scores.topk(num)
107-
proposals = proposals[topk_inds, :]
108-
return proposals
115+
ids = ids[valid_inds]
116+
117+
# TODO: remove the hard coded nms type
118+
nms_cfg = dict(type='nms', iou_thr=cfg.nms_thr)
119+
dets, keep = batched_nms(proposals, scores, ids, nms_cfg)
120+
return dets[:cfg.nms_post]

mmdet/ops/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
deform_conv, deform_roi_pooling, modulated_deform_conv)
99
from .generalized_attention import GeneralizedAttention
1010
from .masked_conv import MaskedConv2d
11-
from .nms import nms, soft_nms
11+
from .nms import batched_nms, nms, soft_nms
1212
from .non_local import NonLocal2D
1313
from .norm import build_norm_layer
1414
from .plugin import build_plugin_layer
@@ -28,5 +28,5 @@
2828
'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D',
2929
'get_compiler_version', 'get_compiling_cuda_version', 'build_conv_layer',
3030
'ConvModule', 'ConvWS2d', 'conv_ws_2d', 'build_norm_layer', 'Scale',
31-
'build_upsample_layer', 'build_plugin_layer'
31+
'build_upsample_layer', 'build_plugin_layer', 'batched_nms'
3232
]

mmdet/ops/nms/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .nms_wrapper import nms, soft_nms
1+
from .nms_wrapper import batched_nms, nms, soft_nms
22

3-
__all__ = ['nms', 'soft_nms']
3+
__all__ = ['nms', 'soft_nms', 'batched_nms']

mmdet/ops/nms/nms_wrapper.py

+33
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,36 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
116116
else:
117117
return new_dets.numpy().astype(dets.dtype), inds.numpy().astype(
118118
np.int64)
119+
120+
121+
def batched_nms(bboxes, scores, inds, nms_cfg):
122+
"""Performs non-maximum suppression in a batched fashion.
123+
124+
Modified from https://github.com/pytorch/vision/blob
125+
/505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
126+
In order to perform NMS independently per class, we add an offset to all
127+
the boxes. The offset is dependent only on the class idx, and is large
128+
enough so that boxes from different classes do not overlap.
129+
130+
Arguments:
131+
bboxes (torch.Tensor): bboxes in shape (N, 4).
132+
scores (torch.Tensor): scores in shape (N, ).
133+
inds (torch.Tensor): each index value correspond to a bbox cluster,
134+
and NMS will not be applied between elements of different inds,
135+
shape (N, ).
136+
nms_cfg (dict): specify nms type and other parameters like iou_thr.
137+
138+
Returns:
139+
tuple: kept bboxes and indice.
140+
"""
141+
max_coordinate = bboxes.max()
142+
offsets = inds.to(bboxes) * (max_coordinate + 1)
143+
bboxes_for_nms = bboxes + offsets[:, None]
144+
nms_cfg_ = nms_cfg.copy()
145+
nms_type = nms_cfg_.pop('type', 'nms')
146+
nms_op = eval(nms_type)
147+
dets, keep = nms_op(
148+
torch.cat([bboxes_for_nms, scores[:, None]], -1), **nms_cfg_)
149+
bboxes = bboxes[keep]
150+
scores = dets[:, -1]
151+
return torch.cat([bboxes, scores[:, None]], -1), keep

mmdet/ops/nms/src/cuda/nms_kernel.cu

+3-4
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
132132

133133
THCudaFree(state, mask_dev);
134134
// TODO improve this part
135-
return std::get<0>(order_t.index({
136-
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
137-
order_t.device(), keep.scalar_type())
138-
}).sort(0, false));
135+
return order_t.index({
136+
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
137+
order_t.device(), keep.scalar_type())});
139138
}

0 commit comments

Comments
 (0)