@@ -266,10 +266,13 @@ def forward_train(self,
266
266
# bbox_targets is a tuple
267
267
roi_labels = bbox_results ['bbox_targets' ][0 ]
268
268
with torch .no_grad ():
269
+ cls_score = bbox_results ['cls_score' ]
270
+ if self .bbox_head [i ].custom_activation :
271
+ cls_score = self .bbox_head [i ].loss_cls .get_activation (
272
+ cls_score )
269
273
roi_labels = torch .where (
270
274
roi_labels == self .bbox_head [i ].num_classes ,
271
- bbox_results ['cls_score' ][:, :- 1 ].argmax (1 ),
272
- roi_labels )
275
+ cls_score [:, :- 1 ].argmax (1 ), roi_labels )
273
276
proposal_list = self .bbox_head [i ].refine_bboxes (
274
277
bbox_results ['rois' ], roi_labels ,
275
278
bbox_results ['bbox_pred' ], pos_is_gts , img_metas )
@@ -309,6 +312,11 @@ def simple_test(self, x, proposal_list, img_metas, rescale=False):
309
312
ms_scores .append (cls_score )
310
313
311
314
if i < self .num_stages - 1 :
315
+ if self .bbox_head [i ].custom_activation :
316
+ cls_score = [
317
+ self .bbox_head [i ].loss_cls .get_activation (s )
318
+ for s in cls_score
319
+ ]
312
320
bbox_label = [s [:, :- 1 ].argmax (dim = 1 ) for s in cls_score ]
313
321
rois = torch .cat ([
314
322
self .bbox_head [i ].regress_by_class (rois [j ], bbox_label [j ],
@@ -429,8 +437,11 @@ def aug_test(self, features, proposal_list, img_metas, rescale=False):
429
437
ms_scores .append (bbox_results ['cls_score' ])
430
438
431
439
if i < self .num_stages - 1 :
432
- bbox_label = bbox_results ['cls_score' ][:, :- 1 ].argmax (
433
- dim = 1 )
440
+ cls_score = bbox_results ['cls_score' ]
441
+ if self .bbox_head [i ].custom_activation :
442
+ cls_score = self .bbox_head [i ].loss_cls .get_activation (
443
+ cls_score )
444
+ bbox_label = cls_score [:, :- 1 ].argmax (dim = 1 )
434
445
rois = self .bbox_head [i ].regress_by_class (
435
446
rois , bbox_label , bbox_results ['bbox_pred' ],
436
447
img_meta [0 ])
0 commit comments