4 changes: 3 additions & 1 deletion NEWS.md
@@ -1,18 +1,20 @@
# Flux Release Notes

See also [GitHub's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.13.16
* Most Greek-letter keyword arguments are deprecated in favour of ASCII ones.
Thus `LayerNorm(3; ϵ=1e-4)` (not `ε`!) should become `LayerNorm(3; eps=1e-4)`.
* `DataLoader(...) |> gpu` will now produce a special iterator, moving each batch as needed,
instead of giving an error.
* Added `Flux.state`, which returns the internal state of the model for serialization (see the sketch below).
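
A minimal, illustrative sketch (not from the release notes) of how these additions fit together; the toy data and model here are made up for demonstration:

```julia
using Flux

X, Y = rand(Float32, 10, 64), rand(Float32, 2, 64)             # toy data
loader = Flux.DataLoader((X, Y); batchsize=16) |> gpu          # now an iterator, moving each batch as needed
model = Chain(Dense(10 => 5, relu), LayerNorm(5; eps=1f-4), Dense(5 => 2))  # ascii keyword
model_state = Flux.state(model)                                 # plain nested state, ready for serialization
```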

## v0.13.15
* Added [MultiHeadAttention](https://github.com/FluxML/Flux.jl/pull/2146) layer.
* `f16, f32, f64` now specifically target floating point arrays (i.e. integer arrays and other types are preserved).
* `f16, f32, f64` can now handle `Complex{<:AbstractFloat}` arrays.
* Added `EmbeddingBag` layer
* Added `EmbeddingBag` layer.

## v0.13.14
* Fixed various deprecation warnings, from `Zygote.@nograd` and `Vararg`.
12 changes: 6 additions & 6 deletions Project.toml
@@ -24,6 +24,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"

[extensions]
AMDGPUExt = "AMDGPU"

[compat]
AMDGPU = "0.4.13"
Adapt = "3.0"
@@ -44,9 +50,6 @@ Zygote = "0.6.49"
cuDNN = "1"
julia = "1.6"

[extensions]
AMDGPUExt = "AMDGPU"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
@@ -59,6 +62,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON"]

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
4 changes: 3 additions & 1 deletion docs/make.jl
@@ -1,4 +1,6 @@
using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, DataFrames
using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
DataFrames, JLD2

DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

7 changes: 7 additions & 0 deletions docs/src/destructure.md
@@ -72,4 +72,11 @@ Another kind of flat view of a nested model is provided by the `modules` command

```@docs
Flux.modules
```

### Save and Load

```@docs
Flux.state
Flux.loadmodel!
```
167 changes: 87 additions & 80 deletions docs/src/saving.md
@@ -1,140 +1,147 @@
# Saving and Loading Models

You may wish to save models so that they can be loaded and run in a later
session. The easiest way to do this is via
session. Flux provides a number of ways to do this.
The recommended way, which is the most robust for long-term storage,
is to use [`Flux.state`](@ref) in combination with a serialization format like
[JLD2.jl](https://juliaio.github.io/JLD2.jl/dev/) or
[BSON.jl](https://github.com/JuliaIO/BSON.jl).

Save a model:

```jldoctest saving
julia> using Flux

julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 524 bytes.
julia> struct MyModel
net
end

julia> using BSON: @save
julia> Flux.@functor MyModel

julia> @save "mymodel.bson" model
julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2)))

julia> model = MyModel()
MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))

julia> model_state = Flux.state(model);

julia> using JLD2

julia> jldsave("mymodel.jld2"; model_state)
```

Load it again:
Load it again in a new session using [`Flux.loadmodel!`](@ref):

```jldoctest saving
julia> using Flux # Flux must be loaded before calling @load
julia> using Flux, JLD2

julia> using BSON: @load
julia> model_state = JLD2.load("mymodel.jld2", "model_state");

julia> @load "mymodel.bson" model
julia> model = MyModel(); # MyModel definition must be available

julia> model
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 524 bytes.
julia> Flux.loadmodel!(model, model_state);
```

Models are just normal Julia structs, so it's fine to use any Julia storage
format for this purpose. BSON.jl is particularly well supported and most likely
to be forwards compatible (that is, models saved now will load in future
versions of Flux).

!!! note

If a saved model's parameters are stored on the GPU, the model will not load
later on if there is no GPU support available. It's best to [move your model
to the CPU](gpu.md) with `cpu(model)` before saving it.
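
For example, a minimal sketch of that pattern (JLD2 as above; the GPU training step is elided):

```julia
using Flux, JLD2

model = Chain(Dense(10 => 5, relu), Dense(5 => 2)) |> gpu

# ... train on the GPU ...

model_state = Flux.state(cpu(model))   # bring the parameters back to the CPU first
jldsave("mymodel.jld2"; model_state)
```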

!!! warning

Previous versions of Flux suggested saving only the model weights using
`@save "mymodel.bson" params(model)`.
This is no longer recommended and is strongly discouraged.
Saving models this way stores only the trainable parameters, which
results in incorrect behavior for layers like `BatchNorm`.

```julia
julia> using Flux

julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 524 bytes.

julia> weights = Flux.params(model);
```

Loading the model as shown above will return a new model with the stored parameters.
But sometimes you already have a model, and you want to load stored parameters into it.
This can be done as

```julia
using Flux: loadmodel!
using BSON

# some predefined model
model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)

# load one model into another
model = loadmodel!(model, BSON.load("mymodel.bson")[:model])
```

This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
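
For instance, a small sketch (model names here are purely illustrative) of syncing one in-memory model from another with the same structure:

```julia
using Flux

teacher = Chain(Dense(10 => 5, relu), Dense(5 => 2))
student = Chain(Dense(10 => 5, relu), Dense(5 => 2))   # same structure, different random weights

Flux.loadmodel!(student, teacher)        # copy teacher's parameters into student
student[1].weight == teacher[1].weight   # true
```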

```@docs
Flux.loadmodel!
```

## Checkpointing

In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut).

```jldoctest saving
julia> using Flux: throttle

julia> using BSON: @save
julia> using JLD2

julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 524 bytes.

julia> evalcb = throttle(30) do
# Show loss
@save "model-checkpoint.bson" model
julia> for epoch in 1:10
# ... train model ...
jldsave("model-checkpoint.jld2", model_state = Flux.state(m))
end;
```

This will update the `"model-checkpoint.bson"` file every thirty seconds.
This will update the `"model-checkpoint.jld2"` file every epoch.

You can get more advanced by saving a series of models throughout training, for example

```julia
@save "model-$(now()).bson" model
jldsave("model-$(now()).jld2", model_state = Flux.state(m))
```

will produce a series of models like `"model-2018-03-06T02:57:10.41.bson"`. You
will produce a series of models like `"model-2018-03-06T02:57:10.41.jld2"`. You
could also store the current test set loss, so that it's easy to (for example)
revert to an older copy of the model if it starts to overfit.

```julia
@save "model-$(now()).bson" model loss = testloss()
jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss())
```

Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are stateful optimisers (which usually utilize an `IdDict` to store their state), and the randomness used to partition the original data into the training and validation sets.
Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are the optimiser state and the randomness used to partition the original data into the training and validation sets.

You can store the optimiser state alongside the model, to resume training
exactly where you left off. BSON is smart enough to [cache values](https://github.com/JuliaIO/BSON.jl/blob/v0.3.4/src/write.jl#L71) and insert links when saving, but only if it knows everything to be saved up front. Thus models and optimisers must be saved together to have the latter work after restoring.
exactly where you left off:

```julia
opt = Adam()
@save "model-$(now()).bson" model opt
model = MyModel()
opt_state = Flux.setup(AdamW(), model)

# ... train model ...

model_state = Flux.state(model)
jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state)
```
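
Later, a sketch of how you might resume from that checkpoint (same filename as above; the training loop itself is elided):

```julia
using Flux, JLD2

model = MyModel()
checkpoint = JLD2.load("checkpoint_epoch=42.jld2")
Flux.loadmodel!(model, checkpoint["model_state"])
opt_state = checkpoint["opt_state"]

# ... continue training, e.g. with Flux.train!(loss, model, data, opt_state) ...
```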

## Saving Models as Julia Structs

Models are just normal Julia structs, so it's fine to use any Julia storage
format to save the struct as it is instead of saving the state returned by [`Flux.state`](@ref).
[BSON.jl](https://github.com/JuliaIO/BSON.jl) is particularly convenient for this,
since it can also save anonymous functions, which are sometimes part of a model definition.

Save a model:

```jldoctest saving
julia> using Flux

julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2));

julia> using BSON: @save

julia> @save "mymodel.bson" model
```

Load it again in a new session:

```jldoctest saving
julia> using Flux, BSON

julia> BSON.@load "mymodel.bson" model

julia> model
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
) # Total: 4 arrays, 67 parameters, 524 bytes.
```
!!! warning
Saving models this way could lead to compatibility issues across Julia versions
and across Flux versions if some of the Flux layers' internals are changed.
It is therefore not recommended for long-term storage; use [`Flux.state`](@ref) instead.

!!! warning

Previous versions of Flux suggested saving only the model weights using
`@save "mymodel.bson" params(model)`.
This is no longer recommended and is strongly discouraged.
Saving models this way stores only the trainable parameters, which
results in incorrect behavior for layers like `BatchNorm`.
70 changes: 70 additions & 0 deletions src/loading.jl
@@ -47,6 +47,8 @@ Non-array elements (such as activation functions) are not copied and need not match.
Zero bias vectors and `bias=false` are considered equivalent
(see extended help for more details).

See also [`Flux.state`](@ref).

# Examples
```julia
julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
@@ -104,3 +106,71 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())

return dst
end

"""
state(x)

Return an object with the same nested structure as `x` according to `Functors.children`,
but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries).

Besides trainable and non-trainable arrays, the state will contain leaf nodes that are not arrays,
such as numbers, symbols, strings, and `nothing` values. The leaf types that end up in the state
could increase in the future.

This method is particularly useful for saving and loading models,
since the state contains only simple data types that can be easily serialized.

The state can be passed to [`loadmodel!`](@ref) to restore the model.

# Examples

## Copy the state into another model

```jldoctest
julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones));

julia> s = Flux.state(m1)
(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = missing), (weight = [1.0 1.0], bias = [0.0], σ = missing)),)

julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false)); # weights are random numbers

julia> Flux.loadmodel!(m2, s);

julia> m2[1].weight # now the weights of m2 are the same as m1
2×1 Matrix{Float32}:
1.0
1.0

julia> Flux.state(trainmode!(Dropout(0.2))) # contains p & activity, but not RNG state
(p = 0.2, dims = missing, active = true, rng = missing)

julia> Flux.state(BatchNorm(1)) # contains non-trainable arrays μ, σ²
(λ = missing, β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1)
```

## Save and load with BSON

```julia-repl
julia> using BSON

julia> BSON.@save "checkpoint.bson" model_state = s

julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state])
```

## Save and load with JLD2

```julia-repl
julia> using JLD2

julia> JLD2.jldsave("checkpoint.jld2", model_state = s)

julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state"))
```
"""
state(x) = Functors.fmapstructure(_state, x)

# Leaf types that are kept in the returned state as-is.
const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol}

_state(x::STATE_TYPES) = x
_state(x) = missing  # any other leaf is replaced by `missing`