"""
-    basicblock(inplanes, planes; stride = 1, downsample = identity,
-               reduction_factor = 1, dilation = 1, first_dilation = dilation,
-               activation = relu, connection = addact\$activation,
-               norm_layer = BatchNorm, drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
+               norm_layer = BatchNorm, prenorm = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)

Creates a basic ResNet block.

@@ -12,24 +11,19 @@ Creates a basic ResNet block.
- `inplanes`: number of input feature maps
- `planes`: number of feature maps for the block
- `stride`: the stride of the block
- - `downsample`: the downsampling function to use
- `reduction_factor`: the factor by which the input feature maps are reduced before the
  first convolution.
- - `dilation`: the dilation of the second convolution.
- - `first_dilation`: the dilation of the first convolution.
- `activation`: the activation function to use.
- - `connection`: the function applied to the output of residual and skip paths in
-   a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
-   PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
- `norm_layer`: the normalization layer to use.
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
  function and passed in.
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
  function and passed in.
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
"""
- function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
-                     norm_layer = BatchNorm, prenorm = false,
+ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
+                     reduction_factor::Integer = 1, activation = relu,
+                     norm_layer = BatchNorm, prenorm::Bool = false,
                      drop_block = identity, drop_path = identity,
                      attn_fn = planes -> identity)
    first_planes = planes ÷ reduction_factor

expansion_factor(::typeof(basicblock)) = 1
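
For orientation, here is a minimal usage sketch of the new `basicblock` signature. This is not part of the diff: it assumes Metalhead's internals and Flux are in scope, and that the block is only the residual branch, with the skip connection applied later by `resnet_stages`.

```julia
using Flux  # for the random input and layer evaluation

# 64 -> 64 feature maps at stride 1; the 3x3 pad-1 convolutions preserve spatial size
block = basicblock(64, 64)
x = rand(Float32, 32, 32, 64, 1)  # WHCN layout
size(block(x))                    # expected: (32, 32, 64, 1)
```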

"""
-    bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1,
-               base_width = 64, reduction_factor = 1, first_dilation = 1,
-               activation = relu, connection = addact\$activation,
-               norm_layer = BatchNorm, drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
+               reduction_factor = 1, activation = relu,
+               norm_layer = BatchNorm, prenorm = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)

Creates a bottleneck ResNet block.

@@ -58,26 +52,22 @@ Creates a bottleneck ResNet block.
- `inplanes`: number of input feature maps
- `planes`: number of feature maps for the block
- `stride`: the stride of the block
- - `downsample`: the downsampling function to use
- `cardinality`: the number of groups in the convolution.
- `base_width`: the number of output feature maps for each convolutional group.
- `reduction_factor`: the factor by which the input feature maps are reduced before the
  first convolution.
- - `first_dilation`: the dilation of the 3x3 convolution.
- `activation`: the activation function to use.
- - `connection`: the function applied to the output of residual and skip paths in
-   a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
-   PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
- `norm_layer`: the normalization layer to use.
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
  function and passed in.
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
  function and passed in.
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
"""
- function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-                     reduction_factor = 1, activation = relu,
-                     norm_layer = BatchNorm, prenorm = false,
+ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+                     cardinality::Integer = 1, base_width::Integer = 64,
+                     reduction_factor::Integer = 1, activation = relu,
+                     norm_layer = BatchNorm, prenorm::Bool = false,
                      drop_block = identity, drop_path = identity,
                      attn_fn = planes -> identity)
    width = floor(Int, planes * (base_width / 64)) * cardinality
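
The `width` line is the standard ResNeXt-style width computation. A worked example with typical ResNeXt-50 32x4d settings (illustrative numbers, not from the diff):

```julia
planes, base_width, cardinality = 256, 4, 32
width = floor(Int, planes * (base_width / 64)) * cardinality
# 256 * (4 / 64) = 16.0 -> floor to 16 -> 16 * 32 = 512
```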

# Downsample layer which is an identity projection. Uses max pooling
# when the output size is more than the input size.
+ # TODO - figure out how to make this work when outplanes < inplanes
function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
    if outplanes > inplanes
        return Chain(MaxPool((1, 1); stride = 2),
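
Based on the snippet above, the identity shortcut should halve the spatial dimensions via stride-2 pooling and pad the channel dimension when `outplanes > inplanes`. A shape-level sketch (assumed behaviour, since the hunk is truncated here):

```julia
ds = downsample_identity(64, 128)
x = rand(Float32, 16, 16, 64, 1)
size(ds(x))  # expected: (8, 8, 128, 1); spatial dims halved, extra channels zero-padded
```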
@@ -174,8 +165,8 @@ on how to use this function.
- `activation`: The activation function used in the stem.
"""
function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
-                     replace_pool::Bool = false, norm_layer = BatchNorm, prenorm = false,
-                     activation = relu)
+                     replace_pool::Bool = false, activation = relu,
+                     norm_layer = BatchNorm, prenorm::Bool = false)
    @assert stem_type in [:default, :deep, :deep_tiered]
            "Stem type must be one of [:default, :deep, :deep_tiered]"
    # Main stem
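
A usage sketch for the stem builder. This is hedged: the next hunk changes `resnet_stem` to return just the `Chain` (dropping the `inplanes` return value), so the call below assumes the post-diff behaviour.

```julia
stem = resnet_stem(:deep; inchannels = 3)
x = rand(Float32, 224, 224, 3, 1)
size(stem(x))  # expected: (56, 56, 64, 1) after the stride-2 conv and stride-2 pool
```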
@@ -203,65 +194,70 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
                    prenorm,
                    stride = 2, pad = 1, bias = false)...) :
               MaxPool((3, 3); stride = 2, pad = 1)
- return Chain(conv1, bn1, stempool), inplanes
- end
-
- # Templating builders for the blocks and the downsampling layers
- function template_builder(block_fn; kwargs...)
-     function (inplanes, planes; _kwargs...)
-         return block_fn(inplanes, planes; kwargs..., _kwargs...)
-     end
- end
-
- function template_builder(::typeof(basicblock); reduction_factor::Integer = 1,
-                           activation = relu, norm_layer = BatchNorm, prenorm::Bool = false,
-                           attn_fn = planes -> identity, kargs...)
-     return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor,
-                                               activation, norm_layer, prenorm, attn_fn)
+ return Chain(conv1, bn1, stempool)
end

- function template_builder(::typeof(bottleneck); cardinality::Integer = 1,
-                           base_width::Integer = 64,
-                           reduction_factor::Integer = 1, activation = relu,
-                           norm_layer = BatchNorm, prenorm::Bool = false,
-                           attn_fn = planes -> identity, kargs...)
-     return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width,
-                                               reduction_factor, activation,
-                                               norm_layer, prenorm, attn_fn)
- end
+ resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1)
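
The new `resnet_planes` helper encodes the usual doubling of channel width per stage:

```julia
resnet_planes.(1:4)  # [64, 128, 256, 512]
```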

- function template_builder(downsample_fn::Union{typeof(downsample_conv),
-                                                typeof(downsample_pool),
-                                                typeof(downsample_identity)};
-                           norm_layer = BatchNorm, prenorm = false)
-     return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm)
+ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64,
+                             reduction_factor::Integer = 1, expansion::Integer = 1,
+                             norm_layer = BatchNorm, prenorm::Bool = false,
+                             activation = relu, attn_fn = planes -> identity,
+                             drop_block_rate = 0.0, drop_path_rate = 0.0,
+                             stride_fn = get_stride, planes_fn = resnet_planes,
+                             downsample_tuple = (downsample_conv, downsample_identity))
+     pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
+     blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
+     # closure over `idxs`
+     function get_layers(stage_idx::Integer, block_idx::Integer)
+         planes = planes_fn(stage_idx)
+         # `get_stride` is a callback that the user can tweak to change the stride of the
+         # blocks. It defaults to the standard behaviour as in the paper
+         stride = stride_fn(stage_idx, block_idx)
+         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                         downsample_tuple[1] : downsample_tuple[2]
+         # DropBlock, DropPath both take in rates based on a linear scaling schedule
+         schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+         drop_path = DropPath(pathschedule[schedule_idx])
+         drop_block = DropBlock(blockschedule[schedule_idx])
+         block = basicblock(inplanes, planes; stride, reduction_factor, activation,
+                            norm_layer, prenorm, attn_fn, drop_path, drop_block)
+         downsample = downsample_fn(inplanes, planes * expansion; stride)
+         # inplanes increases by expansion after each block
+         inplanes = planes * expansion
+         return block, downsample
+     end
+     return get_layers
end
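
A sketch of how the returned `get_layers` closure is intended to be driven (assuming `(stage_idx, block_idx)` indexing as used by `resnet_stages`). Note that `inplanes` is captured and mutated by the closure, so blocks must be requested in order:

```julia
get_layers = basicblock_builder([2, 2, 2, 2])
block, downsample = get_layers(1, 1)  # first block of stage 1: 64 -> 64 feature maps
```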

- resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1)
-
- function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer};
-                                 stride_fn = get_stride, plane_fn = resnet_planes,
-                                 downsample_templates::NTuple{2, Any},
-                                 inplanes::Integer = 64,
-                                 drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...)
+ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64,
+                             cardinality::Integer = 1, base_width::Integer = 64,
+                             reduction_factor::Integer = 1, expansion::Integer = 4,
+                             norm_layer = BatchNorm, prenorm::Bool = false,
+                             activation = relu, attn_fn = planes -> identity,
+                             drop_block_rate = 0.0, drop_path_rate = 0.0,
+                             stride_fn = get_stride, planes_fn = resnet_planes,
+                             downsample_tuple = (downsample_conv, downsample_identity))
    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
    # closure over `idxs`
    function get_layers(stage_idx::Integer, block_idx::Integer)
-         planes = plane_fn(stage_idx)
+         planes = planes_fn(stage_idx)
        # `get_stride` is a callback that the user can tweak to change the stride of the
        # blocks. It defaults to the standard behaviour as in the paper
        stride = stride_fn(stage_idx, block_idx)
-         downsample_template = (stride != 1 || inplanes != planes * expansion) ?
-                               downsample_templates[1] : downsample_templates[2]
+         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                         downsample_tuple[1] : downsample_tuple[2]
        # DropBlock, DropPath both take in rates based on a linear scaling schedule
        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
        drop_path = DropPath(pathschedule[schedule_idx])
        drop_block = DropBlock(blockschedule[schedule_idx])
-         block = block_template(inplanes, planes; stride, drop_path, drop_block)
-         downsample = downsample_template(inplanes, planes * expansion; stride)
+         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
+                            reduction_factor, activation, norm_layer, prenorm,
+                            attn_fn, drop_path, drop_block)
+         downsample = downsample_fn(inplanes, planes * expansion; stride)
        # inplanes increases by expansion after each block
-         inplanes = (planes * expansion)
+         inplanes = planes * expansion
        return block, downsample
    end
    return get_layers
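
The same builder shape covers ResNeXt-style models by widening the bottleneck; an illustrative (not diff-sourced) configuration:

```julia
# ResNeXt-50 32x4d style settings
get_layers = bottleneck_builder([3, 4, 6, 3]; cardinality = 32, base_width = 4)
```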
@@ -283,41 +279,59 @@ function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
    return Chain(stages...)
end

- function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier)
-     stage_blocks = resnet_stages(get_layers, block_repeats, connection)
-     return Chain(Chain(stem, stage_blocks), classifier)
+ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
+                 downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+                 cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
+                 reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
+                 inchannels::Integer = 3, stem_fn = resnet_stem,
+                 connection = addact, activation = relu, norm_layer = BatchNorm,
+                 prenorm::Bool = false, attn_fn = planes -> identity,
+                 pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
+                 drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
+                 nclasses::Integer = 1000)
+     # Build stem
+     stem = stem_fn(; inchannels)
+     # Block builder
+     if block_type == :basicblock
+         @assert cardinality == 1 "Cardinality must be 1 for `basicblock`"
+         @assert base_width == 64 "Base width must be 64 for `basicblock`"
+         get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
+                                         activation, norm_layer, prenorm, attn_fn,
+                                         drop_block_rate, drop_path_rate,
+                                         stride_fn = get_stride, planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt)
+     elseif block_type == :bottleneck
+         get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
+                                         reduction_factor, activation, norm_layer,
+                                         prenorm, attn_fn, drop_block_rate, drop_path_rate,
+                                         stride_fn = get_stride, planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt)
+     else
+         throw(ArgumentError("Unknown block type $block_type"))
+     end
+     classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
+                                                    pool_layer, use_conv)
+     return resnet((imsize..., inchannels), stem, connection$activation, get_layers,
+                   block_repeats, classifier_fn)
+ end
+ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
+     return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
end
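
Taken together, the new keyword-driven entry point can be exercised as below (a sketch using only names that appear in the diff):

```julia
# ResNet-50-style model with a 1000-way classifier
model = resnet(:bottleneck, [3, 4, 6, 3]; nclasses = 1000)
# Shortcut layout picked by symbol via the second method
model_b = resnet(:basicblock, [2, 2, 2, 2], :B)
```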

- function resnet(block_fn, block_repeats::Vector{<:Integer},
-                 downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
-                 imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
-                 stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
-                 connection = addact, activation = relu,
-                 pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                 dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
-     # Configure downsample templates
-     downsample_templates = map(template_builder, downsample_opt)
-     # Configure block templates
-     block_template = template_builder(block_fn; kwargs...)
-     get_layers = configure_resnet_block(block_template, expansion_factor(block_fn),
-                                         block_repeats; inplanes, downsample_templates,
-                                         kwargs...)
+ function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer},
+                 classifier_fn)
    # Build stages of the ResNet
-     stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation)
+     stage_blocks = resnet_stages(get_layers, block_repeats, connection)
    backbone = Chain(stem, stage_blocks)
    # Build the classifier head
-     nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3]
-     classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer,
-                                    use_conv)
+     nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
+     classifier = classifier_fn(nfeaturemaps)
    return Chain(backbone, classifier)
end
- function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
-     return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt], kwargs...)
- end
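
For completeness, a hedged sketch of calling the low-level compositional method directly, mirroring the call the high-level method makes (`connection$activation` is the PartialFunctions.jl partial application seen in the diff):

```julia
stem = resnet_stem(:default)
get_layers = basicblock_builder([2, 2, 2, 2])
model = resnet((256, 256, 3), stem, addact$relu, get_layers, [2, 2, 2, 2],
               nfeatures -> create_classifier(nfeatures, 1000))
```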

# block-layer configurations for ResNet-like models
- const resnet_configs = Dict(18 => (basicblock, [2, 2, 2, 2]),
-                             34 => (basicblock, [3, 4, 6, 3]),
-                             50 => (bottleneck, [3, 4, 6, 3]),
-                             101 => (bottleneck, [3, 4, 23, 3]),
-                             152 => (bottleneck, [3, 8, 36, 3]))
+ const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]),
+                             34 => (:basicblock, [3, 4, 6, 3]),
+                             50 => (:bottleneck, [3, 4, 6, 3]),
+                             101 => (:bottleneck, [3, 4, 23, 3]),
+                             152 => (:bottleneck, [3, 8, 36, 3]))
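
With the configs now storing symbols, a depth lookup composes directly with the new `resnet` entry point:

```julia
block_type, block_repeats = resnet_configs[50]
model = resnet(block_type, block_repeats)  # ResNet-50
```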