Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule `opt`. Iterates through `data` once,
- evaluating `loss(model, d...)` for each `d` in data.
+ evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`,
+ or else `loss(model, d)` for other `d`.
For example, with these definitions...
```
- data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple
+ data = [(x1, y1), (x2, y2), (x3, y3)]
loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
```
You can also write this loop yourself, if you need more flexibility.
For this reason `train!` is not highly extensible.
- It adds only a few featurs to the loop above:
+ It adds only a few features to the loop above:
* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
@@ -88,9 +89,6 @@ It adds only a few featurs to the loop above:
(This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
* Instead of `loss` being a function which accepts only the data,
  now it must also accept the `model` itself, as the first argument.
- * `data` must iterate tuples, otherwise you get an error.
-   (Previously non-tuple types were not splatted into the loss.
-   Pass in `((d,) for d in data)` to simulate this.)
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
  such as `Adam()` without this step should give you a warning.
* Callback functions are not supported.
@@ -100,9 +98,8 @@ function train!(loss, model, data, opt; cb = nothing)
  isnothing(cb) || error("""train! does not support callback functions.
                            For more control use a loop with `gradient` and `update!`.""")
  @withprogress for (i,d) in enumerate(data)
-   d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
-     Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
-   l, gs = Zygote.withgradient(m -> loss(m, d...), model)
+   d_splat = d isa Tuple ? d : (d,)
+   l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
    if !isfinite(l)
      throw(DomainError("Loss is $l on data item $i, stopping training"))
    end
@@ -112,8 +109,8 @@ function train!(loss, model, data, opt; cb = nothing)
end
# This method let you use Optimisers.Descent() without setup, when there is no state
- function train!(loss, model, data, rule::Optimisers.AbstractRule)
-   train!(loss, model, data, _rule_to_state(model, rule))
+ function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
+   train!(loss, model, data, _rule_to_state(model, rule); cb)
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
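
As a quick illustration of the behaviour this diff enables (a minimal sketch, not part of the commit; `model`, `loss1`, `loss2`, `xs`, and `ys` are made-up names): non-tuple elements of `data` are now wrapped as `(d,)` and passed whole to the loss, while tuple elements are still splatted.

```
using Flux, LinearAlgebra

model = Dense(2 => 1)
opt = Flux.setup(Adam(0.01), model)

# Tuple data: each (x, y) is still splatted into the loss as loss2(model, x, y).
ys = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:3]
loss2(m, x, y) = norm(m(x) .- y)
Flux.train!(loss2, model, ys, opt)

# Non-tuple data: each batch is now passed whole as loss1(model, d),
# where previously train! raised "expects as data an iterator producing tuples".
xs = [rand(Float32, 2, 8) for _ in 1:3]
loss1(m, x) = norm(m(x))
Flux.train!(loss1, model, xs, opt)
```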