1
1
module LuxCore
2
2
3
- using DocStringExtensions
4
3
using Functors, Random, Setfield
5
4
6
5
function _default_rng ()
7
- @static if VERSION >= v " 1.7"
8
- return Xoshiro (1234 )
9
- else
10
- return MersenneTwister (1234 )
11
- end
6
+ rng = Random. default_rng ()
7
+ Random. seed! (rng, 1234 )
8
+ return rng
12
9
end
13
10
14
11
"""
15
- $(TYPEDEF)
12
+ abstract type AbstractExplicitLayer
16
13
17
14
Abstract Type for all Lux Layers
18
15
@@ -36,7 +33,7 @@ See also [`AbstractExplicitContainerLayer`](@ref)
36
33
abstract type AbstractExplicitLayer end
37
34
38
35
"""
39
- $(TYPEDSIGNATURES )
36
+ initialparameters(rng::AbstractRNG, layer )
40
37
41
38
Generate the initial parameters of the layer `l`.
42
39
"""
@@ -45,18 +42,36 @@ function initialparameters(rng::AbstractRNG, l::NamedTuple)
45
42
return map (Base. Fix1 (initialparameters, rng), l)
46
43
end
47
44
initialparameters (:: AbstractRNG , :: Nothing ) = NamedTuple ()
45
+ function initialparameters (rng:: AbstractRNG , l:: Union{Tuple, AbstractArray} )
46
+ any (Base. Fix2 (isa, AbstractExplicitLayer), l) &&
47
+ return map (Base. Fix1 (initialparameters, rng), l)
48
+ throw (MethodError (initialparameters, (rng, l)))
49
+ end
50
+ function initialparameters (rng:: AbstractRNG , l)
51
+ contains_lux_layer (l) && return fmap (Base. Fix1 (initialparameters, rng), l)
52
+ throw (MethodError (initialparameters, (rng, l)))
53
+ end
48
54
49
55
"""
50
- $(TYPEDSIGNATURES )
56
+ initialstates(rng::AbstractRNG, layer )
51
57
52
58
Generate the initial states of the layer `l`.
53
59
"""
54
60
initialstates (:: AbstractRNG , :: AbstractExplicitLayer ) = NamedTuple ()
55
61
initialstates (rng:: AbstractRNG , l:: NamedTuple ) = map (Base. Fix1 (initialstates, rng), l)
56
62
initialstates (:: AbstractRNG , :: Nothing ) = NamedTuple ()
63
+ function initialstates (rng:: AbstractRNG , l:: Union{Tuple, AbstractArray} )
64
+ any (Base. Fix2 (isa, AbstractExplicitLayer), l) &&
65
+ return map (Base. Fix1 (initialstates, rng), l)
66
+ throw (MethodError (initialstates, (rng, l)))
67
+ end
68
+ function initialstates (rng:: AbstractRNG , l)
69
+ contains_lux_layer (l) && return fmap (Base. Fix1 (initialstates, rng), l)
70
+ throw (MethodError (initialstates, (rng, l)))
71
+ end
57
72
58
73
"""
59
- $(TYPEDSIGNATURES )
74
+ parameterlength(layer )
60
75
61
76
Return the total number of parameters of the layer `l`.
62
77
"""
69
84
parameterlength (a:: AbstractArray ) = length (a)
70
85
71
86
"""
72
- $(TYPEDSIGNATURES )
87
+ statelength(layer )
73
88
74
89
Return the total number of states of the layer `l`.
75
90
"""
76
91
statelength (l:: AbstractExplicitLayer ) = statelength (initialstates (_default_rng (), l))
77
92
statelength (nt:: Union{NamedTuple, Tuple} ) = length (nt) == 0 ? 0 : sum (statelength, nt)
78
93
statelength (a:: AbstractArray ) = length (a)
79
- statelength (x :: Union{Number, Symbol, Val, <:AbstractRNG} ) = 1
94
+ statelength (:: Any ) = 1
80
95
81
96
"""
82
- $(TYPEDSIGNATURES )
97
+ setup(rng::AbstractRNG, layer )
83
98
84
99
Shorthand for getting the parameters and states of the layer `l`. Is equivalent to
85
100
`(initialparameters(rng, l), initialstates(rng, l))`.
@@ -90,18 +105,14 @@ This function is not pure, it mutates `rng`.
90
105
91
106
:::
92
107
"""
93
- function setup (rng:: AbstractRNG , l:: AbstractExplicitLayer )
94
- return (initialparameters (rng, l), initialstates (rng, l))
95
- end
108
+ setup (rng:: AbstractRNG , l) = (initialparameters (rng, l), initialstates (rng, l))
96
109
97
110
"""
98
- $(TYPEDSIGNATURES )
111
+ apply(model, x, ps, st )
99
112
100
113
Simply calls `model(x, ps, st)`
101
114
"""
102
- function apply (model:: AbstractExplicitLayer , x, ps, st:: NamedTuple )
103
- return model (x, ps, st)
104
- end
115
+ apply (model:: AbstractExplicitLayer , x, ps, st) = model (x, ps, st)
105
116
106
117
"""
107
118
display_name(layer::AbstractExplicitLayer)
@@ -120,7 +131,7 @@ Base.show(io::IO, x::AbstractExplicitLayer) = print(io, "$(display_name(x))()")
120
131
121
132
# Abstract Container Layers
122
133
"""
123
- $(TYPEDEF)
134
+ abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer
124
135
125
136
Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames
126
137
for the layer, and constructs the parameters and states using those.
@@ -171,21 +182,22 @@ end
171
182
172
183
# Test Mode
173
184
"""
174
- $(TYPEDSIGNATURES )
185
+ testmode(st::NamedTuple )
175
186
176
187
Make all occurances of `training` in state `st` -- `Val(false)`.
177
188
"""
178
189
testmode (st:: NamedTuple ) = update_state (st, :training , Val (false ))
179
190
180
191
"""
181
- $(TYPEDSIGNATURES )
192
+ trainmode(st::NamedTuple )
182
193
183
194
Make all occurances of `training` in state `st` -- `Val(true)`.
184
195
"""
185
196
trainmode (st:: NamedTuple ) = update_state (st, :training , Val (true ))
186
197
187
198
"""
188
- $(TYPEDSIGNATURES)
199
+ update_state(st::NamedTuple, key::Symbol, value;
200
+ layer_check=_default_layer_check(key))
189
201
190
202
Recursively update all occurances of the `key` in the state `st` with the `value`.
191
203
"""
@@ -202,4 +214,42 @@ function _default_layer_check(key)
202
214
return _default_layer_check_closure
203
215
end
204
216
217
+ """
218
+ contains_lux_layer(l) -> Bool
219
+
220
+ Check if the structure `l` is a Lux AbstractExplicitLayer or a container of such a layer.
221
+ """
222
+ function contains_lux_layer (l)
223
+ return check_fmap_condition (Base. Fix2 (isa, AbstractExplicitLayer),
224
+ AbstractExplicitLayer, l)
225
+ end
226
+
227
+ """
228
+ check_fmap_condition(cond, tmatch, x) -> Bool
229
+
230
+ `fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf
231
+ elements.
232
+
233
+ ## Arguments
234
+
235
+ * `cond` - A function that takes a single argument and returns a `Bool`.
236
+ * `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing
237
+ `nothing`.
238
+ * `x` - The structure to check.
239
+
240
+ ## Returns
241
+
242
+ A Boolean Value
243
+ """
244
+ function check_fmap_condition (cond, tmatch, x)
245
+ tmatch != = nothing && x isa tmatch && return true
246
+ matched = Ref (false )
247
+ function __check (l)
248
+ cond (l) && (matched[] = true )
249
+ return l
250
+ end
251
+ fmap (__check, x)
252
+ return matched[]
253
+ end
254
+
205
255
end
0 commit comments