Skip to content

Commit 3258f01

Browse files
committed
Add EfficientNet-Lite w/ ported weights, PyTorch trained MobileNetV3, improved weight initialization for DW convs
1 parent 75c8e95 commit 3258f01

File tree

6 files changed

+187
-9
lines changed

6 files changed

+187
-9
lines changed

README.md

+18-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ All models are implemented by GenEfficientNet or MobileNetV3 classes, with strin
66

77
## What's New
88

9+
### March 23, 2020
10+
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
11+
* Add PyTorch trained MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
12+
* IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior
13+
914
### Feb 12, 2020
1015
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
1116
* Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization.
@@ -49,6 +54,7 @@ Implemented models include:
4954
* EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
5055
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
5156
* EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
57+
* EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
5258
* MixNet (https://arxiv.org/abs/1907.09595)
5359
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
5460
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
@@ -76,6 +82,7 @@ I've managed to train several of the models to accuracies close to or above the
7682
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 |
7783
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 |
7884
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 |
85+
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 | 0.875 |
7986
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 |
8087
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 |
8188
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 |
@@ -91,7 +98,7 @@ More pretrained models to come...
9198
The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args.
9299

93100
**IMPORTANT:**
94-
* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
101+
* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
95102
* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl.
96103

97104
To run validation for tf_efficientnet_b5:
@@ -145,14 +152,18 @@ To run validation for a model with Inception preprocessing, ie EfficientNet-B8 A
145152
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 |
146153
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 |
147154
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
155+
| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 |
148156
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A |
157+
| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A |
149158
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 |
150159
| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 |
151160
| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A |
152161
| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A |
153162
| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 |
154163
| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A |
155164
| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 |
165+
| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 |
166+
| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A |
156167
| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A |
157168
| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 |
158169
| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 |
@@ -170,17 +181,23 @@ To run validation for a model with Inception preprocessing, ie EfficientNet-B8 A
170181
| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 |
171182
| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 |
172183
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A |
184+
| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A |
185+
| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 |
173186
| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A |
174187
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A |
175188
| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A |
176189
| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 |
177190
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A |
178191
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 |
179192
| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 |
193+
| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A |
194+
| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 |
180195
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A |
181196
| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A |
182197
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 |
183198
| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 |
199+
| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A |
200+
| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 |
184201
| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A |
185202
| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 |
186203
| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A |

geffnet/efficientnet_builder.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
610610
return sa_scaled
611611

612612

613-
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1):
613+
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
614614
arch_args = []
615615
for stack_idx, block_strings in enumerate(arch_def):
616616
assert isinstance(block_strings, list)
@@ -623,22 +623,29 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
623623
ba['num_experts'] *= experts_multiplier
624624
stack_args.append(ba)
625625
repeats.append(rep)
626-
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
626+
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
627+
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
628+
else:
629+
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
627630
return arch_args
628631

629632

630-
def initialize_weight_goog(m, n=''):
633+
def initialize_weight_goog(m, n='', fix_group_fanout=True):
631634
# weight init as per Tensorflow Official impl
632635
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
633636
if isinstance(m, CondConv2d):
634637
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
638+
if fix_group_fanout:
639+
fan_out //= m.groups
635640
init_weight_fn = get_condconv_initializer(
636641
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
637642
init_weight_fn(m.weight)
638643
if m.bias is not None:
639644
m.bias.data.zero_()
640645
elif isinstance(m, nn.Conv2d):
641646
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
647+
if fix_group_fanout:
648+
fan_out //= m.groups
642649
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
643650
if m.bias is not None:
644651
m.bias.data.zero_()

geffnet/gen_efficientnet.py

+149-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
- Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
99
- Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
1010
11+
* EfficientNet-Lite
12+
1113
* MixNet (Small, Medium, and Large)
1214
- MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
1315
@@ -36,6 +38,7 @@
3638
'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8',
3739
'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el',
3840
'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e',
41+
'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4',
3942
'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3',
4043
'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8',
4144
'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap',
@@ -45,6 +48,8 @@
4548
'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
4649
'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el',
4750
'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
51+
'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3',
52+
'tf_efficientnet_lite4',
4853
'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l']
4954

5055

@@ -54,12 +59,14 @@
5459
'mnasnet_100':
5560
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
5661
'mnasnet_140': None,
62+
'mnasnet_small': None,
63+
5764
'semnasnet_050': None,
5865
'semnasnet_075': None,
5966
'semnasnet_100':
6067
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
6168
'semnasnet_140': None,
62-
'mnasnet_small': None,
69+
6370
'fbnetc_100':
6471
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
6572
'spnasnet_100':
@@ -89,6 +96,12 @@
8996
'efficientnet_cc_b0_8e': None,
9097
'efficientnet_cc_b1_8e': None,
9198

99+
'efficientnet_lite0': None,
100+
'efficientnet_lite1': None,
101+
'efficientnet_lite2': None,
102+
'efficientnet_lite3': None,
103+
'efficientnet_lite4': None,
104+
92105
'tf_efficientnet_b0':
93106
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
94107
'tf_efficientnet_b1':
@@ -162,6 +175,17 @@
162175
'tf_efficientnet_cc_b1_8e':
163176
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
164177

178+
'tf_efficientnet_lite0':
179+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
180+
'tf_efficientnet_lite1':
181+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
182+
'tf_efficientnet_lite2':
183+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
184+
'tf_efficientnet_lite3':
185+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
186+
'tf_efficientnet_lite4':
187+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
188+
165189
'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth',
166190
'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth',
167191
'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth',
@@ -187,15 +211,16 @@ class GenEfficientNet(nn.Module):
187211
* Single-Path NAS Pixel1
188212
"""
189213

190-
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
214+
def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
191215
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
192216
pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
193217
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
194218
weight_init='goog'):
195219
super(GenEfficientNet, self).__init__()
196220
self.drop_rate = drop_rate
197221

198-
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
222+
if not fix_stem:
223+
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
199224
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
200225
self.bn1 = norm_layer(stem_size, **norm_kwargs)
201226
self.act1 = act_layer(inplace=True)
@@ -521,6 +546,47 @@ def _gen_efficientnet_condconv(
521546
return model
522547

523548

549+
def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
550+
"""Creates an EfficientNet-Lite model.
551+
552+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
553+
Paper: https://arxiv.org/abs/1905.11946
554+
555+
EfficientNet params
556+
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
557+
'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
558+
'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
559+
'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
560+
'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
561+
'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
562+
563+
Args:
564+
channel_multiplier: multiplier to number of channels per layer
565+
depth_multiplier: multiplier to number of repeats per stage
566+
"""
567+
arch_def = [
568+
['ds_r1_k3_s1_e1_c16'],
569+
['ir_r2_k3_s2_e6_c24'],
570+
['ir_r2_k5_s2_e6_c40'],
571+
['ir_r3_k3_s2_e6_c80'],
572+
['ir_r3_k5_s1_e6_c112'],
573+
['ir_r4_k5_s2_e6_c192'],
574+
['ir_r1_k3_s1_e6_c320'],
575+
]
576+
model_kwargs = dict(
577+
block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
578+
num_features=1280,
579+
stem_size=32,
580+
fix_stem=True,
581+
channel_multiplier=channel_multiplier,
582+
act_layer=nn.ReLU6,
583+
norm_kwargs=resolve_bn_args(kwargs),
584+
**kwargs,
585+
)
586+
model = _create_model(model_kwargs, variant, pretrained)
587+
return model
588+
589+
524590
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
525591
"""Creates a MixNet Small model.
526592
@@ -795,6 +861,41 @@ def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
795861
return model
796862

797863

864+
def efficientnet_lite0(pretrained=False, **kwargs):
865+
""" EfficientNet-Lite0 """
866+
model = _gen_efficientnet_lite(
867+
'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
868+
return model
869+
870+
871+
def efficientnet_lite1(pretrained=False, **kwargs):
872+
""" EfficientNet-Lite1 """
873+
model = _gen_efficientnet_lite(
874+
'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
875+
return model
876+
877+
878+
def efficientnet_lite2(pretrained=False, **kwargs):
879+
""" EfficientNet-Lite2 """
880+
model = _gen_efficientnet_lite(
881+
'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
882+
return model
883+
884+
885+
def efficientnet_lite3(pretrained=False, **kwargs):
886+
""" EfficientNet-Lite3 """
887+
model = _gen_efficientnet_lite(
888+
'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
889+
return model
890+
891+
892+
def efficientnet_lite4(pretrained=False, **kwargs):
893+
""" EfficientNet-Lite4 """
894+
model = _gen_efficientnet_lite(
895+
'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
896+
return model
897+
898+
798899
def tf_efficientnet_b0(pretrained=False, **kwargs):
799900
""" EfficientNet-B0 AutoAug. Tensorflow compatible variant """
800901
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
@@ -1148,6 +1249,51 @@ def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
11481249
return model
11491250

11501251

1252+
def tf_efficientnet_lite0(pretrained=False, **kwargs):
1253+
""" EfficientNet-Lite0. Tensorflow compatible variant """
1254+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1255+
kwargs['pad_type'] = 'same'
1256+
model = _gen_efficientnet_lite(
1257+
'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
1258+
return model
1259+
1260+
1261+
def tf_efficientnet_lite1(pretrained=False, **kwargs):
1262+
""" EfficientNet-Lite1. Tensorflow compatible variant """
1263+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1264+
kwargs['pad_type'] = 'same'
1265+
model = _gen_efficientnet_lite(
1266+
'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
1267+
return model
1268+
1269+
1270+
def tf_efficientnet_lite2(pretrained=False, **kwargs):
1271+
""" EfficientNet-Lite2. Tensorflow compatible variant """
1272+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1273+
kwargs['pad_type'] = 'same'
1274+
model = _gen_efficientnet_lite(
1275+
'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
1276+
return model
1277+
1278+
1279+
def tf_efficientnet_lite3(pretrained=False, **kwargs):
1280+
""" EfficientNet-Lite3. Tensorflow compatible variant """
1281+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1282+
kwargs['pad_type'] = 'same'
1283+
model = _gen_efficientnet_lite(
1284+
'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
1285+
return model
1286+
1287+
1288+
def tf_efficientnet_lite4(pretrained=False, **kwargs):
1289+
""" EfficientNet-Lite4. Tensorflow compatible variant """
1290+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1291+
kwargs['pad_type'] = 'same'
1292+
model = _gen_efficientnet_lite(
1293+
'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
1294+
return model
1295+
1296+
11511297
def mixnet_s(pretrained=False, **kwargs):
11521298
"""Creates a MixNet Small model.
11531299
"""

geffnet/mobilenetv3.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
'mobilenetv3_rw':
2222
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
2323
'mobilenetv3_large_075': None,
24-
'mobilenetv3_large_100': None,
24+
'mobilenetv3_large_100':
25+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
2526
'mobilenetv3_large_minimal_100': None,
2627
'mobilenetv3_small_075': None,
2728
'mobilenetv3_small_100': None,

geffnet/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.9.7'
1+
__version__ = '0.9.8'

0 commit comments

Comments
 (0)