add Flux.state(x) #2239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 9 commits (15 commits in total, all by CarloLucibello):
- ec8699d Flux.state
- db80171 some docs
- f97602b some docs
- e4e3921 add keep keyword
- a4f9607 update
- e3db0e1 apply suggestions
- a9c59b1 cleanup
- 30cf25a remove callback from docs
- 532fe94 fix
- 7149292 new proposal pruning missings
- 6e1a5e1 sentinel is empty tuple
- 6895b26 rewording
- b50d1a9 fix loadmodel!
- 1eb89a2 don't drop fields
- 8d45e77 require keys to be equal in loadmodel!

# Saving and Loading Models

You may wish to save models so that they can be loaded and run in a later
session. Flux provides a number of ways to do this.
The recommended way, which is the most robust one for long term storage,
is to use [`Flux.state`](@ref) in combination with a serialization format like
[JLD2.jl](https://juliaio.github.io/JLD2.jl/dev/) or
[BSON.jl](https://github.com/JuliaIO/BSON.jl).
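
To give a feel for what gets saved (a minimal sketch, not part of the page above): [`Flux.state`](@ref) returns a nested structure containing the model's arrays, with anything it cannot serialize, such as activation functions, replaced by the empty-tuple sentinel `()` (see the "sentinel is empty tuple" commit above).

```julia
using Flux

# Flux.state extracts a nested NamedTuple of the numeric arrays in a model.
# Non-array leaves (here the activation function) become the sentinel ().
s = Flux.state(Dense(2 => 3, tanh))

keys(s)         # (:weight, :bias, :σ)
size(s.weight)  # (3, 2)
s.σ             # () -- the activation is not part of the saved state
```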

Save a model:

```jldoctest saving
julia> using Flux

julia> struct MyModel
           net
       end

julia> Flux.@functor MyModel

julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2)))

julia> model = MyModel()
MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))

julia> model_state = Flux.state(model);

julia> using JLD2

julia> jldsave("mymodel.jld2"; model_state)
```

Load it again in a new session using [`Flux.loadmodel!`](@ref):

```jldoctest saving
julia> using Flux, JLD2

julia> model_state = JLD2.load("mymodel.jld2", "model_state");

julia> model = MyModel(); # MyModel definition must be available

julia> Flux.loadmodel!(model, model_state);
```

!!! note

    If a saved model's parameters are stored on the GPU, the model will not load
    later on if there is no GPU support available. It's best to [move your model
    to the CPU](gpu.md) with `cpu(model)` before saving it.
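
In practice that means calling `cpu` before `Flux.state` (a minimal sketch, not from the page above; `gpu` is a no-op when no GPU is available):

```julia
using Flux, JLD2

model = Chain(Dense(10 => 5, relu), Dense(5 => 2)) |> gpu

# Move the parameters back to the CPU before extracting and saving the
# state, so the file can be loaded on machines without GPU support.
model_state = Flux.state(cpu(model))
jldsave("mymodel.jld2"; model_state)
```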

[`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.

```@docs
Flux.loadmodel!
```
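
For example (a minimal sketch, not from the page above), copying the parameters of one in-memory model into another with the same structure:

```julia
using Flux

src = Chain(Dense(10 => 5, relu), Dense(5 => 2))
dst = Chain(Dense(10 => 5, relu), Dense(5 => 2))

# Copies the arrays of `src` into `dst` in place; the structures must
# match, otherwise loadmodel! throws an error.
Flux.loadmodel!(dst, src)

dst[1].weight == src[1].weight  # true
```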

## Checkpointing

In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut).

```jldoctest saving
julia> using JLD2

julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Chain(
  Dense(10 => 5, relu),                 # 55 parameters
  Dense(5 => 2),                        # 12 parameters
)                   # Total: 4 arrays, 67 parameters, 524 bytes.

julia> for epoch in 1:10
           # ... train model ...
           jldsave("model-checkpoint.jld2", model_state = Flux.state(m))
       end;
```

This will update the `"model-checkpoint.jld2"` file every epoch.

You can get more advanced by saving a series of models throughout training, for example

```julia
using Dates: now  # for timestamped filenames

jldsave("model-$(now()).jld2", model_state = Flux.state(m))
```

will produce a series of models like `"model-2018-03-06T02:57:10.41.jld2"`. You
could also store the current test set loss, so that it's easy to (for example)
revert to an older copy of the model if it starts to overfit.

```julia
jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss())
```

Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are the optimiser state and the randomness used to partition the original data into the training and validation sets.
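
For the randomness part, one common approach (a sketch, not from the page above) is to fix the RNG seed before making the split, so the same partition can be reproduced when resuming:

```julia
using Random

Random.seed!(123)  # fix the seed so the shuffle below is reproducible
idx = randperm(100)                  # indices of a hypothetical 100-sample dataset
train_idx, val_idx = idx[1:80], idx[81:100]
```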

You can store the optimiser state alongside the model, to resume training
exactly where you left off:

```julia
model = MyModel()
opt_state = Flux.setup(AdamW(), model)

# ... train model ...

model_state = Flux.state(model)
jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state)
```
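
Restoring such a checkpoint in a fresh session might then look like this (a minimal sketch, not part of the page above; it assumes the `MyModel` definition is available):

```julia
using Flux, JLD2

ckpt = JLD2.load("checkpoint_epoch=42.jld2")

model = MyModel()  # rebuild the architecture first
Flux.loadmodel!(model, ckpt["model_state"])

opt_state = ckpt["opt_state"]  # the optimiser state deserializes as-is

# ... continue training with `model` and `opt_state` ...
```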

# Saving Models as Julia Structs

Models are just normal Julia structs, so it's fine to use any Julia storage
format to save the struct as it is instead of saving the state returned by [`Flux.state`](@ref).
[BSON.jl](https://github.com/JuliaIO/BSON.jl) is particularly convenient for this,
since it can also save anonymous functions, which are sometimes part of a model definition.

Save a model:

```jldoctest saving
julia> using Flux

julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2));

julia> using BSON: @save

julia> @save "mymodel.bson" model
```

Load it again in a new session:

```jldoctest saving
julia> using Flux, BSON

julia> BSON.@load "mymodel.bson" model

julia> model
Chain(
  Dense(10 => 5, relu),                 # 55 parameters
  Dense(5 => 2),                        # 12 parameters
)                   # Total: 4 arrays, 67 parameters, 524 bytes.
```
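
As a sketch of the anonymous-function point above (not from the page; the model here is made up for illustration):

```julia
using Flux
using BSON: @save, @load

# A model containing an anonymous function as a layer. Flux.state would
# replace it with the () sentinel, but BSON can serialize it directly.
model = Chain(Dense(2 => 3), x -> 2 .* x)

@save "withfn.bson" model
@load "withfn.bson" model  # reconstructs the anonymous function as well
```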

!!! warning
    Saving models this way could lead to compatibility issues across Julia versions
    and across Flux versions if some of the Flux layers' internals are changed.
    It is therefore not recommended for long term storage; use [`Flux.state`](@ref) instead.

!!! warning

    Previous versions of Flux suggested saving only the model weights using
    `@save "mymodel.bson" params(model)`.
    This is no longer recommended and even strongly discouraged.
    Saving models this way will only store the trainable parameters, which
    will result in incorrect behavior for layers like `BatchNorm`.
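
To see why (a minimal sketch, not part of the page above): `BatchNorm` keeps running statistics that are not trainable parameters, so `params` misses them while `Flux.state` captures them:

```julia
using Flux

bn = BatchNorm(2)

length(Flux.params(bn))  # 2: only the trainable shift β and scale γ
keys(Flux.state(bn))     # also includes the running mean μ and variance σ²
```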