Skip to content

Commit 4efecfd

Browse files
committed
Add drop_connect impl to try during training, fix a few comments
1 parent 0fc4cca commit 4efecfd

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

models/genmobilenet.py

+38-14
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,13 @@ class _BlockBuilder:
276276
"""
277277

278278
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
279-
act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
279+
drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
280280
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
281281
folded_bn=False, padding_same=False, verbose=False):
282282
self.channel_multiplier = channel_multiplier
283283
self.channel_divisor = channel_divisor
284284
self.channel_min = channel_min
285+
self.drop_connect_rate = drop_connect_rate
285286
self.act_fn = act_fn
286287
self.se_gate_fn = se_gate_fn
287288
self.se_reduce_mid = se_reduce_mid
@@ -310,10 +311,12 @@ def _make_block(self, ba):
310311
print('args:', ba)
311312
# could replace this if with lambdas or functools binding if variety increases
312313
if bt == 'ir':
314+
ba['drop_connect_rate'] = self.drop_connect_rate
313315
ba['se_gate_fn'] = self.se_gate_fn
314316
ba['se_reduce_mid'] = self.se_reduce_mid
315317
block = InvertedResidual(**ba)
316318
elif bt == 'ds' or bt == 'dsa':
319+
ba['drop_connect_rate'] = self.drop_connect_rate
317320
block = DepthwiseSeparableConv(**ba)
318321
elif bt == 'ca':
319322
block = CascadeConv(**ba)
@@ -402,6 +405,19 @@ def hard_sigmoid(x):
402405
return F.relu6(x + 3.) / 6.
403406

404407

408+
def drop_connect(inputs, training=False, drop_connect_rate=0.):
409+
"""Apply drop connect."""
410+
if not training:
411+
return inputs
412+
413+
keep_prob = 1 - drop_connect_rate
414+
random_tensor = keep_prob + torch.rand(
415+
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
416+
random_tensor.floor_() # binarize
417+
output = inputs.div(keep_prob) * random_tensor
418+
return output
419+
420+
405421
class ChannelShuffle(nn.Module):
406422
# FIXME haven't used yet
407423
def __init__(self, groups):
@@ -474,13 +490,14 @@ def __init__(self, in_chs, out_chs, kernel_size,
474490
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
475491
se_ratio=0., se_gate_fn=torch.sigmoid,
476492
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
477-
folded_bn=False, padding_same=False):
493+
folded_bn=False, padding_same=False, drop_connect_rate=0.):
478494
super(DepthwiseSeparableConv, self).__init__()
479495
assert stride in [1, 2]
480496
self.has_se = se_ratio is not None and se_ratio > 0.
481497
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
482498
self.has_pw_act = pw_act # activation after point-wise conv
483499
self.act_fn = act_fn
500+
self.drop_connect_rate = drop_connect_rate
484501
dw_padding = _padding_arg(kernel_size // 2, padding_same)
485502
pw_padding = _padding_arg(0, padding_same)
486503

@@ -515,7 +532,9 @@ def forward(self, x):
515532
x = self.act_fn(x)
516533

517534
if self.has_residual:
518-
x += residual # FIXME add drop-connect
535+
if self.drop_connect_rate > 0.:
536+
x = drop_connect(x, self.training, self.drop_connect_rate)
537+
x += residual
519538
return x
520539

521540

@@ -557,12 +576,13 @@ def __init__(self, in_chs, out_chs, kernel_size,
557576
se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
558577
shuffle_type=None, pw_group=1,
559578
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
560-
folded_bn=False, padding_same=False):
579+
folded_bn=False, padding_same=False, drop_connect_rate=0.):
561580
super(InvertedResidual, self).__init__()
562581
mid_chs = int(in_chs * exp_ratio)
563582
self.has_se = se_ratio is not None and se_ratio > 0.
564583
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
565584
self.act_fn = act_fn
585+
self.drop_connect_rate = drop_connect_rate
566586
dw_padding = _padding_arg(kernel_size // 2, padding_same)
567587
pw_padding = _padding_arg(0, padding_same)
568588

@@ -619,7 +639,9 @@ def forward(self, x):
619639
x = self.bn3(x)
620640

621641
if self.has_residual:
622-
x += residual # FIXME add drop-connect
642+
if self.drop_connect_rate > 0.:
643+
x = drop_connect(x, self.training, self.drop_connect_rate)
644+
x += residual
623645

624646
# NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
625647

@@ -643,12 +665,14 @@ class GenMobileNet(nn.Module):
643665
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
644666
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
645667
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
646-
drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
668+
drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
669+
se_gate_fn=torch.sigmoid, se_reduce_mid=False,
647670
global_pool='avg', head_conv='default', weight_init='goog',
648-
folded_bn=False, padding_same=False):
671+
folded_bn=False, padding_same=False,):
649672
super(GenMobileNet, self).__init__()
650673
self.num_classes = num_classes
651674
self.drop_rate = drop_rate
675+
self.drop_connect_rate = drop_connect_rate
652676
self.act_fn = act_fn
653677
self.num_features = num_features
654678

@@ -661,7 +685,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
661685

662686
builder = _BlockBuilder(
663687
channel_multiplier, channel_divisor, channel_min,
664-
act_fn, se_gate_fn, se_reduce_mid,
688+
drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
665689
bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
666690
self.blocks = nn.Sequential(*builder(in_chs, block_args))
667691
in_chs = builder.in_chs
@@ -1090,7 +1114,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
10901114

10911115

10921116
def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
1093-
"""Creates a MobileNet-V3 model.
1117+
"""Creates an EfficientNet model.
10941118
10951119
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
10961120
Paper: https://arxiv.org/abs/1905.11946
@@ -1347,7 +1371,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
13471371
def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
13481372
""" EfficientNet """
13491373
default_cfg = default_cfgs['efficientnet_b0']
1350-
# NOTE dropout should be 0.2 for train
1374+
# NOTE for train, drop_rate should be 0.2
13511375
model = _gen_efficientnet(
13521376
channel_multiplier=1.0, depth_multiplier=1.0,
13531377
num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1360,7 +1384,7 @@ def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
13601384
def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
13611385
""" EfficientNet """
13621386
default_cfg = default_cfgs['efficientnet_b1']
1363-
# NOTE dropout should be 0.2 for train
1387+
# NOTE for train, drop_rate should be 0.2
13641388
model = _gen_efficientnet(
13651389
channel_multiplier=1.0, depth_multiplier=1.1,
13661390
num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1373,7 +1397,7 @@ def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
13731397
def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
13741398
""" EfficientNet """
13751399
default_cfg = default_cfgs['efficientnet_b2']
1376-
# NOTE dropout should be 0.3 for train
1400+
# NOTE for train, drop_rate should be 0.3
13771401
model = _gen_efficientnet(
13781402
channel_multiplier=1.1, depth_multiplier=1.2,
13791403
num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1386,7 +1410,7 @@ def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
13861410
def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
13871411
""" EfficientNet """
13881412
default_cfg = default_cfgs['efficientnet_b3']
1389-
# NOTE dropout should be 0.3 for train
1413+
# NOTE for train, drop_rate should be 0.3
13901414
model = _gen_efficientnet(
13911415
channel_multiplier=1.2, depth_multiplier=1.4,
13921416
num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1399,7 +1423,7 @@ def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
13991423
def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
14001424
""" EfficientNet """
14011425
default_cfg = default_cfgs['efficientnet_b4']
1402-
# NOTE dropout should be 0.4 for train
1426+
# NOTE for train, drop_rate should be 0.4
14031427
model = _gen_efficientnet(
14041428
channel_multiplier=1.4, depth_multiplier=1.8,
14051429
num_classes=num_classes, in_chans=in_chans, **kwargs)

0 commit comments

Comments
 (0)