Skip to content

Commit 4ac9207

Browse files
theabhirathdarsnack
andcommitted
Hardcode large ResNet model dict for block configs
Also misc. cleanup Co-Authored-By: Kyle Daruwalla <[email protected]>
1 parent 61e55f9 commit 4ac9207

File tree

5 files changed

+20
-12
lines changed

5 files changed

+20
-12
lines changed

src/convnets/resnets/core.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
116116
end
117117

118118
# Shortcut configurations for the ResNet models
119-
const SHORTCUT_DICT = Dict(:A => (downsample_identity, downsample_identity),
119+
const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity),
120120
:B => (downsample_conv, downsample_identity),
121121
:C => (downsample_conv, downsample_conv),
122122
:D => (downsample_pool, downsample_identity))
@@ -343,7 +343,7 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer},
343343
connection$activation, classifier_fn)
344344
end
345345
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
346-
return resnet(block_fn, block_repeats, SHORTCUT_DICT[downsample_opt]; kwargs...)
346+
return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)
347347
end
348348

349349
# block-layer configurations for ResNet-like models
@@ -352,3 +352,7 @@ const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
352352
50 => (:bottleneck, [3, 4, 6, 3]),
353353
101 => (:bottleneck, [3, 4, 23, 3]),
354354
152 => (:bottleneck, [3, 8, 36, 3]))
355+
356+
const LRESNET_CONFIGS = Dict(50 => (:bottleneck, [3, 4, 6, 3]),
357+
101 => (:bottleneck, [3, 4, 23, 3]),
358+
152 => (:bottleneck, [3, 8, 36, 3]))

src/convnets/resnets/res2net.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
3333
for _ in 1:max(1, scale - 1)]
3434
reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) :
3535
Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...)))
36-
tuplify(x) = is_first ? tuple(x...) : tuple(x[1], tuple(x[2:end]...))
36+
if is_first
37+
tuplify(x) = tuple(x...)
38+
else
39+
tuplify(x) = tuple(x[1], tuple(x[2:end]...))
40+
end
3741
layers = [conv_norm((1, 1), inplanes => width * scale, activation;
3842
norm_layer, revnorm, bias = false)...,
3943
chunk$(; size = width, dims = 3), tuplify, reslayer,
@@ -138,8 +142,8 @@ end
138142
function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
139143
base_width::Integer = 4, cardinality::Integer = 8,
140144
inchannels::Integer = 3, nclasses::Integer = 1000)
141-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
142-
layers = resnet(:bottle2neck, RESNET_CONFIGS[depth][2]; base_width, scale,
145+
_checkconfig(depth, keys(LRESNET_CONFIGS))
146+
layers = resnet(:bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale,
143147
cardinality, inchannels, nclasses)
144148
if pretrain
145149
loadpretrain!(layers,
@@ -152,4 +156,4 @@ end
152156
(m::Res2NeXt)(x) = m.layers(x)
153157

154158
backbone(m::Res2NeXt) = m.layers[1]
155-
classifier(m::Res2NeXt) = m.layers[2]
159+
classifier(m::Res2NeXt) = m.layers[2]

src/convnets/resnets/resnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ end
6565

6666
function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
6767
nclasses::Integer = 1000)
68-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
69-
layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
68+
_checkconfig(depth, keys(LRESNET_CONFIGS))
69+
layers = resnet(LRESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
7070
if pretrain
7171
loadpretrain!(layers, string("WideResNet", depth))
7272
end

src/convnets/resnets/resnext.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ end
2929

3030
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
3131
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
32-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
33-
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
32+
_checkconfig(depth, keys(LRESNET_CONFIGS))
33+
layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
3434
if pretrain
3535
loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
3636
end

src/convnets/resnets/seresnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ end
6969

7070
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
7171
inchannels::Integer = 3, nclasses::Integer = 1000)
72-
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
73-
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
72+
_checkconfig(depth, keys(LRESNET_CONFIGS))
73+
layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
7474
attn_fn = squeeze_excite)
7575
if pretrain
7676
loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))

0 commit comments

Comments
 (0)