
Commit 7449985

Merge pull request #190 from theabhirath/refine
Expose a uniform API at the highest level for models
2 parents 7e4f9db + 59e1ef4 · commit 7449985

39 files changed: +678 additions, -723 deletions

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -34,8 +34,8 @@ jobs:
           - '"Inception"'
           - '"DenseNet"'
           - '["ConvNeXt", "ConvMixer"]'
-          - 'r"ViTs"'
           - 'r"Mixers"'
+          - 'r"ViTs"'
     steps:
       - uses: actions/checkout@v2
       - uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 7 additions & 3 deletions
@@ -20,9 +20,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 BSON = "0.3.2"
 Flux = "0.13"
-Functors = "0.2"
-MLUtils = "0.2.6"
-NNlib = "0.7.34, 0.8"
+Functors = "0.2, 0.3"
+CUDA = "3"
+ChainRulesCore = "1"
+PartialFunctions = "1"
+MLUtils = "0.2.10"
+NNlib = "0.8"
+NNlibCUDA = "0.2"
 julia = "1.6"

 [publish]
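
For orientation, Pkg treats each `[compat]` entry as a caret bound that caps at the next breaking release; a sketch of how the new bounds read (illustrative commentary, not part of the diff):

    # Flux = "0.13"          allows [0.13.0, 0.14.0)
    # Functors = "0.2, 0.3"  allows the union [0.2.0, 0.4.0)
    # NNlib = "0.8"          allows [0.8.0, 0.9.0), dropping the old 0.7.34 floor
    # julia = "1.6"          allows [1.6.0, 2.0.0)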

src/Metalhead.jl

Lines changed: 3 additions & 5 deletions
@@ -56,14 +56,12 @@ include("vit-based/vit.jl")
 include("pretrain.jl")

 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
-       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
+       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
+       WideResNet, ResNeXt, SEResNet, SEResNeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
-       WideResNet, SEResNet, SEResNeXt,
-       MLPMixer, ResMLP, gMLP,
-       ViT,
-       ConvMixer, ConvNeXt
+       MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
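
With the export list regrouped, every model family is reachable from the top level through the same constructor shape. A minimal sketch of the uniform API this PR exposes, using only signatures that appear in the file diffs below (assumes Metalhead is installed):

    using Metalhead

    model = AlexNet(; inchannels = 3, nclasses = 1000)   # keyword-only configuration
    mixer = ConvMixer(:base; nclasses = 1000)            # config symbol plus keywords
    convn = ConvNeXt(:tiny; nclasses = 1000)             # same shape of API across models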

src/convnets/alexnet.jl

Lines changed: 31 additions & 26 deletions
@@ -1,54 +1,59 @@
 """
-    alexnet(; nclasses = 1000)
+    alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)

 Create an AlexNet model
 ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

 # Arguments

+  - `inchannels`: The number of input channels.
   - `nclasses`: the number of output classes
 """
-function alexnet(; nclasses = 1000)
-    layers = Chain(Chain(Conv((11, 11), 3 => 64, relu; stride = (4, 4), pad = (2, 2)),
-                         MaxPool((3, 3); stride = (2, 2)),
-                         Conv((5, 5), 64 => 192, relu; pad = (2, 2)),
-                         MaxPool((3, 3); stride = (2, 2)),
-                         Conv((3, 3), 192 => 384, relu; pad = (1, 1)),
-                         Conv((3, 3), 384 => 256, relu; pad = (1, 1)),
-                         Conv((3, 3), 256 => 256, relu; pad = (1, 1)),
-                         MaxPool((3, 3); stride = (2, 2)),
-                         AdaptiveMeanPool((6, 6))),
-                   Chain(MLUtils.flatten,
-                         Dropout(0.5),
-                         Dense(256 * 6 * 6, 4096, relu),
-                         Dropout(0.5),
-                         Dense(4096, 4096, relu),
-                         Dense(4096, nclasses)))
-    return layers
+function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
+    backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
+                     MaxPool((3, 3); stride = 2),
+                     Conv((5, 5), 64 => 192, relu; pad = 2),
+                     MaxPool((3, 3); stride = 2),
+                     Conv((3, 3), 192 => 384, relu; pad = 1),
+                     Conv((3, 3), 384 => 256, relu; pad = 1),
+                     Conv((3, 3), 256 => 256, relu; pad = 1),
+                     MaxPool((3, 3); stride = 2))
+    classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten,
+                       Dropout(0.5),
+                       Dense(256 * 6 * 6, 4096, relu),
+                       Dropout(0.5),
+                       Dense(4096, 4096, relu),
+                       Dense(4096, nclasses))
+    return Chain(backbone, classifier)
 end

 """
-    AlexNet(; pretrain = false, nclasses = 1000)
+    AlexNet(; pretrain::Bool = false, inchannels::Integer = 3,
+            nclasses::Integer = 1000)

 Create a `AlexNet`.
-See also [`alexnet`](#).
-
-!!! warning
-
-    `AlexNet` does not currently support pretrained weights.
+([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

 # Arguments

   - `pretrain`: set to `true` to load pre-trained weights for ImageNet
+  - `inchannels`: The number of input channels.
   - `nclasses`: the number of output classes
+
+!!! warning
+
+    `AlexNet` does not currently support pretrained weights.
+
+See also [`alexnet`](#).
 """
 struct AlexNet
     layers::Any
 end
 @functor AlexNet

-function AlexNet(; pretrain = false, nclasses = 1000)
-    layers = alexnet(; nclasses = nclasses)
+function AlexNet(; pretrain::Bool = false, inchannels::Integer = 3,
+                 nclasses::Integer = 1000)
+    layers = alexnet(; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, "AlexNet")
     end
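
The refactor splits the model into an explicit backbone and classifier stage. A hedged usage sketch (the 224×224 input size and the `model.layers` field access are assumptions based on the struct shown above, not part of the commit):

    using Metalhead

    model = AlexNet(; inchannels = 3, nclasses = 1000)
    x = rand(Float32, 224, 224, 3, 1)     # WHCN batch with one 224×224 RGB image
    backbone = model.layers[1]            # convolutional stage
    classifier = model.layers[2]          # pooling + dense stage
    size(classifier(backbone(x)))         # == (1000, 1)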

src/convnets/convmixer.jl

Lines changed: 23 additions & 27 deletions
@@ -1,6 +1,7 @@
 """
-    convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), patch_size::Dims{2} = 7,
-              activation = gelu, nclasses = 1000)
+    convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
+              patch_size::Dims{2} = (7, 7), activation = gelu,
+              inchannels::Integer = 3, nclasses::Integer = 1000)

 Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))
@@ -9,61 +10,56 @@ Creates a ConvMixer model.

   - `planes`: number of planes in the output of each block
   - `depth`: number of layers
-  - `inchannels`: The number of channels in the input.
   - `kernel_size`: kernel size of the convolutional layers
   - `patch_size`: size of the patches
   - `activation`: activation function used after the convolutional layers
+  - `inchannels`: The number of channels in the input.
   - `nclasses`: number of classes in the output
 """
-function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
+function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
+                   patch_size::Dims{2} = (7, 7), activation = gelu,
+                   inchannels::Integer = 3, nclasses::Integer = 1000)
     stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
                      stride = patch_size[1])
     blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
                                                    preact = true, groups = planes,
                                                    pad = SamePad())), +),
                     conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
-    return Chain(Chain(stem..., Chain(blocks)), head)
+    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
 end

-const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
-                                             :kernel_size => (9, 9),
-                                             :patch_size => (7, 7)),
-                               :small => Dict(:planes => 768, :depth => 32,
-                                              :kernel_size => (7, 7),
-                                              :patch_size => (7, 7)),
-                               :large => Dict(:planes => 1024, :depth => 20,
-                                              :kernel_size => (9, 9),
-                                              :patch_size => (7, 7)))
+const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
+                                         (kernel_size = (9, 9),
+                                          patch_size = (7, 7))),
+                               :small => ((768, 32),
+                                          (kernel_size = (7, 7),
+                                           patch_size = (7, 7))),
+                               :large => ((1024, 20),
+                                          (kernel_size = (9, 9),
+                                           patch_size = (7, 7))))

 """
-    ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
+    ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)

 Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))

 # Arguments

-  - `mode`: the mode of the model, either `:base`, `:small` or `:large`
+  - `config`: the size of the model, either `:base`, `:small` or `:large`
   - `inchannels`: The number of channels in the input.
-  - `activation`: activation function used after the convolutional layers
   - `nclasses`: number of classes in the output
 """
 struct ConvMixer
     layers::Any
 end
 @functor ConvMixer

-function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
-    _checkconfig(mode, keys(CONVMIXER_CONFIGS))
-    planes = CONVMIXER_CONFIGS[mode][:planes]
-    depth = CONVMIXER_CONFIGS[mode][:depth]
-    kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
-    patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
-    layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
-                       nclasses)
+function ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(config, keys(CONVMIXER_CONFIGS))
+    layers = convmixer(CONVMIXER_CONFIGS[config][1]...; CONVMIXER_CONFIGS[config][2]...,
+                       inchannels, nclasses)
    return ConvMixer(layers)
 end
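
The `Dict`-of-`Dict`s config table becomes an `(args, kwargs)` pair that splats directly into `convmixer`. A sketch of how the new lookup unpacks (these are internal names, so they are qualified with `Metalhead.`; illustrative only):

    using Metalhead

    args, kwargs = Metalhead.CONVMIXER_CONFIGS[:small]
    # args   == (768, 32)                                  (planes, depth)
    # kwargs == (kernel_size = (7, 7), patch_size = (7, 7))
    layers = Metalhead.convmixer(args...; kwargs..., inchannels = 3, nclasses = 1000)
    model = ConvMixer(:small)             # the equivalent high-level call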

src/convnets/convnext.jl

Lines changed: 34 additions & 33 deletions
@@ -1,5 +1,5 @@
 """
-    convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
+    convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = 1.0f-6)

 Creates a single block of ConvNeXt.
 ([reference](https://arxiv.org/abs/2201.03545))
@@ -8,61 +8,64 @@ Creates a single block of ConvNeXt.

   - `planes`: number of input channels.
   - `drop_path_rate`: Stochastic depth rate.
-  - `λ`: Initial value for [`LayerScale`](#)
+  - `layerscale_init`: Initial value for [`LayerScale`](#)
 """
-function convnextblock(planes, drop_path_rate = 0.0, λ = 1.0f-6)
+function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = 1.0f-6)
     layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
                                   swapdims((3, 1, 2, 4)),
                                   LayerNorm(planes; ϵ = 1.0f-6),
                                   mlp_block(planes, 4 * planes),
-                                  LayerScale(planes, λ),
+                                  LayerScale(planes, layerscale_init),
                                   swapdims((2, 3, 1, 4)),
                                   DropPath(drop_path_rate)), +)
     return layers
 end

 """
-    convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
+    convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer};
+             drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
+             nclasses::Integer = 1000)

 Creates the layers for a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))

 # Arguments

-  - `inchannels`: number of input channels.
   - `depths`: list with configuration for depth of each block
   - `planes`: list with configuration for number of output channels in each block
   - `drop_path_rate`: Stochastic depth rate.
-  - `λ`: Initial value for [`LayerScale`](#)
+  - `layerscale_init`: Initial value for [`LayerScale`](#)
     ([reference](https://arxiv.org/abs/2103.17239))
+  - `inchannels`: number of input channels.
   - `nclasses`: number of output classes
 """
-function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
-                  nclasses = 1000)
+function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer};
+                  drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
+                  nclasses::Integer = 1000)
     @assert length(depths) == length(planes)
     "`planes` should have exactly one value for each block"
     downsample_layers = []
-    stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
-                 ChannelLayerNorm(planes[1]))
-    push!(downsample_layers, stem)
+    push!(downsample_layers,
+          Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4,
+                          norm_layer = ChannelLayerNorm)...))
     for m in 1:(length(depths) - 1)
-        downsample_layer = Chain(ChannelLayerNorm(planes[m]),
-                                 Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
-        push!(downsample_layers, downsample_layer)
+        push!(downsample_layers,
+              Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2,
+                              norm_layer = ChannelLayerNorm, revnorm = true)...))
     end
     stages = []
     dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
     cur = 0
     for i in eachindex(depths)
-        push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
+        push!(stages,
+              [convnextblock(planes[i], dp_rates[cur + j], layerscale_init)
+               for j in 1:depths[i]])
         cur += depths[i]
     end
     backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
-    head = Chain(GlobalMeanPool(),
-                 MLUtils.flatten,
-                 LayerNorm(planes[end]),
-                 Dense(planes[end], nclasses))
-    return Chain(Chain(backbone), head)
+    classifier = Chain(GlobalMeanPool(), MLUtils.flatten,
+                       LayerNorm(planes[end]), Dense(planes[end], nclasses))
+    return Chain(Chain(backbone...), classifier)
 end
@@ -72,30 +75,28 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
                               :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
                               :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

-struct ConvNeXt
-    layers::Any
-end
-@functor ConvNeXt
-
 """
-    ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
+    ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)

 Creates a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))

 # Arguments

+  - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
   - `inchannels`: The number of channels in the input.
-  - `drop_path_rate`: Stochastic depth rate.
-  - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
   - `nclasses`: number of output classes

 See also [`Metalhead.convnext`](#).
 """
-function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
-                  nclasses = 1000)
-    _checkconfig(mode, keys(CONVNEXT_CONFIGS))
-    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
+struct ConvNeXt
+    layers::Any
+end
+@functor ConvNeXt
+
+function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(config, keys(CONVNEXT_CONFIGS))
+    layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses)
     return ConvNeXt(layers)
 end
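
After this change the top-level constructor takes only `config`, `inchannels` and `nclasses`; stochastic depth and the LayerScale init are reachable only through the lower-level `convnext` builder. A hedged sketch of both entry points, using the `:tiny` configuration values shown in the hunk header above (illustrative only):

    using Metalhead

    model = ConvNeXt(:tiny; inchannels = 3, nclasses = 1000)
    # for stochastic depth, drop to the unexported builder:
    layers = Metalhead.convnext([3, 3, 9, 3], [96, 192, 384, 768]; drop_path_rate = 0.1)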
