
Commit a5e5546

allow non-tuple data in the new train! (#2119)
* allow non-tuple data
* cl/batchme
* add tests
* test multiple callback
* cleanup notes
* cleanup
* cleanup
* remove callbacks
* cleanup
* Update src/train.jl

Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
1 parent da8ce81 commit a5e5546
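
This commit makes `train!` wrap each non-tuple item `d` as `(d,)` before splatting it into the loss, so an iterator of plain arrays no longer needs the `((d,) for d in data)` workaround. A minimal sketch of what this enables, with illustrative names (`model`, `loss`, `xs` are not from the commit):

using Flux

# Sketch of the new behaviour, assuming a Flux version that includes this commit.
model = Dense(10 => 1)
loss(m, x) = sum(abs2, m(x))              # loss takes the model plus one non-tuple item
xs = [rand(Float32, 10) for _ in 1:100]   # data items are plain arrays, not tuples

opt = Flux.setup(Adam(), model)
Flux.train!(loss, model, xs, opt)         # previously this required ((x,) for x in xs)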

File tree

3 files changed (+16 -15 lines):
  Project.toml
  src/train.jl
  test/train.jl

Project.toml

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]

src/train.jl

Lines changed: 8 additions & 11 deletions
@@ -56,11 +56,12 @@ end

 Uses a `loss` function and training `data` to improve the `model`'s parameters
 according to a particular optimisation rule `opt`. Iterates through `data` once,
-evaluating `loss(model, d...)` for each `d` in data.
+evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`,
+or else `loss(model, d)` for other `d`.

 For example, with these definitions...
 ```
-data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple
+data = [(x1, y1), (x2, y2), (x3, y3)]

 loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument

@@ -76,7 +77,7 @@ end
 ```
 You can also write this loop yourself, if you need more flexibility.
 For this reason `train!` is not highly extensible.
-It adds only a few featurs to the loop above:
+It adds only a few features to the loop above:

 * Stop with a `DomainError` if the loss is infinite or `NaN` at any point.

@@ -88,9 +89,6 @@ It adds only a few featurs to the loop above:
   (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
 * Instead of `loss` being a function which accepts only the data,
   now it must also accept the `model` itself, as the first argument.
-* `data` must iterate tuples, otherwise you get an error.
-  (Previously non-tuple types were not splatted into the loss.
-  Pass in `((d,) for d in data)` to simulate this.)
 * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
   such as `Adam()` without this step should give you a warning.
 * Callback functions are not supported.
@@ -100,9 +98,8 @@ function train!(loss, model, data, opt; cb = nothing)
   isnothing(cb) || error("""train! does not support callback functions.
                             For more control use a loop with `gradient` and `update!`.""")
   @withprogress for (i,d) in enumerate(data)
-    d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
-                            Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
-    l, gs = Zygote.withgradient(m -> loss(m, d...), model)
+    d_splat = d isa Tuple ? d : (d,)
+    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
     if !isfinite(l)
       throw(DomainError("Loss is $l on data item $i, stopping training"))
     end
@@ -112,8 +109,8 @@ function train!(loss, model, data, opt; cb = nothing)
 end

 # This method let you use Optimisers.Descent() without setup, when there is no state
-function train!(loss, model, data, rule::Optimisers.AbstractRule)
-  train!(loss, model, data, _rule_to_state(model, rule))
+function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
+  train!(loss, model, data, _rule_to_state(model, rule); cb)
 end

 function _rule_to_state(model, rule::Optimisers.AbstractRule)
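
The last hunk above also threads the `cb` keyword through the convenience method that takes a bare optimisation rule, so the informative `train! does not support callback functions` error is reached on that path too. A small sketch under assumed names (nothing here is taken from the commit itself):

using Flux
import Optimisers

m = Dense(2 => 1)
loss(m, x) = sum(abs2, m(x))
data = [rand(Float32, 2) for _ in 1:10]   # non-tuple items, handled by the new splat fallback

Flux.train!(loss, m, data, Optimisers.Descent(0.1))   # bare rule: state built via _rule_to_state

# The keyword is now accepted and forwarded, so this call raises the helpful
# "does not support callback functions" error rather than a MethodError:
# Flux.train!(loss, m, data, Optimisers.Descent(0.1); cb = () -> nothing)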

test/train.jl

Lines changed: 8 additions & 3 deletions
@@ -44,10 +44,15 @@ end
     @test CNT == 51  # stopped early
     @test m1.weight[1] ≈ -5  # did not corrupt weights
   end
-  @testset "data must give tuples" begin
-    m1 = Dense(1 => 1)
-    @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(x=1, y=2) for _ in 1:3], Descent(0.1))
+
+  @testset "non-tuple data" begin
+    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    model = (weight=copy(w2), bias=zeros(10))
+    opt = Flux.setup(AdamW(), model)
+    Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt)
+    @test loss(model, rand(10, 10)) < 0.01
   end
+
   @testset "callbacks give helpful error" begin
     m1 = Dense(1 => 1)
     cb = () -> println("this should not be printed")
