Skip to content

Commit 913e099

Browse files
authored
Support empty tensor input for some models. (open-mmlab#2280)
* support-empty-tensor * minor update * add unit test * fix unit test * add assert value equal * simplify some codes * simplify unit tests * distinguish x with x_empty and x_normal * ref only forward once * fix python3.5 ci error
1 parent 55d3b8c commit 913e099

File tree

10 files changed

+326
-22
lines changed

10 files changed

+326
-22
lines changed

mmdet/core/mask/mask_target.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
88
cfg_list = [cfg for _ in range(len(pos_proposals_list))]
99
mask_targets = map(mask_target_single, pos_proposals_list,
1010
pos_assigned_gt_inds_list, gt_masks_list, cfg_list)
11-
mask_targets = torch.cat(list(mask_targets))
11+
mask_targets = list(mask_targets)
12+
if len(mask_targets) > 0:
13+
mask_targets = torch.cat(mask_targets)
1214
return mask_targets
1315

1416

mmdet/models/bbox_heads/bbox_head.py

+2
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ def loss(self,
138138
bbox_weights[pos_inds.type(torch.bool)],
139139
avg_factor=bbox_targets.size(0),
140140
reduction_override=reduction_override)
141+
else:
142+
losses['loss_bbox'] = bbox_pred.sum() * 0
141143
return losses
142144

143145
@force_fp32(apply_to=('cls_score', 'bbox_pred'))

mmdet/models/mask_heads/fcn_mask_head.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.nn.modules.utils import _pair
66

77
from mmdet.core import auto_fp16, force_fp32, mask_target
8-
from mmdet.ops import ConvModule, build_upsample_layer
8+
from mmdet.ops import Conv2d, ConvModule, build_upsample_layer
99
from mmdet.ops.carafe import CARAFEPack
1010
from mmdet.ops.grid_sampler import grid_sample
1111
from ..builder import build_loss
@@ -98,7 +98,7 @@ def __init__(self,
9898
logits_in_channel = (
9999
self.conv_out_channels
100100
if self.upsample_method == 'deconv' else upsample_in_channels)
101-
self.conv_logits = nn.Conv2d(logits_in_channel, out_channels, 1)
101+
self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
102102
self.relu = nn.ReLU(inplace=True)
103103
self.debug_imgs = None
104104

@@ -136,11 +136,14 @@ def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
136136
@force_fp32(apply_to=('mask_pred', ))
137137
def loss(self, mask_pred, mask_targets, labels):
138138
loss = dict()
139-
if self.class_agnostic:
140-
loss_mask = self.loss_mask(mask_pred, mask_targets,
141-
torch.zeros_like(labels))
139+
if mask_pred.size(0) == 0:
140+
loss_mask = mask_pred.sum() * 0
142141
else:
143-
loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
142+
if self.class_agnostic:
143+
loss_mask = self.loss_mask(mask_pred, mask_targets,
144+
torch.zeros_like(labels))
145+
else:
146+
loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
144147
loss['loss_mask'] = loss_mask
145148
return loss
146149

mmdet/models/mask_heads/maskiou_head.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch.nn.modules.utils import _pair
66

77
from mmdet.core import force_fp32
8+
from mmdet.ops import Conv2d, Linear, MaxPool2d
89
from ..builder import build_loss
910
from ..registry import HEADS
1011

@@ -41,7 +42,7 @@ def __init__(self,
4142
in_channels = self.conv_out_channels
4243
stride = 2 if i == num_convs - 1 else 1
4344
self.convs.append(
44-
nn.Conv2d(
45+
Conv2d(
4546
in_channels,
4647
self.conv_out_channels,
4748
3,
@@ -55,11 +56,11 @@ def __init__(self,
5556
in_channels = (
5657
self.conv_out_channels *
5758
pooled_area if i == 0 else self.fc_out_channels)
58-
self.fcs.append(nn.Linear(in_channels, self.fc_out_channels))
59+
self.fcs.append(Linear(in_channels, self.fc_out_channels))
5960

60-
self.fc_mask_iou = nn.Linear(self.fc_out_channels, self.num_classes)
61+
self.fc_mask_iou = Linear(self.fc_out_channels, self.num_classes)
6162
self.relu = nn.ReLU()
62-
self.max_pool = nn.MaxPool2d(2, 2)
63+
self.max_pool = MaxPool2d(2, 2)
6364
self.loss_iou = build_loss(loss_iou)
6465

6566
def init_weights(self):
@@ -82,7 +83,7 @@ def forward(self, mask_feat, mask_pred):
8283

8384
for conv in self.convs:
8485
x = self.relu(conv(x))
85-
x = x.view(x.size(0), -1)
86+
x = x.flatten(1)
8687
for fc in self.fcs:
8788
x = self.relu(fc(x))
8889
mask_iou = self.fc_mask_iou(x)
@@ -95,7 +96,7 @@ def loss(self, mask_iou_pred, mask_iou_targets):
9596
loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds],
9697
mask_iou_targets[pos_inds])
9798
else:
98-
loss_mask_iou = mask_iou_pred * 0
99+
loss_mask_iou = mask_iou_pred.sum() * 0
99100
return dict(loss_mask_iou=loss_mask_iou)
100101

101102
@force_fp32(apply_to=('mask_pred', ))

mmdet/models/roi_extractors/single_level.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,17 @@ def roi_rescale(self, rois, scale_factor):
8888

8989
@force_fp32(apply_to=('feats', ), out_fp16=True)
9090
def forward(self, feats, rois, roi_scale_factor=None):
91-
if len(feats) == 1:
92-
return self.roi_layers[0](feats[0], rois)
93-
9491
out_size = self.roi_layers[0].out_size
9592
num_levels = len(feats)
96-
target_lvls = self.map_roi_levels(rois, num_levels)
9793
roi_feats = feats[0].new_zeros(
9894
rois.size(0), self.out_channels, *out_size)
95+
96+
if num_levels == 1:
97+
if len(rois) == 0:
98+
return roi_feats
99+
return self.roi_layers[0](feats[0], rois)
100+
101+
target_lvls = self.map_roi_levels(rois, num_levels)
99102
if roi_scale_factor is not None:
100103
rois = self.roi_rescale(rois, roi_scale_factor)
101104
for i in range(num_levels):

mmdet/ops/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
1919
from .upsample import build_upsample_layer
2020
from .utils import get_compiler_version, get_compiling_cuda_version
21+
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
2122

2223
__all__ = [
2324
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
@@ -28,5 +29,6 @@
2829
'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D',
2930
'get_compiler_version', 'get_compiling_cuda_version', 'build_conv_layer',
3031
'ConvModule', 'ConvWS2d', 'conv_ws_2d', 'build_norm_layer', 'Scale',
31-
'build_upsample_layer', 'build_plugin_layer', 'batched_nms'
32+
'build_upsample_layer', 'build_plugin_layer', 'batched_nms', 'Conv2d',
33+
'ConvTranspose2d', 'MaxPool2d', 'Linear'
3234
]

mmdet/ops/conv.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from torch import nn as nn
2-
31
from .conv_ws import ConvWS2d
42
from .dcn import DeformConvPack, ModulatedDeformConvPack
3+
from .wrappers import Conv2d
54

65
conv_cfg = {
7-
'Conv': nn.Conv2d,
6+
'Conv': Conv2d,
87
'ConvWS': ConvWS2d,
98
'DCN': DeformConvPack,
109
'DCNv2': ModulatedDeformConvPack,

mmdet/ops/upsample.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from mmcv.cnn import xavier_init
44

55
from .carafe import CARAFEPack
6+
from .wrappers import ConvTranspose2d
67

78

89
class PixelShufflePack(nn.Module):
@@ -45,7 +46,7 @@ def forward(self, x):
4546
# layer_abbreviation: module
4647
'nearest': nn.Upsample,
4748
'bilinear': nn.Upsample,
48-
'deconv': nn.ConvTranspose2d,
49+
'deconv': ConvTranspose2d,
4950
'pixel_shuffle': PixelShufflePack,
5051
'carafe': CARAFEPack
5152
}

mmdet/ops/wrappers.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
Modified from https://github.com/facebookresearch/detectron2/blob/master
3+
/detectron2/layers/wrappers.py
4+
Wrap some nn modules to support empty tensor input.
5+
Currently, these wrappers are mainly used in mask heads like fcn_mask_head
6+
and maskiou_head since mask heads are trained on only positive RoIs.
7+
"""
8+
import math
9+
10+
import torch
11+
import torch.nn as nn
12+
from torch.nn.modules.utils import _pair
13+
14+
15+
class NewEmptyTensorOp(torch.autograd.Function):
    """Autograd function returning an uninitialised tensor of a given shape.

    The result is allocated with ``x.new_empty``, so it shares ``x``'s dtype
    and device.  The gradient handed back for ``x`` is likewise an
    uninitialised tensor with ``x``'s original shape.
    """

    @staticmethod
    def forward(ctx, x, new_shape):
        """Allocate a tensor of ``new_shape`` with ``x``'s dtype and device.

        Args:
            ctx: Autograd context; the input shape is stored on it.
            x (Tensor): Tensor whose dtype/device the result inherits.
            new_shape (sequence of int): Shape of the tensor to allocate.

        Returns:
            Tensor: Uninitialised tensor of shape ``new_shape``.
        """
        # Remember the input shape so backward can size its gradient.
        ctx.shape = x.shape
        result = x.new_empty(new_shape)
        return result

    @staticmethod
    def backward(ctx, grad):
        """Return a gradient for ``x`` shaped like the original input.

        ``new_shape`` is not a tensor, so its gradient slot is ``None``.
        """
        grad_for_x = NewEmptyTensorOp.apply(grad, ctx.shape)
        return grad_for_x, None
26+
27+
28+
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` wrapper that supports zero-element input on old PyTorch.

    When the input has no elements and the running PyTorch release is 1.4 or
    older, the native convolution is bypassed and an empty tensor with the
    expected output shape is returned.  Every other input is forwarded to
    ``nn.Conv2d.forward`` unchanged.
    """

    @staticmethod
    def _legacy_torch(version=None):
        """Return True if ``version`` is PyTorch 1.4 or older.

        Args:
            version (str, optional): Version string to test.  Defaults to
                ``torch.__version__``.

        Returns:
            bool: Whether the (major, minor) release is <= (1, 4).
        """
        if version is None:
            version = torch.__version__
        # Compare numerically: as raw strings '1.10' sorts before '1.4' and
        # '1.4.0' sorts after '1.4', which misclassifies both releases.
        release = str(version).split('+')[0].split('.')[:2]
        try:
            return tuple(int(part) for part in release) <= (1, 4)
        except ValueError:
            # Non-numeric version string: assume a modern build.
            return False

    def forward(self, x):
        """Apply the convolution, tolerating empty ``x`` on old PyTorch.

        Args:
            x (Tensor): Input; its last two dims are treated as spatial.

        Returns:
            Tensor: Convolution output, or an empty tensor of the matching
            output shape when the fallback path is taken.
        """
        if x.numel() == 0 and self._legacy_torch():
            out_shape = [x.shape[0], self.out_channels]
            # Convolution output-size formula, one spatial dim at a time.
            for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
                                     self.padding, self.stride,
                                     self.dilation):
                out_shape.append((i + 2 * p - (d * (k - 1) + 1)) // s + 1)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(
                    param.view(-1)[0] for param in self.parameters()) * 0.0
                return empty + dummy
            return empty

        return super().forward(x)
46+
47+
48+
class ConvTranspose2d(nn.ConvTranspose2d):
    """``nn.ConvTranspose2d`` wrapper supporting empty input on old PyTorch.

    When the input has no elements and the running PyTorch release is 1.4 or
    older, the native op is bypassed and an empty tensor with the expected
    output shape is returned.  Every other input is forwarded to
    ``nn.ConvTranspose2d.forward`` unchanged.
    """

    @staticmethod
    def _legacy_torch(version=None):
        """Return True if ``version`` is PyTorch 1.4 or older.

        Args:
            version (str, optional): Version string to test.  Defaults to
                ``torch.__version__``.

        Returns:
            bool: Whether the (major, minor) release is <= (1, 4).
        """
        if version is None:
            version = torch.__version__
        # Compare numerically: as raw strings '1.10.0' sorts before '1.4.0',
        # which would wrongly treat new releases as legacy ones.
        release = str(version).split('+')[0].split('.')[:2]
        try:
            return tuple(int(part) for part in release) <= (1, 4)
        except ValueError:
            # Non-numeric version string: assume a modern build.
            return False

    def forward(self, x):
        """Apply the transposed convolution, tolerating empty ``x``.

        Args:
            x (Tensor): Input; its last two dims are treated as spatial.

        Returns:
            Tensor: Op output, or an empty tensor of the matching output
            shape when the fallback path is taken.
        """
        if x.numel() == 0 and self._legacy_torch():
            out_shape = [x.shape[0], self.out_channels]
            # Transposed-convolution output-size formula per spatial dim.
            for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append(
                    (i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(
                    param.view(-1)[0] for param in self.parameters()) * 0.0
                return empty + dummy
            return empty

        return super().forward(x)
66+
67+
68+
class MaxPool2d(nn.MaxPool2d):
    """``nn.MaxPool2d`` wrapper that supports empty input on old PyTorch.

    When the input has no elements and the running PyTorch release is 1.4 or
    older, the native pooling op is bypassed and an empty tensor with the
    expected output shape is returned.  Every other input is forwarded to
    ``nn.MaxPool2d.forward`` unchanged.
    """

    @staticmethod
    def _legacy_torch(version=None):
        """Return True if ``version`` is PyTorch 1.4 or older.

        Args:
            version (str, optional): Version string to test.  Defaults to
                ``torch.__version__``.

        Returns:
            bool: Whether the (major, minor) release is <= (1, 4).
        """
        if version is None:
            version = torch.__version__
        # Compare numerically: as raw strings '1.10' sorts before '1.4' and
        # '1.4.0' sorts after '1.4', which misclassifies both releases.
        release = str(version).split('+')[0].split('.')[:2]
        try:
            return tuple(int(part) for part in release) <= (1, 4)
        except ValueError:
            # Non-numeric version string: assume a modern build.
            return False

    def forward(self, x):
        """Apply max pooling, tolerating empty ``x`` on old PyTorch.

        Args:
            x (Tensor): Input; its first two dims are kept as-is and its
                last two dims are treated as spatial.

        Returns:
            Tensor: Pooled output, or an empty tensor of the matching output
            shape when the fallback path is taken.
        """
        if x.numel() == 0 and self._legacy_torch():
            out_shape = list(x.shape[:2])
            # The pooling attributes may be ints or pairs; normalise them.
            for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
                                     _pair(self.padding), _pair(self.stride),
                                     _pair(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                # Rounding direction follows the module's ceil_mode flag.
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            return NewEmptyTensorOp.apply(x, out_shape)

        return super().forward(x)
83+
84+
85+
class Linear(torch.nn.Linear):
    """``nn.Linear`` wrapper that supports zero-element input.

    An input with no elements is not passed to the native layer; an empty
    tensor with the expected output shape is returned instead.  Every other
    input is forwarded to ``nn.Linear.forward`` unchanged.
    """

    def _empty_output_shape(self, x):
        """Return the output shape for ``x`` as a list of ints.

        All leading dimensions of ``x`` are kept and the last one is replaced
        by ``out_features``, so N-d inputs shaped ``(*, in_features)`` do not
        lose their middle dimensions.  A 1-D input yields
        ``[len(x), out_features]``.
        """
        if x.dim() > 1:
            return list(x.shape[:-1]) + [self.out_features]
        return [x.shape[0], self.out_features]

    def forward(self, x):
        """Apply the linear layer, tolerating empty ``x``.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Layer output, or an empty tensor of the matching output
            shape when ``x`` has no elements.
        """
        if x.numel() == 0:
            empty = NewEmptyTensorOp.apply(x, self._empty_output_shape(x))
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(
                    param.view(-1)[0] for param in self.parameters()) * 0.0
                return empty + dummy
            return empty

        return super().forward(x)

0 commit comments

Comments
 (0)