Skip to content

Commit 1e1da28

Browse files
authored
Speed-up normalization layers (#2220)
* Speedup normalization layers * Revert 'normalise' change
1 parent 7088682 commit 1e1da28

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/layers/normalise.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,14 @@ function _norm_layer_forward(
238238
end
239239

240240
eps = convert(float(T), l.ϵ)
241-
o = _norm_layer_forward(x, μ, σ², eps)
242-
hasaffine(l) || return l.λ.(o)
241+
hasaffine(l) || return l.λ.(_norm_layer_forward(x, μ, σ², eps))
243242

244243
γ = reshape(l.γ, affine_shape)
245244
β = reshape(l.β, affine_shape)
246-
return l.λ.(γ .* o .+ β)
245+
246+
scale = γ ./ sqrt.(σ² .+ eps)
247+
bias = -scale .* μ .+ β
248+
l.λ.(scale .* x .+ bias)
247249
end
248250

249251
@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

0 commit comments

Comments
 (0)