Skip to content

Commit 75ff840

Browse files
authoredSep 3, 2023
Merge pull request #8 from LuxDL/ap/better_defaults
Add better defaults to initialparams/states
2 parents 03520dc + 4d8c7a5 commit 75ff840

File tree

3 files changed

+111
-27
lines changed

3 files changed

+111
-27
lines changed
 

‎lib/LuxCore/Project.toml

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
name = "LuxCore"
22
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
7-
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
87
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
98
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
109
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1110

1211
[compat]
13-
DocStringExtensions = "0.9"
1412
Functors = "0.2, 0.3, 0.4"
1513
Setfield = "0.8, 1"
1614
julia = "1.6"

‎lib/LuxCore/src/LuxCore.jl

+74-24
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
module LuxCore
22

3-
using DocStringExtensions
43
using Functors, Random, Setfield
54

65
function _default_rng()
7-
@static if VERSION >= v"1.7"
8-
return Xoshiro(1234)
9-
else
10-
return MersenneTwister(1234)
11-
end
6+
rng = Random.default_rng()
7+
Random.seed!(rng, 1234)
8+
return rng
129
end
1310

1411
"""
15-
$(TYPEDEF)
12+
abstract type AbstractExplicitLayer
1613
1714
Abstract Type for all Lux Layers
1815
@@ -36,7 +33,7 @@ See also [`AbstractExplicitContainerLayer`](@ref)
3633
abstract type AbstractExplicitLayer end
3734

3835
"""
39-
$(TYPEDSIGNATURES)
36+
initialparameters(rng::AbstractRNG, layer)
4037
4138
Generate the initial parameters of the layer `l`.
4239
"""
@@ -45,18 +42,36 @@ function initialparameters(rng::AbstractRNG, l::NamedTuple)
4542
return map(Base.Fix1(initialparameters, rng), l)
4643
end
4744
initialparameters(::AbstractRNG, ::Nothing) = NamedTuple()
45+
function initialparameters(rng::AbstractRNG, l::Union{Tuple, AbstractArray})
46+
any(Base.Fix2(isa, AbstractExplicitLayer), l) &&
47+
return map(Base.Fix1(initialparameters, rng), l)
48+
throw(MethodError(initialparameters, (rng, l)))
49+
end
50+
function initialparameters(rng::AbstractRNG, l)
51+
contains_lux_layer(l) && return fmap(Base.Fix1(initialparameters, rng), l)
52+
throw(MethodError(initialparameters, (rng, l)))
53+
end
4854

4955
"""
50-
$(TYPEDSIGNATURES)
56+
initialstates(rng::AbstractRNG, layer)
5157
5258
Generate the initial states of the layer `l`.
5359
"""
5460
initialstates(::AbstractRNG, ::AbstractExplicitLayer) = NamedTuple()
5561
initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rng), l)
5662
initialstates(::AbstractRNG, ::Nothing) = NamedTuple()
63+
function initialstates(rng::AbstractRNG, l::Union{Tuple, AbstractArray})
64+
any(Base.Fix2(isa, AbstractExplicitLayer), l) &&
65+
return map(Base.Fix1(initialstates, rng), l)
66+
throw(MethodError(initialstates, (rng, l)))
67+
end
68+
function initialstates(rng::AbstractRNG, l)
69+
contains_lux_layer(l) && return fmap(Base.Fix1(initialstates, rng), l)
70+
throw(MethodError(initialstates, (rng, l)))
71+
end
5772

5873
"""
59-
$(TYPEDSIGNATURES)
74+
parameterlength(layer)
6075
6176
Return the total number of parameters of the layer `l`.
6277
"""
@@ -69,17 +84,17 @@ end
6984
parameterlength(a::AbstractArray) = length(a)
7085

7186
"""
72-
$(TYPEDSIGNATURES)
87+
statelength(layer)
7388
7489
Return the total number of states of the layer `l`.
7590
"""
7691
statelength(l::AbstractExplicitLayer) = statelength(initialstates(_default_rng(), l))
7792
statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt)
7893
statelength(a::AbstractArray) = length(a)
79-
statelength(x::Union{Number, Symbol, Val, <:AbstractRNG}) = 1
94+
statelength(::Any) = 1
8095

8196
"""
82-
$(TYPEDSIGNATURES)
97+
setup(rng::AbstractRNG, layer)
8398
8499
Shorthand for getting the parameters and states of the layer `l`. Is equivalent to
85100
`(initialparameters(rng, l), initialstates(rng, l))`.
@@ -90,18 +105,14 @@ This function is not pure, it mutates `rng`.
90105
91106
:::
92107
"""
93-
function setup(rng::AbstractRNG, l::AbstractExplicitLayer)
94-
return (initialparameters(rng, l), initialstates(rng, l))
95-
end
108+
setup(rng::AbstractRNG, l) = (initialparameters(rng, l), initialstates(rng, l))
96109

97110
"""
98-
$(TYPEDSIGNATURES)
111+
apply(model, x, ps, st)
99112
100113
Simply calls `model(x, ps, st)`
101114
"""
102-
function apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple)
103-
return model(x, ps, st)
104-
end
115+
apply(model::AbstractExplicitLayer, x, ps, st) = model(x, ps, st)
105116

106117
"""
107118
display_name(layer::AbstractExplicitLayer)
@@ -120,7 +131,7 @@ Base.show(io::IO, x::AbstractExplicitLayer) = print(io, "$(display_name(x))()")
120131

121132
# Abstract Container Layers
122133
"""
123-
$(TYPEDEF)
134+
abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer
124135
125136
Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames
126137
for the layer, and constructs the parameters and states using those.
@@ -171,21 +182,22 @@ end
171182

172183
# Test Mode
173184
"""
174-
$(TYPEDSIGNATURES)
185+
testmode(st::NamedTuple)
175186
176187
Make all occurances of `training` in state `st` -- `Val(false)`.
177188
"""
178189
testmode(st::NamedTuple) = update_state(st, :training, Val(false))
179190

180191
"""
181-
$(TYPEDSIGNATURES)
192+
trainmode(st::NamedTuple)
182193
183194
Make all occurances of `training` in state `st` -- `Val(true)`.
184195
"""
185196
trainmode(st::NamedTuple) = update_state(st, :training, Val(true))
186197

187198
"""
188-
$(TYPEDSIGNATURES)
199+
update_state(st::NamedTuple, key::Symbol, value;
200+
layer_check=_default_layer_check(key))
189201
190202
Recursively update all occurances of the `key` in the state `st` with the `value`.
191203
"""
@@ -202,4 +214,42 @@ function _default_layer_check(key)
202214
return _default_layer_check_closure
203215
end
204216

217+
"""
218+
contains_lux_layer(l) -> Bool
219+
220+
Check if the structure `l` is a Lux AbstractExplicitLayer or a container of such a layer.
221+
"""
222+
function contains_lux_layer(l)
223+
return check_fmap_condition(Base.Fix2(isa, AbstractExplicitLayer),
224+
AbstractExplicitLayer, l)
225+
end
226+
227+
"""
228+
check_fmap_condition(cond, tmatch, x) -> Bool
229+
230+
`fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf
231+
elements.
232+
233+
## Arguments
234+
235+
* `cond` - A function that takes a single argument and returns a `Bool`.
236+
* `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing
237+
`nothing`.
238+
* `x` - The structure to check.
239+
240+
## Returns
241+
242+
A Boolean Value
243+
"""
244+
function check_fmap_condition(cond, tmatch, x)
245+
tmatch !== nothing && x isa tmatch && return true
246+
matched = Ref(false)
247+
function __check(l)
248+
cond(l) && (matched[] = true)
249+
return l
250+
end
251+
fmap(__check, x)
252+
return matched[]
253+
end
254+
205255
end

‎lib/LuxCore/test/runtests.jl

+36
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,40 @@ end
194194

195195
@test LuxCore.display_name(model) == "StructWithName"
196196
end
197+
198+
@testset "initialparameter/initialstate for Default Containers" begin
199+
models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))),
200+
Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]]
201+
models2 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))),
202+
Chain2(Dense(5, 10), Dense(10, 5)), (Dense(5, 10), Dense(10, 5))]
203+
204+
for models in (models1, models2)
205+
ps, st = LuxCore.setup(rng, models)
206+
@test length(ps) == length(models)
207+
@test length(st) == length(models)
208+
@test typeof(ps[1]) == typeof(LuxCore.initialparameters(rng, models[1]))
209+
@test typeof(ps[2]) == typeof(LuxCore.initialparameters(rng, models[2]))
210+
@test typeof(ps[3][1]) == typeof(LuxCore.initialparameters(rng, models[3][1]))
211+
@test typeof(ps[3][2]) == typeof(LuxCore.initialparameters(rng, models[3][2]))
212+
@test typeof(st[1]) == typeof(LuxCore.initialstates(rng, models[1]))
213+
@test typeof(st[2]) == typeof(LuxCore.initialstates(rng, models[2]))
214+
@test typeof(st[3][1]) == typeof(LuxCore.initialstates(rng, models[3][1]))
215+
@test typeof(st[3][2]) == typeof(LuxCore.initialstates(rng, models[3][2]))
216+
end
217+
end
218+
219+
@testset "Convenience Checks" begin
220+
models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))),
221+
Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]]
222+
223+
@test LuxCore.contains_lux_layer(models1)
224+
225+
models2 = [1, 2, 3, 4]
226+
227+
@test !LuxCore.contains_lux_layer(models2)
228+
229+
models3 = [1, 2, 3, (; a=Dense(5, 10), b=Dense(10, 5))]
230+
231+
@test LuxCore.contains_lux_layer(models3)
232+
end
197233
end

0 commit comments

Comments
 (0)
Please sign in to comment.