Commit ae496b4

Add ResNeXt back

Also add tests. A lot of tests.

1 parent 3be1d81

5 files changed: +66 −44 lines

src/Metalhead.jl

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@ include("vit-based/vit.jl")
 include("pretrain.jl")
 
 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
-       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt,
+       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
@@ -47,7 +47,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ConvMixer, ConvNeXt
 
 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt,
+for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt,
          :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
          :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
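
Aside: this hunk only changes which types the loop registers; the loop body is not shown in the diff. A minimal sketch of what it presumably expands to, assuming (per the comment above) each type's show method simply delegates to Flux._big_show:

    # Hypothetical loop body (not part of this diff): give each wrapper type
    # a text/plain show method that calls Flux's pretty printer.
    for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt)
        @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = Flux._big_show(io, model)
    end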

src/convnets/resne(x)t.jl

Lines changed: 22 additions & 6 deletions

@@ -173,18 +173,18 @@ function _drop_blocks(drop_block_prob = 0.0)
     ]
 end
 
-function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
+function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
                 stem = first(resnet_stem(; inchannels)), inplanes = 64,
                 downsample_fn = downsample_conv,
                 drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0,
                                           drop_block_rate = 0.0),
                 block_args::NamedTuple = NamedTuple())
     # Feature Blocks
     channels = [64, 128, 256, 512]
-    stage_blocks = _make_blocks(block, channels, layers, inplanes;
+    stage_blocks = _make_blocks(block_fn, channels, layers, inplanes;
                                 output_stride, downsample_fn, drop_rates, block_args)
     # Head (Pooling and Classifier)
-    expansion = expansion_factor(block)
+    expansion = expansion_factor(block_fn)
     num_features = 512 * expansion
     classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten,
                        Dense(num_features, nclasses))
@@ -201,11 +201,27 @@ struct ResNet
 end
 @functor ResNet
 
-function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...)
-    @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]"
-    model = resnet(resnet_config[depth]...; nclasses, kwargs...)
+function ResNet(depth::Integer; pretrain = false, nclasses = 1000)
+    @assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]"
+    model = resnet(resnet_config[depth]...; nclasses)
     if pretrain
         loadpretrain!(model, string("resnet", depth))
     end
     return model
 end
+
+struct ResNeXt
+    layers::Any
+end
+@functor ResNeXt
+
+function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000)
+    @assert depth in [50, 101, 152] "Invalid depth. Must be one of [50, 101, 152]"
+    # Look up the stage layout for this depth rather than hardcoding [3, 4, 6, 3],
+    # so depths 101 and 152 build the correct number of blocks.
+    _, layers = resnet_config[depth]
+    model = resnet(bottleneck, layers; nclasses, block_args = (; cardinality, base_width))
+    if pretrain
+        loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width))
+    end
+    return model
+end
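
As a quick sanity check of the restored constructor, a minimal usage sketch; the 224×224 WHCN input and the (1000, 1) output shape mirror the tests below rather than anything guaranteed by this diff:

    using Metalhead

    m = ResNeXt(50; cardinality = 32, base_width = 4)  # ResNeXt-50 (32x4d)
    x = rand(Float32, 224, 224, 3, 1)                  # one 224×224 RGB image, WHCN layout
    size(m(x))                                         # (1000, 1): logits for nclasses = 1000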

src/layers/Layers.jl

Lines changed: 1 addition & 2 deletions

@@ -2,8 +2,7 @@ module Layers
 
 using Flux
 using CUDA
-using NNlib
-using NNlibCUDA
+using NNlib, NNlibCUDA
 using Functors
 using ChainRulesCore
 using Statistics

src/utilities.jl

Lines changed: 0 additions & 18 deletions

@@ -9,24 +9,6 @@ function _round_channels(channels, divisor, min_value = divisor)
     return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels
 end
 
-"""
-    addrelu(x, y)
-
-Convenience function for `(x, y) -> @. relu(x + y)`.
-Useful as the `connection` argument for [`resnet`](#).
-See also [`reluadd`](#).
-"""
-addrelu(x, y) = @. relu(x + y)
-
-"""
-    reluadd(x, y)
-
-Convenience function for `(x, y) -> @. relu(x) + relu(y)`.
-Useful as the `connection` argument for [`resnet`](#).
-See also [`addrelu`](#).
-"""
-reluadd(x, y) = @. relu(x) + relu(y)
-
 """
     cat_channels(x, y, zs...)
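With addrelu and reluadd deleted, any downstream code that still wants them can inline the one-liners from the removed docstrings; the names below are illustrative replacements, not part of the package:

    using Flux: relu

    my_addrelu(x, y) = @. relu(x + y)        # former addrelu: relu of the sum
    my_reluadd(x, y) = @. relu(x) + relu(y)  # former reluadd: relu before the sum
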
test/convnets.jl

Lines changed: 41 additions & 16 deletions

@@ -27,18 +27,39 @@ GC.safepoint()
 GC.gc()
 
 @testset "ResNet" begin
-    @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
-        m = ResNet(sz)
-        @test size(m(x_256)) == (1000, 1)
-        ## TODO: find a way to port pretrained models to the new ResNet API
+    # Tests for pretrained ResNets
+    ## TODO: find a way to port pretrained models to the new ResNet API
+    # @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
     #     if (ResNet, sz) in PRETRAINED_MODELS
     #         @test acctest(ResNet(sz, pretrain = true))
     #     else
     #         @test_throws ArgumentError ResNet(sz, pretrain = true)
     #     end
-        @test gradtest(m, x_256)
-        GC.safepoint()
-        GC.gc()
+    # end
+
+    @testset "resnet" begin
+        @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck]
+            layer_list = [
+                [2, 2, 2, 2],
+                [3, 4, 6, 3],
+                [3, 4, 23, 3],
+                [3, 8, 36, 3]
+            ]
+            @testset for layers in layer_list
+                drop_list = [
+                    (drop_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1),
+                    (drop_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5),
+                    (drop_rate = 0.9, drop_path_rate = 0.9, drop_block_rate = 0.9),
+                ]
+                @testset for drop_rates in drop_list
+                    m = Metalhead.resnet(block_fn, layers; drop_rates)
+                    @test size(m(x_224)) == (1000, 1)
+                    @test gradtest(m, x_224)
+                    GC.safepoint()
+                    GC.gc()
+                end
+            end
+        end
     end
 end
 
@@ -47,16 +68,20 @@ GC.gc()
 
 @testset "ResNeXt" begin
     @testset for depth in [50, 101, 152]
-        m = ResNeXt(depth)
-        @test size(m(x_224)) == (1000, 1)
-        if ResNeXt in PRETRAINED_MODELS
-            @test acctest(ResNeXt(depth, pretrain = true))
-        else
-            @test_throws ArgumentError ResNeXt(depth, pretrain = true)
+        @testset for cardinality in [32, 64]
+            @testset for base_width in [4, 8]
+                m = ResNeXt(depth; cardinality, base_width)
+                @test size(m(x_224)) == (1000, 1)
+                if string("resnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS
+                    @test acctest(ResNeXt(depth, pretrain = true))
+                else
+                    @test_throws ArgumentError ResNeXt(depth, pretrain = true)
+                end
+                @test gradtest(m, x_224)
+                GC.safepoint()
+                GC.gc()
+            end
         end
-        @test gradtest(m, x_224)
-        GC.safepoint()
-        GC.gc()
     end
 end
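
Note that the new suites standardize on x_224 where the old ResNet loop used x_256. Those fixtures are defined outside this diff; a sketch of their assumed definitions near the top of test/convnets.jl:

    # Assumed shared test inputs: single-image WHCN batches.
    x_224 = rand(Float32, 224, 224, 3, 1)
    x_256 = rand(Float32, 256, 256, 3, 1)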
