Skip to content

Commit 3174d69

Browse files
authored
Add train_cfg/test_cfg in AnchorHead (open-mmlab#2422)
* add cfg for anchor head * add copy in two stage rpn
1 parent af6e1f8 commit 3174d69

14 files changed

+66
-51
lines changed

mmdet/models/anchor_heads/anchor_head.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def __init__(self,
4949
use_sigmoid=True,
5050
loss_weight=1.0),
5151
loss_bbox=dict(
52-
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
52+
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
53+
train_cfg=None,
54+
test_cfg=None):
5355
super(AnchorHead, self).__init__()
5456
self.in_channels = in_channels
5557
self.num_classes = num_classes
@@ -80,6 +82,8 @@ def __init__(self,
8082

8183
self.loss_cls = build_loss(loss_cls)
8284
self.loss_bbox = build_loss(loss_bbox)
85+
self.train_cfg = train_cfg
86+
self.test_cfg = test_cfg
8387
self.fp16_enabled = False
8488

8589
self.anchor_generators = []
@@ -149,7 +153,7 @@ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
149153
return anchor_list, valid_flag_list
150154

151155
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
152-
bbox_targets, bbox_weights, num_total_samples, cfg):
156+
bbox_targets, bbox_weights, num_total_samples):
153157
# classification loss
154158
labels = labels.reshape(-1)
155159
label_weights = label_weights.reshape(-1)
@@ -175,7 +179,6 @@ def loss(self,
175179
gt_bboxes,
176180
gt_labels,
177181
img_metas,
178-
cfg,
179182
gt_bboxes_ignore=None):
180183
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
181184
assert len(featmap_sizes) == len(self.anchor_generators)
@@ -192,7 +195,7 @@ def loss(self,
192195
img_metas,
193196
self.target_means,
194197
self.target_stds,
195-
cfg,
198+
self.train_cfg,
196199
gt_bboxes_ignore_list=gt_bboxes_ignore,
197200
gt_labels_list=gt_labels,
198201
label_channels=label_channels,
@@ -212,8 +215,7 @@ def loss(self,
212215
label_weights_list,
213216
bbox_targets_list,
214217
bbox_weights_list,
215-
num_total_samples=num_total_samples,
216-
cfg=cfg)
218+
num_total_samples=num_total_samples)
217219
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
218220

219221
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))

mmdet/models/anchor_heads/atss_head.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def forward_single(self, x, scale):
123123
return cls_score, bbox_pred, centerness
124124

125125
def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
126-
label_weights, bbox_targets, num_total_samples, cfg):
126+
label_weights, bbox_targets, num_total_samples):
127127

128128
anchors = anchors.reshape(-1, 4)
129129
cls_score = cls_score.permute(0, 2, 3,
@@ -186,7 +186,6 @@ def loss(self,
186186
gt_bboxes,
187187
gt_labels,
188188
img_metas,
189-
cfg,
190189
gt_bboxes_ignore=None):
191190

192191
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
@@ -202,7 +201,6 @@ def loss(self,
202201
valid_flag_list,
203202
gt_bboxes,
204203
img_metas,
205-
cfg,
206204
gt_bboxes_ignore_list=gt_bboxes_ignore,
207205
gt_labels_list=gt_labels,
208206
label_channels=label_channels)
@@ -226,8 +224,7 @@ def loss(self,
226224
labels_list,
227225
label_weights_list,
228226
bbox_targets_list,
229-
num_total_samples=num_total_samples,
230-
cfg=cfg)
227+
num_total_samples=num_total_samples)
231228

232229
bbox_avg_factor = sum(bbox_avg_factor)
233230
bbox_avg_factor = reduce_mean(bbox_avg_factor).item()

mmdet/models/anchor_heads/fcos_head.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def __init__(self,
5353
use_sigmoid=True,
5454
loss_weight=1.0),
5555
conv_cfg=None,
56-
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)):
56+
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
57+
train_cfg=None,
58+
test_cfg=None):
5759
super(FCOSHead, self).__init__()
5860
self.num_classes = num_classes
5961
self.cls_out_channels = num_classes
@@ -65,6 +67,8 @@ def __init__(self,
6567
self.loss_cls = build_loss(loss_cls)
6668
self.loss_bbox = build_loss(loss_bbox)
6769
self.loss_centerness = build_loss(loss_centerness)
70+
self.train_cfg = train_cfg
71+
self.test_cfg = test_cfg
6872
self.conv_cfg = conv_cfg
6973
self.norm_cfg = norm_cfg
7074
self.fp16_enabled = False
@@ -147,7 +151,6 @@ def loss(self,
147151
gt_bboxes,
148152
gt_labels,
149153
img_metas,
150-
cfg,
151154
gt_bboxes_ignore=None):
152155
assert len(cls_scores) == len(bbox_preds) == len(centernesses)
153156
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]

mmdet/models/anchor_heads/fovea_head.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def __init__(self,
6262
loss_cls=None,
6363
loss_bbox=None,
6464
conv_cfg=None,
65-
norm_cfg=None):
65+
norm_cfg=None,
66+
train_cfg=None,
67+
test_cfg=None):
6668
super(FoveaHead, self).__init__()
6769
self.num_classes = num_classes
6870
self.cls_out_channels = num_classes
@@ -84,6 +86,8 @@ def __init__(self,
8486
self.loss_bbox = build_loss(loss_bbox)
8587
self.conv_cfg = conv_cfg
8688
self.norm_cfg = norm_cfg
89+
self.train_cfg = train_cfg
90+
self.test_cfg = test_cfg
8791
self._init_layers()
8892

8993
def _init_layers(self):
@@ -195,7 +199,6 @@ def loss(self,
195199
gt_bbox_list,
196200
gt_label_list,
197201
img_metas,
198-
cfg,
199202
gt_bboxes_ignore=None):
200203
assert len(cls_scores) == len(bbox_preds)
201204

mmdet/models/anchor_heads/free_anchor_retina_head.py

-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def loss(self,
3838
gt_bboxes,
3939
gt_labels,
4040
img_metas,
41-
cfg,
4241
gt_bboxes_ignore=None):
4342
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
4443
assert len(featmap_sizes) == len(self.anchor_generators)

mmdet/models/anchor_heads/ga_rpn_head.py

-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def loss(self,
4040
loc_preds,
4141
gt_bboxes,
4242
img_metas,
43-
cfg,
4443
gt_bboxes_ignore=None):
4544
losses = super(GARPNHead, self).loss(
4645
cls_scores,
@@ -50,7 +49,6 @@ def loss(self,
5049
gt_bboxes,
5150
None,
5251
img_metas,
53-
cfg,
5452
gt_bboxes_ignore=gt_bboxes_ignore)
5553
return dict(
5654
loss_rpn_cls=losses['loss_cls'],

mmdet/models/anchor_heads/guided_anchor_head.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def __init__(
113113
deformable_groups=4,
114114
loc_filter_thr=0.01,
115115
background_label=None,
116+
train_cfg=None,
117+
test_cfg=None,
116118
loss_loc=dict(
117119
type='FocalLoss',
118120
use_sigmoid=True,
@@ -176,6 +178,9 @@ def __init__(
176178
self.loss_cls = build_loss(loss_cls)
177179
self.loss_bbox = build_loss(loss_bbox)
178180

181+
self.train_cfg = train_cfg
182+
self.test_cfg = test_cfg
183+
179184
self.fp16_enabled = False
180185

181186
self._init_layers()
@@ -418,7 +423,6 @@ def loss(self,
418423
gt_bboxes,
419424
gt_labels,
420425
img_metas,
421-
cfg,
422426
gt_bboxes_ignore=None):
423427
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
424428
assert len(featmap_sizes) == len(self.approx_generators)
@@ -431,26 +435,26 @@ def loss(self,
431435
featmap_sizes,
432436
self.octave_base_scale,
433437
self.anchor_strides,
434-
center_ratio=cfg.center_ratio,
435-
ignore_ratio=cfg.ignore_ratio)
438+
center_ratio=self.train_cfg.center_ratio,
439+
ignore_ratio=self.train_cfg.ignore_ratio)
436440

437441
# get sampled approxes
438442
approxs_list, inside_flag_list = self.get_sampled_approxs(
439-
featmap_sizes, img_metas, cfg, device=device)
443+
featmap_sizes, img_metas, self.train_cfg, device=device)
440444
# get squares and guided anchors
441445
squares_list, guided_anchors_list, _ = self.get_anchors(
442446
featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
443447

444448
# get shape targets
445-
sampling = False if not hasattr(cfg, 'ga_sampler') else True
449+
sampling = False if not hasattr(self.train_cfg, 'ga_sampler') else True
446450
shape_targets = ga_shape_target(
447451
approxs_list,
448452
inside_flag_list,
449453
squares_list,
450454
gt_bboxes,
451455
img_metas,
452456
self.approxs_per_octave,
453-
cfg,
457+
self.train_cfg,
454458
sampling=sampling)
455459
if shape_targets is None:
456460
return None
@@ -469,7 +473,7 @@ def loss(self,
469473
img_metas,
470474
self.target_means,
471475
self.target_stds,
472-
cfg,
476+
self.train_cfg,
473477
gt_bboxes_ignore_list=gt_bboxes_ignore,
474478
gt_labels_list=gt_labels,
475479
label_channels=label_channels,
@@ -492,8 +496,7 @@ def loss(self,
492496
label_weights_list,
493497
bbox_targets_list,
494498
bbox_weights_list,
495-
num_total_samples=num_total_samples,
496-
cfg=cfg)
499+
num_total_samples=num_total_samples)
497500

498501
# get anchor location loss
499502
losses_loc = []
@@ -503,7 +506,7 @@ def loss(self,
503506
loc_targets[i],
504507
loc_weights[i],
505508
loc_avg_factor=loc_avg_factor,
506-
cfg=cfg)
509+
cfg=self.train_cfg)
507510
losses_loc.append(loss_loc)
508511

509512
# get anchor shape loss

mmdet/models/anchor_heads/reppoints_head.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def __init__(self,
6464
use_grid_points=False,
6565
center_init=True,
6666
transform_method='moment',
67-
moment_mul=0.01):
67+
moment_mul=0.01,
68+
train_cfg=None,
69+
test_cfg=None):
6870
super(RepPointsHead, self).__init__()
6971
self.in_channels = in_channels
7072
self.num_classes = num_classes
@@ -77,6 +79,8 @@ def __init__(self,
7779
self.point_strides = point_strides
7880
self.conv_cfg = conv_cfg
7981
self.norm_cfg = norm_cfg
82+
self.train_cfg = train_cfg
83+
self.test_cfg = test_cfg
8084

8185
self.background_label = (
8286
num_classes if background_label is None else background_label)
@@ -423,7 +427,6 @@ def loss(self,
423427
gt_bboxes,
424428
gt_labels,
425429
img_metas,
426-
cfg,
427430
gt_bboxes_ignore=None):
428431
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
429432
assert len(featmap_sizes) == len(self.point_generators)
@@ -434,7 +437,7 @@ def loss(self,
434437
img_metas)
435438
pts_coordinate_preds_init = self.offset_to_pts(center_list,
436439
pts_preds_init)
437-
if cfg.init.assigner['type'] == 'PointAssigner':
440+
if self.train_cfg.init.assigner['type'] == 'PointAssigner':
438441
# Assign target for center list
439442
candidate_list = center_list
440443
else:
@@ -447,7 +450,7 @@ def loss(self,
447450
valid_flag_list,
448451
gt_bboxes,
449452
img_metas,
450-
cfg.init,
453+
self.train_cfg.init,
451454
gt_bboxes_ignore_list=gt_bboxes_ignore,
452455
gt_labels_list=gt_labels,
453456
label_channels=label_channels,
@@ -481,7 +484,7 @@ def loss(self,
481484
valid_flag_list,
482485
gt_bboxes,
483486
img_metas,
484-
cfg.refine,
487+
self.train_cfg.refine,
485488
gt_bboxes_ignore_list=gt_bboxes_ignore,
486489
gt_labels_list=gt_labels,
487490
label_channels=label_channels,

mmdet/models/anchor_heads/rpn_head.py

-2
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@ def loss(self,
4040
bbox_preds,
4141
gt_bboxes,
4242
img_metas,
43-
cfg,
4443
gt_bboxes_ignore=None):
4544
losses = super(RPNHead, self).loss(
4645
cls_scores,
4746
bbox_preds,
4847
gt_bboxes,
4948
None,
5049
img_metas,
51-
cfg,
5250
gt_bboxes_ignore=gt_bboxes_ignore)
5351
return dict(
5452
loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

0 commit comments

Comments
 (0)