Skip to content

Commit 73df024

Browse files
committed
Add WideResNet
1 parent 73131bf commit 73df024

File tree

6 files changed

+56
-4
lines changed

6 files changed

+56
-4
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ jobs:
2929
- '["AlexNet", "VGG"]'
3030
- '["GoogLeNet", "SqueezeNet", "MobileNet"]'
3131
- '["EfficientNet"]'
32-
- '[r"ResNet", r"ResNeXt"]'
32+
- 'r"/^ResNet\z/"'
33+
- '[r"ResNeXt", r"SEResNet"]'
3334
- '"Inception"'
3435
- '"DenseNet"'
3536
- '["ConvNeXt", "ConvMixer"]'

src/Metalhead.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
6060
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
6161
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
6262
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
63-
SEResNet, SEResNeXt,
63+
WideResNet, SEResNet, SEResNeXt,
6464
MLPMixer, ResMLP, gMLP,
6565
ViT,
6666
ConvMixer, ConvNeXt

src/convnets/resnets/resnet.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,43 @@ end
3535

3636
backbone(m::ResNet) = m.layers[1]
3737
classifier(m::ResNet) = m.layers[2]
38+
39+
"""
40+
WideResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
41+
42+
Creates a Wide ResNet model with the specified depth. The model is the same as ResNet
43+
except for the bottleneck number of channels which is twice larger in every block.
44+
The number of channels in outer 1x1 convolutions is the same.
45+
((reference)[https://arxiv.org/abs/1605.07146])
46+
47+
# Arguments
48+
49+
- `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the Wide ResNet model.
50+
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
51+
- `inchannels`: The number of input channels.
52+
- `nclasses`: the number of output classes
53+
54+
!!! warning
55+
56+
`WideResNet` does not currently support pretrained weights.
57+
58+
Advanced users who want more configuration options will be better served by using [`resnet`](#).
59+
"""
60+
struct WideResNet
61+
layers::Any
62+
end
63+
@functor WideResNet
64+
65+
function WideResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
66+
_checkconfig(depth, [50, 101])
67+
layers = resnet(RESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses)
68+
if pretrain
69+
loadpretrain!(layers, string("WideResNet", depth))
70+
end
71+
return WideResNet(layers)
72+
end
73+
74+
(m::WideResNet)(x) = m.layers(x)
75+
76+
backbone(m::WideResNet) = m.layers[1]
77+
classifier(m::WideResNet) = m.layers[2]

src/layers/conv.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu
5252
return revnorm ? reverse(layers) : layers
5353
end
5454

55-
function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...)
55+
function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity;
56+
kwargs...)
5657
inplanes, outplanes = ch
5758
return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...)
5859
end

src/vit-based/vit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rat
1616
layers = [Chain(SkipConnection(prenorm(planes,
1717
MHAttention(planes, nheads;
1818
attn_dropout_rate = dropout_rate,
19-
proj_dropout_rate = dropout_rate)), +),
19+
proj_dropout_rate = dropout_rate)),
20+
+),
2021
SkipConnection(prenorm(planes,
2122
mlp_block(planes, floor(Int, mlp_ratio * planes);
2223
dropout_rate)), +))

test/convnets.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ end
5656
end
5757
end
5858
end
59+
60+
@testset "WideResNet" begin
61+
@testset "WideResNet($sz)" for sz in [50, 101]
62+
m = WideResNet(sz)
63+
@test size(m(x_224)) == (1000, 1)
64+
@test gradtest(m, x_224)
65+
_gc()
66+
end
67+
end
5968
end
6069

6170
@testset "ResNeXt" begin

0 commit comments

Comments
 (0)