@@ -276,12 +276,13 @@ class _BlockBuilder:
     """
 
     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
+                 drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  folded_bn=False, padding_same=False, verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
         self.channel_min = channel_min
+        self.drop_connect_rate = drop_connect_rate
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
@@ -310,10 +311,12 @@ def _make_block(self, ba):
             print('args:', ba)
         # could replace this if with lambdas or functools binding if variety increases
         if bt == 'ir':
+            ba['drop_connect_rate'] = self.drop_connect_rate
             ba['se_gate_fn'] = self.se_gate_fn
             ba['se_reduce_mid'] = self.se_reduce_mid
             block = InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
+            ba['drop_connect_rate'] = self.drop_connect_rate
             block = DepthwiseSeparableConv(**ba)
         elif bt == 'ca':
             block = CascadeConv(**ba)
@@ -402,6 +405,19 @@ def hard_sigmoid(x):
     return F.relu6(x + 3.) / 6.
 
 
+def drop_connect(inputs, training=False, drop_connect_rate=0.):
+    """Apply drop connect."""
+    if not training:
+        return inputs
+
+    keep_prob = 1 - drop_connect_rate
+    random_tensor = keep_prob + torch.rand(
+        (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
+    random_tensor.floor_()  # binarize
+    output = inputs.div(keep_prob) * random_tensor
+    return output
+
+
 class ChannelShuffle(nn.Module):
     # FIXME haven't used yet
     def __init__(self, groups):
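The drop_connect added above is batch-wise stochastic depth: during training it zeroes the block output for a random subset of samples and rescales the survivors by 1 / keep_prob, so the expected activation is unchanged and dropped samples fall back to the identity skip. A minimal sanity check, assuming the function is in scope as defined here (the batch size and 0.2 rate are illustrative, not from the commit):

    import torch

    x = torch.ones(8, 16, 4, 4)  # a batch of residual-branch outputs
    y = drop_connect(x, training=True, drop_connect_rate=0.2)

    # Each sample is either zeroed entirely or scaled by 1 / 0.8 = 1.25,
    # so E[y] == x and a dropped sample contributes only the skip path.
    print(y.view(8, -1).max(dim=1)[0])  # per-sample max: 0.0 or 1.25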
@@ -474,13 +490,14 @@ def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
                  se_ratio=0., se_gate_fn=torch.sigmoid,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
+                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
         self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.act_fn = act_fn
+        self.drop_connect_rate = drop_connect_rate
         dw_padding = _padding_arg(kernel_size // 2, padding_same)
         pw_padding = _padding_arg(0, padding_same)
@@ -515,7 +532,9 @@ def forward(self, x):
             x = self.act_fn(x)
 
         if self.has_residual:
-            x += residual  # FIXME add drop-connect
+            if self.drop_connect_rate > 0.:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x += residual
         return x
 
 
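Note the ordering: drop_connect is applied to the branch output before the skip is added back, so a dropped sample degenerates to identity rather than to zeros. A hedged sketch of the pattern in isolation (ResidualWithDropConnect and branch are illustrative names, not from this file; drop_connect is assumed in scope as above):

    import torch.nn as nn

    class ResidualWithDropConnect(nn.Module):
        # branch is any sub-network whose output shape matches its input
        # (stride 1, in_chs == out_chs), mirroring the has_residual condition.
        def __init__(self, branch, drop_connect_rate=0.):
            super(ResidualWithDropConnect, self).__init__()
            self.branch = branch
            self.drop_connect_rate = drop_connect_rate

        def forward(self, x):
            residual = x
            x = self.branch(x)
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual
            return x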
@@ -557,12 +576,13 @@ def __init__(self, in_chs, out_chs, kernel_size,
                  se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
                  shuffle_type=None, pw_group=1,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
+                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
         super(InvertedResidual, self).__init__()
         mid_chs = int(in_chs * exp_ratio)
         self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.act_fn = act_fn
+        self.drop_connect_rate = drop_connect_rate
         dw_padding = _padding_arg(kernel_size // 2, padding_same)
         pw_padding = _padding_arg(0, padding_same)
@@ -619,7 +639,9 @@ def forward(self, x):
         x = self.bn3(x)
 
         if self.has_residual:
-            x += residual  # FIXME add drop-connect
+            if self.drop_connect_rate > 0.:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x += residual
 
         # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
 
@@ -643,12 +665,14 @@ class GenMobileNet(nn.Module):
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
+                 drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
+                 se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  global_pool='avg', head_conv='default', weight_init='goog',
-                 folded_bn=False, padding_same=False):
+                 folded_bn=False, padding_same=False, ):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.drop_connect_rate = drop_connect_rate
         self.act_fn = act_fn
         self.num_features = num_features
 
@@ -661,7 +685,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
 
         builder = _BlockBuilder(
             channel_multiplier, channel_divisor, channel_min,
-            act_fn, se_gate_fn, se_reduce_mid,
+            drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
             bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs
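With this plumbing, a single drop_connect_rate passed to GenMobileNet fans out through _BlockBuilder to every 'ir', 'ds', and 'dsa' block as a constant per-block rate. A hedged construction sketch (the argument values are illustrative; block_args is whatever decoded block-definition list the rest of this file produces):

    model = GenMobileNet(
        block_args,
        num_classes=1000,
        drop_rate=0.2,           # dropout before the classifier
        drop_connect_rate=0.2)   # drop connect inside each residual block

For comparison, the TF EfficientNet reference implementation scales the rate linearly with block index (rate * idx / num_blocks) rather than applying one constant rate, which this commit does not replicate.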
@@ -1090,7 +1114,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
 
 
 def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
-    """Creates a MobileNet-V3 model.
+    """Creates an EfficientNet model.
 
     Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
     Paper: https://arxiv.org/abs/1905.11946
@@ -1347,7 +1371,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b0']
-    # NOTE dropout should be 0.2 for train
+    # NOTE for train, drop_rate should be 0.2
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.0,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
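Because these constructors forward **kwargs down into _gen_efficientnet and the model, the training-time rates flagged in the NOTEs can presumably be supplied at creation time; a hedged usage sketch (drop_rate per the NOTE above; pairing it with drop_connect_rate=0.2 is an assumption borrowed from the TF reference default, not stated in this commit):

    model = efficientnet_b0(
        num_classes=1000, drop_rate=0.2, drop_connect_rate=0.2)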
@@ -1360,7 +1384,7 @@ def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b1']
-    # NOTE dropout should be 0.2 for train
+    # NOTE for train, drop_rate should be 0.2
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.1,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1373,7 +1397,7 @@ def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b2']
-    # NOTE dropout should be 0.3 for train
+    # NOTE for train, drop_rate should be 0.3
     model = _gen_efficientnet(
         channel_multiplier=1.1, depth_multiplier=1.2,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1386,7 +1410,7 @@ def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b3']
-    # NOTE dropout should be 0.3 for train
+    # NOTE for train, drop_rate should be 0.3
     model = _gen_efficientnet(
         channel_multiplier=1.2, depth_multiplier=1.4,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1399,7 +1423,7 @@ def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b4']
-    # NOTE dropout should be 0.4 for train
+    # NOTE for train, drop_rate should be 0.4
     model = _gen_efficientnet(
         channel_multiplier=1.4, depth_multiplier=1.8,
         num_classes=num_classes, in_chans=in_chans, **kwargs)