diff --git a/NEWS.md b/NEWS.md index e7fad6ccf0..611fd53868 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,12 +7,13 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl Thus `LayerNorm(3; ϵ=1e-4)` (not `ε`!) should become `LayerNorm(3; eps=1e-4)`. * `DataLoader(...) |> gpu` will now produce a special iterator, moving each batch as needed, instead of giving an error. +* Added `Flux.state` returning the internal state of the model for serialization. ## v0.13.15 * Added [MultiHeadAttention](https://github.com/FluxML/Flux.jl/pull/2146) layer. * `f16, f32, f64` now specifically target floating point arrays (i.e. integers arrays and other types are preserved). * `f16, f32, f64` can now handle `Complex{<:AbstractFloat}` arrays. -* Added `EmbeddingBag` layer +* Added `EmbeddingBag` layer. ## v0.13.14 * Fixed various deprecation warnings, from `Zygone.@nograd` and `Vararg`. diff --git a/Project.toml b/Project.toml index 4e2dba884e..709e99a1d2 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + +[extensions] +AMDGPUExt = "AMDGPU" + [compat] AMDGPU = "0.4.13" Adapt = "3.0" @@ -44,9 +50,6 @@ Zygote = "0.6.49" cuDNN = "1" julia = "1.6" -[extensions] -AMDGPUExt = "AMDGPU" - [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" @@ -59,6 +62,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON"] - -[weakdeps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" diff --git a/docs/Project.toml b/docs/Project.toml index 4af31f2254..a4d907e63e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/docs/make.jl b/docs/make.jl index a536014d99..a1b588d618 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,6 @@ -using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, DataFrames +using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, + OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, + DataFrames, JLD2 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true) diff --git a/docs/src/destructure.md b/docs/src/destructure.md index 6e9eac191e..1cdcad5ce7 100644 --- a/docs/src/destructure.md +++ b/docs/src/destructure.md @@ -72,4 +72,11 @@ Another kind of flat view of a nested model is provided by the `modules` command ```@docs Flux.modules +``` + +### Save and Load + +```@docs +Flux.state +Flux.loadmodel! ``` \ No newline at end of file diff --git a/docs/src/saving.md b/docs/src/saving.md index 853f4b0d9c..16f944ef08 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -1,7 +1,10 @@ # Saving and Loading Models You may wish to save models so that they can be loaded and run in a later -session. The easiest way to do this is via +session. Flux provides a number of ways to do this. +The recommended way, which is the most robust one for long term storage, +is to use [`Flux.state`](@ref) in combination with a serialization format like +[JLD2.jl](https://juliaio.github.io/JLD2.jl/dev/) or [BSON.jl](https://github.com/JuliaIO/BSON.jl). Save a model: @@ -9,132 +12,136 @@ Save a model: ```jldoctest saving julia> using Flux -julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax) -Chain( - Dense(10 => 5, relu), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. +julia> struct MyModel + net + end -julia> using BSON: @save +julia> Flux.@functor MyModel -julia> @save "mymodel.bson" model +julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))); + +julia> model = MyModel() +MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) + +julia> model_state = Flux.state(model); + +julia> using JLD2 + +julia> jldsave("mymodel.jld2"; model_state) ``` -Load it again: +Load it again in a new session using [`Flux.loadmodel!`](@ref): ```jldoctest saving -julia> using Flux # Flux must be loaded before calling @load +julia> using Flux, JLD2 -julia> using BSON: @load +julia> model_state = JLD2.load("mymodel.jld2", "model_state"); -julia> @load "mymodel.bson" model +julia> model = MyModel(); # MyModel definition must be available -julia> model -Chain( - Dense(10 => 5, relu), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. +julia> Flux.loadmodel!(model, model_state); ``` -Models are just normal Julia structs, so it's fine to use any Julia storage -format for this purpose. BSON.jl is particularly well supported and most likely -to be forwards compatible (that is, models saved now will load in future -versions of Flux). - !!! note If a saved model's parameters are stored on the GPU, the model will not load later on if there is no GPU support available. It's best to [move your model to the CPU](gpu.md) with `cpu(model)` before saving it. -!!! warning - - Previous versions of Flux suggested saving only the model weights using - `@save "mymodel.bson" params(model)`. - This is no longer recommended and even strongly discouraged. - Saving models this way will only store the trainable parameters which - will result in incorrect behavior for layers like `BatchNorm`. - -```julia -julia> using Flux - -julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax) -Chain( - Dense(10 => 5, relu), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. - -julia> weights = Flux.params(model); -``` - -Loading the model as shown above will return a new model with the stored parameters. -But sometimes you already have a model, and you want to load stored parameters into it. -This can be done as - -```julia -using Flux: loadmodel! -using BSON - -# some predefined model -model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax) - -# load one model into another -model = loadmodel!(model, BSON.load("mymodel.bson")[:model]) -``` - -This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory. - -```@docs -Flux.loadmodel! -``` ## Checkpointing -In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md). +In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). ```jldoctest saving julia> using Flux: throttle -julia> using BSON: @save +julia> using JLD2 -julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax) +julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2)) Chain( Dense(10 => 5, relu), # 55 parameters Dense(5 => 2), # 12 parameters - NNlib.softmax, ) # Total: 4 arrays, 67 parameters, 524 bytes. -julia> evalcb = throttle(30) do - # Show loss - @save "model-checkpoint.bson" model +julia> for epoch in 1:10 + # ... train model ... + jldsave("model-checkpoint.jld2", model_state = Flux.state(m)) end; ``` -This will update the `"model-checkpoint.bson"` file every thirty seconds. +This will update the `"model-checkpoint.jld2"` every epoch. You can get more advanced by saving a series of models throughout training, for example ```julia -@save "model-$(now()).bson" model +jldsave("model-$(now()).jld2", model_state = Flux.state(m)) ``` -will produce a series of models like `"model-2018-03-06T02:57:10.41.bson"`. You +will produce a series of models like `"model-2018-03-06T02:57:10.41.jld2"`. You could also store the current test set loss, so that it's easy to (for example) revert to an older copy of the model if it starts to overfit. ```julia -@save "model-$(now()).bson" model loss = testloss() +jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss()) ``` -Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are stateful optimisers (which usually utilize an `IdDict` to store their state), and the randomness used to partition the original data into the training and validation sets. +Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are the optimiser state and the randomness used to partition the original data into the training and validation sets. You can store the optimiser state alongside the model, to resume training -exactly where you left off. BSON is smart enough to [cache values](https://github.com/JuliaIO/BSON.jl/blob/v0.3.4/src/write.jl#L71) and insert links when saving, but only if it knows everything to be saved up front. Thus models and optimisers must be saved together to have the latter work after restoring. +exactly where you left off: ```julia -opt = Adam() -@save "model-$(now()).bson" model opt +model = MyModel() +opt_state = Flux.setup(AdamW(), model) + +# ... train model ... + +model_state = Flux.state(model) +jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state) +``` + +# Saving Models as Julia Structs + +Models are just normal Julia structs, so it's fine to use any Julia storage +format to save the struct as it is instead of saving the state returned by [`Flux.state`](@ref). +[BSON.jl](https://github.com/JuliaIO/BSON.jl) is particularly convenient for this, +since it can also save anynomous functions, which are sometimes part of a model definition. + +Save a model: + +```jldoctest saving +julia> using Flux + +julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2)); + +julia> using BSON: @save + +julia> @save "mymodel.bson" model ``` + +Load it again in a new session: + +```jldoctest saving +julia> using Flux, BSON + +julia> BSON.@load "mymodel.bson" model + +julia> model +Chain( + Dense(10 => 5, relu), # 55 parameters + Dense(5 => 2), # 12 parameters +) # Total: 4 arrays, 67 parameters, 524 bytes. +``` +!!! warning + Saving models this way could lead to compatibility issues across julia versions + and across Flux versions if some of the Flux layers' internals are changed. + It is therefore not recommended for long term storage, use [`Flux.state`](@ref) instead. + +!!! warning + + Previous versions of Flux suggested saving only the model weights using + `@save "mymodel.bson" params(model)`. + This is no longer recommended and even strongly discouraged. + Saving models this way will only store the trainable parameters which + will result in incorrect behavior for layers like `BatchNorm`. diff --git a/src/loading.jl b/src/loading.jl index 5cdd129936..8238a19cde 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -47,6 +47,8 @@ Non-array elements (such as activation functions) are not copied and need not ma Zero bias vectors and `bias=false` are considered equivalent (see extended help for more details). +See also [`Flux.state`](@ref). + # Examples ```julia julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0])) @@ -88,12 +90,14 @@ but copying a `src` value of `true` will error. function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) ldsts = _filter_children(filter, Functors.children(dst)) lsrcs = _filter_children(filter, Functors.children(src)) - (keys(ldsts) == keys(lsrcs)) || - throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match.")) - - foreach(ldsts, lsrcs) do ldst, lsrc + keys_ldsts = keys(ldsts) + keys_lsrcs = keys(lsrcs) + collect(keys_ldsts) == collect(keys_lsrcs) || throw(ArgumentError("Tried to load $(keys_lsrcs) into $(keys_ldsts) but the structures do not match.")) + + for k in keys_lsrcs + lsrc, ldst = lsrcs[k], ldsts[k] if ldst in cache # we already loaded this parameter before - _tie_check(ldst, lsrc) && return ldst + _tie_check(ldst, lsrc) elseif Functors.isleaf(ldst) # our first time loading this leaf push!(cache, ldst) loadleaf!(ldst, lsrc) @@ -104,3 +108,71 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) return dst end + +""" + state(x) + +Return an object with the same nested structure as `x` according to `Functors.children`, +but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries). + +Besides trainable and non-trainable arrays, the state will contain leaf nodes that are not arrays, +such as numbers, symbols, strings, and nothing values. The leaf types that end up in the state +could increase in the future. + +This method is particularly useful for saving and loading models, +since the state contain only simple data types that can be easily serialized. + +The state can be passed to [`loadmodel!`](@ref) to restore the model. + +# Examples + +## Copy the state into another model + +```jldoctest +julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones)); + +julia> s = Flux.state(m1) +(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = ()), (weight = [1.0 1.0], bias = [0.0], σ = ())),) + +julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false)); # weights are random numbers + +julia> Flux.loadmodel!(m2, s); + +julia> m2[1].weight # now the weights of m2 are the same as m1 +2×1 Matrix{Float32}: + 1.0 + 1.0 + +julia> Flux.state(trainmode!(Dropout(0.2))) # contains p & activity, but not RNG state +(p = 0.2, dims = (), active = true, rng = ()) + +julia> Flux.state(BatchNorm(1)) # contains non-trainable arrays μ, σ² +(λ = (), β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1) +``` + +## Save and load with BSON + +```julia-repl +julia> using BSON + +julia> BSON.@save "checkpoint.bson" model_state = s + +julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state]) +``` + +## Save and load with JLD2 + +```julia-repl +julia> using JLD2 + +julia> JLD2.jldsave("checkpoint.jld2", model_state = s) + +julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state")) +``` +""" +state(x) = Functors.fmapstructure(_state, x) + +const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol} + +_state(x::STATE_TYPES) = x +_state(x) = () diff --git a/test/loading.jl b/test/loading.jl new file mode 100644 index 0000000000..06bc412d31 --- /dev/null +++ b/test/loading.jl @@ -0,0 +1,239 @@ + +ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense +dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout)) +dm(bias) = Chain( + dl(3, 5, bias), + dl(5, 4, bias), + dl(4, 3, bias) +) + +nobias(n) = false +testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt))) + @test l1.weight == l2.weight + @test l1.bias == l2.bias + @test_skip typeof(l1.bias) === typeof(l2.bias) +end + + +@testset "loadmodel!(dst, src)" begin + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m2 = Chain(Dense(10, 5), Dense(5, 2)) + m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) + m4 = Chain(Dense(10, 6), Dense(6, 2)) + m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) + m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) + + loadmodel!(m1, m2) + # trainable parameters copy over + @test m1[1].weight == m2[1].weight + @test m1[1].bias == m2[1].bias + # non-array leaves are untouched + @test m1[2].σ == relu + + loadmodel!(m5, m6) + # more complex nested structures also work + @test m5[1].weight == m6[1].weight + @test m5[2][1].weight == m6[2][1].weight + # false bias is not overwritten + @test m5[2][1].bias == false + + # mismatched nodes throw an error + @test_throws ArgumentError loadmodel!(m1, m3) + @test_throws ArgumentError loadmodel!(m1, m5) + # size mismatches throw an error + @test_throws DimensionMismatch loadmodel!(m1, m4) + + # tests for BatchNorm and Dropout + m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) + m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) + m2[2].μ .= rand(Float32, size(m2[2].μ)...) + loadmodel!(m1, m2) + # non-trainable parameters are copied as well + @test m1[2].μ == m2[2].μ + # functions are not copied + @test m1[3] == Flux.flatten + # dropout rate is not copied + @test m1[4].p == 0.2 + + # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33) + # tests Chain(...) vs Chain([...]) + # tests MaxPool + # tests testmode!/trainmode! is not copied + # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model + chain1 = Chain(Dropout(0.2), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 32 => 16, relu), + Dropout(0.2), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 16 => 10, relu), + Dropout(0.2), + x -> reshape(x, :, size(x, 4)), + Dropout(0.2), + Dense(90, 10), + softmax) + chain2 = Chain([Dropout(0.1), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 32 => 16, relu), + Dropout(0.1), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 16 => 10, relu), + Dropout(0.1), + x -> reshape(x, :, size(x, 4)), + Dropout(0.1), + Dense(90, 10), + softmax]) + chain2[3].μ .= 5f0 + chain2[3].σ² .= 2f0 + testmode!(chain2) + loadmodel!(chain1, chain2) + for (dst, src) in zip(chain1, chain2) + if dst isa Dropout + @test dst.p == 0.2 + elseif dst isa Union{Conv, Dense} + @test dst.weight == src.weight + @test dst.bias == src.bias + elseif dst isa MaxPool + @test dst.k == (2, 2) + elseif dst isa BatchNorm + @test dst.μ == src.μ + @test dst.σ² == src.σ² + @test isnothing(dst.active) + end + end + + # copy only a subset of the model + chain1[end - 1].weight .= 1f0 + chain1[3].μ .= 3f0 + chain1[2].bias .= 5f0 + loadmodel!(chain2[end - 1], chain1[end - 1]) + loadmodel!(chain2[3], chain1[3]) + @test chain2[end - 1].weight == chain1[end - 1].weight + @test chain2[3].μ == chain1[3].μ + @test chain2[2].bias != chain1[2].bias + + # test shared weights + shared_dst = Dense(10 => 10) + shared_src = Dense(10 => 10) + # matched weights are okay + m1 = Chain(shared_dst, Dense(shared_dst.weight)) + m2 = Chain(shared_src, Dense(shared_src.weight)) + loadmodel!(m1, m2) + @test m1[1].weight === m1[2].weight + @test m1[1].weight == m2[2].weight + # mismatched weights are an error + m2 = Chain(Dense(10 => 10), Dense(10 => 10)) + @test_throws ErrorException loadmodel!(m1, m2) + # loading into tied weights with absent parameter is okay when the dst == zero + b = Flux.zeros32(5) + m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b)) + m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false)) + loadmodel!(m1, m2) + @test m1[1].bias === m1[2].bias + @test iszero(m1[1].bias) + # loading into tied weights with absent parameter is bad when the dst != zero + m2[1].bias .= 1 + @test_throws ErrorException loadmodel!(m1, m2) + + @testset "loadmodel! & filter" begin + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) + m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) + + # this will not error cause Dropout is skipped + loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) + @test m1[1].weight == m2[1].weight + @test m1[2].weight == m2[3].weight + + # this will not error cause Dropout is skipped + loadmodel!(m2, m3; filter = x -> !(x isa Dropout)) + @test m3[1].weight == m2[1].weight + @test m3[2].weight == m2[3].weight + end + + @testset "loadmodel! & absent bias" begin + m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) + m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) + m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) + + Flux.loadmodel!(m1, m2) + @test m1[1].bias == 7:9 + @test sum(m1[1].weight) == 21 + + # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it + m1 = Flux.loadmodel!(m1, m0) + @test iszero(m1[1].bias) + @test sum(m1[1].weight) == 6 # written before error + + # load into a model without bias -- should it ignore the parameter which has no home, or error? + m0 = Flux.loadmodel!(m0, m2) + @test iszero(m0[1].bias) # obviously unchanged + @test sum(m0[1].weight) == 21 + end +end + +@testset "loadmodel!(dst, src) with BSON" begin + m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) + m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1)) + @test m1[1].weight != m2[1].weight + mktempdir() do dir + BSON.@save joinpath(dir, "test.bson") m1 + m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1]) + @test m1[1].weight == m2[1].weight + end +end + +@testset "state" begin + m1 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) + m2 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2))) + s = Flux.state(m1) + @test s isa NamedTuple + @test fieldnames(typeof(s)) == (:layers,) + @test s.layers isa Tuple + @test length(s.layers) == 2 + @test s.layers[1].weight === m1[1].weight + @test s.layers[1].σ === () + @test s.layers[2].layers[1].weight === m1[2].layers[1].weight + + Flux.loadmodel!(m2, s) + @test m2[1].weight == m1[1].weight + @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias) + + @testset "non-state elements are replaced with empty tuple" begin + @test Flux.state((1, tanh)) == (1, ()) + @test Flux.state((a=1, b=tanh)) == (; a=1, b=()) + @test Flux.state(Dict(:a=>1, :b=>tanh)) == Dict(:a=>1, :b=>()) + X, Y = Flux.ones32(3, 2), Flux.zeros32(2, 2) + tree = Dict(:a=>1, :b=>(; c=X, d=(Y, 1, (tanh,)), e=sin)) + state_tree = Dict(:a=>1, :b=>(; c=X, d=(Y, 1, ((),)), e=())) + @test Flux.state(tree) == state_tree + end + + @testset "track active state and batch norm params" begin + m3 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2), BatchNorm(2)) + trainmode!(m3) + s = Flux.state(m3) + @test s.layers[2].active == true + @test s.layers[2].p == 0.2 + @test s.layers[4].λ === () + for k in (:β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs) + @test s.layers[4][k] === getfield(m3[4], k) + end + end + + @testset "preservation of saved types" begin + m = (num = 1, cnum = Complex(1.2, 2), str = "hello", arr = [1, 2, 3], + bool = true, dict = Dict(:a => 1, :b => 2), tup = (1, 2, 3), + sym = :a, nth = nothing) + + s = Flux.state(m) + @test s == m + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 09a65bb046..8285a712d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,10 @@ Random.seed!(0) include("utils.jl") end + @testset "Loading" begin + include("loading.jl") + end + @testset "Optimise / Train" begin include("optimise.jl") include("train.jl") diff --git a/test/utils.jl b/test/utils.jl index bda738f6a0..bac8deefa6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -419,182 +419,6 @@ end @test_skip typeof(l1.bias) === typeof(l2.bias) end - - @testset "loadmodel!(dst, src)" begin - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) - m2 = Chain(Dense(10, 5), Dense(5, 2)) - m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) - m4 = Chain(Dense(10, 6), Dense(6, 2)) - m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) - m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) - - loadmodel!(m1, m2) - # trainable parameters copy over - @test m1[1].weight == m2[1].weight - @test m1[1].bias == m2[1].bias - # non-array leaves are untouched - @test m1[2].σ == relu - - loadmodel!(m5, m6) - # more complex nested structures also work - @test m5[1].weight == m6[1].weight - @test m5[2][1].weight == m6[2][1].weight - # false bias is not overwritten - @test m5[2][1].bias == false - - # mismatched nodes throw an error - @test_throws ArgumentError loadmodel!(m1, m3) - @test_throws ArgumentError loadmodel!(m1, m5) - # size mismatches throw an error - @test_throws DimensionMismatch loadmodel!(m1, m4) - - # tests for BatchNorm and Dropout - m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) - m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) - m2[2].μ .= rand(Float32, size(m2[2].μ)...) - loadmodel!(m1, m2) - # non-trainable parameters are copied as well - @test m1[2].μ == m2[2].μ - # functions are not copied - @test m1[3] == Flux.flatten - # dropout rate is not copied - @test m1[4].p == 0.2 - - # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33) - # tests Chain(...) vs Chain([...]) - # tests MaxPool - # tests testmode!/trainmode! is not copied - # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model - chain1 = Chain(Dropout(0.2), - Conv((3, 3), 1 => 32, relu), - BatchNorm(32, relu), - MaxPool((2, 2)), - Dropout(0.2), - Conv((3, 3), 32 => 16, relu), - Dropout(0.2), - MaxPool((2, 2)), - Dropout(0.2), - Conv((3, 3), 16 => 10, relu), - Dropout(0.2), - x -> reshape(x, :, size(x, 4)), - Dropout(0.2), - Dense(90, 10), - softmax) - chain2 = Chain([Dropout(0.1), - Conv((3, 3), 1 => 32, relu), - BatchNorm(32, relu), - MaxPool((3, 3)), - Dropout(0.1), - Conv((3, 3), 32 => 16, relu), - Dropout(0.1), - MaxPool((3, 3)), - Dropout(0.1), - Conv((3, 3), 16 => 10, relu), - Dropout(0.1), - x -> reshape(x, :, size(x, 4)), - Dropout(0.1), - Dense(90, 10), - softmax]) - chain2[3].μ .= 5f0 - chain2[3].σ² .= 2f0 - testmode!(chain2) - loadmodel!(chain1, chain2) - for (dst, src) in zip(chain1, chain2) - if dst isa Dropout - @test dst.p == 0.2 - elseif dst isa Union{Conv, Dense} - @test dst.weight == src.weight - @test dst.bias == src.bias - elseif dst isa MaxPool - @test dst.k == (2, 2) - elseif dst isa BatchNorm - @test dst.μ == src.μ - @test dst.σ² == src.σ² - @test isnothing(dst.active) - end - end - - # copy only a subset of the model - chain1[end - 1].weight .= 1f0 - chain1[3].μ .= 3f0 - chain1[2].bias .= 5f0 - loadmodel!(chain2[end - 1], chain1[end - 1]) - loadmodel!(chain2[3], chain1[3]) - @test chain2[end - 1].weight == chain1[end - 1].weight - @test chain2[3].μ == chain1[3].μ - @test chain2[2].bias != chain1[2].bias - - # test shared weights - shared_dst = Dense(10 => 10) - shared_src = Dense(10 => 10) - # matched weights are okay - m1 = Chain(shared_dst, Dense(shared_dst.weight)) - m2 = Chain(shared_src, Dense(shared_src.weight)) - loadmodel!(m1, m2) - @test m1[1].weight === m1[2].weight - @test m1[1].weight == m2[2].weight - # mismatched weights are an error - m2 = Chain(Dense(10 => 10), Dense(10 => 10)) - @test_throws ErrorException loadmodel!(m1, m2) - # loading into tied weights with absent parameter is okay when the dst == zero - b = Flux.zeros32(5) - m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b)) - m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false)) - loadmodel!(m1, m2) - @test m1[1].bias === m1[2].bias - @test iszero(m1[1].bias) - # loading into tied weights with absent parameter is bad when the dst != zero - m2[1].bias .= 1 - @test_throws ErrorException loadmodel!(m1, m2) - - @testset "loadmodel! & filter" begin - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) - m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) - m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) - - # this will not error cause Dropout is skipped - loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) - @test m1[1].weight == m2[1].weight - @test m1[2].weight == m2[3].weight - - # this will not error cause Dropout is skipped - loadmodel!(m2, m3; filter = x -> !(x isa Dropout)) - @test m3[1].weight == m2[1].weight - @test m3[2].weight == m2[3].weight - end - - @testset "loadmodel! & absent bias" begin - m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) - m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) - m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) - - Flux.loadmodel!(m1, m2) - @test m1[1].bias == 7:9 - @test sum(m1[1].weight) == 21 - - # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it - m1 = Flux.loadmodel!(m1, m0) - @test iszero(m1[1].bias) - @test sum(m1[1].weight) == 6 # written before error - - # load into a model without bias -- should it ignore the parameter which has no home, or error? - m0 = Flux.loadmodel!(m0, m2) - @test iszero(m0[1].bias) # obviously unchanged - @test sum(m0[1].weight) == 21 - end - end - - @testset "loadmodel!(dst, src) with BSON" begin - m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) - m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1)) - @test m1[1].weight != m2[1].weight - mktempdir() do dir - BSON.@save joinpath(dir, "test.bson") m1 - m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1]) - @test m1[1].weight == m2[1].weight - end - end - @testset "destructure" begin import Flux: destructure @testset "Bias type $bt" for bt in (zeros, nobias)