Skip to content

Commit 89c8d43

Browse files
committed
in-place destructure
1 parent 1cd1e87 commit 89c8d43

File tree

3 files changed

+111
-6
lines changed

3 files changed

+111
-6
lines changed

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export AbstractRule
99
include("adjust.jl")
1010

1111
include("destructure.jl")
12-
export destructure
12+
export destructure, destructure!
1313

1414
include("rules.jl")
1515
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,

src/destructure.jl

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
99
to a vector, and returns also a function which reverses this transformation.
1010
Differentiable.
1111
12+
See also [`destructure!`](@ref).
13+
1214
# Example
1315
```jldoctest
1416
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
@@ -31,6 +33,36 @@ function destructure(x)
3133
flat, Restructure(x, off, len)
3234
end
3335

36+
"""
37+
destructure!(model) -> vector, reconstructor
38+
39+
This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
40+
Requires that all trainable parameters in the model be mutable arrays!
41+
42+
# Example
43+
```jldoctest
44+
julia> m = (x=[1.0, 2.0], y=(sin, Float32[3.0 4.0], cos))
45+
46+
julia> v, re! = destructure!(m)
47+
([1.0, 2.0, 3.0, 4.0], Restructure!(NamedTuple, ..., 4))
48+
49+
julia> m === re!([3, 5, 7, 9]) # mutates the original m, and returns it
50+
true
51+
52+
julia> m
53+
(x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
54+
```
55+
"""
56+
function destructure!(x)
57+
flat, off, len = _flatten(x)
58+
flat, Restructure!(x, off, len)
59+
end
60+
61+
# function destructure!(flat::AbstractVector, x)
62+
# flat, off, len = _flatten!(flat, x)
63+
# flat, Restructure!(x, off, len)
64+
# end
65+
3466
"""
3567
Restructure(Model, ..., length)
3668
@@ -55,12 +87,20 @@ struct Restructure{T,S}
5587
model::T
5688
offsets::S
5789
length::Int
90+
mutate::Bool
5891
end
59-
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
92+
Restructure(model, offsets, length) = Restructure(model, offsets, length, false)
93+
Restructure!(model, offsets, length) = Restructure(model, offsets, length, true)
94+
95+
(re::Restructure)(flat::AbstractVector) = re.mutate ? _rebuild!(re.model, re.offsets, flat, re.length) : _rebuild(re.model, re.offsets, flat, re.length)
6096
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
61-
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
6297
Base.length(re::Restructure) = re.length
6398

99+
function Base.show(io::IO, re::Restructure{T}) where T
100+
print(io, "Restructure", re.mutate ? "!" : "")
101+
print(io, "(", T.name.name, ", ..., ", re.length, ")")
102+
end
103+
64104
# This flattens a model, and returns a web of offsets for later use:
65105
function _flatten(x)
66106
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
@@ -75,6 +115,17 @@ function _flatten(x)
75115
isempty(arrays) && return Bool[], off, 0
76116
reduce(vcat, arrays), off, len[]
77117
end
118+
# function _flatten!(flat, x)
119+
# isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case
120+
# len = Ref(0)
121+
# off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
122+
# o = len[]
123+
# copyto!(flat, o, _vec(y))
124+
# len[] = o + length(y)
125+
# o
126+
# end
127+
# flat, off, len[]
128+
# end
78129

79130
struct _TrainableStructWalk <: AbstractWalk end
80131

@@ -97,10 +148,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trai
97148
_getat(y, o, flat)
98149
end
99150
end
151+
# (mutating version, same arguments & same return)
152+
function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
153+
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
154+
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
155+
copyto!(y, _getat(y, o, flat, view))
156+
end
157+
x
158+
end
100159

101-
_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
102-
_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
103-
ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
160+
_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1])
161+
_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) =
162+
ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y))) # ProjectTo is just correcting eltypes
104163

105164
struct _Trainable_biwalk <: AbstractWalk end
106165

@@ -135,6 +194,10 @@ function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
135194
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
136195
_rebuild(x, off, flat, len; kw...), _rebuild_back
137196
end
197+
function ChainRulesCore.rrule(::typeof(_rebuild!), x, off, flat, len; kw...)
198+
_rebuild!_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
199+
_rebuild!(x, off, flat, len; kw...), _rebuild!_back
200+
end
138201

139202
_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad!
140203
ChainRulesCore.@non_differentiable _zero(x)

test/destructure.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ m9 = (a = m1, b = mat, c = [mat, m1])
2424
@test destructure(m9)[1] == 1:7
2525

2626
@test destructure(m1)[2](7:9) == [7,8,9]
27+
@test m1 == 1:3 # not mutated
2728
@test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
2829
@test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
30+
@test m3.z == 4:6 # not mutated
2931
m4′ = destructure(m4)[2](4:9)
3032
@test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9])
3133
@test m4′.x === m4′.y
@@ -60,11 +62,31 @@ m9 = (a = m1, b = mat, c = [mat, m1])
6062
@test_throws Exception destructure(m7)[2]([10,20,30,40])
6163
end
6264

65+
@testset "destructure!" begin
66+
m3′ = deepcopy(m3)
67+
@test destructure!(m3′)[1] == 1:6
68+
@test destructure!(m3′)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
69+
@test m3′ == (x = [4,5,6], y = sin, z = [7,8,9])
70+
71+
m7′ = deepcopy(m7)
72+
@test destructure!(m7′)[1] == 1:3
73+
destructure!(m7′)[2]([10,20,30])
74+
@test m7′.a == (sin, [10,20,30])
75+
@test m7′.b == (cos, [4,5,6])
76+
@test m7′.c == (tan, [7,8,9])
77+
78+
# errors
79+
@test_throws Exception destructure!(m7)[2]([10,20])
80+
@test_throws Exception destructure!(m7)[2]([10,20,30,40])
81+
end
82+
6383
@testset "gradient of flatten" begin
6484
@test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
85+
@test gradient(m -> destructure!(m)[1][1], m1)[1] == [1,0,0]
6586
@test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
6687
@test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
6788
@test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
89+
@test gradient(m -> destructure!(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
6890
@test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])
6991

7092
g5 = gradient(m -> destructure(m)[1][3], m5)[1]
@@ -206,6 +228,26 @@ end
206228
end
207229
end
208230

231+
@testset "gradient of rebuild!" begin
232+
re1 = destructure!(deepcopy(m1))[2]
233+
@test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
234+
235+
re2 = destructure!(deepcopy(m2))[2]
236+
@test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
237+
238+
re3 = destructure!(deepcopy(m3))[2]
239+
@test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
240+
@test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]
241+
242+
re4 = destructure!(deepcopy(m4))[2]
243+
@test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
244+
@test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
245+
@test gradient(rand(6)) do x
246+
m = re4(x)
247+
m.x[1] + 2*m.y[2] + 3*m.z[3]
248+
end[1] == [1,2,0, 0,0,3]
249+
end
250+
209251
@testset "Flux issue 1826" begin
210252
v, re = destructure((x=[1,2.0], y=[3,4,5.0]))
211253
@test gradient(zero(v)) do w

0 commit comments

Comments
 (0)