Commit da5321d

Make all config dicts const and capitalise
Also misc. formatting
1 parent 2aa3459 commit da5321d

20 files changed: +78 -76 lines changed
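
Why `const` and the capitalisation: in Julia, a non-`const` global binding can be reassigned to a value of any type, so every access through it is type-unstable; declaring the table `const` lets the compiler infer its concrete type, and SCREAMING_SNAKE_CASE is the conventional spelling for module-level constants. A minimal sketch of the pattern (illustrative only, not code from this commit):

# Before: `demo_configs = Dict(:small => 8, :base => 12)` — a plain global,
# reassignable at any time, so lookups through it cannot be inferred.
# After: a `const` binding with a constant-style name.
const DEMO_CONFIGS = Dict(:small => 8, :base => 12)

# Lookups through the const global now infer as Int.
demo_depth(mode::Symbol) = DEMO_CONFIGS[mode]

@assert demo_depth(:base) == 12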

src/convnets/convmixer.jl

Lines changed: 14 additions & 14 deletions
@@ -28,15 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
     return Chain(Chain(stem..., Chain(blocks)), head)
 end

-convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20,
-                                       :kernel_size => (9, 9),
-                                       :patch_size => (7, 7)),
-                         :small => Dict(:planes => 768, :depth => 32,
-                                        :kernel_size => (7, 7),
-                                        :patch_size => (7, 7)),
-                         :large => Dict(:planes => 1024, :depth => 20,
-                                        :kernel_size => (9, 9),
-                                        :patch_size => (7, 7)))
+const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
+                                             :kernel_size => (9, 9),
+                                             :patch_size => (7, 7)),
+                               :small => Dict(:planes => 768, :depth => 32,
+                                              :kernel_size => (7, 7),
+                                              :patch_size => (7, 7)),
+                               :large => Dict(:planes => 1024, :depth => 20,
+                                              :kernel_size => (9, 9),
+                                              :patch_size => (7, 7)))

 """
     ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -57,11 +57,11 @@ end
 @functor ConvMixer

 function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
-    _checkconfig(mode, keys(convmixer_configs))
-    planes = convmixer_configs[mode][:planes]
-    depth = convmixer_configs[mode][:depth]
-    kernel_size = convmixer_configs[mode][:kernel_size]
-    patch_size = convmixer_configs[mode][:patch_size]
+    _checkconfig(mode, keys(CONVMIXER_CONFIGS))
+    planes = CONVMIXER_CONFIGS[mode][:planes]
+    depth = CONVMIXER_CONFIGS[mode][:depth]
+    kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
+    patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
     layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
                        nclasses)
     return ConvMixer(layers)
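
`_checkconfig` guards every constructor touched by this commit; a plausible sketch of the helper, assuming it simply rejects unknown keys with a readable error (the real implementation lives elsewhere in Metalhead and may differ):

# Hypothetical sketch of Metalhead's internal `_checkconfig` (assumption,
# not taken from this commit): validate the requested key before building.
function _checkconfig(config, configs)
    config in configs ||
        throw(ArgumentError("config must be one of $(sort(collect(configs))); got $config"))
end

# _checkconfig(:base, keys(CONVMIXER_CONFIGS))  # passes silently
# _checkconfig(:tiny, keys(CONVMIXER_CONFIGS))  # throws ArgumentError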

src/convnets/convnext.jl

Lines changed: 7 additions & 7 deletions
@@ -66,11 +66,11 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0
 end

 # Configurations for ConvNeXt models
-convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
-                        :small => ([3, 3, 27, 3], [96, 192, 384, 768]),
-                        :base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
-                        :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
-                        :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
+const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
+                              :small => ([3, 3, 27, 3], [96, 192, 384, 768]),
+                              :base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
+                              :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
+                              :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

 struct ConvNeXt
     layers::Any
@@ -94,8 +94,8 @@ See also [`Metalhead.convnext`](#).
 """
 function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
                   nclasses = 1000)
-    _checkconfig(mode, keys(convnext_configs))
-    layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses)
+    _checkconfig(mode, keys(CONVNEXT_CONFIGS))
+    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
     return ConvNeXt(layers)
 end

src/convnets/densenet.jl

Lines changed: 3 additions & 3 deletions
@@ -140,7 +140,7 @@ end
 backbone(m::DenseNet) = m.layers[1]
 classifier(m::DenseNet) = m.layers[2]

-const densenet_configs = Dict(121 => (6, 12, 24, 16),
+const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16),
                               161 => (6, 12, 36, 24),
                               169 => (6, 12, 32, 32),
                               201 => (6, 12, 48, 32))
@@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
 See also [`Metalhead.densenet`](#).
 """
 function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
-    _checkconfig(config, keys(densenet_configs))
-    model = DenseNet(densenet_configs[config]; nclasses = nclasses)
+    _checkconfig(config, keys(DENSENET_CONFIGS))
+    model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses)
     if pretrain
         loadpretrain!(model, string("DenseNet", config))
     end

src/convnets/efficientnet.jl

Lines changed: 4 additions & 4 deletions
@@ -59,7 +59,7 @@ end
 # e: expantion ratio
 # i: block input channels
 # o: block output channels
-const efficientnet_block_configs = [
+const EFFICIENTNET_BLOCK_CONFIGS = [
     # (n, k, s, e, i, o)
     (1, 3, 1, 1, 32, 16),
     (2, 3, 2, 6, 16, 24),
@@ -73,7 +73,7 @@ const efficientnet_block_configs = [
 # w: width scaling
 # d: depth scaling
 # r: image resolution
-const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)),
+const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
                                          :b1 => (240, (1.0, 1.1)),
                                          :b2 => (260, (1.1, 1.2)),
                                          :b3 => (300, (1.2, 1.4)),
@@ -137,8 +137,8 @@ See also [`efficientnet`](#).
 - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
 """
 function EfficientNet(name::Symbol; pretrain = false)
-    _checkconfig(name, keys(efficientnet_global_configs))
-    model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
+    _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS))
+    model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS)
     pretrain && loadpretrain!(model, string("efficientnet-", name))
     return model
 end
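
For orientation, the two tables cooperate: each entry of EFFICIENTNET_GLOBAL_CONFIGS pairs a preferred image resolution with (width, depth) coefficients that rescale the base block table. A simplified sketch of that combination — the `snap8` rounding rule and helper names here are assumptions, not the package's exact code:

# Simplified illustration of compound scaling: `w` scales channel counts,
# `d` scales block repeat counts (hypothetical helper names).
snap8(x) = max(8, 8 * round(Int, x / 8))  # assumed channel-rounding rule

function scale_block((n, k, s, e, i, o), w, d)
    return (ceil(Int, d * n), k, s, e, snap8(w * i), snap8(w * o))
end

# For :b2 with (w, d) = (1.1, 1.2), the first block (1, 3, 1, 1, 32, 16)
# becomes (2, 3, 1, 1, 32, 16) under this simplified rounding.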

src/convnets/inception/xception.jl

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
         push!(layers, relu)
         append!(layers,
                 depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false,
-                                      use_bn = (false, false)))
+                                        use_bn = (false, false)))
         push!(layers, BatchNorm(outc))
     end
     layers = start_with_relu ? layers : layers[2:end]

src/convnets/mobilenet/mobilenetv1.jl

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@ function mobilenetv1(width_mult, config;
         for _ in 1:nrepeats
             layer = dw ?
                     depthwise_sep_conv_norm((3, 3), inchannels, outch, activation;
-                                          stride = stride, pad = 1, bias = false) :
+                                            stride = stride, pad = 1, bias = false) :
                     conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1,
                               bias = false)
             append!(layers, layer)
@@ -45,7 +45,7 @@ function mobilenetv1(width_mult, config;
                        Dense(inchannels, nclasses)))
 end

-const mobilenetv1_configs = [
+const MOBILENETV1_CONFIGS = [
     # dw, c, s, r
     (false, 32, 2, 1),
     (true, 64, 1, 1),
@@ -84,7 +84,7 @@ end

 function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false,
                      nclasses = 1000)
-    layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses)
+    layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("MobileNetv1"))
     end

src/convnets/mobilenet/mobilenetv2.jl

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla
 end

 # Layer configurations for MobileNetv2
-const mobilenetv2_configs = [
+const MOBILENETV2_CONFIGS = [
     # t, c, n, s, a
     (1, 16, 1, 1, relu6),
     (6, 24, 2, 2, relu6),
@@ -83,7 +83,7 @@ See also [`Metalhead.mobilenetv2`](#).
 """
 function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false,
                      nclasses = 1000)
-    layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses)
+    layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses)
     pretrain && loadpretrain!(layers, string("MobileNetv2"))
     if pretrain
         loadpretrain!(layers, string("MobileNetv2"))

src/convnets/mobilenet/mobilenetv3.jl

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla
 end

 # Configurations for small and large mode for MobileNetv3
-mobilenetv3_configs = Dict(:small => [
+MOBILENETV3_CONFIGS = Dict(:small => [
                                # k, t, c, SE, a, s
                                (3, 1, 16, 4, relu, 2),
                                (3, 4.5, 24, nothing, relu, 2),
@@ -115,7 +115,7 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels =
                      pretrain = false, nclasses = 1000)
     @assert mode in [:large, :small] "`mode` has to be either :large or :small"
     max_width = (mode == :large) ? 1280 : 1024
-    layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width,
+    layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width,
                          nclasses)
     if pretrain
         loadpretrain!(layers, string("MobileNetv3", mode))

src/convnets/resnets/core.jl

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
 end

 # block-layer configurations for ResNet-like models
-const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]),
+const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
                             34 => (:basicblock, [3, 4, 6, 3]),
                             50 => (:bottleneck, [3, 4, 6, 3]),
                             101 => (:bottleneck, [3, 4, 23, 3]),
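
The `(:block, repeats)` tuples stored here are splatted straight into `resnet` by the constructors in the next few files; a toy demonstration of that splat (assumed signature, not the real `resnet`):

# `RESNET_CONFIGS[depth]...` expands the tuple into two positional
# arguments (toy signature for illustration).
toy_resnet(block_fn::Symbol, block_repeats::Vector{Int}) = (block_fn, sum(block_repeats))

cfg = (:bottleneck, [3, 4, 6, 3])              # what RESNET_CONFIGS[50] holds
@assert toy_resnet(cfg...) == (:bottleneck, 16)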

src/convnets/resnets/resnet.jl

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@ end
 @functor ResNet

 function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, keys(resnet_configs))
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses)
+    _checkconfig(depth, keys(RESNET_CONFIGS))
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("ResNet", depth))
     end

src/convnets/resnets/resnext.jl

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ end

 function ResNeXt(depth::Integer; pretrain = false, cardinality = 32,
                  base_width = 4, inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width)
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
     if pretrain
         loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
     end
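
The `sort(collect(keys(RESNET_CONFIGS)))[3:end]` idiom is what restricts ResNeXt (and SEResNeXt below) to bottleneck depths. Assuming the table holds the five standard depths (18, 34, 50, 101, 152), it evaluates as follows:

# Stand-in for collect(keys(RESNET_CONFIGS)), assuming the five standard
# depths; dropping the first two sorted keys leaves the bottleneck ones.
depths = sort([101, 18, 34, 152, 50])
@assert depths[3:end] == [50, 101, 152]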

src/convnets/resnets/seresnet.jl

Lines changed: 4 additions & 4 deletions
@@ -25,8 +25,8 @@ end
 (m::SEResNet)(x) = m.layers(x)

 function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, keys(resnet_configs))
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses,
+    _checkconfig(depth, keys(RESNET_CONFIGS))
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses,
                     attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("SEResNet", depth))
@@ -68,8 +68,8 @@ end

 function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4,
                    inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width,
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
                     attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))

src/convnets/vgg.jl

Lines changed: 4 additions & 4 deletions
@@ -99,12 +99,12 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr
     return Chain(Chain(conv), class)
 end

-const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
+const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
                               :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
                               :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
                               :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])

-const vgg_configs = Dict(11 => :A,
+const VGG_CONFIGS = Dict(11 => :A,
                          13 => :B,
                          16 => :D,
                          19 => :E)
@@ -153,8 +153,8 @@ See also [`VGG`](#).
 - `pretrain`: set to `true` to load pre-trained model weights for ImageNet
 """
 function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000)
-    _checkconfig(depth, keys(vgg_configs))
-    model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]],
+    _checkconfig(depth, keys(VGG_CONFIGS))
+    model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]],
                 inchannels = 3,
                 batchnorm = batchnorm,
                 nclasses = nclasses,
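
`VGG` resolves its depth through two tables: depth → letter config, then letter → per-stage (channels, repeats). Tracing the lookup for the default depth of 16, with values copied from the hunk above:

# VGG_CONFIGS[16] == :D
# VGG_CONV_CONFIGS[:D] == [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]
# Five stages of (output channels, conv layers per stage): 2+2+3+3+3 = 13
# conv layers, plus the three dense layers of the head — the "16" in VGG-16.
stage_config = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E)[16]
@assert stage_config == :D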

src/layers/conv.jl

Lines changed: 4 additions & 2 deletions
@@ -87,10 +87,12 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
 - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
 """
 function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
-                                 norm_layer = BatchNorm, revnorm = false, use_norm = (true, true),
+                                 norm_layer = BatchNorm, revnorm = false,
+                                 use_norm = (true, true),
                                  stride = 1, kwargs...)
     return vcat(conv_norm(kernel_size, inplanes, inplanes, activation;
-                          norm_layerm, revnorm, use_bn = use_bn[1], stride, groups = inplanes,
+                          norm_layerm, revnorm, use_bn = use_bn[1], stride,
+                          groups = inplanes,
                           kwargs...),
                 conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm,
                           use_bn = use_bn[2]))

src/mixers/core.jl

Lines changed: 4 additions & 4 deletions
@@ -37,7 +37,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3,
 end

 # Configurations for MLPMixer models
-mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512),
-                     :base => Dict(:depth => 12, :planes => 768),
-                     :large => Dict(:depth => 24, :planes => 1024),
-                     :huge => Dict(:depth => 32, :planes => 1280))
+const MIXER_CONFIGS = Dict(:small => Dict(:depth => 8, :planes => 512),
+                           :base => Dict(:depth => 12, :planes => 768),
+                           :large => Dict(:depth => 24, :planes => 1024),
+                           :huge => Dict(:depth => 32, :planes => 1280))

src/mixers/gmlp.jl

Lines changed: 3 additions & 3 deletions
@@ -96,9 +96,9 @@ See also [`Metalhead.mlpmixer`](#).
 """
 function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16),
               imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000)
-    _checkconfig(size, keys(mixer_configs))
-    depth = mixer_configs[size][:depth]
-    embedplanes = mixer_configs[size][:planes]
+    _checkconfig(size, keys(MIXER_CONFIGS))
+    depth = MIXER_CONFIGS[size][:depth]
+    embedplanes = MIXER_CONFIGS[size][:planes]
     layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block,
                       patch_size, embedplanes, drop_path_rate, depth, nclasses)
     return gMLP(layers)

src/mixers/mlpmixer.jl

Lines changed: 3 additions & 3 deletions
@@ -55,9 +55,9 @@ See also [`Metalhead.mlpmixer`](#).
 """
 function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16),
                   imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000)
-    _checkconfig(size, keys(mixer_configs))
-    depth = mixer_configs[size][:depth]
-    embedplanes = mixer_configs[size][:planes]
+    _checkconfig(size, keys(MIXER_CONFIGS))
+    depth = MIXER_CONFIGS[size][:depth]
+    embedplanes = MIXER_CONFIGS[size][:planes]
     layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate,
                       nclasses)
     return MLPMixer(layers)

src/mixers/resmlp.jl

Lines changed: 3 additions & 3 deletions
@@ -58,9 +58,9 @@ See also [`Metalhead.mlpmixer`](#).
 """
 function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16),
                 imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000)
-    _checkconfig(size, keys(mixer_configs))
-    depth = mixer_configs[size][:depth]
-    embedplanes = mixer_configs[size][:planes]
+    _checkconfig(size, keys(MIXER_CONFIGS))
+    depth = MIXER_CONFIGS[size][:depth]
+    embedplanes = MIXER_CONFIGS[size][:planes]
     layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes,
                       drop_path_rate, depth, nclasses)
     return ResMLP(layers)

src/vit-based/vit.jl

Lines changed: 11 additions & 11 deletions
@@ -62,15 +62,15 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} =
                  Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
 end

-vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
-                   :small => (depth = 12, embedplanes = 384, nheads = 6),
-                   :base => (depth = 12, embedplanes = 768, nheads = 12),
-                   :large => (depth = 24, embedplanes = 1024, nheads = 16),
-                   :huge => (depth = 32, embedplanes = 1280, nheads = 16),
-                   :giant => (depth = 40, embedplanes = 1408, nheads = 16,
-                              mlp_ratio = 48 // 11),
-                   :gigantic => (depth = 48, embedplanes = 1664, nheads = 16,
-                                 mlp_ratio = 64 // 13))
+const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
+                         :small => (depth = 12, embedplanes = 384, nheads = 6),
+                         :base => (depth = 12, embedplanes = 768, nheads = 12),
+                         :large => (depth = 24, embedplanes = 1024, nheads = 16),
+                         :huge => (depth = 32, embedplanes = 1280, nheads = 16),
+                         :giant => (depth = 40, embedplanes = 1408, nheads = 16,
+                                    mlp_ratio = 48 // 11),
+                         :gigantic => (depth = 48, embedplanes = 1664, nheads = 16,
+                                       mlp_ratio = 64 // 13))

 """
     ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3,
@@ -98,8 +98,8 @@ end

 function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3,
              patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000)
-    _checkconfig(mode, keys(vit_configs))
-    kwargs = vit_configs[mode]
+    _checkconfig(mode, keys(VIT_CONFIGS))
+    kwargs = VIT_CONFIGS[mode]
     layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...)
     return ViT(layers)
 end
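
Unlike the `Dict`-of-`Dict` configs elsewhere in this commit, the ViT table stores NamedTuples, which is what lets the constructor splat `VIT_CONFIGS[mode]` directly into keyword arguments. A self-contained sketch of that mechanism (toy signature, not the real `vit`):

# NamedTuples splat into keyword arguments by field name.
toy_vit(; depth, embedplanes, nheads, mlp_ratio = 4) = (depth, embedplanes, nheads, mlp_ratio)

cfg = (depth = 40, embedplanes = 1408, nheads = 16, mlp_ratio = 48 // 11)
@assert toy_vit(; cfg...) == (40, 1408, 16, 48 // 11)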

test/convnets.jl

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ end
 @testset "EfficientNet" begin
     @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8]
         # preferred image resolution scaling
-        r = Metalhead.efficientnet_global_configs[name][1]
+        r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1]
         x = rand(Float32, r, r, 3, 1)
         m = EfficientNet(name)
         @test size(m(x)) == (1000, 1)
