using Flux, Test, BSON  # imports needed when this file is run standalone

ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))
dm(bias) = Chain(
  dl(3, 5, bias),
  dl(5, 4, bias),
  dl(4, 3, bias)
)

nobias(n) = false
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
  @test l1.weight == l2.weight
  @test l1.bias == l2.bias
  @test_skip typeof(l1.bias) === typeof(l2.bias)
end
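
# `ls(5, 3)` fills a 5×3 Float32 matrix column-major with 1:15, so every `dm(bt)`
# is built from the same deterministic weights and `testdense` can compare a
# loaded model against a freshly built reference layer by layer, e.g. (illustrative)
#   testdense(loadmodel!(dm(nobias), dm(nobias)), nobias)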

@testset "loadmodel!(dst, src)" begin
  m1 = Chain(Dense(10 => 5), Dense(5 => 2, relu))
  m2 = Chain(Dense(10 => 5), Dense(5 => 2))
  m3 = Chain(Conv((3, 3), 3 => 16), Dense(5 => 2))
  m4 = Chain(Dense(10 => 6), Dense(6 => 2))
  m5 = Chain(Dense(10 => 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2)))
  m6 = Chain(Dense(10 => 5), Parallel(+, Dense(5 => 2), Dense(5 => 2)))

  loadmodel!(m1, m2)
  # trainable parameters copy over
  @test m1[1].weight == m2[1].weight
  @test m1[1].bias == m2[1].bias
  # non-array leaves are untouched
  @test m1[2].σ == relu

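  # `loadmodel!` recursively pairs the children of dst and src and `copyto!`s
  # array leaves in place, so dst keeps its own structure and array types while
  # taking on src's values; nested containers load as long as the trees line up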
  loadmodel!(m5, m6)
  # more complex nested structures also work
  @test m5[1].weight == m6[1].weight
  @test m5[2][1].weight == m6[2][1].weight
  # false bias is not overwritten
  @test m5[2][1].bias == false

  # mismatched nodes throw an error
  @test_throws ArgumentError loadmodel!(m1, m3)
  @test_throws ArgumentError loadmodel!(m1, m5)
  # size mismatches throw an error
  @test_throws DimensionMismatch loadmodel!(m1, m4)
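  # (m1 vs m3 pairs a Dense with a Conv, and m1 vs m5 pairs a Dense with a
  # Parallel, so the structures don't match; m4's first weight is 6×10 while
  # m1's is 5×10, hence the DimensionMismatch from `copyto!`)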

  # tests for BatchNorm and Dropout
  m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2))
  m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1))
  m2[2].μ .= rand(Float32, size(m2[2].μ)...)
  loadmodel!(m1, m2)
  # non-trainable parameters are copied as well
  @test m1[2].μ == m2[2].μ
  # functions are not copied
  @test m1[3] == Flux.flatten
  # dropout rate is not copied
  @test m1[4].p == 0.2
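  # note that m1 and m2 don't even hold the same flatten function (Flux.flatten
  # vs an anonymous closure); non-array leaves are never compared or copied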

  # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33)
  # tests Chain(...) vs Chain([...])
  # tests MaxPool
  # tests that testmode!/trainmode! state is not copied
  # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model
  chain1 = Chain(Dropout(0.2),
    Conv((3, 3), 1 => 32, relu),
    BatchNorm(32, relu),
    MaxPool((2, 2)),
    Dropout(0.2),
    Conv((3, 3), 32 => 16, relu),
    Dropout(0.2),
    MaxPool((2, 2)),
    Dropout(0.2),
    Conv((3, 3), 16 => 10, relu),
    Dropout(0.2),
    x -> reshape(x, :, size(x, 4)),
    Dropout(0.2),
    Dense(90 => 10),
    softmax)
  chain2 = Chain([Dropout(0.1),
    Conv((3, 3), 1 => 32, relu),
    BatchNorm(32, relu),
    MaxPool((3, 3)),
    Dropout(0.1),
    Conv((3, 3), 32 => 16, relu),
    Dropout(0.1),
    MaxPool((3, 3)),
    Dropout(0.1),
    Conv((3, 3), 16 => 10, relu),
    Dropout(0.1),
    x -> reshape(x, :, size(x, 4)),
    Dropout(0.1),
    Dense(90 => 10),
    softmax])
  chain2[3].μ .= 5f0
  chain2[3].σ² .= 2f0
  testmode!(chain2)
  loadmodel!(chain1, chain2)
  for (dst, src) in zip(chain1, chain2)
    if dst isa Dropout
      @test dst.p == 0.2
    elseif dst isa Union{Conv, Dense}
      @test dst.weight == src.weight
      @test dst.bias == src.bias
    elseif dst isa MaxPool
      @test dst.k == (2, 2)
    elseif dst isa BatchNorm
      @test dst.μ == src.μ
      @test dst.σ² == src.σ²
      @test isnothing(dst.active)
    end
  end
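  # `testmode!(chain2)` set src's `active` flags to false, but loadmodel! leaves
  # them alone, so chain1 stays in automatic mode (active === nothing)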

  # copy only a subset of the model
  chain1[end - 1].weight .= 1f0
  chain1[3].μ .= 3f0
  chain1[2].bias .= 5f0
  loadmodel!(chain2[end - 1], chain1[end - 1])
  loadmodel!(chain2[3], chain1[3])
  @test chain2[end - 1].weight == chain1[end - 1].weight
  @test chain2[3].μ == chain1[3].μ
  @test chain2[2].bias != chain1[2].bias
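  # loadmodel! works on any matching pair of layers, not just whole Chains;
  # layers never passed to it (here chain2[2]) keep their old values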

  # test shared weights
  shared_dst = Dense(10 => 10)
  shared_src = Dense(10 => 10)
  # matched weights are okay
  m1 = Chain(shared_dst, Dense(shared_dst.weight))
  m2 = Chain(shared_src, Dense(shared_src.weight))
  loadmodel!(m1, m2)
  @test m1[1].weight === m1[2].weight
  @test m1[1].weight == m2[2].weight
  # mismatched weights are an error
  m2 = Chain(Dense(10 => 10), Dense(10 => 10))
  @test_throws ErrorException loadmodel!(m1, m2)
  # loading a `false` (absent) bias into tied biases is okay when the tied value is zero
  b = Flux.zeros32(5)
  m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b))
  m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false))
  loadmodel!(m1, m2)
  @test m1[1].bias === m1[2].bias
  @test iszero(m1[1].bias)
  # loading a `false` (absent) bias into tied biases errors when the tied value is nonzero
  m2[1].bias .= 1
  @test_throws ErrorException loadmodel!(m1, m2)
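  # loadmodel! remembers which dst arrays it has already written (tracking them
  # by identity), which is how conflicting values for the same tied array are caught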

  @testset "loadmodel! & filter" begin
    m1 = Chain(Dense(10 => 5), Dense(5 => 2, relu))
    m2 = Chain(Dense(10 => 5), Dropout(0.2), Dense(5 => 2))
    m3 = Chain(Dense(10 => 5), Dense(5 => 2, relu))

    # this does not error because the Dropout layer is skipped
    loadmodel!(m1, m2; filter = x -> !(x isa Dropout))
    @test m1[1].weight == m2[1].weight
    @test m1[2].weight == m2[3].weight

    # likewise, no error when the Dropout is on the destination side
    loadmodel!(m2, m3; filter = x -> !(x isa Dropout))
    @test m3[1].weight == m2[1].weight
    @test m3[2].weight == m2[3].weight
  end
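  # with Dropout filtered out, the remaining children pair up Dense-to-Dense,
  # which is why the differing Chain lengths above are not a structure error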

  @testset "loadmodel! & absent bias" begin
    m0 = Chain(Dense(2 => 3; bias = false, init = Flux.ones32), Dense(3 => 1))
    m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
    m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))

    Flux.loadmodel!(m1, m2)
    @test m1[1].bias == 7:9
    @test sum(m1[1].weight) == 21

    # load from a model without bias -- the `false` source zeroes the destination bias
    m1 = Flux.loadmodel!(m1, m0)
    @test iszero(m1[1].bias)
    @test sum(m1[1].weight) == 6 # the all-ones weight is copied as usual

    # load into a model without bias -- the parameter with no home is ignored, not an error
    m0 = Flux.loadmodel!(m0, m2)
    @test iszero(m0[1].bias) # still `false`, i.e. unchanged
    @test sum(m0[1].weight) == 21
  end
end

@testset "loadmodel!(dst, src) with BSON" begin
  m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
  m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1))
  @test m1[1].weight != m2[1].weight
  mktempdir() do dir
    BSON.@save joinpath(dir, "test.bson") m1
    m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1])
    @test m1[1].weight == m2[1].weight
  end
end
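
# Serializing the model object itself ties the file to Flux's layer types;
# `Flux.state` (tested below) extracts just the values as plain nested
# tuples/NamedTuples, which is the more robust thing to save.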

@testset "state" begin
  m1 = Chain(Dense(10 => 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2)))
  m2 = Chain(Dense(10 => 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2)))
  s = Flux.state(m1)
  @test s isa NamedTuple
  @test fieldnames(typeof(s)) == (:layers,)
  @test s.layers isa Tuple
  @test length(s.layers) == 2
  @test s.layers[1].weight === m1[1].weight
  @test s.layers[1].σ === nothing
  @test s.layers[2].layers[1].weight === m1[2].layers[1].weight

  Flux.loadmodel!(m2, s)
  @test m2[1].weight == m1[1].weight
  @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias)

  @testset "track active state and batch norm params" begin
    m3 = Chain(Dense(10 => 5), Dropout(0.2), Dense(5 => 2), BatchNorm(2))
    trainmode!(m3)
    s = Flux.state(m3)
    @test s.layers[2].active == true
    @test s.layers[2].p == 0.2
    @test s.layers[4] == (λ = nothing, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0],
      μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true,
      track_stats = true, active = true, chs = 2)
  end
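
  # by default `state` keeps scalar hyperparameters (p, ϵ, momentum, chs) and
  # flags alongside the arrays, while function-valued fields like λ become
  # nothing; the `keep` filter below restricts which leaves are retained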
  @testset "keep" begin
    s = Flux.state(m1, keep = x -> x isa AbstractArray)
    @test s.layers[1].weight isa AbstractArray
    @test s.layers[1].σ === nothing
    @test s.layers[2].connection === nothing
    @test s.layers[2].layers[1].bias === nothing
  end
end
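
# A minimal save/load round trip with the APIs tested above (BSON as used
# earlier; the filename is illustrative):
#
#   s = Flux.state(model)
#   BSON.@save "checkpoint.bson" s
#   Flux.loadmodel!(model, BSON.load("checkpoint.bson")[:s])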