Skip to content

Commit e91da70

Browse files
authored
Refactor anchor_generator and point_generator (open-mmlab#5349)
* add sparse priors * add mlvlpointsgenerator * revert __init__ of core * refactor reppoints * delete label channal * add docstr * fix typo * fix args * fix typo * fix doc * fix stride_h * add offset * add offset * fix docstr * new interface of single_proir * fix device * add unitest * add cuda unitest * add more cuda unintest * fix reppoints * fix device * add unintest for ssd and yolo and rename prior_idxs * add docstr for MlvlPointGenerator * add space * add num_base_priors
1 parent 269bb9e commit e91da70

File tree

6 files changed

+626
-49
lines changed

6 files changed

+626
-49
lines changed

mmdet/core/anchor/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
22
YOLOAnchorGenerator)
3-
from .builder import ANCHOR_GENERATORS, build_anchor_generator
4-
from .point_generator import PointGenerator
3+
from .builder import (ANCHOR_GENERATORS, PRIOR_GENERATORS,
4+
build_anchor_generator, build_prior_generator)
5+
from .point_generator import MlvlPointGenerator, PointGenerator
56
from .utils import anchor_inside_flags, calc_region, images_to_levels
67

78
__all__ = [
89
'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
910
'PointGenerator', 'images_to_levels', 'calc_region',
10-
'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
11+
'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator',
12+
'build_prior_generator', 'PRIOR_GENERATORS', 'MlvlPointGenerator'
1113
]

mmdet/core/anchor/anchor_generator.py

+121-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import warnings
2+
13
import mmcv
24
import numpy as np
35
import torch
46
from torch.nn.modules.utils import _pair
57

6-
from .builder import ANCHOR_GENERATORS
8+
from .builder import PRIOR_GENERATORS
79

810

9-
@ANCHOR_GENERATORS.register_module()
11+
@PRIOR_GENERATORS.register_module()
1012
class AnchorGenerator:
1113
"""Standard anchor generator for 2D anchor-based detectors.
1214
@@ -68,7 +70,7 @@ def __init__(self,
6870
# check center and center_offset
6971
if center_offset != 0:
7072
assert centers is None, 'center cannot be set when center_offset' \
71-
f'!=0, {centers} is given.'
73+
f'!=0, {centers} is given.'
7274
if not (0 <= center_offset <= 1):
7375
raise ValueError('center_offset should be in range [0, 1], '
7476
f'{center_offset} is given.')
@@ -87,7 +89,7 @@ def __init__(self,
8789

8890
# calculate scales of anchors
8991
assert ((octave_base_scale is not None
90-
and scales_per_octave is not None) ^ (scales is not None)), \
92+
and scales_per_octave is not None) ^ (scales is not None)), \
9193
'scales and octave_base_scale with scales_per_octave cannot' \
9294
' be set at the same time'
9395
if scales is not None:
@@ -112,6 +114,12 @@ def __init__(self,
112114
@property
113115
def num_base_anchors(self):
114116
"""list[int]: total number of base anchors in a feature grid"""
117+
return self.num_base_priors
118+
119+
@property
120+
def num_base_priors(self):
121+
"""list[int]: The number of priors (anchors) at a point
122+
on the feature grid"""
115123
return [base_anchors.size(0) for base_anchors in self.base_anchors]
116124

117125
@property
@@ -204,6 +212,99 @@ def _meshgrid(self, x, y, row_major=True):
204212
else:
205213
return yy, xx
206214

215+
def grid_priors(self, featmap_sizes, device='cuda'):
216+
"""Generate grid anchors in multiple feature levels.
217+
218+
Args:
219+
featmap_sizes (list[tuple]): List of feature map sizes in
220+
multiple feature levels.
221+
device (str): The device where the anchors will be put on.
222+
223+
Return:
224+
list[torch.Tensor]: Anchors in multiple feature levels. \
225+
The sizes of each tensor should be [N, 4], where \
226+
N = width * height * num_base_anchors, width and height \
227+
are the sizes of the corresponding feature level, \
228+
num_base_anchors is the number of anchors for that level.
229+
"""
230+
assert self.num_levels == len(featmap_sizes)
231+
multi_level_anchors = []
232+
for i in range(self.num_levels):
233+
anchors = self.single_level_grid_priors(
234+
featmap_sizes[i], level_idx=i, device=device)
235+
multi_level_anchors.append(anchors)
236+
return multi_level_anchors
237+
238+
def single_level_grid_priors(self, featmap_size, level_idx, device='cuda'):
239+
"""Generate grid anchors of a single level.
240+
241+
Note:
242+
This function is usually called by method ``self.grid_priors``.
243+
244+
Args:
245+
featmap_size (tuple[int]): Size of the feature maps.
246+
level_idx (int): The index of corresponding feature map level.
247+
device (str, optional): The device the tensor will be put on.
248+
Defaults to 'cuda'.
249+
250+
Returns:
251+
torch.Tensor: Anchors in the overall feature maps.
252+
"""
253+
254+
base_anchors = self.base_anchors[level_idx].to(device)
255+
feat_h, feat_w = featmap_size
256+
stride_w, stride_h = self.strides[level_idx]
257+
shift_x = torch.arange(0, feat_w, device=device) * stride_w
258+
shift_y = torch.arange(0, feat_h, device=device) * stride_h
259+
260+
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
261+
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
262+
shifts = shifts.type_as(base_anchors)
263+
# first feat_w elements correspond to the first row of shifts
264+
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
265+
# shifted anchors (K, A, 4), reshape to (K*A, 4)
266+
267+
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
268+
all_anchors = all_anchors.view(-1, 4)
269+
# first A rows correspond to A anchors of (0, 0) in feature map,
270+
# then (0, 1), (0, 2), ...
271+
return all_anchors
272+
273+
def sparse_priors(self,
274+
prior_idxs,
275+
featmap_size,
276+
level_idx,
277+
dtype=torch.float32,
278+
device='cuda'):
279+
"""Generate sparse anchors according to the ``prior_idxs``.
280+
281+
Args:
282+
prior_idxs (Tensor): The index of corresponding anchors
283+
in the feature map.
284+
featmap_size (tuple[int]): feature map size arrange as (h, w).
285+
level_idx (int): The level index of corresponding feature
286+
map.
287+
dtype (obj:`torch.dtype`): Date type of points.Defaults to
288+
``torch.float32``.
289+
device (obj:`torch.device`): The device where the points is
290+
located.
291+
Returns:
292+
Tensor: Anchor with shape (N, 4), N should be equal to
293+
the length of ``prior_idxs``.
294+
"""
295+
296+
height, width = featmap_size
297+
num_base_anchors = self.num_base_anchors[level_idx]
298+
base_anchor_id = prior_idxs % num_base_anchors
299+
x = (prior_idxs //
300+
num_base_anchors) % width * self.strides[level_idx][0]
301+
y = (prior_idxs // width //
302+
num_base_anchors) % height * self.strides[level_idx][1]
303+
priors = torch.stack([x, y, x, y], 1).to(dtype).to(device) + \
304+
self.base_anchors[level_idx][base_anchor_id, :].to(device)
305+
306+
return priors
307+
207308
def grid_anchors(self, featmap_sizes, device='cuda'):
208309
"""Generate grid anchors in multiple feature levels.
209310
@@ -219,6 +320,9 @@ def grid_anchors(self, featmap_sizes, device='cuda'):
219320
are the sizes of the corresponding feature level, \
220321
num_base_anchors is the number of anchors for that level.
221322
"""
323+
warnings.warn('``grid_anchors`` would be deprecated soon. '
324+
'Please use ``grid_priors`` ')
325+
222326
assert self.num_levels == len(featmap_sizes)
223327
multi_level_anchors = []
224328
for i in range(self.num_levels):
@@ -251,7 +355,13 @@ def single_level_grid_anchors(self,
251355
Returns:
252356
torch.Tensor: Anchors in the overall feature maps.
253357
"""
254-
# keep as Tensor, so that we can covert to ONNX correctly
358+
359+
warnings.warn(
360+
'``single_level_grid_anchors`` would be deprecated soon. '
361+
'Please use ``single_level_grid_priors`` ')
362+
363+
# keep featmap_size as Tensor instead of int, so that we
364+
# can covert to ONNX correctly
255365
feat_h, feat_w = featmap_size
256366
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
257367
shift_y = torch.arange(0, feat_h, device=device) * stride[1]
@@ -304,7 +414,8 @@ def single_level_valid_flags(self,
304414
"""Generate the valid flags of anchor in a single feature map.
305415
306416
Args:
307-
featmap_size (tuple[int]): The size of feature maps.
417+
featmap_size (tuple[int]): The size of feature maps, arrange
418+
as (h, w).
308419
valid_size (tuple[int]): The valid size of the feature maps.
309420
num_base_anchors (int): The number of base anchors.
310421
device (str, optional): Device where the flags will be put on.
@@ -346,7 +457,7 @@ def __repr__(self):
346457
return repr_str
347458

348459

349-
@ANCHOR_GENERATORS.register_module()
460+
@PRIOR_GENERATORS.register_module()
350461
class SSDAnchorGenerator(AnchorGenerator):
351462
"""Anchor generator for SSD.
352463
@@ -470,7 +581,7 @@ def __repr__(self):
470581
return repr_str
471582

472583

473-
@ANCHOR_GENERATORS.register_module()
584+
@PRIOR_GENERATORS.register_module()
474585
class LegacyAnchorGenerator(AnchorGenerator):
475586
"""Legacy anchor generator used in MMDetection V1.x.
476587
@@ -569,7 +680,7 @@ def gen_single_level_base_anchors(self,
569680
return base_anchors
570681

571682

572-
@ANCHOR_GENERATORS.register_module()
683+
@PRIOR_GENERATORS.register_module()
573684
class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
574685
"""Legacy anchor generator used in MMDetection V1.x.
575686
@@ -591,7 +702,7 @@ def __init__(self,
591702
self.base_anchors = self.gen_base_anchors()
592703

593704

594-
@ANCHOR_GENERATORS.register_module()
705+
@PRIOR_GENERATORS.register_module()
595706
class YOLOAnchorGenerator(AnchorGenerator):
596707
"""Anchor generator for YOLO.
597708

mmdet/core/anchor/builder.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
import warnings
2+
13
from mmcv.utils import Registry, build_from_cfg
24

3-
ANCHOR_GENERATORS = Registry('Anchor generator')
5+
PRIOR_GENERATORS = Registry('Generator for anchors and points')
6+
7+
ANCHOR_GENERATORS = PRIOR_GENERATORS
8+
9+
10+
def build_prior_generator(cfg, default_args=None):
11+
return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)
412

513

614
def build_anchor_generator(cfg, default_args=None):
7-
return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
15+
warnings.warn(
16+
'``build_anchor_generator`` would be deprecated soon, please use '
17+
'``build_prior_generator`` ')
18+
return build_prior_generator(cfg, default_args=default_args)

0 commit comments

Comments
 (0)