Skip to content

Commit 099c1a5

Browse files
authored
Merge pull request #171 from darsnack/efficient-net
2 parents 07ad654 + a3f44c8 commit 099c1a5

File tree

5 files changed

+185
-6
lines changed

5 files changed

+185
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv1.html) | N |
2929
| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
3030
| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |
31+
| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.EfficientNet.html) | N |
3132
| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N |
3233
| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResMLP.html) | N |
3334
| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.gMLP.html) | N |

src/Metalhead.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include("convnets/resnext.jl")
2727
include("convnets/densenet.jl")
2828
include("convnets/squeezenet.jl")
2929
include("convnets/mobilenet.jl")
30+
include("convnets/efficientnet.jl")
3031
include("convnets/convnext.jl")
3132
include("convnets/convmixer.jl")
3233

@@ -42,7 +43,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
4243
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
4344
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
4445
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
45-
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3,
46+
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
4647
MLPMixer, ResMLP, gMLP,
4748
ViT,
4849
ConvMixer, ConvNeXt

src/convnets/efficientnet.jl

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""
2+
efficientnet(scalings, block_config;
3+
inchannels = 3, nclasses = 1000, max_width = 1280)
4+
5+
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
6+
7+
# Arguments
8+
9+
- `scalings`: global width and depth scaling (given as a tuple)
10+
- `block_config`: configuration for each inverted residual block,
11+
given as a vector of tuples with elements:
12+
- `n`: number of block repetitions (will be scaled by global depth scaling)
13+
- `k`: kernel size
14+
- `s`: kernel stride
15+
- `e`: expansion ratio
16+
- `i`: block input channels (will be scaled by global width scaling)
17+
- `o`: block output channels (will be scaled by global width scaling)
18+
- `inchannels`: number of input channels
19+
- `nclasses`: number of output classes
20+
- `max_width`: maximum number of output channels before the fully connected
21+
classification blocks
22+
"""
23+
function efficientnet(scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
    # `scalings` unpacks into the global width and depth multipliers.
    wscale, dscale = scalings
    # A multiplier that is approximately 1 returns the value unchanged;
    # any other multiplier is applied and the result is rounded up to an integer.
    scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w)
    scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d)

    # Stem: 3x3, stride-2 layers built by `conv_bn` on the raw input channels.
    out_channels = _round_channels(scalew(32), 8)
    stem = conv_bn((3, 3), inchannels, out_channels, swish;
                   bias = false, stride = 2, pad = SamePad())

    blocks = []
    for (n, k, s, e, i, o) in block_config
        in_channels = _round_channels(scalew(i), 8)
        # Reassigns the function-level `out_channels`, so after the loop it holds
        # the output width of the last configured stage (used by the head below).
        out_channels = _round_channels(scalew(o), 8)
        repeats = scaled(n)

        # First block of the stage: applies the configured stride and maps
        # `in_channels` to `out_channels`.
        push!(blocks,
              invertedresidual(k, in_channels, in_channels * e, out_channels, swish;
                               stride = s, reduction = 4))
        # Remaining repetitions keep the channel count and use stride 1.
        for _ in 1:(repeats - 1)
            push!(blocks,
                  invertedresidual(k, out_channels, out_channels * e, out_channels, swish;
                                   stride = 1, reduction = 4))
        end
    end
    blocks = Chain(blocks...)

    # Head: 1x1 layers built by `conv_bn`, widening to `max_width` channels
    # (after `_round_channels(_, 8)`).
    head_out_channels = _round_channels(max_width, 8)
    head = conv_bn((1, 1), out_channels, head_out_channels, swish;
                   bias = false, pad = SamePad())

    top = Dense(head_out_channels, nclasses)

    # The result is a two-element chain: element 1 is stem + blocks + head,
    # element 2 is pooling + flatten + the final dense layer.
    return Chain(Chain([stem..., blocks, head...]),
                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top))
end
59+
60+
# Default per-stage block settings. Each entry is a tuple with the fields:
#   n: number of block repetitions
#   k: kernel size (k x k)
#   s: stride
#   e: expansion ratio
#   i: block input channels
#   o: block output channels
const efficientnet_block_configs = NTuple{6, Int}[
    # n, k, s, e,   i,   o
    (1, 3, 1, 1,  32,  16),
    (2, 3, 2, 6,  16,  24),
    (2, 5, 2, 6,  24,  40),
    (3, 3, 2, 6,  40,  80),
    (3, 5, 1, 6,  80, 112),
    (4, 5, 2, 6, 112, 192),
    (1, 3, 1, 6, 192, 320),
]
76+
77+
# Named default configurations. Each value has the form (r, (w, d)) with:
#   r: image resolution
#   w: width scaling
#   d: depth scaling
const efficientnet_global_configs = Dict{Symbol, Tuple{Int, NTuple{2, Float64}}}(
    :b0 => (224, (1.0, 1.0)),
    :b1 => (240, (1.0, 1.1)),
    :b2 => (260, (1.1, 1.2)),
    :b3 => (300, (1.2, 1.4)),
    :b4 => (380, (1.4, 1.8)),
    :b5 => (456, (1.6, 2.2)),
    :b6 => (528, (1.8, 2.6)),
    :b7 => (600, (2.0, 3.1)),
    :b8 => (672, (2.2, 3.6)),
)
92+
93+
# Wrapper type holding the layers of an EfficientNet model in a single field.
struct EfficientNet
    # Untyped on purpose: an unannotated field is the same as `::Any`.
    layers
end
96+
97+
"""
98+
EfficientNet(scalings, block_config;
99+
inchannels = 3, nclasses = 1000, max_width = 1280)
100+
101+
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
102+
See also [`efficientnet`](#).
103+
104+
# Arguments
105+
106+
- `scalings`: global width and depth scaling (given as a tuple)
107+
- `block_config`: configuration for each inverted residual block,
108+
given as a vector of tuples with elements:
109+
- `n`: number of block repetitions (will be scaled by global depth scaling)
110+
- `k`: kernel size
111+
- `s`: kernel stride
112+
- `e`: expansion ratio
113+
- `i`: block input channels (will be scaled by global width scaling)
114+
- `o`: block output channels (will be scaled by global width scaling)
115+
- `inchannels`: number of input channels
116+
- `nclasses`: number of output classes
117+
- `max_width`: maximum number of output channels before the fully connected
118+
classification blocks
119+
"""
120+
function EfficientNet(scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
    # Forward every argument unchanged to `efficientnet` and wrap the
    # resulting layers in the `EfficientNet` struct.
    return EfficientNet(efficientnet(scalings, block_config;
                                     inchannels, nclasses, max_width))
end
128+
129+
@functor EfficientNet

# Calling the model applies the wrapped layers to the input.
function (model::EfficientNet)(input)
    return model.layers(input)
end

# `backbone` is the first element of the wrapped layers.
function backbone(model::EfficientNet)
    return model.layers[1]
end

# `classifier` is the second element of the wrapped layers.
function classifier(model::EfficientNet)
    return model.layers[2]
end
135+
136+
"""
137+
EfficientNet(name::Symbol; pretrain = false)
138+
139+
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
140+
See also [`efficientnet`](#).
141+
142+
# Arguments
143+
144+
- `name`: name of default configuration
145+
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
146+
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
147+
"""
148+
function EfficientNet(name::Symbol; pretrain = false)
    # The parenthesised macro call keeps the message attached to the assertion.
    # Written as `@assert cond` with the string on the following line, the string
    # is a separate (discarded) statement and a failure carries no message.
    @assert(haskey(efficientnet_global_configs, name),
            "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))")

    # Only the second element of the configuration entry (the scalings) is
    # needed to build the model; the first element is not used here.
    _, scalings = efficientnet_global_configs[name]
    model = EfficientNet(scalings, efficientnet_block_configs)
    # Weights are requested under the name "efficientnet-<name>" only when asked for.
    pretrain && loadpretrain!(model, string("efficientnet-", name))

    return model
end

test/convnets.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,27 @@ end
7070
GC.safepoint()
7171
GC.gc()
7272

73+
@testset "EfficientNet" begin
    # Variants :b5, :b6, :b7 and :b8 are currently left out of this loop.
    @testset "EfficientNet($name)" for name in (:b0, :b1, :b2, :b3, :b4)
        # The first entry of the global configuration is the image resolution
        # associated with this variant; use it for a square 3-channel input.
        resolution = first(Metalhead.efficientnet_global_configs[name])
        input = rand(Float32, resolution, resolution, 3, 1)
        model = EfficientNet(name)
        @test size(model(input)) == (1000, 1)
        has_weights = (EfficientNet, name) in PRETRAINED_MODELS
        if has_weights
            @test acctest(EfficientNet(name, pretrain = true))
        else
            # Without a PRETRAINED_MODELS entry, requesting weights must throw.
            @test_throws ArgumentError EfficientNet(name, pretrain = true)
        end
        @test gradtest(model, input)
        GC.safepoint()
        GC.gc()
    end
end
90+
91+
GC.safepoint()
92+
GC.gc()
93+
7394
@testset "GoogLeNet" begin
7495
m = GoogLeNet()
7596
@test size(m(x_224)) == (1000, 1)
@@ -215,7 +236,7 @@ GC.safepoint()
215236
GC.gc()
216237

217238
@testset "ConvNeXt" verbose = true begin
218-
@testset for mode in [:small, :base, :large] # :tiny, #, :xlarge]
239+
@testset for mode in [:small, :base] #, :large # :tiny, #, :xlarge]
219240
@testset for drop_path_rate in [0.0, 0.5]
220241
m = ConvNeXt(mode; drop_path_rate)
221242
@test size(m(x_224)) == (1000, 1)
@@ -230,7 +251,7 @@ GC.safepoint()
230251
GC.gc()
231252

232253
@testset "ConvMixer" verbose = true begin
233-
@testset for mode in [:small, :base, :large]
254+
@testset for mode in [:small, :base] #, :large]
234255
m = ConvMixer(mode)
235256

236257
@test size(m(x_224)) == (1000, 1)

test/other.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "MLPMixer" begin
2-
@testset for mode in [:small, :base, :large] # :huge]
2+
@testset for mode in [:small, :base] # :large, # :huge]
33
@testset for drop_path_rate in [0.0, 0.5]
44
m = MLPMixer(mode; drop_path_rate)
55
@test size(m(x_224)) == (1000, 1)
@@ -11,7 +11,7 @@
1111
end
1212

1313
@testset "ResMLP" begin
14-
@testset for mode in [:small, :base, :large] # :huge]
14+
@testset for mode in [:small, :base] # :large, # :huge]
1515
@testset for drop_path_rate in [0.0, 0.5]
1616
m = ResMLP(mode; drop_path_rate)
1717
@test size(m(x_224)) == (1000, 1)
@@ -23,7 +23,7 @@ end
2323
end
2424

2525
@testset "gMLP" begin
26-
@testset for mode in [:small, :base, :large] # :huge]
26+
@testset for mode in [:small, :base] # :large, # :huge]
2727
@testset for drop_path_rate in [0.0, 0.5]
2828
m = gMLP(mode; drop_path_rate)
2929
@test size(m(x_224)) == (1000, 1)

0 commit comments

Comments
 (0)