1
+ import warnings
2
+
1
3
import mmcv
2
4
import numpy as np
3
5
import torch
4
6
from torch .nn .modules .utils import _pair
5
7
6
- from .builder import ANCHOR_GENERATORS
8
+ from .builder import PRIOR_GENERATORS
7
9
8
10
9
- @ANCHOR_GENERATORS .register_module ()
11
+ @PRIOR_GENERATORS .register_module ()
10
12
class AnchorGenerator :
11
13
"""Standard anchor generator for 2D anchor-based detectors.
12
14
@@ -68,7 +70,7 @@ def __init__(self,
68
70
# check center and center_offset
69
71
if center_offset != 0 :
70
72
assert centers is None , 'center cannot be set when center_offset' \
71
- f'!=0, { centers } is given.'
73
+ f'!=0, { centers } is given.'
72
74
if not (0 <= center_offset <= 1 ):
73
75
raise ValueError ('center_offset should be in range [0, 1], '
74
76
f'{ center_offset } is given.' )
@@ -87,7 +89,7 @@ def __init__(self,
87
89
88
90
# calculate scales of anchors
89
91
assert ((octave_base_scale is not None
90
- and scales_per_octave is not None ) ^ (scales is not None )), \
92
+ and scales_per_octave is not None ) ^ (scales is not None )), \
91
93
'scales and octave_base_scale with scales_per_octave cannot' \
92
94
' be set at the same time'
93
95
if scales is not None :
@@ -112,6 +114,12 @@ def __init__(self,
112
114
@property
def num_base_anchors(self):
    """list[int]: Total number of base anchors in a feature grid,
    one entry per feature level (alias of :attr:`num_base_priors`)."""
    return self.num_base_priors

@property
def num_base_priors(self):
    """list[int]: Number of priors (anchors) placed at each point of
    the feature grid, one entry per feature level."""
    return [anchors.shape[0] for anchors in self.base_anchors]
116
124
117
125
@property
@@ -204,6 +212,99 @@ def _meshgrid(self, x, y, row_major=True):
204
212
else :
205
213
return yy , xx
206
214
215
def grid_priors(self, featmap_sizes, device='cuda'):
    """Generate anchors (priors) on the grids of multiple feature levels.

    Args:
        featmap_sizes (list[tuple]): Sizes of the feature maps, one
            (h, w) tuple per feature level.
        device (str): The device where the anchors will be put on.
            Defaults to 'cuda'.

    Returns:
        list[torch.Tensor]: Anchors in multiple feature levels. Each
        tensor has shape ``(N, 4)`` with
        ``N = width * height * num_base_anchors``, where width and
        height are the sizes of the corresponding feature level and
        num_base_anchors is the number of anchors for that level.
    """
    assert self.num_levels == len(featmap_sizes)
    # Delegate the per-level work; one anchor tensor per pyramid level.
    return [
        self.single_level_grid_priors(
            featmap_sizes[level], level_idx=level, device=device)
        for level in range(self.num_levels)
    ]
237
+
238
def single_level_grid_priors(self, featmap_size, level_idx, device='cuda'):
    """Generate grid anchors of a single feature level.

    Note:
        Usually called through :meth:`grid_priors` rather than directly.

    Args:
        featmap_size (tuple[int]): Size of the feature map, arranged
            as (h, w).
        level_idx (int): The index of the corresponding feature map level.
        device (str, optional): The device the tensor will be put on.
            Defaults to 'cuda'.

    Returns:
        torch.Tensor: Anchors in the overall feature map, with shape
        ``(feat_h * feat_w * num_base_anchors, 4)``.
    """
    base_anchors = self.base_anchors[level_idx].to(device)
    feat_h, feat_w = featmap_size
    stride_w, stride_h = self.strides[level_idx]
    # Image-space coordinates of every grid-cell origin along each axis.
    xs = torch.arange(0, feat_w, device=device) * stride_w
    ys = torch.arange(0, feat_h, device=device) * stride_h
    grid_x, grid_y = self._meshgrid(xs, ys)
    # One (x1, y1, x2, y2) shift per grid cell; the first ``feat_w``
    # entries correspond to the first row of the feature map.
    shifts = torch.stack([grid_x, grid_y, grid_x, grid_y], dim=-1)
    shifts = shifts.type_as(base_anchors)
    # Broadcast-add the A base anchors (1, A, 4) onto the K cell shifts
    # (K, 1, 4), then flatten to (K * A, 4).  The first A rows are the
    # anchors of cell (0, 0), then (0, 1), (0, 2), ...
    anchors = (base_anchors[None, :, :] + shifts[:, None, :]).view(-1, 4)
    return anchors
272
+
273
def sparse_priors(self,
                  prior_idxs,
                  featmap_size,
                  level_idx,
                  dtype=torch.float32,
                  device='cuda'):
    """Generate only the anchors selected by ``prior_idxs``.

    Args:
        prior_idxs (Tensor): Flattened indices of the wanted anchors
            within the feature map (same ordering as the output of
            :meth:`grid_priors`).
        featmap_size (tuple[int]): Feature map size arranged as (h, w).
        level_idx (int): The level index of the corresponding feature
            map.
        dtype (obj:`torch.dtype`): Data type of the returned priors.
            Defaults to ``torch.float32``.
        device (obj:`torch.device`): The device where the priors are
            located. Defaults to 'cuda'.

    Returns:
        Tensor: Anchors with shape (N, 4), where N equals the length
        of ``prior_idxs``.
    """
    height, width = featmap_size
    num_base_anchors = self.num_base_anchors[level_idx]
    # Which of the per-cell base anchors each flat index refers to
    # (base anchors vary fastest in the flattened layout).
    base_idx = prior_idxs % num_base_anchors
    # Flattened grid-cell index, then its (x, y) cell coordinates,
    # scaled to image space by this level's stride.
    cell_idx = prior_idxs // num_base_anchors
    x = cell_idx % width * self.strides[level_idx][0]
    y = cell_idx // width % height * self.strides[level_idx][1]
    centers = torch.stack([x, y, x, y], dim=1).to(dtype).to(device)
    priors = centers + self.base_anchors[level_idx][base_idx, :].to(device)
    return priors
307
+
207
308
def grid_anchors (self , featmap_sizes , device = 'cuda' ):
208
309
"""Generate grid anchors in multiple feature levels.
209
310
@@ -219,6 +320,9 @@ def grid_anchors(self, featmap_sizes, device='cuda'):
219
320
are the sizes of the corresponding feature level, \
220
321
num_base_anchors is the number of anchors for that level.
221
322
"""
323
+ warnings .warn ('``grid_anchors`` would be deprecated soon. '
324
+ 'Please use ``grid_priors`` ' )
325
+
222
326
assert self .num_levels == len (featmap_sizes )
223
327
multi_level_anchors = []
224
328
for i in range (self .num_levels ):
@@ -251,7 +355,13 @@ def single_level_grid_anchors(self,
251
355
Returns:
252
356
torch.Tensor: Anchors in the overall feature maps.
253
357
"""
254
- # keep as Tensor, so that we can covert to ONNX correctly
358
+
359
+ warnings .warn (
360
+ '``single_level_grid_anchors`` would be deprecated soon. '
361
+ 'Please use ``single_level_grid_priors`` ' )
362
+
363
+ # keep featmap_size as Tensor instead of int, so that we
364
+ # can covert to ONNX correctly
255
365
feat_h , feat_w = featmap_size
256
366
shift_x = torch .arange (0 , feat_w , device = device ) * stride [0 ]
257
367
shift_y = torch .arange (0 , feat_h , device = device ) * stride [1 ]
@@ -304,7 +414,8 @@ def single_level_valid_flags(self,
304
414
"""Generate the valid flags of anchor in a single feature map.
305
415
306
416
Args:
307
- featmap_size (tuple[int]): The size of feature maps.
417
+ featmap_size (tuple[int]): The size of feature maps, arrange
418
+ as (h, w).
308
419
valid_size (tuple[int]): The valid size of the feature maps.
309
420
num_base_anchors (int): The number of base anchors.
310
421
device (str, optional): Device where the flags will be put on.
@@ -346,7 +457,7 @@ def __repr__(self):
346
457
return repr_str
347
458
348
459
349
- @ANCHOR_GENERATORS .register_module ()
460
+ @PRIOR_GENERATORS .register_module ()
350
461
class SSDAnchorGenerator (AnchorGenerator ):
351
462
"""Anchor generator for SSD.
352
463
@@ -470,7 +581,7 @@ def __repr__(self):
470
581
return repr_str
471
582
472
583
473
- @ANCHOR_GENERATORS .register_module ()
584
+ @PRIOR_GENERATORS .register_module ()
474
585
class LegacyAnchorGenerator (AnchorGenerator ):
475
586
"""Legacy anchor generator used in MMDetection V1.x.
476
587
@@ -569,7 +680,7 @@ def gen_single_level_base_anchors(self,
569
680
return base_anchors
570
681
571
682
572
- @ANCHOR_GENERATORS .register_module ()
683
+ @PRIOR_GENERATORS .register_module ()
573
684
class LegacySSDAnchorGenerator (SSDAnchorGenerator , LegacyAnchorGenerator ):
574
685
"""Legacy anchor generator used in MMDetection V1.x.
575
686
@@ -591,7 +702,7 @@ def __init__(self,
591
702
self .base_anchors = self .gen_base_anchors ()
592
703
593
704
594
- @ANCHOR_GENERATORS .register_module ()
705
+ @PRIOR_GENERATORS .register_module ()
595
706
class YOLOAnchorGenerator (AnchorGenerator ):
596
707
"""Anchor generator for YOLO.
597
708
0 commit comments