Commit 5fe8ada

add Flux.state(x) (#2239)

* Flux.state
* some docs
* some docs
* add keep keyword
* update
* apply suggestions
* cleanup
* remove callback from docs
* fix
* new proposal pruning missings
* sentinel is empty tuple
* rewording
* fix loadmodel!
* don't drop fields
* require keys to be equal in loadmodel!
1 parent 1e1da28 commit 5fe8ada
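For orientation, the workflow this commit documents (see the `docs/src/saving.md` changes below) looks roughly like the following sketch; the model, layer sizes, and file name here are illustrative, not part of the commit:

```julia
using Flux, JLD2

# Build a model, extract its serializable state, and save it with JLD2.
model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
model_state = Flux.state(model)      # nested structure of arrays and numbers
jldsave("mymodel.jld2"; model_state)

# Later (possibly in a new session): rebuild the same architecture and copy
# the saved state back in. `Flux.loadmodel!` expects the keys to match.
model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Flux.loadmodel!(model2, JLD2.load("mymodel.jld2", "model_state"))
```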

10 files changed: +426 / -269 lines


NEWS.md

Lines changed: 2 additions & 1 deletion
@@ -7,12 +7,13 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 Thus `LayerNorm(3; ϵ=1e-4)` (not `ε`!) should become `LayerNorm(3; eps=1e-4)`.
 * `DataLoader(...) |> gpu` will now produce a special iterator, moving each batch as needed,
   instead of giving an error.
+* Added `Flux.state` returning the internal state of the model for serialization.
 
 ## v0.13.15
 * Added [MultiHeadAttention](https://github.com/FluxML/Flux.jl/pull/2146) layer.
 * `f16, f32, f64` now specifically target floating point arrays (i.e. integers arrays and other types are preserved).
 * `f16, f32, f64` can now handle `Complex{<:AbstractFloat}` arrays.
-* Added `EmbeddingBag` layer
+* Added `EmbeddingBag` layer.
 
 ## v0.13.14
 * Fixed various deprecation warnings, from `Zygone.@nograd` and `Vararg`.
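The `DataLoader(...) |> gpu` entry above corresponds to usage roughly like this; a minimal sketch that assumes a functional GPU backend is available (without one, `gpu` is a no-op):

```julia
using Flux

X = rand(Float32, 10, 100)   # features
Y = rand(Float32, 2, 100)    # targets

# `gpu` applied to a DataLoader now returns an iterator that moves each batch
# to the device as it is produced, instead of erroring.
loader = Flux.DataLoader((X, Y), batchsize = 32) |> gpu

for (x, y) in loader
    # x and y arrive on the GPU one batch at a time
end
```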

Project.toml

Lines changed: 6 additions & 6 deletions
@@ -24,6 +24,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
+[weakdeps]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+
+[extensions]
+AMDGPUExt = "AMDGPU"
+
 [compat]
 AMDGPU = "0.4.13"
 Adapt = "3.0"
@@ -44,9 +50,6 @@ Zygote = "0.6.49"
 cuDNN = "1"
 julia = "1.6"
 
-[extensions]
-AMDGPUExt = "AMDGPU"
-
 [extras]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
@@ -59,6 +62,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
 test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON"]
-
-[weakdeps]
-AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
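The `[weakdeps]`/`[extensions]` reshuffle above keeps the AMDGPU integration as a package extension: on Julia 1.9+ the `AMDGPUExt` module is only loaded once AMDGPU.jl itself is loaded. A rough sketch of how that behaves, assuming AMDGPU.jl is installed:

```julia
using Flux

Base.get_extension(Flux, :AMDGPUExt)  # returns `nothing`: extension not loaded yet

using AMDGPU                          # loading the weak dependency...

Base.get_extension(Flux, :AMDGPUExt)  # ...now returns the extension module
```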

docs/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

docs/make.jl

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
-using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, DataFrames
+using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
+      OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
+      DataFrames, JLD2
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
 
docs/src/destructure.md

Lines changed: 7 additions & 0 deletions
@@ -72,4 +72,11 @@ Another kind of flat view of a nested model is provided by the `modules` command
 
 ```@docs
 Flux.modules
+```
+
+### Save and Load
+
+```@docs
+Flux.state
+Flux.loadmodel!
 ```
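For a sense of what the newly documented `Flux.state` returns, a rough illustration follows; the description in the comments is indicative, not verbatim doctest output:

```julia
using Flux

m = Chain(Dense(2 => 3, relu), BatchNorm(3))

s = Flux.state(m)
# `s` is a nested NamedTuple mirroring the model: numeric arrays and numbers
# (weights, biases, BatchNorm's running statistics, ...) are kept, while fields
# such as activation functions are replaced by an empty-tuple placeholder.

Flux.loadmodel!(m, s)  # copies the state back into a structurally matching model
```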

docs/src/saving.md

Lines changed: 87 additions & 80 deletions
@@ -1,140 +1,147 @@
 # Saving and Loading Models
 
 You may wish to save models so that they can be loaded and run in a later
-session. The easiest way to do this is via
+session. Flux provides a number of ways to do this.
+The recommended way, which is the most robust one for long term storage,
+is to use [`Flux.state`](@ref) in combination with a serialization format like
+[JLD2.jl](https://juliaio.github.io/JLD2.jl/dev/) or
 [BSON.jl](https://github.com/JuliaIO/BSON.jl).
 
 Save a model:
 
 ```jldoctest saving
 julia> using Flux
 
-julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 524 bytes.
+julia> struct MyModel
+           net
+       end
 
-julia> using BSON: @save
+julia> Flux.@functor MyModel
 
-julia> @save "mymodel.bson" model
+julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2)));
+
+julia> model = MyModel()
+MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))
+
+julia> model_state = Flux.state(model);
+
+julia> using JLD2
+
+julia> jldsave("mymodel.jld2"; model_state)
 ```
 
-Load it again:
+Load it again in a new session using [`Flux.loadmodel!`](@ref):
 
 ```jldoctest saving
-julia> using Flux # Flux must be loaded before calling @load
+julia> using Flux, JLD2
 
-julia> using BSON: @load
+julia> model_state = JLD2.load("mymodel.jld2", "model_state");
 
-julia> @load "mymodel.bson" model
+julia> model = MyModel(); # MyModel definition must be available
 
-julia> model
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 524 bytes.
+julia> Flux.loadmodel!(model, model_state);
 ```
 
-Models are just normal Julia structs, so it's fine to use any Julia storage
-format for this purpose. BSON.jl is particularly well supported and most likely
-to be forwards compatible (that is, models saved now will load in future
-versions of Flux).
-
 !!! note
 
     If a saved model's parameters are stored on the GPU, the model will not load
     later on if there is no GPU support available. It's best to [move your model
    to the CPU](gpu.md) with `cpu(model)` before saving it.
 
-!!! warning
-
-    Previous versions of Flux suggested saving only the model weights using
-    `@save "mymodel.bson" params(model)`.
-    This is no longer recommended and even strongly discouraged.
-    Saving models this way will only store the trainable parameters which
-    will result in incorrect behavior for layers like `BatchNorm`.
-
-```julia
-julia> using Flux
-
-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 524 bytes.
-
-julia> weights = Flux.params(model);
-```
-
-Loading the model as shown above will return a new model with the stored parameters.
-But sometimes you already have a model, and you want to load stored parameters into it.
-This can be done as
-
-```julia
-using Flux: loadmodel!
-using BSON
-
-# some predefined model
-model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
-
-# load one model into another
-model = loadmodel!(model, BSON.load("mymodel.bson")[:model])
-```
-
-This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
-
-```@docs
-Flux.loadmodel!
-```
 
 ## Checkpointing
 
-In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
+In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut).
 
 ```jldoctest saving
 julia> using Flux: throttle
 
-julia> using BSON: @save
+julia> using JLD2
 
-julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
+julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
 Chain(
   Dense(10 => 5, relu),                 # 55 parameters
   Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
 )                   # Total: 4 arrays, 67 parameters, 524 bytes.
 
-julia> evalcb = throttle(30) do
-         # Show loss
-         @save "model-checkpoint.bson" model
+julia> for epoch in 1:10
+         # ... train model ...
+         jldsave("model-checkpoint.jld2", model_state = Flux.state(m))
        end;
 ```
 
-This will update the `"model-checkpoint.bson"` file every thirty seconds.
+This will update the `"model-checkpoint.jld2"` file every epoch.
 
 You can get more advanced by saving a series of models throughout training, for example
 
 ```julia
-@save "model-$(now()).bson" model
+jldsave("model-$(now()).jld2", model_state = Flux.state(m))
 ```
 
-will produce a series of models like `"model-2018-03-06T02:57:10.41.bson"`. You
+will produce a series of models like `"model-2018-03-06T02:57:10.41.jld2"`. You
 could also store the current test set loss, so that it's easy to (for example)
 revert to an older copy of the model if it starts to overfit.
 
 ```julia
-@save "model-$(now()).bson" model loss = testloss()
+jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss())
 ```
 
-Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are stateful optimisers (which usually utilize an `IdDict` to store their state), and the randomness used to partition the original data into the training and validation sets.
+Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are the optimiser state and the randomness used to partition the original data into the training and validation sets.
 
 You can store the optimiser state alongside the model, to resume training
-exactly where you left off. BSON is smart enough to [cache values](https://github.com/JuliaIO/BSON.jl/blob/v0.3.4/src/write.jl#L71) and insert links when saving, but only if it knows everything to be saved up front. Thus models and optimisers must be saved together to have the latter work after restoring.
+exactly where you left off:
 
 ```julia
-opt = Adam()
-@save "model-$(now()).bson" model opt
+model = MyModel()
+opt_state = Flux.setup(AdamW(), model)
+
+# ... train model ...
+
+model_state = Flux.state(model)
+jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state)
+```
+
+# Saving Models as Julia Structs
+
+Models are just normal Julia structs, so it's fine to use any Julia storage
+format to save the struct as it is instead of saving the state returned by [`Flux.state`](@ref).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl) is particularly convenient for this,
+since it can also save anonymous functions, which are sometimes part of a model definition.
+
+Save a model:
+
+```jldoctest saving
+julia> using Flux
+
+julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2));
+
+julia> using BSON: @save
+
+julia> @save "mymodel.bson" model
 ```
+
+Load it again in a new session:
+
+```jldoctest saving
+julia> using Flux, BSON
+
+julia> BSON.@load "mymodel.bson" model
+
+julia> model
+Chain(
+  Dense(10 => 5, relu),                 # 55 parameters
+  Dense(5 => 2),                        # 12 parameters
+)                   # Total: 4 arrays, 67 parameters, 524 bytes.
+```
+!!! warning
+    Saving models this way could lead to compatibility issues across Julia versions
+    and across Flux versions if some of the Flux layers' internals are changed.
+    It is therefore not recommended for long term storage, use [`Flux.state`](@ref) instead.
+
+!!! warning
+
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and even strongly discouraged.
+    Saving models this way will only store the trainable parameters which
+    will result in incorrect behavior for layers like `BatchNorm`.
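To round out the checkpointing example in the docs above, restoring from such a checkpoint could look roughly like this; the file name and the commented training call are placeholders, not part of the commit:

```julia
using Flux, JLD2

model = MyModel()                               # rebuild the architecture
ckpt = JLD2.load("checkpoint_epoch=42.jld2")    # Dict with "model_state" and "opt_state"

Flux.loadmodel!(model, ckpt["model_state"])     # restore the weights and buffers
opt_state = ckpt["opt_state"]                   # reuse the saved optimiser state

# continue training with the restored optimiser state, e.g.
# Flux.train!(loss, model, data, opt_state)
```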
