@@ -69,7 +69,7 @@ extraChain(::Tuple{}, x) = ()
69 69
70 70
71 71 """
72-        Dense(in, out, σ = identity; bias = true, init = glorot_uniform)
72+        Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
73 73     Dense(W::AbstractMatrix, [bias, σ])
74 74
75 75 Create a traditional `Dense` layer, whose forward pass is given by:
@@ -81,7 +81,7 @@ as an `in × N` matrix, or any array with `size(x,1) == in`.
81 81 The out `y` will be a vector of length `out`, or a batch with
82 82 `size(y) == (out, size(x)[2:end]...)`
83 83
84-    Keyword `bias = false` will switch off trainable bias for the layer.
84+    Keyword `bias=false` will switch off trainable bias for the layer.
85 85 The initialisation of the weight matrix is `W = init(out, in)`, calling the function
86 86 given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
87 87 The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
@@ -109,41 +109,45 @@ julia> Flux.params(d1) # no trainable bias
109109Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
110110```
111111"""
112-    struct Dense{F,S<:AbstractArray,T}
113-      weight::S
114-      bias::T
112+    struct Dense{F, M<:AbstractMatrix, B}
113+      weight::M
114+      bias::B
115 115   σ::F
116+      function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
117+        b = create_bias(W, bias, size(W,1))
118+        new{F,M,typeof(b)}(W, b, σ)
119+      end
116 120 end
117121
118-    Dense(W, b) = Dense(W, b, identity)
122+    function Dense(in::Integer, out::Integer, σ = identity;
123+                   initW = nothing, initb = nothing,
124+                   init = glorot_uniform, bias=true)
119 125
120-    Dense(W::AbstractArray, b::Bool = true, σ = identity) =
121-      Dense(W, create_bias(W, b, size(W,1)), σ)
122-
123-    function Dense(in::Integer, out::Integer, σ = identity; initW = nothing,
124-                   init = glorot_uniform, initb = nothing, bias::Bool = true)
125-      if initW !== nothing
126-        Base.depwarn("initW is deprecated, please use the `init` keyword instead", :Dense)
127-        init = initW
126+      W = if initW !== nothing
127+        Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a function like randn)", :Dense)
128+        initW(out, in)
129+      else
130+        init(out, in)
128 131   end
129 132
130-      if initb !== nothing
131-        Base.depwarn("initb is deprecated, please use the array based constructors instead", :Dense)
132-        initb = initb
133+      b = if bias === true && initb !== nothing
134+        Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
135+        initb(out)
133 136   else
134-        initb = zeros
137+        bias
135 138   end
136-      Dense(init(out, in), bias ? initb(out) : Zeros(), σ)
139+
140+      return Dense(W, b, σ)
137 141 end
138142
139143@functor Dense
140144
141 145 function (a::Dense)(x::AbstractVecOrMat)
142 146   W, b, σ = a.weight, a.bias, a.σ
143-      σ.(W * x .+ b)
147+      return σ.(W*x .+ b)
144 148 end
145149
146-    (a::Dense)(x) =
150+    (a::Dense)(x::AbstractArray) =
147 151   reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
148 152
149 153 function Base.show(io::IO, l::Dense)
@@ -292,6 +296,7 @@ If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of
292 296 with `B` a Bilinear layer.
293 297
294 298 If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`
299+
295 300 The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
296 301 which is accepted as the input to a `Chain`.
297 302
@@ -300,7 +305,6 @@ By default the bias vector is `zeros(Float32, out)`, option `bias=false` will switch off
300 305 trainable bias. Either of these may be provided explicitly.
301 306
302 307 # Examples
303-
304 308 ```jldoctest
305 309 julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);
306 310
@@ -417,4 +421,4 @@ function Base.show(io::IO, m::Parallel)
417 421   print(io, "Parallel(", m.connection, ", ")
418 422   join(io, m.layers, ", ")
419 423   print(io, ")")
420-    end
424+    end
0 commit comments