Commit 85961da

update
1 parent 6797f95 commit 85961da

4 files changed: +132 -115 lines

docs/src/destructure.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -74,9 +74,9 @@ Another kind of flat view of a nested model is provided by the `modules` command
 Flux.modules
 ```
 
-### Saving and Loading
+### Save and Load
 
 ```@docs
-Flux.loadmodel!
 Flux.state
-```
+Flux.loadmodel!
+```
````

docs/src/saving.md

Lines changed: 80 additions & 81 deletions
````diff
@@ -1,95 +1,53 @@
 # Saving and Loading Models
 
 You may wish to save models so that they can be loaded and run in a later
-session. The easiest way to do this is via
+session. Flux provides a number of ways to do this.
+The recommended way, which is the most robust one for long-term storage,
+is to use [`Flux.state`](@ref) in combination with a serialization format like
+[JLD2.jl](https://juliaio.github.io/JLD2.jl/dev/) or
 [BSON.jl](https://github.com/JuliaIO/BSON.jl).
 
 Save a model:
 
 ```jldoctest saving
 julia> using Flux
 
-julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-) # Total: 4 arrays, 67 parameters, 524 bytes.
+julia> struct MyModel
+           net
+       end
 
-julia> using BSON: @save
+julia> Flux.@functor MyModel
 
-julia> @save "mymodel.bson" model
+julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2)))
+
+julia> model = MyModel()
+MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))
+
+julia> model_state = Flux.state(model);
+
+julia> using JLD2
+
+julia> jldsave("mymodel.jld2"; model_state)
 ```
 
-Load it again:
+Load it again in a new session using [`Flux.loadmodel!`](@ref):
 
 ```jldoctest saving
-julia> using Flux # Flux must be loaded before calling @load
+julia> using Flux, JLD2
 
-julia> using BSON: @load
+julia> model_state = JLD2.load("mymodel.jld2", "model_state")
 
-julia> @load "mymodel.bson" model
+julia> model = MyModel(); # MyModel definition must be available
 
-julia> model
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-) # Total: 4 arrays, 67 parameters, 524 bytes.
+julia> Flux.loadmodel!(model, model_state);
 ```
 
-Models are just normal Julia structs, so it's fine to use any Julia storage
-format for this purpose. BSON.jl is particularly well supported and most likely
-to be forwards compatible (that is, models saved now will load in future
-versions of Flux).
-
 !!! note
 
     If a saved model's parameters are stored on the GPU, the model will not load
     later on if there is no GPU support available. It's best to [move your model
     to the CPU](gpu.md) with `cpu(model)` before saving it.
 
-!!! warning
-
-    Previous versions of Flux suggested saving only the model weights using
-    `@save "mymodel.bson" params(model)`.
-    This is no longer recommended and even strongly discouraged.
-    Saving models this way will only store the trainable parameters which
-    will result in incorrect behavior for layers like `BatchNorm`.
-
-```julia
-julia> using Flux
-
-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-) # Total: 4 arrays, 67 parameters, 524 bytes.
-
-julia> weights = Flux.params(model);
-```
-
-Loading the model as shown above will return a new model with the stored parameters.
-But sometimes you already have a model, and you want to load stored parameters into it.
-This can be done as
-
-```julia
-using Flux: loadmodel!
-using BSON
-
-# some predefined model
-model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
-
-# load one model into another
-model = loadmodel!(model, BSON.load("mymodel.bson")[:model])
-```
-
-This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
-
-```@docs
-Flux.loadmodel!
-```
 
 ## Checkpointing
 
@@ -98,50 +56,91 @@ In longer training runs it's a good idea to periodically save your model, so that
 ```jldoctest saving
 julia> using Flux: throttle
 
-julia> using BSON: @save
+julia> using JLD2
 
-julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
+julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
 Chain(
   Dense(10 => 5, relu),                 # 55 parameters
   Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
 ) # Total: 4 arrays, 67 parameters, 524 bytes.
 
 julia> evalcb = throttle(30) do
-           # Show loss
-           @save "model-checkpoint.bson" model
+           jldsave("model-checkpoint.jld2", model_state = Flux.state(m))
        end;
 ```
 
-This will update the `"model-checkpoint.bson"` file every thirty seconds.
+This will update the `"model-checkpoint.jld2"` file every thirty seconds.
 
 You can get more advanced by saving a series of models throughout training, for example
 
 ```julia
-@save "model-$(now()).bson" model
+jldsave("model-$(now()).jld2", model_state = Flux.state(m))
 ```
 
-will produce a series of models like `"model-2018-03-06T02:57:10.41.bson"`. You
+will produce a series of models like `"model-2018-03-06T02:57:10.41.jld2"`. You
 could also store the current test set loss, so that it's easy to (for example)
 revert to an older copy of the model if it starts to overfit.
 
 ```julia
-@save "model-$(now()).bson" model loss = testloss()
+jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss())
 ```
 
-Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are stateful optimisers (which usually utilize an `IdDict` to store their state), and the randomness used to partition the original data into the training and validation sets.
+Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are the optimiser state and the randomness used to partition the original data into the training and validation sets.
 
 You can store the optimiser state alongside the model, to resume training
-exactly where you left off. BSON is smart enough to [cache values](https://github.com/JuliaIO/BSON.jl/blob/v0.3.4/src/write.jl#L71) and insert links when saving, but only if it knows everything to be saved up front. Thus models and optimisers must be saved together to have the latter work after restoring.
+exactly where you left off:
 
 ```julia
-opt = Adam()
-@save "model-$(now()).bson" model opt
+model = MyModel()
+opt_state = Flux.setup(AdamW(), model)
+
+# ... train model ...
+
+model_state = Flux.state(model)
+jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state)
 ```
 
-## Saving the state only
+## Saving Models as Julia Structs
 
-An alternative ... TODO
+Models are just normal Julia structs, so it's fine to use any Julia storage
+format to save the struct as it is instead of saving the state returned by [`Flux.state`](@ref).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl) is particularly convenient for this,
+since it can also save anonymous functions, which are sometimes part of a model definition.
 
-```julia
+Save a model:
+
+```jldoctest saving
+julia> using Flux
+
+julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2));
+
+julia> using BSON: @save
+
+julia> @save "mymodel.bson" model
+```
 
+Load it again in a new session:
+
+```jldoctest saving
+julia> using Flux, BSON
+
+julia> BSON.@load "mymodel.bson" model
+
+julia> model
+Chain(
+  Dense(10 => 5, relu),                 # 55 parameters
+  Dense(5 => 2),                        # 12 parameters
+) # Total: 4 arrays, 67 parameters, 524 bytes.
+```
+
+!!! warning
+    Saving models this way could lead to compatibility issues across Julia versions
+    and across Flux versions if some of the Flux layers' internals are changed.
+    It is therefore not recommended for long-term storage; use [`Flux.state`](@ref) instead.
+
+!!! warning
+
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and even strongly discouraged.
+    Saving models this way will only store the trainable parameters which
+    will result in incorrect behavior for layers like `BatchNorm`.
````
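Note: the new checkpointing section saves `model_state` and `opt_state` but leaves the restore step implicit. Below is a minimal sketch of resuming from such a checkpoint, assuming the `MyModel` definition from the docs above; the dummy batch and loss are illustrative placeholders, not part of the commit:

```julia
using Flux, JLD2

model = MyModel()  # the MyModel struct definition must be in scope

# jldsave stored both objects; JLD2.load with several names returns them in order
model_state, opt_state = JLD2.load("checkpoint_epoch=42.jld2",
                                   "model_state", "opt_state")

Flux.loadmodel!(model, model_state)  # copy parameters and buffers in place

# dummy batch and loss, just to show the restored optimiser state being reused
data = [(randn(Float32, 10, 1), randn(Float32, 2, 1))]
loss(m, x, y) = Flux.Losses.mse(m.net(x), y)

Flux.train!(loss, model, data, opt_state)  # training continues from the saved state
```

The optimiser tree returned by `Flux.setup` is plain nested data, so the value loaded back from JLD2 can be passed straight to `Flux.train!` with no extra reconstruction step.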

src/loading.jl

Lines changed: 39 additions & 21 deletions
````diff
@@ -47,6 +47,8 @@ Non-array elements (such as activation functions) are not copied and need not match.
 Zero bias vectors and `bias=false` are considered equivalent
 (see extended help for more details).
 
+See also [`Flux.state`](@ref).
+
 # Examples
 ```julia
 julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
@@ -106,23 +108,28 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())
 end
 
 """
-    state(x; keep = leaf -> !(leaf isa Function))
+    state(x)
+
+Return an object with the same nested structure as `x` according to `Functors.children`,
+but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries).
 
-Return an object with the same nested structure as `x`
-according to `Functors.children`, but made only of
-basic containers (e.g. named tuples, tuples, arrays, and dictionaries).
+Besides trainable and non-trainable arrays, the state will contain leaf nodes that are not arrays,
+such as numbers, symbols, strings, and nothing values. The leaf types that end up in the state
+could increase in the future.
 
 This method is particularly useful for saving and loading models,
-since it doesn't require the user to specify the model type.
-The state can be passed to `loadmodel!` to restore the model.
+since the state contains only simple data types that can be easily serialized.
 
-The `keep` function is applied on the leaves of `x`.
-If `keep(leaf)` is `false`, the leaf is replaced by `nothing`,
-otherwise it is left as is. By default, all functions are excluded.
+The state can be passed to [`loadmodel!`](@ref) to restore the model.
 
 # Examples
 
+## Copy the state into another model
+
 ```julia-repl
+julia> s = Flux.state(Dense(1, 2, tanh))
+(weight = Float32[0.5058468; 1.2398405;;], bias = Float32[0.0, 0.0], σ = missing)
+
 julia> m1 = Chain(Dense(1, 2, tanh), Dense(2, 1));
 
 julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1));
@@ -135,18 +142,29 @@ julia> Flux.loadmodel!(m2, s);
 julia> m2[1].weight == m1[1].weight
 true
 ```
+
+## Save and load with BSON
+```julia-repl
+julia> using BSON
+
+julia> BSON.@save "checkpoint.bson" model_state = s
+
+julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state])
+```
+
+## Save and load with JLD2
+
+```julia-repl
+julia> using JLD2
+
+julia> JLD2.jldsave("checkpoint.jld2", model_state = s)
+
+julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state"))
+```
 """
-function state(x; keep = _state_keep)
-    if Functors.isleaf(x)
-        return keep(x) ? x : nothing
-    else
-        return _valuemap(c -> state(c; keep), Functors.children(x))
-    end
-end
+state(x) = Functors.fmapstructure(x -> _state_keep(x) ? x : missing, x)
 
-_state_keep(x::Function) = false
-_state_keep(x) = true
+const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol}
 
-# map for tuples, namedtuples, and dicts
-_valuemap(f, x) = map(f, x)
-_valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x)
+_state_keep(x::STATE_TYPES) = true
+_state_keep(x) = false
````
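As a quick illustration of the rewritten `state` (a sketch, not part of the commit): with the `_state_keep` rules above, any leaf outside `STATE_TYPES`, such as an activation function, is replaced by `missing`, while arrays and plain numbers survive the round trip:

```julia
using Flux

m = Chain(Dense(2 => 3, tanh), BatchNorm(3))
s = Flux.state(m)

s.layers[1].σ         # missing: functions are not in STATE_TYPES
s.layers[1].weight    # the 3×2 Float32 weight matrix, kept as-is
s.layers[2].momentum  # 0.1f0: plain numbers are kept too

# the nested structure restores into a fresh model of the same shape
m2 = Chain(Dense(2 => 3, tanh), BatchNorm(3))
Flux.loadmodel!(m2, s)
m2[1].weight == m[1].weight  # true
```

This replaces the old `keep` keyword: the filtering is now fixed by type rather than configurable per call, which is what keeps the returned state made of easily serializable values.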

test/loading.jl

Lines changed: 10 additions & 10 deletions
````diff
@@ -199,7 +199,7 @@ end
     @test s.layers isa Tuple
     @test length(s.layers) == 2
     @test s.layers[1].weight === m1[1].weight
-    @test s.layers[1].σ === nothing
+    @test s.layers[1].σ === missing
     @test s.layers[2].layers[1].weight === m1[2].layers[1].weight
 
     Flux.loadmodel!(m2, s)
@@ -212,16 +212,16 @@ end
     s = Flux.state(m3)
     @test s.layers[2].active == true
     @test s.layers[2].p == 0.2
-    @test s.layers[4] == (λ = nothing, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0],
-                          μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true,
-                          track_stats = true, active = true, chs = 2)
+    @test s.layers[4].λ === missing
+    for k in (:β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs)
+        @test s.layers[4][k] === getfield(m3[4], k)
+    end
 end
 
-@testset "keep" begin
-    s = Flux.state(m1, keep = x -> x isa AbstractArray)
-    @test s.layers[1].weight isa AbstractArray
-    @test s.layers[1].σ === nothing
-    @test s.layers[2].connection === nothing
-    @test s.layers[2].layers[1].bias === nothing
+@testset "saved types" begin
+    m = (num = 1, cnum = Complex(1.2, 2), str = "hello", arr = [1, 2, 3],
+         dict = Dict(:a => 1, :b => 2), tup = (1, 2, 3), sym = :a, nth = nothing)
+    s = Flux.state(m)
+    @test s == m
 end
 end
````
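A small freestanding sketch of the same rules on a plain NamedTuple, complementing the `saved types` testset (an illustration, not part of the commit):

```julia
using Flux

# A leaf outside STATE_TYPES (here an anonymous function) becomes `missing`,
# while arrays, numbers, and symbols pass through untouched.
nt = (f = x -> x + 1, w = [1.0, 2.0], name = :layer1)
s = Flux.state(nt)

@assert s.f === missing
@assert s.w == [1.0, 2.0]
@assert s.name === :layer1
```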
