Skip to content

Commit 102f379

Browse files
authored
Some loss (Seesaw loss..) may have custom activation to get the predicted labels. (open-mmlab#5428)
* Some loss (Seesaw loss..) may have custom activation to get the predicted labels. * change custom_cls_channels to custom_activation.
1 parent e91da70 commit 102f379

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

mmdet/models/roi_heads/cascade_roi_head.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,13 @@ def forward_train(self,
266266
# bbox_targets is a tuple
267267
roi_labels = bbox_results['bbox_targets'][0]
268268
with torch.no_grad():
269+
cls_score = bbox_results['cls_score']
270+
if self.bbox_head[i].custom_activation:
271+
cls_score = self.bbox_head[i].loss_cls.get_activation(
272+
cls_score)
269273
roi_labels = torch.where(
270274
roi_labels == self.bbox_head[i].num_classes,
271-
bbox_results['cls_score'][:, :-1].argmax(1),
272-
roi_labels)
275+
cls_score[:, :-1].argmax(1), roi_labels)
273276
proposal_list = self.bbox_head[i].refine_bboxes(
274277
bbox_results['rois'], roi_labels,
275278
bbox_results['bbox_pred'], pos_is_gts, img_metas)
@@ -309,6 +312,11 @@ def simple_test(self, x, proposal_list, img_metas, rescale=False):
309312
ms_scores.append(cls_score)
310313

311314
if i < self.num_stages - 1:
315+
if self.bbox_head[i].custom_activation:
316+
cls_score = [
317+
self.bbox_head[i].loss_cls.get_activation(s)
318+
for s in cls_score
319+
]
312320
bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score]
313321
rois = torch.cat([
314322
self.bbox_head[i].regress_by_class(rois[j], bbox_label[j],
@@ -429,8 +437,11 @@ def aug_test(self, features, proposal_list, img_metas, rescale=False):
429437
ms_scores.append(bbox_results['cls_score'])
430438

431439
if i < self.num_stages - 1:
432-
bbox_label = bbox_results['cls_score'][:, :-1].argmax(
433-
dim=1)
440+
cls_score = bbox_results['cls_score']
441+
if self.bbox_head[i].custom_activation:
442+
cls_score = self.bbox_head[i].loss_cls.get_activation(
443+
cls_score)
444+
bbox_label = cls_score[:, :-1].argmax(dim=1)
434445
rois = self.bbox_head[i].regress_by_class(
435446
rois, bbox_label, bbox_results['bbox_pred'],
436447
img_meta[0])

0 commit comments

Comments
 (0)