@@ -9,6 +9,8 @@ Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
9
9
to a vector, and returns also a function which reverses this transformation.
10
10
Differentiable.
11
11
12
+ See also [`destructure!`](@ref).
13
+
12
14
# Example
13
15
```jldoctest
14
16
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
@@ -31,6 +33,36 @@ function destructure(x)
31
33
flat, Restructure (x, off, len)
32
34
end
33
35
36
+ """
37
+ destructure!(model) -> vector, reconstructor
38
+
39
+ This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
40
+ Requires that all trainable parameters in the model be mutable arrays!
41
+
42
+ # Example
43
+ ```jldoctest
44
+ julia> m = (x=[1.0, 2.0], y=(sin, Float32[3.0 4.0], cos))
45
+
46
+ julia> v, re! = destructure!(m)
47
+ ([1.0, 2.0, 3.0, 4.0], Restructure!(NamedTuple, ..., 4))
48
+
49
+ julia> m === re!([3, 5, 7, 9]) # mutates the original m, and returns it
50
+ true
51
+
52
+ julia> m
53
+ (x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
54
+ ```
55
+ """
56
+ function destructure! (x)
57
+ flat, off, len = _flatten (x)
58
+ flat, Restructure! (x, off, len)
59
+ end
60
+
61
+ # function destructure!(flat::AbstractVector, x)
62
+ # flat, off, len = _flatten!(flat, x)
63
+ # flat, Restructure!(x, off, len)
64
+ # end
65
+
34
66
"""
35
67
Restructure(Model, ..., length)
36
68
@@ -55,12 +87,20 @@ struct Restructure{T,S}
55
87
model:: T
56
88
offsets:: S
57
89
length:: Int
90
+ mutate:: Bool
58
91
end
59
- (re:: Restructure )(flat:: AbstractVector ) = _rebuild (re. model, re. offsets, flat, re. length)
92
+ Restructure (model, offsets, length) = Restructure (model, offsets, length, false )
93
+ Restructure! (model, offsets, length) = Restructure (model, offsets, length, true )
94
+
95
+ (re:: Restructure )(flat:: AbstractVector ) = re. mutate ? _rebuild! (re. model, re. offsets, flat, re. length) : _rebuild (re. model, re. offsets, flat, re. length)
60
96
(re:: Restructure )(x, flat:: AbstractVector ) = re (flat)(x)
61
- Base. show (io:: IO , re:: Restructure{T} ) where T = print (io, " Restructure(" , T. name. name, " , ..., " , re. length, " )" )
62
97
Base. length (re:: Restructure ) = re. length
63
98
99
+ function Base. show (io:: IO , re:: Restructure{T} ) where T
100
+ print (io, " Restructure" , re. mutate ? " !" : " " )
101
+ print (io, " (" , T. name. name, " , ..., " , re. length, " )" )
102
+ end
103
+
64
104
# This flattens a model, and returns a web of offsets for later use:
65
105
function _flatten (x)
66
106
isnumeric (x) && return vcat (_vec (x)), 0 , length (x) # trivial case
@@ -75,6 +115,17 @@ function _flatten(x)
75
115
isempty (arrays) && return Bool[], off, 0
76
116
reduce (vcat, arrays), off, len[]
77
117
end
118
+ # function _flatten!(flat, x)
119
+ # isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case
120
+ # len = Ref(0)
121
+ # off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
122
+ # o = len[]
123
+ # copyto!(flat, o, _vec(y))
124
+ # len[] = o + length(y)
125
+ # o
126
+ # end
127
+ # flat, off, len[]
128
+ # end
78
129
79
130
struct _TrainableStructWalk <: AbstractWalk end
80
131
@@ -97,10 +148,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trai
97
148
_getat (y, o, flat)
98
149
end
99
150
end
151
+ # (mutating version, same arguments & same return)
152
+ function _rebuild! (x, off, flat:: AbstractVector , len = length (flat); walk = _Trainable_biwalk (), kw... )
153
+ len == length (flat) || throw (DimensionMismatch (" Rebuild expected a vector of length $len , got $(length (flat)) " ))
154
+ fmap (x, off; exclude = isnumeric, walk, kw... ) do y, o
155
+ copyto! (y, _getat (y, o, flat, view))
156
+ end
157
+ x
158
+ end
100
159
101
- _getat (y:: Number , o:: Int , flat:: AbstractVector ) = ProjectTo (y)(flat[o + 1 ])
102
- _getat (y:: AbstractArray , o:: Int , flat:: AbstractVector ) =
103
- ProjectTo (y)(reshape (flat[ o .+ (1 : length (y))] , axes (y))) # ProjectTo is just correcting eltypes
160
+ _getat (y:: Number , o:: Int , flat:: AbstractVector , _ ... ) = ProjectTo (y)(flat[o + 1 ])
161
+ _getat (y:: AbstractArray , o:: Int , flat:: AbstractVector , get = getindex ) =
162
+ ProjectTo (y)(reshape (get ( flat, o .+ (1 : length (y))) , axes (y))) # ProjectTo is just correcting eltypes
104
163
105
164
struct _Trainable_biwalk <: AbstractWalk end
106
165
@@ -135,6 +194,10 @@ function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
135
194
_rebuild_back (dx) = (NoT, NoT, NoT, _grad! (x, unthunk (dx), off, _zero (flat)), NoT)
136
195
_rebuild (x, off, flat, len; kw... ), _rebuild_back
137
196
end
197
+ function ChainRulesCore. rrule (:: typeof (_rebuild!), x, off, flat, len; kw... )
198
+ _rebuild!_back (dx) = (NoT, NoT, NoT, _grad! (x, unthunk (dx), off, _zero (flat)), NoT)
199
+ _rebuild! (x, off, flat, len; kw... ), _rebuild!_back
200
+ end
138
201
139
202
_zero (x) = map! (zero, similar (x, float (eltype (x))), x) # mutable zero array for _grad!
140
203
ChainRulesCore. @non_differentiable _zero (x)
0 commit comments