
Commit 2c3e257

remove params entirely
1 parent 17db916 commit 2c3e257

8 files changed: +43 -95 lines changed

src/layers/show.jl

Lines changed: 5 additions & 5 deletions

@@ -90,21 +90,21 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
   _str = isnothing(name) ? "" : "$name = "
   str = _str * sprint(show, layer, context=io)
   print(io, " "^indent, str, indent==0 ? "" : ",")
-  if !isempty(params(layer))
+  if !isempty(trainables(layer))
     print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
-    printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
+    printstyled(io, "# ", underscorise(sum(length, trainables(layer); init=0)), " parameters";
                 color=:light_black)
-    nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
+    nonparam = _childarray_sum(length, layer) - sum(length, trainables(layer), init=0)
     if nonparam > 0
       printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
     end
-    _nan_show(io, params(layer))
+    _nan_show(io, trainables(layer))
   end
   indent==0 || println(io)
 end
 
 function _big_finale(io::IO, m)
-  ps = params(m)
+  ps = trainables(m)
   if length(ps) > 2
     pars = underscorise(sum(length, ps; init=0))
     bytes = Base.format_bytes(Base.summarysize(m))
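
For context, the hunk above swaps the implicit Zygote `params` collection for the explicit `trainables` when the layer summary counts parameters. A minimal sketch of the same counting pattern, assuming a Flux version that provides `Flux.trainables` (the model and the expected count below are illustrative, not taken from this commit):

using Flux

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))

# Old, implicit style removed by this commit:
#   n = sum(length, Flux.params(model))
# New, explicit style used in _layer_show:
n = sum(length, Flux.trainables(model); init = 0)
println(n)   # 67 = (10*5 + 5) + (5*2 + 2)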

src/outputsize.jl

Lines changed: 0 additions & 2 deletions

@@ -302,8 +302,6 @@ function ChainRulesCore.rrule(::typeof(striplazy), m)
   striplazy(m), _ -> error("striplazy should never be used within a gradient")
 end
 
-params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). Call striplazy(m) first.")
-
 Functors.functor(::Type{<:LazyLayer}, x) = error("LazyLayer should not be walked with Functors.jl, as the arrays which Flux.gpu wants to move may not exist yet.")
 
 function Base.show(io::IO, l::LazyLayer)

test/data.jl

Lines changed: 4 additions & 2 deletions

@@ -82,7 +82,8 @@ using Random
   X = zeros(2, 10)
   loss(x) = sum((x .- θ).^2)
   d = DataLoader(X)
-  Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+  opt_state = Flux.setup(Descent(0.1), θ)
+  Flux.train!(loss, θ, ncycle(d, 10), opt_state)
   @test norm(θ) < 1e-4
 
   # test interaction with `train!`
@@ -91,7 +92,8 @@ using Random
   Y = fill(2, 10)
   loss(x, y) = sum((y - x'*θ).^2)
   d = DataLoader((X, Y))
-  Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+  opt_state = Flux.setup(Descent(0.1), θ)
+  Flux.train!(loss, θ, ncycle(d, 10), opt_state)
   @test norm(θ .- 1) < 1e-10
 
   # specify the rng
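
A hedged sketch of the explicit `train!` API these hunks move to: the model is passed directly and the optimiser state comes from `Flux.setup`, instead of wrapping parameters in `Params`. The `Dense` model, data shapes, and `mse` loss here are illustrative, not taken from the test file:

using Flux

model = Dense(2 => 1)
data = [(randn(Float32, 2, 8), randn(Float32, 1, 8)) for _ in 1:10]

opt_state = Flux.setup(Descent(0.1), model)    # replaces passing Descent(0.1) alongside Params([...])
Flux.train!(model, data, opt_state) do m, x, y
    Flux.mse(m(x), y)                          # the loss now receives the model explicitly
end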

test/ext_cuda/curnn.jl

Lines changed: 1 addition & 13 deletions

@@ -1,20 +1,8 @@
-using Flux, CUDA, Test
-
-@testset for R in [RNN, GRU, LSTM, GRUv3]
-  m = R(10, 5) |> gpu
-  x = gpu(rand(10))
-  (m̄,) = gradient(m -> sum(m(x)), m)
-  Flux.reset!(m)
-  θ = gradient(() -> sum(m(x)), params(m))
-  @test x isa CuArray
-  @test θ[m.cell.Wi] isa CuArray
-  @test collect(m̄.cell.Wi) == collect(θ[m.cell.Wi])
-end
 
 @testset "RNN" begin
   @testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5)
     rnn = R(10, 5)
-    curnn = fmap(gpu, rnn)
+    curnn = rnn |> gpu
 
     Flux.reset!(rnn)
     Flux.reset!(curnn)
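
The deleted testset compared implicit and explicit gradients; only the explicit path remains, and `rnn |> gpu` replaces the manual `fmap(gpu, rnn)` since `gpu` already walks the model. A hedged sketch of that remaining pattern (requires a working CUDA device; the assertions are illustrative, not from the test):

using Flux, CUDA

rnn   = RNN(10, 5)
curnn = rnn |> gpu                       # same effect as fmap(gpu, rnn)

x = rand(Float32, 10) |> gpu
g = gradient(m -> sum(m(x)), curnn)[1]   # structural gradient, no params(m)
@assert g.cell.Wi isa CuArray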

test/ext_cuda/layers.jl

Lines changed: 20 additions & 42 deletions

@@ -110,17 +110,17 @@ end
   l = cl((2,2), 1=>3, bias = false) |> gpu
   ip = zeros(Float32, 28,28,1,1) |> gpu
   @test sum(l(ip)) ≈ 0.f0
-  gs = gradient(() -> sum(l(ip)), Flux.params(l))
-  @test l.bias ∉ gs.params
+  gs = gradient(l -> sum(l(ip)), l)[1]
+  @test gs.bias === nothing
 end
 
 @testset "Dense without bias" begin
   l = Dense(ones(Float32, 4, 3), false) |> gpu
   ip = zeros(Float32, 3, 7) |> gpu
 
   @test sum(l(ip)) ≈ 0.f0
-  gs = gradient(() -> sum(l(ip)), Flux.params(l))
-  @test l.bias ∉ gs.params
+  gs = gradient(l -> sum(l(ip)), l)[1]
+  @test gs.bias === nothing
 end
 
 @testset "Extended BatchNorm" begin
@@ -133,13 +133,13 @@ end
   μ_cpu = copy(m_cpu.μ)
   m_cpu(x_cpu)
   @test m_cpu.μ ≈ μ_cpu
-  gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
+  gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
   @test !(m_cpu.μ ≈ μ_cpu)
 
   μ_gpu = copy(m_gpu.μ)
   m_gpu(x_gpu)
   @test m_gpu.μ ≈ μ_gpu
-  gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
+  gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
   @test !(m_gpu.μ ≈ μ_gpu)
 
   @test Array(m_gpu.μ) ≈ m_cpu.μ
@@ -149,14 +149,14 @@ end
   μ_cpu = copy(m_cpu.μ)
   m_cpu(x_cpu)
   @test m_cpu.μ ≈ μ_cpu
-  gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
+  gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
   @test m_cpu.μ ≈ μ_cpu
 
   testmode!(m_gpu)
   μ_gpu = copy(m_gpu.μ)
   m_gpu(x_gpu)
   @test m_gpu.μ ≈ μ_gpu
-  gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
+  gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
   @test m_gpu.μ ≈ μ_gpu
 
   ## In trainmode, always track statistics
@@ -165,52 +165,36 @@ end
   m_cpu(x_cpu)
   @test !(m_cpu.μ ≈ μ_cpu)
   μ_cpu = copy(m_cpu.μ)
-  gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
+  gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
   @test !(m_cpu.μ ≈ μ_cpu)
 
   trainmode!(m_gpu)
   μ_gpu = copy(m_gpu.μ)
   m_gpu(x_gpu)
   @test !(m_gpu.μ ≈ μ_gpu)
   μ_gpu = copy(m_gpu.μ)
-  gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
+  gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
   @test !(m_gpu.μ ≈ μ_gpu)
-
-  ## No errors if input type mismatch
-  # x_cpu = rand(Float64, 3, 2, 2)
-  # x_gpu = x_cpu |> gpu
-  # m_cpu(x_cpu)
-  # gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
-  # m_gpu(x_gpu)
-  # gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
 end
 
 @testset "Two-streams Bilinear" begin
   x = zeros(Float32,10,9) |> gpu
   y = zeros(Float32,2,9) |> gpu
   b = Flux.Bilinear(10, 2, 3) |> gpu
-  @test size(b(x,y)) == (3,9)
-  @test sum(abs2, b(x,y)) ≈ 0f0
-  gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
-  b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
-  gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
-  for (pgpu, pcpu) in zip(params(b), params(b_cpu))
-    @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
-  end
+  @test size(b(x, y)) == (3,9)
+  @test sum(abs2, b(x, y)) ≈ 0f0
+  test_gradients(b |> cpu, x |> cpu, y |> cpu,
+                 test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
 end
 
 @testset "Two-streams Bilinear" begin
   x = zeros(Float32,10,9) |> gpu
   y = zeros(Float32,2,9) |> gpu
   b = Flux.Bilinear(10, 2, 3) |> gpu
-  @test size(b(x,y)) == (3,9)
-  @test sum(abs2, b(x,y)) ≈ 0f0
-  gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
-  b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
-  gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
-  for (pgpu, pcpu) in zip(params(b), params(b_cpu))
-    @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
-  end
+  @test size(b(x, y)) == (3,9)
+  @test sum(abs2, b(x, y)) ≈ 0f0
+  test_gradients(b |> cpu, x |> cpu, y |> cpu,
+                 test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
 end
 
 @testset "Parallel" begin
@@ -228,15 +212,9 @@ end
   end
 
   @testset "gradient" begin
-    input_cpu = randn(10, 10, 10, 10)
-    input_gpu = input_cpu |> gpu
     layer_cpu = Parallel(+, x -> zero(x), identity)
-    layer_gpu = layer_cpu |> gpu
-    gs_cpu = gradient(() -> sum(abs2.(layer_cpu(input_cpu))), params(layer_cpu))
-    gs_gpu = gradient(() -> sum(abs2.(layer_gpu(input_gpu))), params(layer_gpu))
-    for (pgpu, pcpu) in zip(params(layer_cpu), params(layer_gpu))
-      @test gs_cpu[pcpu] ≈ gs_gpu[pgpu]
-    end
+    test_gradients(layer_cpu, randn(5, 5, 5, 5),
+                   test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
   end
 end
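
The bias checks above change because explicit gradients are structural: a layer built without a bias simply carries `nothing` in that slot of the returned NamedTuple, so there is no `Params` collection to test membership against. A CPU-only sketch of the same idea (the sizes are illustrative):

using Flux

l  = Dense(ones(Float32, 4, 3), false)    # no trainable bias
ip = zeros(Float32, 3, 7)

gs = gradient(l -> sum(l(ip)), l)[1]
@assert gs.bias === nothing               # was `l.bias ∉ gs.params` in the implicit style
@assert size(gs.weight) == (4, 3)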

test/layers/basic.jl

Lines changed: 5 additions & 6 deletions

@@ -196,7 +196,7 @@ using Flux: activations
   x = randn(Float32,11,7)
   b = Flux.Bilinear(11, 11, 3)
   @test size(b(x)) == (3,7)
-  @test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b))
+  test_gradients(b, x)
 end
 
 @testset "constructors" begin
@@ -436,16 +436,15 @@ end
 @testset "gradients of Chain{Vector}" begin
   m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
   m1v = Chain([m1[1], m1[2]])
-  @test sum(length, params(m1)) == sum(length, params(m1v))
+  @test sum(length, trainables(m1)) == sum(length, trainables(m1v))
 
   x1 = randn(Float32,3,5)
   @test m1(x1) ≈ m1v(x1)
 
   y1 = rand(Bool,2,5)
-  g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
-  g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
-  @test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
-  @test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
+  g1 = gradient(m1 -> Flux.logitcrossentropy(m1(x1), y1), m1)[1]
+  g1v = gradient(m1v -> Flux.logitcrossentropy(m1v(x1), y1), m1v)[1]
+  check_equal_leaves(g1, g1v)
 
   @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1]
   z1 = rand(22);
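
`check_equal_leaves` is a helper from this test suite; outside of it, the same comparison can be written by indexing into the structural gradients directly. A hedged sketch under that assumption, mirroring the models used above:

using Flux

m1  = Chain(Dense(3 => 4, tanh; bias=false), Dense(4 => 2))
m1v = Chain([m1[1], m1[2]])              # same layers, stored in a Vector

x1 = randn(Float32, 3, 5)
y1 = rand(Bool, 2, 5)

g1  = gradient(m -> Flux.logitcrossentropy(m(x1), y1), m1)[1]
g1v = gradient(m -> Flux.logitcrossentropy(m(x1), y1), m1v)[1]

@assert g1.layers[1].weight ≈ g1v.layers[1].weight
@assert g1.layers[2].bias   ≈ g1v.layers[2].bias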

test/runtests.jl

Lines changed: 2 additions & 3 deletions

@@ -1,6 +1,5 @@
 using Flux
 using Flux: OneHotArray, OneHotMatrix, OneHotVector
-using Flux: params
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
@@ -11,9 +10,9 @@ using Functors: fmapstructure_with_path
 
 ## Uncomment below to change the default test settings
 # ENV["FLUX_TEST_AMDGPU"] = "true"
-ENV["FLUX_TEST_CUDA"] = "true"
+# ENV["FLUX_TEST_CUDA"] = "true"
 # ENV["FLUX_TEST_METAL"] = "true"
-ENV["FLUX_TEST_CPU"] = "false"
+# ENV["FLUX_TEST_CPU"] = "false"
 # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
 # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
 ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing
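
Re-commenting the two ENV lines restores the default test selection. To run, say, the CUDA group locally without editing runtests.jl, the flag can be set in the environment before launching the suite; a hedged sketch, assuming a working CUDA setup:

# from the Julia REPL, before running the suite
ENV["FLUX_TEST_CUDA"] = "true"
using Pkg
Pkg.test("Flux")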

test/utils.jl

Lines changed: 6 additions & 22 deletions

@@ -2,7 +2,7 @@ using Flux
 using Flux: throttle, nfan, glorot_uniform, glorot_normal,
             kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
             sparse_init, identity_init, unstack, batch, unbatch,
-            unsqueeze, params, loadmodel!
+            unsqueeze, loadmodel!
 using MLUtils
 using Statistics, LinearAlgebra
 using Random
@@ -334,28 +334,12 @@ end
   o = ones(s)
   z = zeros(s)
 
-  @testset "Explicit" begin
-    gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
-    g = gfun(o, z)
-    @test gfun(o, false) == (g[1], nothing)
+  gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
+  g = gfun(o, z)
+  @test gfun(o, false) == (g[1], nothing)
 
-    g = gfun(z, o)
-    @test gfun(false, o) == (nothing, g[2])
-  end
-
-  @testset "Implicit" begin
-    gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
-    g = gfun(o, z)
-
-    gres = gfun(o, false)
-    @test gres[o] == g[o]
-    @test false ∉ gres.params
-
-    g = gfun(z, o)
-    gres = gfun(false, o)
-    @test gres[o] == g[o]
-    @test false ∉ gres.params
-  end
+  g = gfun(z, o)
+  @test gfun(false, o) == (nothing, g[2])
 end
 end
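
Only the explicit testset survives here. It relies on Zygote treating a `Bool` argument as non-differentiable, so that argument's gradient slot comes back as `nothing`. A minimal sketch of that behaviour (the `.+` stands in for the ops the real testset loops over):

using Flux   # re-exports Zygote's gradient

gfun(x, y) = gradient((a, b) -> sum(a .+ b), x, y)

g = gfun(ones(3), zeros(3))
@assert g == (ones(3), ones(3))
@assert gfun(ones(3), false) == (ones(3), nothing)   # Bool argument => nothing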
