
Commit 7e4f9db

Merge pull request #174 from theabhirath/resnet-plus
2 parents 2b1fbd1 + 72cd4a9


50 files changed: +2539 −2201 lines

.github/workflows/CI.yml
Lines changed: 8 additions & 9 deletions

@@ -27,16 +27,15 @@ jobs:
           - x64
         suite:
           - '["AlexNet", "VGG"]'
-          - '["GoogLeNet", "SqueezeNet"]'
-          - '["EfficientNet", "MobileNet"]'
-          - '[r"/*/ResNet*", "ResNeXt"]'
-          - 'r"/*/Inception/Inceptionv*"'
-          - '["InceptionResNetv2", "Xception"]'
+          - '["GoogLeNet", "SqueezeNet", "MobileNet"]'
+          - '["EfficientNet"]'
+          - 'r"/*/ResNet*"'
+          - '[r"ResNeXt", r"SEResNet"]'
+          - '"Inception"'
           - '"DenseNet"'
-          - '"ConvNeXt"'
-          - '"ConvMixer"'
-          - '"ViT"'
-          - '"Other"'
+          - '["ConvNeXt", "ConvMixer"]'
+          - 'r"ViTs"'
+          - 'r"Mixers"'
     steps:
       - uses: actions/checkout@v2
       - uses: julia-actions/setup-julia@v1
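Each `suite` entry is a Julia literal (a string, regex, or vector of patterns) that the test runner uses to select test groups. A minimal sketch of how such an entry could be consumed — the `GROUP` environment variable name and the matching logic are assumptions for illustration, not this repo's actual harness:

# hypothetical consumption of a CI matrix entry in test/runtests.jl
suite = eval(Meta.parse(get(ENV, "GROUP", "\"DenseNet\"")))  # e.g. ["GoogLeNet", "SqueezeNet", "MobileNet"]
patterns = suite isa AbstractVector ? suite : [suite]
# occursin accepts both String and Regex needles, so one predicate covers both
shouldrun(name) = any(p -> occursin(p, name), patterns)
shouldrun("SqueezeNet")  # true for the GoogLeNet/SqueezeNet/MobileNet group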

Project.toml
Lines changed: 4 additions & 0 deletions

@@ -5,11 +5,15 @@ version = "0.8.0-DEV"
 [deps]
 Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
+PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
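The diff adds four dependencies: CUDA, ChainRulesCore, NNlibCUDA, and PartialFunctions. For reference, the equivalent Pkg commands — the UUIDs recorded above are the registered ones, so adding by name resolves the same packages:

using Pkg
Pkg.add(["CUDA", "ChainRulesCore", "NNlibCUDA", "PartialFunctions"])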

docs/make.jl
Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 using Pkg

-Pkg.develop(path = "..")
+Pkg.develop(; path = "..")

 using Publish
 using Artifacts, LazyArtifacts
@@ -13,5 +13,5 @@ p = Publish.Project(Metalhead)

 function build_and_deploy(label)
     rm(label; recursive = true, force = true)
-    deploy(Metalhead; root = "/Metalhead.jl", label = label)
+    return deploy(Metalhead; root = "/Metalhead.jl", label = label)
 end
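Both edits here are stylistic and behaviour-preserving: the leading semicolon only marks `path` explicitly as a keyword argument, and `build_and_deploy` already returned the value of its last expression before the explicit `return` was added. A quick check of the `Pkg.develop` equivalence:

using Pkg
Pkg.develop(path = "..")    # keyword argument inferred
Pkg.develop(; path = "..")  # keyword argument marked explicitly — identical call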

docs/serve.jl
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 using Pkg

-Pkg.develop(path = "..")
+Pkg.develop(; path = "..")

 using Revise
 using Publish

src/Metalhead.jl
Lines changed: 27 additions & 10 deletions

@@ -7,6 +7,7 @@ using BSON
 using Artifacts, LazyArtifacts
 using Statistics
 using MLUtils
+using PartialFunctions
 using Random

 import Functors
@@ -20,38 +21,54 @@ using .Layers
 # CNN models
 include("convnets/alexnet.jl")
 include("convnets/vgg.jl")
-include("convnets/inception.jl")
-include("convnets/googlenet.jl")
-include("convnets/resnet.jl")
-include("convnets/resnext.jl")
+## ResNets
+include("convnets/resnets/core.jl")
+include("convnets/resnets/resnet.jl")
+include("convnets/resnets/resnext.jl")
+include("convnets/resnets/seresnet.jl")
+## Inceptions
+include("convnets/inception/googlenet.jl")
+include("convnets/inception/inceptionv3.jl")
+include("convnets/inception/inceptionv4.jl")
+include("convnets/inception/inceptionresnetv2.jl")
+include("convnets/inception/xception.jl")
+## MobileNets
+include("convnets/mobilenet/mobilenetv1.jl")
+include("convnets/mobilenet/mobilenetv2.jl")
+include("convnets/mobilenet/mobilenetv3.jl")
+## Others
 include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
-include("convnets/mobilenet.jl")
 include("convnets/efficientnet.jl")
 include("convnets/convnext.jl")
 include("convnets/convmixer.jl")

-# Other models
-include("other/mlpmixer.jl")
+# Mixers
+include("mixers/core.jl")
+include("mixers/mlpmixer.jl")
+include("mixers/resmlp.jl")
+include("mixers/gmlp.jl")

-# ViT-based models
+# ViTs
 include("vit-based/vit.jl")

+# Load pretrained weights
 include("pretrain.jl")

 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
+       WideResNet, SEResNet, SEResNeXt,
        MLPMixer, ResMLP, gMLP,
        ViT,
        ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet,
+for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
           :GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
-          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
+          :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
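The reorganisation splits the monolithic `inception.jl`, `resnet.jl`, and `mobilenet.jl` files into per-family directories, and the export list gains the new ResNet-family models. A usage sketch — the depth-style constructor arguments below are an assumption by analogy with the existing `ResNet` API, not something this diff shows:

using Metalhead

# hypothetical calls, assuming the new models follow the
# depth-parameterised pattern of ResNet(50) / ResNeXt(50)
se  = SEResNet(50)
wrn = WideResNet(50)
sex = SEResNeXt(50)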

src/convnets/alexnet.jl
Lines changed: 4 additions & 4 deletions

@@ -24,7 +24,6 @@ function alexnet(; nclasses = 1000)
                   Dropout(0.5),
                   Dense(4096, 4096, relu),
                   Dense(4096, nclasses)))
-
     return layers
 end

@@ -46,15 +45,16 @@ See also [`alexnet`](#).
 struct AlexNet
     layers::Any
 end
+@functor AlexNet

 function AlexNet(; pretrain = false, nclasses = 1000)
     layers = alexnet(; nclasses = nclasses)
-    pretrain && loadpretrain!(layers, "AlexNet")
+    if pretrain
+        loadpretrain!(layers, "AlexNet")
+    end
     return AlexNet(layers)
 end

-@functor AlexNet
-
 (m::AlexNet)(x) = m.layers(x)

 backbone(m::AlexNet) = m.layers[1]
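Hoisting `@functor AlexNet` next to the struct definition is cosmetic, but the macro itself is what makes the wrapper transparent to Flux. Standard Functors.jl behaviour, shown here for illustration:

using Flux, Metalhead

model = AlexNet()            # random weights; pretrain = false by default
ps = Flux.params(model)      # @functor lets params recurse into model.layers
gpu_model = Flux.gpu(model)  # device movement traverses the struct the same way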

src/convnets/convmixer.jl
Lines changed: 23 additions & 21 deletions

@@ -9,32 +9,34 @@ Creates a ConvMixer model.

   - `planes`: number of planes in the output of each block
   - `depth`: number of layers
-  - `inchannels`: The number of channels in the input. The default value is 3.
+  - `inchannels`: The number of channels in the input.
   - `kernel_size`: kernel size of the convolutional layers
   - `patch_size`: size of the patches
   - `activation`: activation function used after the convolutional layers
   - `nclasses`: number of classes in the output
 """
 function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
                    patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
-    stem = conv_bn(patch_size, inchannels, planes, activation; preact = true,
-                   stride = patch_size[1])
-    blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
-                                                 preact = true, groups = planes,
-                                                 pad = SamePad())), +),
-                    conv_bn((1, 1), planes, planes, activation; preact = true)...)
+    stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
+                     stride = patch_size[1])
+    blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
+                                                   preact = true, groups = planes,
+                                                   pad = SamePad())), +),
+                    conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
     head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
     return Chain(Chain(stem..., Chain(blocks)), head)
 end

-convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
-                                      :patch_size => (7, 7)),
-                        :small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7),
-                                       :patch_size => (7, 7)),
-                        :large => Dict(:planes => 1024, :depth => 20,
-                                       :kernel_size => (9, 9),
-                                       :patch_size => (7, 7)))
+const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
+                                             :kernel_size => (9, 9),
+                                             :patch_size => (7, 7)),
+                               :small => Dict(:planes => 768, :depth => 32,
+                                              :kernel_size => (7, 7),
+                                              :patch_size => (7, 7)),
+                               :large => Dict(:planes => 1024, :depth => 20,
+                                              :kernel_size => (9, 9),
+                                              :patch_size => (7, 7)))

 """
     ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -45,26 +47,26 @@ Creates a ConvMixer model.
 # Arguments

   - `mode`: the mode of the model, either `:base`, `:small` or `:large`
-  - `inchannels`: The number of channels in the input. The default value is 3.
+  - `inchannels`: The number of channels in the input.
   - `activation`: activation function used after the convolutional layers
   - `nclasses`: number of classes in the output
 """
 struct ConvMixer
     layers::Any
 end
+@functor ConvMixer

 function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
-    planes = convmixer_config[mode][:planes]
-    depth = convmixer_config[mode][:depth]
-    kernel_size = convmixer_config[mode][:kernel_size]
-    patch_size = convmixer_config[mode][:patch_size]
+    _checkconfig(mode, keys(CONVMIXER_CONFIGS))
+    planes = CONVMIXER_CONFIGS[mode][:planes]
+    depth = CONVMIXER_CONFIGS[mode][:depth]
+    kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
+    patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
     layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
                        nclasses)
     return ConvMixer(layers)
 end

-@functor ConvMixer
-
 (m::ConvMixer)(x) = m.layers(x)

 backbone(m::ConvMixer) = m.layers[1]
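The constructor now delegates validation to `_checkconfig`, a helper added elsewhere in this PR. Its definition is not part of this excerpt; a minimal sketch consistent with the call sites might be:

# assumed behaviour only — the real definition lives elsewhere in the PR
function _checkconfig(config, configs)
    @assert config in configs "config must be one of $(sort(collect(configs)))"
end

_checkconfig(:base, (:base, :small, :large))  # passes silently
_checkconfig(:huge, (:base, :small, :large))  # throws AssertionError listing the valid keys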

src/convnets/convnext.jl
Lines changed: 18 additions & 29 deletions

@@ -4,7 +4,7 @@
 Creates a single block of ConvNeXt.
 ([reference](https://arxiv.org/abs/2201.03545))

-# Arguments:
+# Arguments

   - `planes`: number of input channels.
   - `drop_path_rate`: Stochastic depth rate.
@@ -27,7 +27,7 @@ end
 Creates the layers for a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))

-# Arguments:
+# Arguments

   - `inchannels`: number of input channels.
   - `depths`: list with configuration for depth of each block
@@ -39,60 +39,53 @@ Creates the layers for a ConvNeXt model.
 """
 function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
                   nclasses = 1000)
-    @assert length(depths)==length(planes) "`planes` should have exactly one value for each block"
-
+    @assert length(depths) == length(planes)
+    "`planes` should have exactly one value for each block"
     downsample_layers = []
     stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
-                 ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
+                 ChannelLayerNorm(planes[1]))
     push!(downsample_layers, stem)
     for m in 1:(length(depths) - 1)
-        downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
+        downsample_layer = Chain(ChannelLayerNorm(planes[m]),
                                  Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
         push!(downsample_layers, downsample_layer)
     end
-
     stages = []
-    dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
+    dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
     cur = 0
-    for i in 1:length(depths)
+    for i in eachindex(depths)
         push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
         cur += depths[i]
     end
-
     backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
     head = Chain(GlobalMeanPool(),
                  MLUtils.flatten,
                  LayerNorm(planes[end]),
                  Dense(planes[end], nclasses))
-
     return Chain(Chain(backbone), head)
 end

 # Configurations for ConvNeXt models
-convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3],
-                                      :planes => [96, 192, 384, 768]),
-                        :small => Dict(:depths => [3, 3, 27, 3],
-                                       :planes => [96, 192, 384, 768]),
-                        :base => Dict(:depths => [3, 3, 27, 3],
-                                      :planes => [128, 256, 512, 1024]),
-                        :large => Dict(:depths => [3, 3, 27, 3],
-                                       :planes => [192, 384, 768, 1536]),
-                        :xlarge => Dict(:depths => [3, 3, 27, 3],
-                                        :planes => [256, 512, 1024, 2048]))
+const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
+                              :small => ([3, 3, 27, 3], [96, 192, 384, 768]),
+                              :base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
+                              :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
+                              :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

 struct ConvNeXt
     layers::Any
 end
+@functor ConvNeXt

 """
     ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)

 Creates a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))

-# Arguments:
+# Arguments

-  - `inchannels`: The number of channels in the input. The default value is 3.
+  - `inchannels`: The number of channels in the input.
   - `drop_path_rate`: Stochastic depth rate.
   - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
   - `nclasses`: number of output classes
@@ -101,16 +94,12 @@ See also [`Metalhead.convnext`](#).
 """
 function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
                   nclasses = 1000)
-    @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
-    depths = convnext_configs[mode][:depths]
-    planes = convnext_configs[mode][:planes]
-    layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
+    _checkconfig(mode, keys(CONVNEXT_CONFIGS))
+    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
    return ConvNeXt(layers)
 end

 (m::ConvNeXt)(x) = m.layers(x)

-@functor ConvNeXt
-
 backbone(m::ConvNeXt) = m.layers[1]
 classifier(m::ConvNeXt) = m.layers[2]
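Two details in the new `convnext` are worth a note. First, the tuple-valued `CONVNEXT_CONFIGS` entries splat straight into the positional `depths` and `planes` arguments:

using Metalhead

depths, planes = Metalhead.CONVNEXT_CONFIGS[:tiny]  # ([3, 3, 9, 3], [96, 192, 384, 768])
# so ConvNeXt(:tiny) reduces to convnext(depths, planes; ...) wrapped in the struct
model = ConvNeXt(:tiny)

Second, a caution: splitting `@assert length(depths) == length(planes)` and its message string onto separate lines, as the reformatted code does, detaches the string — it parses as a standalone no-op expression, so the assertion silently falls back to the default error message. `linear_scheduler` is another helper introduced elsewhere in this PR; judging from the `LinRange` call it replaces, a plausible sketch (assumption, definition not in this excerpt) is:

linear_scheduler(drop_rate; depth) = LinRange(0.0f0, Float32(drop_rate), depth)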
