
Commit be419fb: add keep keyword (1 parent: f7d5743)

File tree: 7 files changed, +272 −217 lines

docs/make.jl

Lines changed: 0 additions & 1 deletion
```diff
@@ -35,7 +35,6 @@ makedocs(
         "Gradients -- Zygote.jl" => "training/zygote.md",
         "Batching Data -- MLUtils.jl" => "data/mlutils.md",
         "OneHotArrays.jl" => "data/onehot.md",
-        "Saving and Loading" => "models/saving.md"
         "Low-level Operations -- NNlib.jl" => "models/nnlib.md",
         "Nested Structures -- Functors.jl" => "models/functors.md",
     ],
```

docs/src/destructure.md

Lines changed: 8 additions & 1 deletion
````diff
@@ -72,4 +72,11 @@ Another kind of flat view of a nested model is provided by the `modules` command
 
 ```@docs
 Flux.modules
-```
+```
+
+### Saving and Loading
+
+```@docs
+Flux.loadmodel!
+Flux.state
+```
````
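These two doc entries describe the save/load workflow that replaces the deleted saving.md page. Below is a minimal sketch of that round trip; the choice of JLD2 for serialization is this sketch's assumption, not something this commit prescribes.

```julia
using Flux, JLD2

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))

# Flux.state reduces the model to plain containers (named tuples, tuples,
# arrays), so it can be serialized without the layer types themselves.
jldsave("model_state.jld2"; state = Flux.state(model))

# To restore: rebuild the same architecture in code, then load the saved
# state into it with loadmodel!.
model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Flux.loadmodel!(model2, JLD2.load("model_state.jld2", "state"))
```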

docs/src/models/saving.md

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/loading.jl

Lines changed: 33 additions & 16 deletions
````diff
@@ -103,30 +103,47 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())
 end
 
 """
-    state(x; full=false)
+    state(x; keep = leaf -> !(leaf isa Function))
 
 Return an object with the same nested structure as `x`
-according to `Functors.children()`, but made only of
+according to `Functors.children`, but made only of
 basic containers (e.g. named tuples, tuples, arrays, and dictionaries).
 
-If `full` is `false` (default), then only arrays and scalar original leaves are used as leaf values in the return,
-with the other leaves being replaced by `nothing`.
+This method is particularly useful for saving and loading models,
+since it doesn't require the user to specify the model type.
+The state can be passed to `loadmodel!` to restore the model.
 
-This method is particularly useful for saving and loading models, since it doesn't
-require the user to specify the model type.
-The returned state, can be passed to `loadmodel!` to restore the model.
+The `keep` function is applied to the leaves of `x`.
+If `keep(leaf)` is `false`, the leaf is replaced by `nothing`,
+otherwise it is left as is. By default, all functions are excluded.
+
+# Examples
+
+```julia-repl
+julia> m1 = Chain(Dense(1, 2, tanh), Dense(2, 1));
+
+julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1));
+
+julia> s = Flux.state(m1)
+(layers = ((weight = Float32[-0.56867087; 1.229064;;], bias = Float32[0.0, 0.0], σ = nothing), (weight = Float32[0.23323897 -0.5561147], bias = Float32[0.0], σ = nothing)),)
+
+julia> Flux.loadmodel!(m2, s);
+
+julia> m2[1].weight == m1[1].weight
+true
+```
 """
-function state(x; full=false)
+function state(x; keep = _state_keep)
   if Functors.isleaf(x)
-    if full
-      return x
-    else
-      return x isa Union{Number, AbstractArray} ? x : nothing
-    end
+    return keep(x) ? x : nothing
   else
-    return valuemap(c -> state(c; full), Functors.children(x))
+    return _valuemap(c -> state(c; keep), Functors.children(x))
   end
 end
 
-valuemap(f, x) = map(f, x)
-valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x)
+_state_keep(x::Function) = false
+_state_keep(x) = true
+
+# map for tuples, namedtuples, and dicts
+_valuemap(f, x) = map(f, x)
+_valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x)
````
test/loading.jl

Lines changed: 227 additions & 0 deletions
```julia
ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))
dm(bias) = Chain(
  dl(3, 5, bias),
  dl(5, 4, bias),
  dl(4, 3, bias)
)

nobias(n) = false
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
  @test l1.weight == l2.weight
  @test l1.bias == l2.bias
  @test_skip typeof(l1.bias) === typeof(l2.bias)
end

@testset "loadmodel!(dst, src)" begin
  m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
  m2 = Chain(Dense(10, 5), Dense(5, 2))
  m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2))
  m4 = Chain(Dense(10, 6), Dense(6, 2))
  m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2)))
  m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2)))

  loadmodel!(m1, m2)
  # trainable parameters copy over
  @test m1[1].weight == m2[1].weight
  @test m1[1].bias == m2[1].bias
  # non-array leaves are untouched
  @test m1[2].σ == relu

  loadmodel!(m5, m6)
  # more complex nested structures also work
  @test m5[1].weight == m6[1].weight
  @test m5[2][1].weight == m6[2][1].weight
  # false bias is not overwritten
  @test m5[2][1].bias == false

  # mismatched nodes throw an error
  @test_throws ArgumentError loadmodel!(m1, m3)
  @test_throws ArgumentError loadmodel!(m1, m5)
  # size mismatches throw an error
  @test_throws DimensionMismatch loadmodel!(m1, m4)

  # tests for BatchNorm and Dropout
  m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2))
  m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1))
  m2[2].μ .= rand(Float32, size(m2[2].μ)...)
  loadmodel!(m1, m2)
  # non-trainable parameters are copied as well
  @test m1[2].μ == m2[2].μ
  # functions are not copied
  @test m1[3] == Flux.flatten
  # dropout rate is not copied
  @test m1[4].p == 0.2

  # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33)
  # tests Chain(...) vs Chain([...])
  # tests MaxPool
  # tests testmode!/trainmode! is not copied
  # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model
  chain1 = Chain(Dropout(0.2),
                 Conv((3, 3), 1 => 32, relu),
                 BatchNorm(32, relu),
                 MaxPool((2, 2)),
                 Dropout(0.2),
                 Conv((3, 3), 32 => 16, relu),
                 Dropout(0.2),
                 MaxPool((2, 2)),
                 Dropout(0.2),
                 Conv((3, 3), 16 => 10, relu),
                 Dropout(0.2),
                 x -> reshape(x, :, size(x, 4)),
                 Dropout(0.2),
                 Dense(90, 10),
                 softmax)
  chain2 = Chain([Dropout(0.1),
                  Conv((3, 3), 1 => 32, relu),
                  BatchNorm(32, relu),
                  MaxPool((3, 3)),
                  Dropout(0.1),
                  Conv((3, 3), 32 => 16, relu),
                  Dropout(0.1),
                  MaxPool((3, 3)),
                  Dropout(0.1),
                  Conv((3, 3), 16 => 10, relu),
                  Dropout(0.1),
                  x -> reshape(x, :, size(x, 4)),
                  Dropout(0.1),
                  Dense(90, 10),
                  softmax])
  chain2[3].μ .= 5f0
  chain2[3].σ² .= 2f0
  testmode!(chain2)
  loadmodel!(chain1, chain2)
  for (dst, src) in zip(chain1, chain2)
    if dst isa Dropout
      @test dst.p == 0.2
    elseif dst isa Union{Conv, Dense}
      @test dst.weight == src.weight
      @test dst.bias == src.bias
    elseif dst isa MaxPool
      @test dst.k == (2, 2)
    elseif dst isa BatchNorm
      @test dst.μ == src.μ
      @test dst.σ² == src.σ²
      @test isnothing(dst.active)
    end
  end

  # copy only a subset of the model
  chain1[end - 1].weight .= 1f0
  chain1[3].μ .= 3f0
  chain1[2].bias .= 5f0
  loadmodel!(chain2[end - 1], chain1[end - 1])
  loadmodel!(chain2[3], chain1[3])
  @test chain2[end - 1].weight == chain1[end - 1].weight
  @test chain2[3].μ == chain1[3].μ
  @test chain2[2].bias != chain1[2].bias

  # test shared weights
  shared_dst = Dense(10 => 10)
  shared_src = Dense(10 => 10)
  # matched weights are okay
  m1 = Chain(shared_dst, Dense(shared_dst.weight))
  m2 = Chain(shared_src, Dense(shared_src.weight))
  loadmodel!(m1, m2)
  @test m1[1].weight === m1[2].weight
  @test m1[1].weight == m2[2].weight
  # mismatched weights are an error
  m2 = Chain(Dense(10 => 10), Dense(10 => 10))
  @test_throws ErrorException loadmodel!(m1, m2)
  # loading into tied weights with absent parameter is okay when the dst == zero
  b = Flux.zeros32(5)
  m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b))
  m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false))
  loadmodel!(m1, m2)
  @test m1[1].bias === m1[2].bias
  @test iszero(m1[1].bias)
  # loading into tied weights with absent parameter is bad when the dst != zero
  m2[1].bias .= 1
  @test_throws ErrorException loadmodel!(m1, m2)

  @testset "loadmodel! & filter" begin
    m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
    m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2))
    m3 = Chain(Dense(10, 5), Dense(5, 2, relu))

    # this will not error because Dropout is skipped
    loadmodel!(m1, m2; filter = x -> !(x isa Dropout))
    @test m1[1].weight == m2[1].weight
    @test m1[2].weight == m2[3].weight

    # this will not error because Dropout is skipped
    loadmodel!(m2, m3; filter = x -> !(x isa Dropout))
    @test m3[1].weight == m2[1].weight
    @test m3[2].weight == m2[3].weight
  end

  @testset "loadmodel! & absent bias" begin
    m0 = Chain(Dense(2 => 3; bias = false, init = Flux.ones32), Dense(3 => 1))
    m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
    m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))

    Flux.loadmodel!(m1, m2)
    @test m1[1].bias == 7:9
    @test sum(m1[1].weight) == 21

    # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it
    m1 = Flux.loadmodel!(m1, m0)
    @test iszero(m1[1].bias)
    @test sum(m1[1].weight) == 6 # written before error

    # load into a model without bias -- should it ignore the parameter which has no home, or error?
    m0 = Flux.loadmodel!(m0, m2)
    @test iszero(m0[1].bias) # obviously unchanged
    @test sum(m0[1].weight) == 21
  end
end

@testset "loadmodel!(dst, src) with BSON" begin
  m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
  m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1))
  @test m1[1].weight != m2[1].weight
  mktempdir() do dir
    BSON.@save joinpath(dir, "test.bson") m1
    m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1])
    @test m1[1].weight == m2[1].weight
  end
end

@testset "state" begin
  m1 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2)))
  m2 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2)))
  s = Flux.state(m1)
  @test s isa NamedTuple
  @test fieldnames(typeof(s)) == (:layers,)
  @test s.layers isa Tuple
  @test length(s.layers) == 2
  @test s.layers[1].weight === m1[1].weight
  @test s.layers[1].σ === nothing
  @test s.layers[2].layers[1].weight === m1[2].layers[1].weight

  Flux.loadmodel!(m2, s)
  @test m2[1].weight == m1[1].weight
  @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias)

  @testset "track active state and batch norm params" begin
    m3 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2), BatchNorm(2))
    trainmode!(m3)
    s = Flux.state(m3)
    @test s.layers[2].active == true
    @test s.layers[2].p == 0.2
    @test s.layers[4] == (λ = nothing, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0],
            μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true,
            track_stats = true, active = true, chs = 2)
  end

  @testset "keep" begin
    s = Flux.state(m1, keep = x -> x isa AbstractArray)
    @test s.layers[1].weight isa AbstractArray
    @test s.layers[1].σ === nothing
    @test s.layers[2].connection === nothing
    @test s.layers[2].layers[1].bias === nothing
  end
end
```

test/runtests.jl

Lines changed: 4 additions & 0 deletions
```diff
@@ -17,6 +17,10 @@ Random.seed!(0)
   include("utils.jl")
 end
 
+@testset "Loading" begin
+  include("loading.jl")
+end
+
 @testset "Optimise / Train" begin
   include("optimise.jl")
   include("train.jl")
```
