
Commit 5bdf443

don't test Optimise module
remove Flux.params from tests
broken deprecation in __old_to_new
pen2
remove params entirely
export Optimisers
cleanup
fix ambiguity
comment
1 parent 7525499 commit 5bdf443

18 files changed, +211 −420 lines

src/Flux.jl

Lines changed: 4 additions & 7 deletions
@@ -12,6 +12,8 @@ using MLUtils
 const stack = MLUtils.stack # now exported by Base
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
 using Optimisers: freeze!, thaw!, adjust!, trainables
+@reexport using Optimisers
+
 using Random: default_rng
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback
@@ -56,13 +58,8 @@ export Chain, Dense, Embedding, EmbeddingBag,
   ))
 
 include("optimise/Optimise.jl")
-using .Optimise
-export Descent, Adam, Momentum, Nesterov, RMSProp,
-  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam,
-  AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
-  WeightDecay, SignDecay, ClipValue, ClipNorm
-
-export ClipGrad, OptimiserChain # these are const defined in deprecations, for ClipValue, Optimiser
+using .Optimise: Optimise
+export ClipValue # this is const defined in deprecations, for ClipGrad
 
 include("train.jl")
 using .Train
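With `@reexport using Optimisers`, the rule types Flux used to define and export from `Flux.Optimise` now come directly from Optimisers.jl. A minimal sketch of what user code looks like after this change (the model here is a made-up example, not part of the commit):

    using Flux  # Adam, Descent, OptimiserChain, ClipNorm, ... are now the Optimisers.jl types

    model = Dense(2 => 1)
    rule = OptimiserChain(ClipNorm(1.0), Adam(1e-3))
    opt_state = Flux.setup(rule, model)   # explicit-style optimiser state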

src/deprecations.jl

Lines changed: 35 additions & 13 deletions
@@ -41,31 +41,40 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error
   """)
 
 train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
-  train!(loss, model, data, _old_to_new(opt); cb)
+  train!(loss, model, data, __old_to_new(opt); cb)
 
 # Next, to use the new `setup` with the still-exported old-style `Adam` etc:
 import .Train: setup
-setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
+setup(rule::Optimise.AbstractOptimiser, model) = setup(__old_to_new(rule), model)
 # ... and allow accidental use of `Optimisers.setup` to do the same:
-Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
+Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(__old_to_new(rule), model)
+
+
+function __old_to_new(rule)
+  Base.depwarn("""Optimisers from Flux.Optimise module are deprecated.
+                  Use optimisers from Optimisers.jl instead.""", :__old_to_new)
+  return _old_to_new(rule)
+end
 
 for T in [:Descent, :Adam, :Momentum, :Nesterov,
           :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
           # :InvDecay, :ExpDecay,
           :SignDecay,
          ]
-  @eval function _old_to_new(rule::$T)
+  @eval function _old_to_new(rule::Optimise.$T)
     args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
     Optimisers.$T(args...)
   end
 end
-_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
-const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
-_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called lambda now
-_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
-_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
-const ClipGrad = Optimise.ClipValue
-_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred
+_old_to_new(rule::Optimise.Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
+# const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
+const Optimiser = Optimisers.OptimiserChain
+_old_to_new(rule::Optimise.WeightDecay) = Optimisers.WeightDecay(rule.wd) # called lambda now
+_old_to_new(rule::Optimise.ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
+_old_to_new(rule::Optimise.ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
+# const ClipGrad = Optimise.ClipValue
+const ClipValue = Optimisers.ClipGrad
+_old_to_new(rule::Optimise.RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred
 
 _old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
 
@@ -83,8 +92,21 @@ function update!(opt::Optimise.AbstractOptimiser, model, grad)
   # to accept only arrays. Remove if this causes problems!
   # update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄)
   error("""Invalid input to `update!`.
-    * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)`
-    * For the explicit style, `update(state, model, grad)` needs `state = Flux.setup(opt, model)`.
+    * For the implicit style, this needs `update!(::AbstractOptimiser, ::Params, ::Grads)`
+    * For the explicit style, `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
+    """)
+end
+
+# TODO this friendly error should go in Optimisers.jl.
+# remove after https://github.com/FluxML/Optimisers.jl/pull/181
+function update!(opt::Optimisers.AbstractRule, model, grad)
+  error("""Invalid input to `update!`.
+    `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
+    """)
+end
+function update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple)
+  error("""Invalid input to `update!`.
+    `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
     """)
 end
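The new `__old_to_new` wrapper adds a `depwarn` on top of the existing `_old_to_new` translation, so old-style rules keep working through `Flux.setup` while nudging users toward the Optimisers.jl types. A rough sketch of the intended behaviour (hypothetical usage, not code from the diff):

    using Flux

    model = Dense(3 => 2)
    old_rule = Flux.Optimise.Adam(0.01)      # old implicit-style struct, still defined in Flux.Optimise
    opt_state = Flux.setup(old_rule, model)  # warns via __old_to_new, then acts like Optimisers.Adam(0.01)

    # non-deprecated equivalent, using the re-exported Optimisers.jl rule:
    opt_state = Flux.setup(Adam(0.01), model)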

src/layers/conv.jl

Lines changed: 2 additions & 2 deletions
@@ -145,7 +145,7 @@ Conv((3,), 4 => 5, σ) # 65 parameters
 julia> layer(randn(100, 4, 64)) |> size
 (98, 5, 64)
 
-julia> Flux.params(layer) |> length
+julia> Flux.trainables(layer) |> length
 2
 ```
 """
@@ -294,7 +294,7 @@ ConvTranspose((3,), 5 => 4, σ) # 64 parameters
 julia> layer(randn(100, 5, 64)) |> size # transposed convolution will increase the dimension size (upsampling)
 (102, 4, 64)
 
-julia> Flux.params(layer) |> length
+julia> Flux.trainables(layer) |> length
 2
 ```
 """

src/layers/show.jl

Lines changed: 5 additions & 5 deletions
@@ -104,15 +104,15 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
   _str = isnothing(name) ? "" : "$name = "
   str = _str * _layer_string(io, layer)
   print(io, " "^indent, str, indent==0 ? "" : ",")
-  if !isempty(params(layer))
+  if !isempty(trainables(layer))
     print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
-    printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
+    printstyled(io, "# ", underscorise(sum(length, trainables(layer); init=0)), " parameters";
       color=:light_black)
-    nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
+    nonparam = _childarray_sum(length, layer) - sum(length, trainables(layer), init=0)
     if nonparam > 0
       printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
     end
-    _nan_show(io, params(layer))
+    _nan_show(io, trainables(layer))
   end
   indent==0 || println(io)
 end
@@ -127,7 +127,7 @@ function _layer_string(::IO, a::AbstractArray)
 end
 
 function _big_finale(io::IO, m)
-  ps = params(m)
+  ps = trainables(m)
   if length(ps) > 2
     pars = underscorise(sum(length, ps; init=0))
     bytes = Base.format_bytes(Base.summarysize(m))
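Since the layer summary only needs the sizes of the trainable arrays, swapping `params` for `trainables` keeps the printed counts the same. A quick sketch of the quantity `_layer_show` reports (the `Chain` here is a made-up example):

    using Flux

    m = Chain(Dense(2 => 3), Dense(3 => 1))
    sum(length, Flux.trainables(m); init=0)   # 13, the number the summary prints via underscorise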

src/outputsize.jl

Lines changed: 0 additions & 2 deletions
@@ -302,8 +302,6 @@ function ChainRulesCore.rrule(::typeof(striplazy), m)
   striplazy(m), _ -> error("striplazy should never be used within a gradient")
 end
 
-params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). Call striplazy(m) first.")
-
 Functors.functor(::Type{<:LazyLayer}, x) = error("LazyLayer should not be walked with Functors.jl, as the arrays which Flux.gpu wants to move may not exist yet.")
 
 function Base.show(io::IO, l::LazyLayer)

test/data.jl

Lines changed: 6 additions & 4 deletions
@@ -80,18 +80,20 @@ using Random
     # test interaction with `train!`
     θ = ones(2)
     X = zeros(2, 10)
-    loss(x) = sum((x .- θ).^2)
+    loss(θ, x) = sum((x .- θ).^2)
     d = DataLoader(X)
-    Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+    opt_state = Flux.setup(Descent(0.1), θ)
+    Flux.train!(loss, θ, ncycle(d, 10), opt_state)
     @test norm(θ) < 1e-4
 
     # test interaction with `train!`
     θ = zeros(2)
     X = ones(2, 10)
     Y = fill(2, 10)
-    loss(x, y) = sum((y - x'*θ).^2)
+    loss(θ, x, y) = sum((y - x'*θ).^2)
     d = DataLoader((X, Y))
-    Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+    opt_state = Flux.setup(Descent(0.1), θ)
+    Flux.train!(loss, θ, ncycle(d, 10), opt_state)
     @test norm(θ .- 1) < 1e-10
 
     # specify the rng
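These tests now follow the explicit `train!` signature: the loss takes the model (here just a parameter vector) as its first argument, and the optimiser state comes from `Flux.setup`. A condensed sketch of that pattern with made-up data:

    using Flux
    using IterTools: ncycle   # used by the test above to repeat the data 10 times

    θ = ones(2)                              # the "model" is a bare parameter vector
    data = DataLoader(zeros(2, 10))          # yields one column per batch by default
    loss(θ, x) = sum((x .- θ) .^ 2)
    opt_state = Flux.setup(Descent(0.1), θ)
    Flux.train!(loss, θ, ncycle(data, 10), opt_state)   # drives θ towards the data (zeros)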

test/ext_cuda/cuda.jl

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ CUDA.allowscalar(false)
   m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
   cm = gpu(m)
 
-  @test all(p isa CuArray for p in Flux.params(cm))
+  @test all(p isa CuArray for p in Flux.trainables(cm))
   @test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
 
   xs = rand(5, 5)

test/ext_cuda/curnn.jl

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+
+@testset "RNN" begin
+  @testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5)
+    rnn = R(10, 5)
+    curnn = rnn |> gpu
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    x = batch_size == 1 ?
+        rand(Float32, 10) :
+        rand(Float32, 10, batch_size)
+    cux = gpu(x)
+
+    y, back = pullback((r, x) -> r(x), rnn, x)
+    cuy, cuback = pullback((r, x) -> r(x), curnn, cux)
+
+    @test y ≈ collect(cuy)
+
+    ȳ = randn(size(y))
+    m̄, x̄ = back(ȳ)
+    cum̄, cux̄ = cuback(gpu(ȳ))
+
+    @test x̄ ≈ collect(cux̄)
+    @test m̄[].cell.Wi ≈ collect(cum̄[].cell.Wi)
+    @test m̄[].cell.Wh ≈ collect(cum̄[].cell.Wh)
+    @test m̄[].cell.b ≈ collect(cum̄[].cell.b)
+    if m̄[].state isa Tuple
+      for (x, cx) in zip(m̄[].state, cum̄[].state)
+        @test x ≈ collect(cx)
+      end
+    else
+      @test m̄[].state ≈ collect(cum̄[].state)
+    end
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    ohx = batch_size == 1 ?
+          Flux.onehot(rand(1:10), 1:10) :
+          Flux.onehotbatch(rand(1:10, batch_size), 1:10)
+    cuohx = gpu(ohx)
+    y = (rnn(ohx); rnn(ohx))
+
+    cuy = (curnn(cuohx); curnn(cuohx))
+    @test y ≈ collect(cuy)
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    fx = rand(Float32, 10, batch_size, 3)
+    cufx = gpu(fx)
+    fy = (rnn(fx); rnn(fx))
+
+    cufy = (curnn(cufx); curnn(cufx))
+    @test fy ≈ collect(cufy)
+  end
+end
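This new CUDA test compares a CPU and a GPU copy of each recurrent layer through `Zygote.pullback`, which returns the forward output together with a closure for the backward pass. A tiny CPU-only illustration of that mechanism, assuming the pre-0.15 recurrent API (`RNN(in, out)`, `Flux.reset!`) that this file targets:

    using Flux, Zygote

    rnn = RNN(10, 5)
    x = rand(Float32, 10)
    y, back = Zygote.pullback((r, x) -> r(x), rnn, x)   # forward pass plus pullback closure
    ȳ = randn(Float32, size(y))
    m̄, x̄ = back(ȳ)                # gradients with respect to the model and the input
    size(x̄) == size(x)            # true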

test/ext_cuda/layers.jl

Lines changed: 20 additions & 42 deletions
@@ -110,17 +110,17 @@ end
    l = cl((2,2), 1=>3, bias = false) |> gpu
    ip = zeros(Float32, 28,28,1,1) |> gpu
    @test sum(l(ip)) ≈ 0.f0
-    gs = gradient(() -> sum(l(ip)), Flux.params(l))
-    @test l.bias ∉ gs.params
+    gs = gradient(l -> sum(l(ip)), l)[1]
+    @test gs.bias === nothing
  end

  @testset "Dense without bias" begin
    l = Dense(ones(Float32, 4, 3), false) |> gpu
    ip = zeros(Float32, 3, 7) |> gpu

    @test sum(l(ip)) ≈ 0.f0
-    gs = gradient(() -> sum(l(ip)), Flux.params(l))
-    @test l.bias ∉ gs.params
+    gs = gradient(l -> sum(l(ip)), l)[1]
+    @test gs.bias === nothing
  end

  @testset "Extended BatchNorm" begin
@@ -133,13 +133,13 @@ end
    μ_cpu = copy(m_cpu.μ)
    m_cpu(x_cpu)
    @test m_cpu.μ ≈ μ_cpu
-    gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
+    gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
    @test !(m_cpu.μ ≈ μ_cpu)

    μ_gpu = copy(m_gpu.μ)
    m_gpu(x_gpu)
    @test m_gpu.μ ≈ μ_gpu
-    gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
+    gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
    @test !(m_gpu.μ ≈ μ_gpu)

    @test Array(m_gpu.μ) ≈ m_cpu.μ
@@ -149,14 +149,14 @@ end
    μ_cpu = copy(m_cpu.μ)
    m_cpu(x_cpu)
    @test m_cpu.μ ≈ μ_cpu
-    gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
+    gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
    @test m_cpu.μ ≈ μ_cpu

    testmode!(m_gpu)
    μ_gpu = copy(m_gpu.μ)
    m_gpu(x_gpu)
    @test m_gpu.μ ≈ μ_gpu
-    gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
+    gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
    @test m_gpu.μ ≈ μ_gpu

    ## In trainmode, always track statistics
@@ -165,52 +165,36 @@ end
    m_cpu(x_cpu)
    @test !(m_cpu.μ ≈ μ_cpu)
    μ_cpu = copy(m_cpu.μ)
-    gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
+    gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
    @test !(m_cpu.μ ≈ μ_cpu)

    trainmode!(m_gpu)
    μ_gpu = copy(m_gpu.μ)
    m_gpu(x_gpu)
    @test !(m_gpu.μ ≈ μ_gpu)
    μ_gpu = copy(m_gpu.μ)
-    gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
+    gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
    @test !(m_gpu.μ ≈ μ_gpu)
-
-    ## No errors if input type mistmatch
-    # x_cpu = rand(Float64, 3, 2, 2)
-    # x_gpu = x_cpu |> gpu
-    # m_cpu(x_cpu)
-    # gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
-    # m_gpu(x_gpu)
-    # gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
  end

  @testset "Two-streams Bilinear" begin
    x = zeros(Float32,10,9) |> gpu
    y = zeros(Float32,2,9) |> gpu
    b = Flux.Bilinear(10, 2, 3) |> gpu
-    @test size(b(x,y)) == (3,9)
-    @test sum(abs2, b(x,y)) ≈ 0f0
-    gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
-    b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
-    gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
-    for (pgpu, pcpu) in zip(params(b), params(b_cpu))
-      @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
-    end
+    @test size(b(x, y)) == (3,9)
+    @test sum(abs2, b(x, y)) ≈ 0f0
+    test_gradients(b |> cpu, x |> cpu, y |> cpu,
+      test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
  end

  @testset "Two-streams Bilinear" begin
    x = zeros(Float32,10,9) |> gpu
    y = zeros(Float32,2,9) |> gpu
    b = Flux.Bilinear(10, 2, 3) |> gpu
-    @test size(b(x,y)) == (3,9)
-    @test sum(abs2, b(x,y)) ≈ 0f0
-    gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
-    b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
-    gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
-    for (pgpu, pcpu) in zip(params(b), params(b_cpu))
-      @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
-    end
+    @test size(b(x, y)) == (3,9)
+    @test sum(abs2, b(x, y)) ≈ 0f0
+    test_gradients(b |> cpu, x |> cpu, y |> cpu,
+      test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
  end

  @testset "Parallel" begin
@@ -228,15 +212,9 @@ end
    end

    @testset "gradient" begin
-      input_cpu = randn(10, 10, 10, 10)
-      input_gpu = input_cpu |> gpu
      layer_cpu = Parallel(+, x -> zero(x), identity)
-      layer_gpu = layer_cpu |> gpu
-      gs_cpu = gradient(() -> sum(abs2.(layer_cpu(input_cpu))), params(layer_cpu))
-      gs_gpu = gradient(() -> sum(abs2.(layer_gpu(input_gpu))), params(layer_gpu))
-      for (pgpu, pcpu) in zip(params(layer_cpu), params(layer_gpu))
-        @test gs_cpu[pcpu] ≈ gs_gpu[pgpu]
-      end
+      test_gradients(layer_cpu, randn(5, 5, 5, 5),
+        test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
    end
  end
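The rewritten GPU layer tests take gradients with respect to the model itself instead of an implicit `Params` set, so the result is a structural gradient mirroring the layer, and a missing bias simply shows up as `nothing`. A minimal CPU-only sketch of that pattern (independent of the CUDA setup above):

    using Flux

    l = Dense(ones(Float32, 4, 3), false)   # Dense layer built without a bias
    ip = zeros(Float32, 3, 7)
    gs = gradient(l -> sum(l(ip)), l)[1]    # NamedTuple gradient, explicit style
    size(gs.weight)                         # (4, 3): one gradient entry per weight
    gs.bias === nothing                     # true: no bias to differentiate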
