
Commit fc74aa1

theabhirath authored and darsnack committed
Cleanup - docs and code
Co-Authored-By: Kyle Daruwalla <[email protected]>
1 parent b143b95 commit fc74aa1

9 files changed: +93 −96 lines changed

src/convnets/densenet.jl

Lines changed: 3 additions & 3 deletions

```diff
@@ -13,9 +13,9 @@ Create a Densenet bottleneck layer
 function dense_bottleneck(inplanes, outplanes)
     inner_channels = 4 * outplanes
     return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false,
-                                          prenorm = true)...,
+                                          revnorm = true)...,
                                 conv_norm((3, 3), inner_channels, outplanes; pad = 1,
-                                          bias = false, prenorm = true)...),
+                                          bias = false, revnorm = true)...),
                                 cat_channels)
 end
@@ -31,7 +31,7 @@ Create a DenseNet transition sequence
   - `outplanes`: number of output feature maps
 """
 function transition(inplanes, outplanes)
-    return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, prenorm = true)...,
+    return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)...,
                  MeanPool((2, 2)))
 end
```
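For intuition, here is a minimal plain-Flux sketch (assumed channel sizes, not Metalhead's exact code) of the block `dense_bottleneck` builds: with `revnorm = true` each norm layer sits *before* its convolution, and the block's input is concatenated onto its output along the channel dimension.

```julia
using Flux

inplanes, outplanes = 64, 32
inner = 4 * outplanes
bottleneck = SkipConnection(Chain(BatchNorm(inplanes),
                                  Conv((1, 1), inplanes => inner, relu; bias = false),
                                  BatchNorm(inner),
                                  Conv((3, 3), inner => outplanes, relu; pad = 1, bias = false)),
                            (mx, x) -> cat(mx, x; dims = 3))  # stand-in for cat_channels

x = rand(Float32, 8, 8, inplanes, 1)
size(bottleneck(x))  # (8, 8, outplanes + inplanes, 1) == (8, 8, 96, 1)
```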

src/convnets/inception/xception.jl

Lines changed: 3 additions & 3 deletions

```diff
@@ -34,7 +34,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
     end
     push!(layers, relu)
     append!(layers,
-            depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false,
-                                  use_bn = (false, false)))
+            depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false,
+                                    use_norm = (false, false)))
     push!(layers, BatchNorm(outc))
 end
@@ -63,8 +63,8 @@ function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000)
                  xception_block(256, 728, 2; stride = 2),
                  [xception_block(728, 728, 3) for _ in 1:8]...,
                  xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
-                 depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
-                 depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
+                 depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
+                 depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
     head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate),
                  Dense(2048, nclasses))
     return Chain(body, head)
```

src/convnets/mobilenet/mobilenetv1.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -30,7 +30,7 @@ function mobilenetv1(width_mult, config;
         outch = Int(outch * width_mult)
         for _ in 1:nrepeats
             layer = dw ?
-                    depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
+                    depthwise_sep_conv_norm((3, 3), inchannels, outch, activation;
                                           stride = stride, pad = 1, bias = false) :
                     conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1,
                               bias = false)
```

src/convnets/resnets/core.jl

Lines changed: 59 additions & 65 deletions
Large diffs are not rendered by default.

src/convnets/resnets/seresnet.jl

Lines changed: 2 additions & 2 deletions

```diff
@@ -27,7 +27,7 @@ end
 function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
     _checkconfig(depth, keys(resnet_configs))
     layers = resnet(resnet_configs[depth]...; inchannels, nclasses,
-                    attn_fn = planes -> squeeze_excite(planes))
+                    attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("SEResNet", depth))
     end
@@ -70,7 +70,7 @@ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_widt
                    inchannels = 3, nclasses = 1000)
     _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
     layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width,
-                    attn_fn = planes -> squeeze_excite(planes))
+                    attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))
     end
```

src/layers/Layers.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -26,7 +26,7 @@ include("normalise.jl")
 export prenorm, ChannelLayerNorm

 include("conv.jl")
-export conv_norm, depthwise_sep_conv_bn, invertedresidual
+export conv_norm, depthwise_sep_conv_norm, invertedresidual

 include("drop.jl")
 export DropBlock, DropPath, droppath_rates
```

src/layers/conv.jl

Lines changed: 18 additions & 18 deletions

```diff
@@ -1,6 +1,6 @@
 """
     conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu;
-              norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true,
+              norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true,
               stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init])

 Create a convolution + batch normalization pair with activation.
@@ -12,44 +12,44 @@ Create a convolution + batch normalization pair with activation.
   - `outplanes`: number of output feature maps
   - `activation`: the activation function for the final layer
   - `norm_layer`: the normalization layer used
-  - `prenorm`: set to `true` to place the batch norm before the convolution
+  - `revnorm`: set to `true` to place the batch norm before the convolution
   - `preact`: set to `true` to place the activation function before the batch norm
-    (only compatible with `prenorm = false`)
+    (only compatible with `revnorm = false`)
   - `use_bn`: set to `false` to disable batch normalization
-    (only compatible with `prenorm = false` and `preact = false`)
+    (only compatible with `revnorm = false` and `preact = false`)
   - `stride`: stride of the convolution kernel
   - `pad`: padding of the convolution kernel
   - `dilation`: dilation of the convolution kernel
   - `groups`: groups for the convolution kernel
   - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
 """
 function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu;
-                   norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true,
+                   norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true,
                    kwargs...)
     if !use_bn
-        if (preact || prenorm)
+        if (preact || revnorm)
             throw(ArgumentError("`preact` only supported with `use_bn = true`"))
         else
             return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)]
         end
     end
-    if prenorm
+    if revnorm
         activations = (conv = activation, bn = identity)
         bnplanes = inplanes
     else
         activations = (conv = identity, bn = activation)
         bnplanes = outplanes
     end
     if preact
-        if prenorm
-            throw(ArgumentError("`preact` and `prenorm` cannot be set at the same time"))
+        if revnorm
+            throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time"))
         else
             activations = (conv = activation, bn = identity)
         end
     end
     layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...),
               norm_layer(bnplanes, activations.bn)]
-    return prenorm ? reverse(layers) : layers
+    return revnorm ? reverse(layers) : layers
 end

 function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes,
```
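To make `revnorm` concrete, here is a minimal plain-Flux sketch (assumed shapes; illustrative, not Metalhead's code) of the two layer orderings `conv_norm` returns:

```julia
using Flux

# Default (revnorm = false): plain conv, then a norm layer carrying the activation.
default_order = Chain(Conv((3, 3), 16 => 32; pad = 1, bias = false), BatchNorm(32, relu))

# revnorm = true: the pair is reversed, the norm is built on the *input* planes,
# and the activation moves onto the convolution.
rev_order = Chain(BatchNorm(16), Conv((3, 3), 16 => 32, relu; pad = 1, bias = false))

x = rand(Float32, 28, 28, 16, 1)
size(default_order(x)) == size(rev_order(x)) == (28, 28, 32, 1)  # true
```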
```diff
@@ -60,7 +60,8 @@ end

 """
-    depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu;
-                          prenorm = false, use_bn = (true, true),
-                          stride = 1, pad = 0, dilation = 1, [bias, weight, init])
+    depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
+                            norm_layer = BatchNorm, revnorm = false, use_norm = (true, true),
+                            stride = 1, pad = 0, dilation = 1, [bias, weight, init])

 Create a depthwise separable convolution chain as used in MobileNetv1.
@@ -79,20 +79,20 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
   - `inplanes`: number of input feature maps
   - `outplanes`: number of output feature maps
   - `activation`: the activation function for the final layer
-  - `prenorm`: set to `true` to place the batch norm before the convolution
-  - `use_bn`: a tuple of two booleans to specify whether to use batch normalization for the first and second convolution
+  - `revnorm`: set to `true` to place the batch norm before the convolution
+  - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution
   - `stride`: stride of the first convolution kernel
   - `pad`: padding of the first convolution kernel
   - `dilation`: dilation of the first convolution kernel
   - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
 """
-function depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu;
-                               prenorm = false, use_bn = (true, true),
-                               stride = 1, kwargs...)
+function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
+                                 norm_layer = BatchNorm, revnorm = false, use_norm = (true, true),
+                                 stride = 1, kwargs...)
     return vcat(conv_norm(kernel_size, inplanes, inplanes, activation;
-                          prenorm, use_bn = use_bn[1], stride, groups = inplanes,
+                          norm_layer, revnorm, use_bn = use_norm[1], stride, groups = inplanes,
                           kwargs...),
-                conv_norm((1, 1), inplanes, outplanes, activation; prenorm,
-                          use_bn = use_bn[2]))
+                conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm,
+                          use_bn = use_norm[2]))
 end
```
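As a rough plain-Flux sketch (assumed sizes, defaults `use_norm = (true, true)` and `revnorm = false`) of what `depthwise_sep_conv_norm` assembles:

```julia
using Flux

inplanes, outplanes = 32, 64
# A depthwise 3×3 convolution (groups = inplanes) plus norm, followed by a
# pointwise 1×1 convolution plus norm — the MobileNetv1 building block.
dwsep = Chain(Conv((3, 3), inplanes => inplanes; pad = 1, groups = inplanes, bias = false),
              BatchNorm(inplanes, relu),
              Conv((1, 1), inplanes => outplanes; bias = false),
              BatchNorm(outplanes, relu))

x = rand(Float32, 56, 56, inplanes, 1)
size(dwsep(x))  # (56, 56, 64, 1)
```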
src/layers/pool.jl

Lines changed: 2 additions & 1 deletion

```diff
@@ -10,6 +10,7 @@ produce a single output. Note that this is equivalent to
   - `output_size`: The size of the output after pooling.
   - `connection`: The connection type to use.
 """
-function AdaptiveMeanMaxPool(output_size = (1, 1); connection = +)
+function AdaptiveMeanMaxPool(connection, output_size = (1, 1))
     return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))
 end
+AdaptiveMeanMaxPool(output_size::Tuple = (1, 1)) = AdaptiveMeanMaxPool(+, output_size)
```
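`connection` is now a positional argument rather than a keyword, with a second method defaulting it to `+`. A quick usage sketch (assumed input size) of the `Parallel` layer this constructs:

```julia
using Flux

# Equivalent to AdaptiveMeanMaxPool(+, (1, 1)): mean- and max-pool branches,
# each reduced to 1×1 spatially, combined elementwise with +.
pool = Parallel(+, AdaptiveMeanPool((1, 1)), AdaptiveMaxPool((1, 1)))

x = rand(Float32, 7, 7, 512, 2)
size(pool(x))  # (1, 1, 512, 2)
```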

test/convnets.jl

Lines changed: 4 additions & 2 deletions

```diff
@@ -23,13 +23,15 @@ end
 @testset "ResNet" begin
     # Tests for pretrained ResNets
     ## TODO: find a way to port pretrained models to the new ResNet API
-    # @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
+    @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
+        m = ResNet(sz)
+        @test size(m(x_224)) == (1000, 1)
         # if (ResNet, sz) in PRETRAINED_MODELS
         #     @test acctest(ResNet(sz, pretrain = true))
         # else
         #     @test_throws ArgumentError ResNet(sz, pretrain = true)
         # end
-    # end
+    end

     @testset "resnet" begin
         @testset for block_fn in [:basicblock, :bottleneck]
```
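For reference, `x_224` is defined elsewhere in the test suite; judging from the `(1000, 1)` output check above, it is a single 224×224 RGB dummy input, along the lines of:

```julia
x_224 = rand(Float32, 224, 224, 3, 1)  # assumed shape, inferred from the size check above
```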
