@@ -113,6 +113,8 @@ def __init__(
113
113
deformable_groups = 4 ,
114
114
loc_filter_thr = 0.01 ,
115
115
background_label = None ,
116
+ train_cfg = None ,
117
+ test_cfg = None ,
116
118
loss_loc = dict (
117
119
type = 'FocalLoss' ,
118
120
use_sigmoid = True ,
@@ -176,6 +178,9 @@ def __init__(
176
178
self .loss_cls = build_loss (loss_cls )
177
179
self .loss_bbox = build_loss (loss_bbox )
178
180
181
+ self .train_cfg = train_cfg
182
+ self .test_cfg = test_cfg
183
+
179
184
self .fp16_enabled = False
180
185
181
186
self ._init_layers ()
@@ -418,7 +423,6 @@ def loss(self,
418
423
gt_bboxes ,
419
424
gt_labels ,
420
425
img_metas ,
421
- cfg ,
422
426
gt_bboxes_ignore = None ):
423
427
featmap_sizes = [featmap .size ()[- 2 :] for featmap in cls_scores ]
424
428
assert len (featmap_sizes ) == len (self .approx_generators )
@@ -431,26 +435,26 @@ def loss(self,
431
435
featmap_sizes ,
432
436
self .octave_base_scale ,
433
437
self .anchor_strides ,
434
- center_ratio = cfg .center_ratio ,
435
- ignore_ratio = cfg .ignore_ratio )
438
+ center_ratio = self . train_cfg .center_ratio ,
439
+ ignore_ratio = self . train_cfg .ignore_ratio )
436
440
437
441
# get sampled approxes
438
442
approxs_list , inside_flag_list = self .get_sampled_approxs (
439
- featmap_sizes , img_metas , cfg , device = device )
443
+ featmap_sizes , img_metas , self . train_cfg , device = device )
440
444
# get squares and guided anchors
441
445
squares_list , guided_anchors_list , _ = self .get_anchors (
442
446
featmap_sizes , shape_preds , loc_preds , img_metas , device = device )
443
447
444
448
# get shape targets
445
- sampling = False if not hasattr (cfg , 'ga_sampler' ) else True
449
+ sampling = False if not hasattr (self . train_cfg , 'ga_sampler' ) else True
446
450
shape_targets = ga_shape_target (
447
451
approxs_list ,
448
452
inside_flag_list ,
449
453
squares_list ,
450
454
gt_bboxes ,
451
455
img_metas ,
452
456
self .approxs_per_octave ,
453
- cfg ,
457
+ self . train_cfg ,
454
458
sampling = sampling )
455
459
if shape_targets is None :
456
460
return None
@@ -469,7 +473,7 @@ def loss(self,
469
473
img_metas ,
470
474
self .target_means ,
471
475
self .target_stds ,
472
- cfg ,
476
+ self . train_cfg ,
473
477
gt_bboxes_ignore_list = gt_bboxes_ignore ,
474
478
gt_labels_list = gt_labels ,
475
479
label_channels = label_channels ,
@@ -492,8 +496,7 @@ def loss(self,
492
496
label_weights_list ,
493
497
bbox_targets_list ,
494
498
bbox_weights_list ,
495
- num_total_samples = num_total_samples ,
496
- cfg = cfg )
499
+ num_total_samples = num_total_samples )
497
500
498
501
# get anchor location loss
499
502
losses_loc = []
@@ -503,7 +506,7 @@ def loss(self,
503
506
loc_targets [i ],
504
507
loc_weights [i ],
505
508
loc_avg_factor = loc_avg_factor ,
506
- cfg = cfg )
509
+ cfg = self . train_cfg )
507
510
losses_loc .append (loss_loc )
508
511
509
512
# get anchor shape loss
0 commit comments