|
"""
    efficientnet(scalings, block_config;
                 inchannels = 3, nclasses = 1000, max_width = 1280)

Build the layers of an EfficientNet model
([reference](https://arxiv.org/abs/1905.11946v5)).

The result is a two-element `Chain`: the first element holds the stem, the
inverted residual blocks and the head convolution; the second element holds
the pooling, flattening and dense classification layers.

# Arguments

- `scalings`: tuple `(width, depth)` of global scaling factors
- `block_config`: iterable of tuples `(n, k, s, e, i, o)`, one per group of
  inverted residual blocks:
  - `n`: number of block repetitions (multiplied by the depth scaling and
    rounded up, unless the depth scaling is approximately 1)
  - `k`: kernel size
  - `s`: stride of the first block of the group (repeated blocks use stride 1)
  - `e`: expansion ratio
  - `i`: block input channels (multiplied by the width scaling)
  - `o`: block output channels (multiplied by the width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: number of output channels of the head convolution that precedes
  the fully connected classification layer
"""
function efficientnet(scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
    width_mult, depth_mult = scalings

    # Width-scaled channel count, passed through `_round_channels(_, 8)`.
    # A multiplier of (approximately) 1 leaves the raw value untouched.
    nchannels(c) = _round_channels(width_mult ≈ 1 ? c : ceil(Int64, width_mult * c), 8)
    # Depth-scaled repetition count, rounded up.
    nrepeats(n) = depth_mult ≈ 1 ? n : ceil(Int64, depth_mult * n)

    stem_width = nchannels(32)
    stem = conv_bn((3, 3), inchannels, stem_width, swish;
                   bias = false, stride = 2, pad = SamePad())

    # Input width of the head: the output width of the last block group, or the
    # stem width when `block_config` is empty.
    head_in = stem_width

    stages = []
    for (n, k, s, e, i, o) in block_config
        width_in = nchannels(i)
        width_out = nchannels(o)

        # The first block of a group changes the channel count and applies the
        # configured stride.
        push!(stages,
              invertedresidual(k, width_in, width_in * e, width_out, swish;
                               stride = s, reduction = 4))
        # The remaining repetitions keep the channel count and use stride 1.
        for _ in 2:nrepeats(n)
            push!(stages,
                  invertedresidual(k, width_out, width_out * e, width_out, swish;
                                   stride = 1, reduction = 4))
        end

        head_in = width_out
    end
    body = Chain(stages...)

    head_width = _round_channels(max_width, 8)
    head = conv_bn((1, 1), head_in, head_width, swish;
                   bias = false, pad = SamePad())

    top = Dense(head_width, nclasses)

    features = Chain([stem..., body, head...])
    return Chain(features, Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top))
end
| 59 | + |
# Default block configuration used by `EfficientNet(name::Symbol)`.
# Each tuple holds, in order:
#   n: number of block repetitions
#   k: kernel size (k x k)
#   s: stride
#   e: expansion ratio
#   i: block input channels
#   o: block output channels
const efficientnet_block_configs = NTuple{6, Int}[
    # (n, k, s, e,   i,   o)
    (1, 3, 1, 1, 32, 16),
    (2, 3, 2, 6, 16, 24),
    (2, 5, 2, 6, 24, 40),
    (3, 3, 2, 6, 40, 80),
    (3, 5, 1, 6, 80, 112),
    (4, 5, 2, 6, 112, 192),
    (1, 3, 1, 6, 192, 320),
]
| 76 | + |
# Named model variants. Each value has the layout `(r, (w, d))`:
#   r: image resolution
#   w: width scaling
#   d: depth scaling
# `EfficientNet(name::Symbol)` forwards only the `(w, d)` pair to the builder.
const efficientnet_global_configs = Dict{Symbol, Tuple{Int, Tuple{Float64, Float64}}}(
    #        (  r, (  w,   d))
    :b0 => (224, (1.0, 1.0)),
    :b1 => (240, (1.0, 1.1)),
    :b2 => (260, (1.1, 1.2)),
    :b3 => (300, (1.2, 1.4)),
    :b4 => (380, (1.4, 1.8)),
    :b5 => (456, (1.6, 2.2)),
    :b6 => (528, (1.8, 2.6)),
    :b7 => (600, (2.0, 3.1)),
    :b8 => (672, (2.2, 3.6)),
)
| 92 | + |
# Model wrapper holding the layer container of an EfficientNet.
# The field is left untyped, which is equivalent to declaring it `::Any`.
struct EfficientNet
    layers
end
| 96 | + |
"""
    EfficientNet(scalings, block_config;
                 inchannels = 3, nclasses = 1000, max_width = 1280)

Construct an EfficientNet model
([reference](https://arxiv.org/abs/1905.11946v5)) by wrapping the layers
returned by `efficientnet`; all arguments are forwarded to it unchanged.
See also [`efficientnet`](#).

# Arguments

- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
  given as a vector of tuples with elements:
  - `n`: number of block repetitions (will be scaled by global depth scaling)
  - `k`: kernel size
  - `s`: kernel stride
  - `e`: expansion ratio
  - `i`: block input channels (will be scaled by global width scaling)
  - `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
  classification blocks
"""
function EfficientNet(scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
    return EfficientNet(efficientnet(scalings, block_config;
                                     inchannels = inchannels,
                                     nclasses = nclasses,
                                     max_width = max_width))
end
| 128 | + |
@functor EfficientNet

# Calling the model applies the wrapped layers to the input.
function (model::EfficientNet)(x)
    return model.layers(x)
end
| 132 | + |
# First element of the wrapped layers. For models built by `efficientnet` this
# is the chain of stem, inverted residual blocks and head convolution.
function backbone(m::EfficientNet)
    return m.layers[1]
end

# Second element of the wrapped layers. For models built by `efficientnet` this
# is the chain of pooling, flattening and the dense output layer.
function classifier(m::EfficientNet)
    return m.layers[2]
end
| 135 | + |
"""
    EfficientNet(name::Symbol; pretrain = false)

Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5))
from one of the named default configurations.
See also [`efficientnet`](#).

# Arguments

- `name`: name of default configuration
  (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet

# Throws

- `AssertionError` if `name` is not a key of `efficientnet_global_configs`;
  the message lists the accepted names.
"""
function EfficientNet(name::Symbol; pretrain = false)
    # Condition and message must belong to the same macro call. With the
    # message on a line of its own it is parsed as a separate, unused string
    # expression and the assertion fails without any explanation.
    @assert(haskey(efficientnet_global_configs, name),
            "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))")

    # Entries are `(resolution, (width, depth))`; only the scalings are needed.
    _, scalings = efficientnet_global_configs[name]
    model = EfficientNet(scalings, efficientnet_block_configs)
    pretrain && loadpretrain!(model, string("efficientnet-", name))

    return model
end
0 commit comments