Commit 4c9e98b

Add support for Optimisers.jl (#114)
* Add Optimisers.jl dep
* Add Optimisers.jl support
* Add optimiser dispatch for `gradient`
1 parent 6780869 commit 4c9e98b
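For orientation, a minimal usage sketch of what this change enables: an optimisation rule from Optimisers.jl can now be passed to a `Learner` wherever a `Flux.Optimise` optimiser was accepted before. Everything in this sketch except `Optimisers.Descent` (the rule used in the new test below) is an assumption for illustration, in particular the toy model and data and the positional `Learner(model, data, optimizer, lossfn)` constructor, whose exact signature may differ between FluxTraining.jl versions.

```julia
# Hedged sketch, not part of the commit: driving a Learner with an Optimisers.jl rule.
using Flux, FluxTraining
import Optimisers

model = Chain(Dense(10, 32, relu), Dense(32, 1))
batches = [(rand(Float32, 10, 16), rand(Float32, 1, 16)) for _ in 1:8]  # toy (x, y) batches

# Previously this slot required an optimiser from Flux.Optimise, e.g. Flux.Descent(0.001);
# after this commit an Optimisers.jl rule works in the same place.
learner = Learner(model, (batches, batches), Optimisers.Descent(0.001), Flux.Losses.mse)
fit!(learner, 5)
```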

8 files changed: +48, -13 lines


CHANGELOG.md

Lines changed: 6 additions & 1 deletion
@@ -1,8 +1,13 @@
 
 # News
 
-## [0.3.0] - 04.04.2022
+## Unreleased
+
+### Added
 
+- Support for [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) https://github.com/FluxML/FluxTraining.jl/pull/114.
+
+## [0.3.0] - 04.04.2022
 
 ### Added
 
Project.toml

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
 OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
@@ -34,8 +35,8 @@ Graphs = "1"
 ImageCore = "0.8, 0.9"
 InlineTest = "0.2"
 OnlineStats = "1.5"
-Parameters = "0.12"
 ParameterSchedulers = "0.3.1"
+Parameters = "0.12"
 PrettyTables = "1, 1.1, 1.2"
 ProgressMeter = "1.4"
 Reexport = "1.0"

src/FluxTraining.jl

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ module ES
 end
 import OnlineStats
 using OnlineStats: EqualWeight, Mean, OnlineStat
+import Optimisers
 using Parameters
 using ProgressMeter: Progress, next!
 using Statistics: mean

src/learner.jl

Lines changed: 10 additions & 8 deletions
@@ -16,6 +16,8 @@ mutable struct Learner
     data::PropDict
     optimizer
     lossfn
+    # this used to store `Flux.Params` but now stores the optimiser state
+    # if an optim from Optimisers.jl is used
     params
     step::PropDict
     callbacks::Callbacks
@@ -96,7 +98,7 @@ function Learner(
         _dataiters(data),
         optimizer,
         lossfn,
-        paramsrec(model),
+        setupoptimstate(model, optimizer),
         PropDict(),
         cbs,
         PropDict())
@@ -129,9 +131,15 @@ phasedataiter(::AbstractValidationPhase) = :validation
 
 function model!(learner, model)
     learner.model = model
-    learner.params = paramsrec(model)
+    learner.params = setupoptimstate(model, learner.optimizer)
 end
 
+# Flux.jl optimisers store `params`, while Optimisers.jl store the result of `setup`
+setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)
+# Optimisers.jl has no abstract supertype so we assume non-Flux optimisers
+# conform to the Optimisers.jl interface.
+setupoptimstate(model, optim) = Optimisers.setup(optim, model)
+
 
 _dataiters(d::PropDict) = d
 _dataiters(t::NamedTuple) = PropDict(pairs(t))
@@ -146,9 +154,3 @@ function _dataiters(t::Tuple)
         error("Please pass a `NamedTuple` or `PropDict` as `data`.")
     end
 end
-
-
-paramsrec(m) = Flux.params(m)
-paramsrec(t::Union{Tuple,NamedTuple}) = map(paramsrec, t)
-
-# Callback utilities
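To make the new dispatch concrete, here is a standalone sketch (outside FluxTraining.jl; the method name is copied from the diff purely for illustration) of how the two `setupoptimstate` branches differ: a `Flux.Optimise` optimiser keeps the old implicit-`Params` style, while anything else is assumed to follow the Optimisers.jl interface and gets an explicit state tree from `Optimisers.setup`. Whichever object is produced ends up in the `Learner`'s `params` field, which is what the new struct comment above records.

```julia
# Standalone sketch of the `setupoptimstate` dispatch; the Dense model is a toy assumption.
using Flux
import Optimisers

setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)
setupoptimstate(model, optim) = Optimisers.setup(optim, model)

model = Dense(2, 2)

ps = setupoptimstate(model, Flux.Descent(0.1))        # Zygote.Params (implicit-gradient style)
st = setupoptimstate(model, Optimisers.Descent(0.1))  # Optimisers.jl state tree (explicit-gradient style)
```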

src/training.jl

Lines changed: 20 additions & 3 deletions
@@ -49,19 +49,36 @@ function step! end
 function step!(learner, phase::TrainingPhase, batch)
     xs, ys = batch
     runstep(learner, phase, (; xs=xs, ys=ys)) do handle, state
-        state.grads = gradient(learner.params) do
-            state.ŷs = learner.model(state.xs)
+
+        state.grads = _gradient(learner.optimizer, learner.model, learner.params) do model
+            state.ŷs = model(state.xs)
             handle(LossBegin())
             state.loss = learner.lossfn(state.ŷs, state.ys)
             handle(BackwardBegin())
             return state.loss
         end
         handle(BackwardEnd())
-        update!(learner.optimizer, learner.params, state.grads)
+        learner.params, learner.model = _update!(
+            learner.optimizer, learner.params, learner.model, state.grads)
     end
 end
 
 
+# Handle both old Flux.jl and new Optimisers.jl optimisers
+
+_gradient(f, _, m, _) = gradient(f, m)[1]
+_gradient(f, ::Flux.Optimise.AbstractOptimiser, m, ps::Params) = gradient(() -> f(m), ps)
+
+function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
+    update!(optimizer, params, grads)
+    return params, model
+end
+function _update!(_, st, model, grads)
+    st, model = Optimisers.update!(st, model, grads)
+    return st, model
+end
+
+
 function step!(learner, phase::ValidationPhase, batch)
     xs, ys = batch
     runstep(learner, phase, (;xs=xs, ys=ys)) do _, state
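The `_gradient`/`_update!` pair bridges two calling conventions: with a `Flux.Optimise` optimiser the loss is differentiated with respect to implicit `Params` and `update!` mutates them in place, whereas with an Optimisers.jl rule the model itself is differentiated and `Optimisers.update!` returns a fresh state and model, which `step!` writes back into `learner.params` and `learner.model`. Below is a standalone sketch of that explicit path; the model, data, and loss are assumptions, and only the `gradient`, `Optimisers.setup`, and `Optimisers.update!` calls mirror the diff.

```julia
# Sketch of the explicit-gradient path taken for Optimisers.jl rules (toy example, not FluxTraining code).
using Flux            # brings Zygote's `gradient` into scope
import Optimisers

model = Dense(3, 1)
x, y = rand(Float32, 3, 8), rand(Float32, 1, 8)

# What `setupoptimstate` stores in `learner.params` for a non-Flux optimiser:
state = Optimisers.setup(Optimisers.Descent(0.1), model)

# Differentiate with respect to the model itself (cf. `_gradient(f, _, m, _) = gradient(f, m)[1]`)...
grads = gradient(m -> Flux.Losses.mse(m(x), y), model)[1]

# ...then update: the new state and model are returned instead of being mutated through `Params`.
state, model = Optimisers.update!(state, model, grads)
```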

test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ version = "0.1.0"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"

test/imports.jl

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 using ReTest
+import Optimisers
 using FluxTraining
 using ParameterSchedulers
 using Colors

test/training.jl

Lines changed: 7 additions & 0 deletions
@@ -47,3 +47,10 @@ end
     fit!(learner, 5)
     @test learner.model.coeff[1] ≈ 3 atol = 0.1
 end
+
+
+@testset "Optimisers.jl compatibility" begin
+    learner = testlearner(coeff = 3, opt=Optimisers.Descent(0.001))
+    fit!(learner, 5)
+    @test learner.model.coeff[1] ≈ 3 atol = 0.1
+end
