Skip to content

Commit 99eb25a

Browse files
committed
Make all config dicts const and capitalise
Also misc. formatting and cleanup
1 parent fc74aa1 commit 99eb25a

20 files changed

+95
-99
lines changed

src/convnets/convmixer.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
2828
return Chain(Chain(stem..., Chain(blocks)), head)
2929
end
3030

31-
convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20,
32-
:kernel_size => (9, 9),
33-
:patch_size => (7, 7)),
34-
:small => Dict(:planes => 768, :depth => 32,
35-
:kernel_size => (7, 7),
36-
:patch_size => (7, 7)),
37-
:large => Dict(:planes => 1024, :depth => 20,
38-
:kernel_size => (9, 9),
39-
:patch_size => (7, 7)))
31+
const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
32+
:kernel_size => (9, 9),
33+
:patch_size => (7, 7)),
34+
:small => Dict(:planes => 768, :depth => 32,
35+
:kernel_size => (7, 7),
36+
:patch_size => (7, 7)),
37+
:large => Dict(:planes => 1024, :depth => 20,
38+
:kernel_size => (9, 9),
39+
:patch_size => (7, 7)))
4040

4141
"""
4242
ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -57,11 +57,11 @@ end
5757
@functor ConvMixer
5858

5959
function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
60-
_checkconfig(mode, keys(convmixer_configs))
61-
planes = convmixer_configs[mode][:planes]
62-
depth = convmixer_configs[mode][:depth]
63-
kernel_size = convmixer_configs[mode][:kernel_size]
64-
patch_size = convmixer_configs[mode][:patch_size]
60+
_checkconfig(mode, keys(CONVMIXER_CONFIGS))
61+
planes = CONVMIXER_CONFIGS[mode][:planes]
62+
depth = CONVMIXER_CONFIGS[mode][:depth]
63+
kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
64+
patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
6565
layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
6666
nclasses)
6767
return ConvMixer(layers)

src/convnets/convnext.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0
6666
end
6767

6868
# Configurations for ConvNeXt models
69-
convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
70-
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
71-
:base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
72-
:large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
73-
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
69+
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
70+
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
71+
:base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
72+
:large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
73+
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
7474

7575
struct ConvNeXt
7676
layers::Any
@@ -94,8 +94,8 @@ See also [`Metalhead.convnext`](#).
9494
"""
9595
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
9696
nclasses = 1000)
97-
_checkconfig(mode, keys(convnext_configs))
98-
layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses)
97+
_checkconfig(mode, keys(CONVNEXT_CONFIGS))
98+
layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
9999
return ConvNeXt(layers)
100100
end
101101

src/convnets/densenet.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ end
140140
backbone(m::DenseNet) = m.layers[1]
141141
classifier(m::DenseNet) = m.layers[2]
142142

143-
const densenet_configs = Dict(121 => (6, 12, 24, 16),
143+
const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16),
144144
161 => (6, 12, 36, 24),
145145
169 => (6, 12, 32, 32),
146146
201 => (6, 12, 48, 32))
@@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
160160
See also [`Metalhead.densenet`](#).
161161
"""
162162
function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
163-
_checkconfig(config, keys(densenet_configs))
164-
model = DenseNet(densenet_configs[config]; nclasses = nclasses)
163+
_checkconfig(config, keys(DENSENET_CONFIGS))
164+
model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses)
165165
if pretrain
166166
loadpretrain!(model, string("DenseNet", config))
167167
end

src/convnets/efficientnet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
# e: expantion ratio
6060
# i: block input channels
6161
# o: block output channels
62-
const efficientnet_block_configs = [
62+
const EFFICIENTNET_BLOCK_CONFIGS = [
6363
# (n, k, s, e, i, o)
6464
(1, 3, 1, 1, 32, 16),
6565
(2, 3, 2, 6, 16, 24),
@@ -73,7 +73,7 @@ const efficientnet_block_configs = [
7373
# w: width scaling
7474
# d: depth scaling
7575
# r: image resolution
76-
const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)),
76+
const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
7777
:b1 => (240, (1.0, 1.1)),
7878
:b2 => (260, (1.1, 1.2)),
7979
:b3 => (300, (1.2, 1.4)),
@@ -137,8 +137,8 @@ See also [`efficientnet`](#).
137137
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
138138
"""
139139
function EfficientNet(name::Symbol; pretrain = false)
140-
_checkconfig(name, keys(efficientnet_global_configs))
141-
model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
140+
_checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS))
141+
model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS)
142142
pretrain && loadpretrain!(model, string("efficientnet-", name))
143143
return model
144144
end

src/convnets/inception/xception.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
3535
push!(layers, relu)
3636
append!(layers,
3737
depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false,
38-
use_bn = (false, false)))
38+
use_norm = (false, false)))
3939
push!(layers, BatchNorm(outc))
4040
end
4141
layers = start_with_relu ? layers : layers[2:end]

src/convnets/mobilenet/mobilenetv1.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function mobilenetv1(width_mult, config;
3131
for _ in 1:nrepeats
3232
layer = dw ?
3333
depthwise_sep_conv_norm((3, 3), inchannels, outch, activation;
34-
stride = stride, pad = 1, bias = false) :
34+
stride = stride, pad = 1, bias = false) :
3535
conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1,
3636
bias = false)
3737
append!(layers, layer)
@@ -45,7 +45,7 @@ function mobilenetv1(width_mult, config;
4545
Dense(inchannels, nclasses)))
4646
end
4747

48-
const mobilenetv1_configs = [
48+
const MOBILENETV1_CONFIGS = [
4949
# dw, c, s, r
5050
(false, 32, 2, 1),
5151
(true, 64, 1, 1),
@@ -84,7 +84,7 @@ end
8484

8585
function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false,
8686
nclasses = 1000)
87-
layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses)
87+
layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses)
8888
if pretrain
8989
loadpretrain!(layers, string("MobileNetv1"))
9090
end

src/convnets/mobilenet/mobilenetv2.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla
4646
end
4747

4848
# Layer configurations for MobileNetv2
49-
const mobilenetv2_configs = [
49+
const MOBILENETV2_CONFIGS = [
5050
# t, c, n, s, a
5151
(1, 16, 1, 1, relu6),
5252
(6, 24, 2, 2, relu6),
@@ -83,7 +83,7 @@ See also [`Metalhead.mobilenetv2`](#).
8383
"""
8484
function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false,
8585
nclasses = 1000)
86-
layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses)
86+
layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses)
8787
pretrain && loadpretrain!(layers, string("MobileNetv2"))
8888
if pretrain
8989
loadpretrain!(layers, string("MobileNetv2"))

src/convnets/mobilenet/mobilenetv3.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla
5353
end
5454

5555
# Configurations for small and large mode for MobileNetv3
56-
mobilenetv3_configs = Dict(:small => [
56+
MOBILENETV3_CONFIGS = Dict(:small => [
5757
# k, t, c, SE, a, s
5858
(3, 1, 16, 4, relu, 2),
5959
(3, 4.5, 24, nothing, relu, 2),
@@ -115,7 +115,7 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels =
115115
pretrain = false, nclasses = 1000)
116116
@assert mode in [:large, :small] "`mode` has to be either :large or :small"
117117
max_width = (mode == :large) ? 1280 : 1024
118-
layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width,
118+
layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width,
119119
nclasses)
120120
if pretrain
121121
loadpretrain!(layers, string("MobileNetv3", mode))

src/convnets/resnets/core.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
2525
drop_block = identity, drop_path = identity,
2626
attn_fn = planes -> identity)
2727
first_planes = planes ÷ reduction_factor
28-
outplanes = planes * expansion_factor(basicblock)
28+
outplanes = planes
2929
conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm,
3030
stride, pad = 1, bias = false)
3131
conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm,
@@ -67,7 +67,7 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
6767
attn_fn = planes -> identity)
6868
width = floor(Int, planes * (base_width / 64)) * cardinality
6969
first_planes = width ÷ reduction_factor
70-
outplanes = planes * expansion_factor(bottleneck)
70+
outplanes = planes * 4
7171
conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm,
7272
bias = false)
7373
conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm,
@@ -215,7 +215,7 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer
215215
drop_block = DropBlock(blockschedule[schedule_idx])
216216
block = basicblock(inplanes, planes; stride, reduction_factor, activation,
217217
norm_layer, revnorm, attn_fn, drop_path, drop_block)
218-
downsample = downsample_fn(inplanes, planes * expansion; stride)
218+
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm)
219219
# inplanes increases by expansion after each block
220220
inplanes = planes * expansion
221221
return block, downsample
@@ -248,18 +248,14 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer
248248
block = bottleneck(inplanes, planes; stride, cardinality, base_width,
249249
reduction_factor, activation, norm_layer, revnorm,
250250
attn_fn, drop_path, drop_block)
251-
downsample = downsample_fn(inplanes, planes * expansion; stride)
251+
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm)
252252
# inplanes increases by expansion after each block
253253
inplanes = planes * expansion
254254
return block, downsample
255255
end
256256
return get_layers
257257
end
258258

259-
# Makes the main stages of the ResNet model. This is an internal function and should not be
260-
# used by end-users. `block_fn` is a function that returns a single block of the ResNet.
261-
# See `basicblock` and `bottleneck` for examples. A block must define a function
262-
# `expansion(::typeof(block))` that returns the expansion factor of the block.
263259
function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
264260
# Construct each stage
265261
stages = []
@@ -316,15 +312,15 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
316312
end
317313
classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
318314
pool_layer, use_conv)
319-
return resnet((imsize..., inchannels), stem, connection$activation, get_layers,
320-
block_repeats, classifier_fn)
315+
return resnet((imsize..., inchannels), stem, get_layers, block_repeats,
316+
connection$activation, classifier_fn)
321317
end
322318
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
323319
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
324320
end
325321

326322
# block-layer configurations for ResNet-like models
327-
const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]),
323+
const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
328324
34 => (:basicblock, [3, 4, 6, 3]),
329325
50 => (:bottleneck, [3, 4, 6, 3]),
330326
101 => (:bottleneck, [3, 4, 23, 3]),

src/convnets/resnets/resnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ end
2323
@functor ResNet
2424

2525
function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
26-
_checkconfig(depth, keys(resnet_configs))
27-
layers = resnet(resnet_configs[depth]...; inchannels, nclasses)
26+
_checkconfig(depth, keys(RESNET_CONFIGS))
27+
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses)
2828
if pretrain
2929
loadpretrain!(layers, string("ResNet", depth))
3030
end

src/convnets/resnets/resnext.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ end
2929

3030
function ResNeXt(depth::Integer; pretrain = false, cardinality = 32,
3131
base_width = 4, inchannels = 3, nclasses = 1000)
32-
_checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
33-
layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width)
32+
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
33+
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
3434
if pretrain
3535
loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
3636
end

src/convnets/resnets/seresnet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ end
2525
(m::SEResNet)(x) = m.layers(x)
2626

2727
function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
28-
_checkconfig(depth, keys(resnet_configs))
29-
layers = resnet(resnet_configs[depth]...; inchannels, nclasses,
28+
_checkconfig(depth, keys(RESNET_CONFIGS))
29+
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses,
3030
attn_fn = squeeze_excite)
3131
if pretrain
3232
loadpretrain!(layers, string("SEResNet", depth))
@@ -68,8 +68,8 @@ end
6868

6969
function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4,
7070
inchannels = 3, nclasses = 1000)
71-
_checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
72-
layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width,
71+
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
72+
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
7373
attn_fn = squeeze_excite)
7474
if pretrain
7575
loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))

src/convnets/vgg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr
9999
return Chain(Chain(conv), class)
100100
end
101101

102-
const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
102+
const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
103103
:B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
104104
:D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
105105
:E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])
106106

107-
const vgg_configs = Dict(11 => :A,
107+
const VGG_CONFIGS = Dict(11 => :A,
108108
13 => :B,
109109
16 => :D,
110110
19 => :E)
@@ -153,8 +153,8 @@ See also [`VGG`](#).
153153
- `pretrain`: set to `true` to load pre-trained model weights for ImageNet
154154
"""
155155
function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000)
156-
_checkconfig(depth, keys(vgg_configs))
157-
model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]],
156+
_checkconfig(depth, keys(VGG_CONFIGS))
157+
model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]],
158158
inchannels = 3,
159159
batchnorm = batchnorm,
160160
nclasses = nclasses,

0 commit comments

Comments
 (0)