
Commit b1e6b42

Fifth refactor is a charm
Also, we aren't using the skips anymore
1 parent aa2a9ef commit b1e6b42

3 files changed: +114 -145 lines changed


src/convnets/resnets/core.jl

Lines changed: 113 additions & 99 deletions
@@ -1,9 +1,8 @@
 """
-    basicblock(inplanes, planes; stride = 1, downsample = identity,
-               reduction_factor = 1, dilation = 1, first_dilation = dilation,
-               activation = relu, connection = addact\$activation,
-               norm_layer = BatchNorm, drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
+               norm_layer = BatchNorm, prenorm = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a basic ResNet block.
@@ -12,24 +11,19 @@ Creates a basic ResNet block.
   - `inplanes`: number of input feature maps
   - `planes`: number of feature maps for the block
   - `stride`: the stride of the block
-  - `downsample`: the downsampling function to use
   - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first
     convolution.
-  - `dilation`: the dilation of the second convolution.
-  - `first_dilation`: the dilation of the first convolution.
   - `activation`: the activation function to use.
-  - `connection`: the function applied to the output of residual and skip paths in
-    a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
-    PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
   - `norm_layer`: the normalization layer to use.
   - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
 """
-function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
-                    norm_layer = BatchNorm, prenorm = false,
+function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                    reduction_factor::Integer = 1, activation = relu,
+                    norm_layer = BatchNorm, prenorm::Bool = false,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
     first_planes = planes ÷ reduction_factor
@@ -45,11 +39,11 @@ end
 expansion_factor(::typeof(basicblock)) = 1
 
 """
-    bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1,
-               base_width = 64, reduction_factor = 1, first_dilation = 1,
-               activation = relu, connection = addact\$activation,
-               norm_layer = BatchNorm, drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
+               reduction_factor = 1, activation = relu,
+               norm_layer = BatchNorm, prenorm = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)
 
 Creates a bottleneck ResNet block.
 
@@ -58,26 +52,22 @@ Creates a bottleneck ResNet block.
   - `inplanes`: number of input feature maps
   - `planes`: number of feature maps for the block
   - `stride`: the stride of the block
-  - `downsample`: the downsampling function to use
   - `cardinality`: the number of groups in the convolution.
   - `base_width`: the number of output feature maps for each convolutional group.
   - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first
     convolution.
-  - `first_dilation`: the dilation of the 3x3 convolution.
   - `activation`: the activation function to use.
-  - `connection`: the function applied to the output of residual and skip paths in
-    a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
-    PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
   - `norm_layer`: the normalization layer to use.
   - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
 """
-function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-                    reduction_factor = 1, activation = relu,
-                    norm_layer = BatchNorm, prenorm = false,
+function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+                    cardinality::Integer = 1, base_width::Integer = 64,
+                    reduction_factor::Integer = 1, activation = relu,
+                    norm_layer = BatchNorm, prenorm::Bool = false,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
     width = floor(Int, planes * (base_width / 64)) * cardinality
@@ -113,6 +103,7 @@ end
 
 # Downsample layer which is an identity projection. Uses max pooling
 # when the output size is more than the input size.
+# TODO - figure out how to make this work when outplanes < inplanes
 function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
     if outplanes > inplanes
         return Chain(MaxPool((1, 1); stride = 2),
@@ -174,8 +165,8 @@ on how to use this function.
   - `activation`: The activation function used in the stem.
 """
 function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
-                     replace_pool::Bool = false, norm_layer = BatchNorm, prenorm = false,
-                     activation = relu)
+                     replace_pool::Bool = false, activation = relu,
+                     norm_layer = BatchNorm, prenorm::Bool = false)
     @assert stem_type in [:default, :deep, :deep_tiered]
             "Stem type must be one of [:default, :deep, :deep_tiered]"
     # Main stem
@@ -203,65 +194,70 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
                           prenorm,
                           stride = 2, pad = 1, bias = false)...) :
                MaxPool((3, 3); stride = 2, pad = 1)
-    return Chain(conv1, bn1, stempool), inplanes
-end
-
-# Templating builders for the blocks and the downsampling layers
-function template_builder(block_fn; kwargs...)
-    function (inplanes, planes; _kwargs...)
-        return block_fn(inplanes, planes; kwargs..., _kwargs...)
-    end
-end
-
-function template_builder(::typeof(basicblock); reduction_factor::Integer = 1,
-                          activation = relu, norm_layer = BatchNorm, prenorm::Bool = false,
-                          attn_fn = planes -> identity, kargs...)
-    return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor,
-                                              activation, norm_layer, prenorm, attn_fn)
+    return Chain(conv1, bn1, stempool)
 end
 
-function template_builder(::typeof(bottleneck); cardinality::Integer = 1,
-                          base_width::Integer = 64,
-                          reduction_factor::Integer = 1, activation = relu,
-                          norm_layer = BatchNorm, prenorm::Bool = false,
-                          attn_fn = planes -> identity, kargs...)
-    return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width,
-                                              reduction_factor, activation,
-                                              norm_layer, prenorm, attn_fn)
-end
+resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1)
 
-function template_builder(downsample_fn::Union{typeof(downsample_conv),
-                                               typeof(downsample_pool),
-                                               typeof(downsample_identity)};
-                          norm_layer = BatchNorm, prenorm = false)
-    return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm)
+function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64,
+                            reduction_factor::Integer = 1, expansion::Integer = 1,
+                            norm_layer = BatchNorm, prenorm::Bool = false,
+                            activation = relu, attn_fn = planes -> identity,
+                            drop_block_rate = 0.0, drop_path_rate = 0.0,
+                            stride_fn = get_stride, planes_fn = resnet_planes,
+                            downsample_tuple = (downsample_conv, downsample_identity))
+    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
+    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
+    # closure over `idxs`
+    function get_layers(stage_idx::Integer, block_idx::Integer)
+        planes = planes_fn(stage_idx)
+        # `get_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper
+        stride = stride_fn(stage_idx, block_idx)
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_tuple[1] : downsample_tuple[2]
+        # DropBlock, DropPath both take in rates based on a linear scaling schedule
+        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+        drop_path = DropPath(pathschedule[schedule_idx])
+        drop_block = DropBlock(blockschedule[schedule_idx])
+        block = basicblock(inplanes, planes; stride, reduction_factor, activation,
+                           norm_layer, prenorm, attn_fn, drop_path, drop_block)
+        downsample = downsample_fn(inplanes, planes * expansion; stride)
+        # inplanes increases by expansion after each block
+        inplanes = planes * expansion
+        return block, downsample
+    end
+    return get_layers
 end
 
-resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1)
-
-function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer};
-                                stride_fn = get_stride, plane_fn = resnet_planes,
-                                downsample_templates::NTuple{2, Any},
-                                inplanes::Integer = 64,
-                                drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...)
+function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64,
+                            cardinality::Integer = 1, base_width::Integer = 64,
+                            reduction_factor::Integer = 1, expansion::Integer = 4,
+                            norm_layer = BatchNorm, prenorm::Bool = false,
+                            activation = relu, attn_fn = planes -> identity,
+                            drop_block_rate = 0.0, drop_path_rate = 0.0,
+                            stride_fn = get_stride, planes_fn = resnet_planes,
+                            downsample_tuple = (downsample_conv, downsample_identity))
     pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
     blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
     # closure over `idxs`
     function get_layers(stage_idx::Integer, block_idx::Integer)
-        planes = plane_fn(stage_idx)
+        planes = planes_fn(stage_idx)
         # `get_stride` is a callback that the user can tweak to change the stride of the
         # blocks. It defaults to the standard behaviour as in the paper
         stride = stride_fn(stage_idx, block_idx)
-        downsample_template = (stride != 1 || inplanes != planes * expansion) ?
-                              downsample_templates[1] : downsample_templates[2]
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_tuple[1] : downsample_tuple[2]
         # DropBlock, DropPath both take in rates based on a linear scaling schedule
         schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
-        block = block_template(inplanes, planes; stride, drop_path, drop_block)
-        downsample = downsample_template(inplanes, planes * expansion; stride)
+        block = bottleneck(inplanes, planes; stride, cardinality, base_width,
+                           reduction_factor, activation, norm_layer, prenorm,
+                           attn_fn, drop_path, drop_block)
+        downsample = downsample_fn(inplanes, planes * expansion; stride)
         # inplanes increases by expansion after each block
-        inplanes = (planes * expansion)
+        inplanes = planes * expansion
         return block, downsample
     end
     return get_layers
@@ -283,41 +279,59 @@ function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
     return Chain(stages...)
 end
 
-function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier)
-    stage_blocks = resnet_stages(get_layers, block_repeats, connection)
-    return Chain(Chain(stem, stage_blocks), classifier)
+function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
+                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+                cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
+                reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
+                inchannels::Integer = 3, stem_fn = resnet_stem,
+                connection = addact, activation = relu, norm_layer = BatchNorm,
+                prenorm::Bool = false, attn_fn = planes -> identity,
+                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
+                drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
+                nclasses::Integer = 1000)
+    # Build stem
+    stem = stem_fn(; inchannels)
+    # Block builder
+    if block_type == :basicblock
+        @assert cardinality==1 "Cardinality must be 1 for `basicblock`"
+        @assert base_width==64 "Base width must be 64 for `basicblock`"
+        get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
+                                        activation, norm_layer, prenorm, attn_fn,
+                                        drop_block_rate, drop_path_rate,
+                                        stride_fn = get_stride, planes_fn = resnet_planes,
+                                        downsample_tuple = downsample_opt)
+    elseif block_type == :bottleneck
+        get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
+                                        reduction_factor, activation, norm_layer,
+                                        prenorm, attn_fn, drop_block_rate, drop_path_rate,
+                                        stride_fn = get_stride, planes_fn = resnet_planes,
+                                        downsample_tuple = downsample_opt)
+    else
+        throw(ArgumentError("Unknown block type $block_type"))
+    end
+    classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
+                                                   pool_layer, use_conv)
+    return resnet((imsize..., inchannels), stem, connection$activation, get_layers,
+                  block_repeats, classifier_fn)
+end
+function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
+    return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
 end
 
-function resnet(block_fn, block_repeats::Vector{<:Integer},
-                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
-                imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
-                stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
-                connection = addact, activation = relu,
-                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
-    # Configure downsample templates
-    downsample_templates = map(template_builder, downsample_opt)
-    # Configure block templates
-    block_template = template_builder(block_fn; kwargs...)
-    get_layers = configure_resnet_block(block_template, expansion_factor(block_fn),
-                                        block_repeats; inplanes, downsample_templates,
-                                        kwargs...)
+function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer},
+                classifier_fn)
     # Build stages of the ResNet
-    stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation)
+    stage_blocks = resnet_stages(get_layers, block_repeats, connection)
     backbone = Chain(stem, stage_blocks)
     # Build the classifier head
-    nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3]
-    classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer,
-                                   use_conv)
+    nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
+    classifier = classifier_fn(nfeaturemaps)
     return Chain(backbone, classifier)
 end
-function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
-    return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt], kwargs...)
-end
 
 # block-layer configurations for ResNet-like models
-const resnet_configs = Dict(18 => (basicblock, [2, 2, 2, 2]),
-                            34 => (basicblock, [3, 4, 6, 3]),
-                            50 => (bottleneck, [3, 4, 6, 3]),
-                            101 => (bottleneck, [3, 4, 23, 3]),
-                            152 => (bottleneck, [3, 8, 36, 3]))
+const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]),
+                            34 => (:basicblock, [3, 4, 6, 3]),
+                            50 => (:bottleneck, [3, 4, 6, 3]),
+                            101 => (:bottleneck, [3, 4, 23, 3]),
+                            152 => (:bottleneck, [3, 8, 36, 3]))
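
Note for readers following the refactor: the `template_builder`/`configure_resnet_block` templating is replaced by `basicblock_builder`/`bottleneck_builder`, each returning a stateful `get_layers(stage_idx, block_idx)` closure that `resnet_stages` calls block by block, while the symbol-based `resnet(:basicblock, ...)`/`resnet(:bottleneck, ...)` entry point picks the builder and hands the closure to the lower-level `resnet(img_dims, stem, connection, get_layers, block_repeats, classifier_fn)` method. The following is a minimal, self-contained Julia sketch of that closure pattern only; `toy_planes` and `toy_builder` are made-up names for illustration and the returned values are simplified stand-ins, not the real Flux layers built above.

# Illustrative sketch of the builder/closure pattern (not Metalhead code).
toy_planes(stage_idx) = 64 * 2^(stage_idx - 1)

function toy_builder(block_repeats; inplanes = 64, expansion = 1)
    function get_layers(stage_idx, block_idx)
        planes = toy_planes(stage_idx)
        # the first block of every stage after the first downsamples via its stride
        stride = (stage_idx == 1 || block_idx != 1) ? 1 : 2
        # a projection shortcut is only needed when the shapes change
        downsample = (stride != 1 || inplanes != planes * expansion) ? :conv : :identity
        block = (; inplanes, planes, stride)
        inplanes = planes * expansion  # captured variable, updated across calls
        return block, downsample
    end
    return get_layers
end

get_layers = toy_builder([2, 2, 2, 2])
get_layers(1, 1)  # ((inplanes = 64, planes = 64, stride = 1), :identity)
get_layers(2, 1)  # ((inplanes = 64, planes = 128, stride = 2), :conv)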

src/layers/Layers.jl

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ include("normalise.jl")
2626
export prenorm, ChannelLayerNorm
2727

2828
include("conv.jl")
29-
export conv_norm, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection
29+
export conv_norm, depthwise_sep_conv_bn, invertedresidual
3030

3131
include("drop.jl")
3232
export DropBlock, DropPath, droppath_rates
