From c118028ad4f2c81907d3a904923dba2316168ce1 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 26 Nov 2022 21:55:14 -0800 Subject: [PATCH 1/6] Add norm functions These roughly correspond to Flux's `*Norm` layers. --- src/NNlib.jl | 3 +- src/normalization.jl | 306 ++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 12 ++ test/normalization.jl | 273 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 + 5 files changed, 596 insertions(+), 2 deletions(-) create mode 100644 test/normalization.jl diff --git a/src/NNlib.jl b/src/NNlib.jl index 8450a0261..04be89387 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -1,7 +1,7 @@ module NNlib import Atomix -import ChainRulesCore: rrule +import ChainRulesCore: rrule, @ignore_derivatives using Base.Broadcast: broadcasted using Base.Threads @@ -16,7 +16,6 @@ using Pkg using Random using Requires using Statistics -using Statistics: mean const libblas = Base.libblas_name diff --git a/src/normalization.jl b/src/normalization.jl index c06843d38..593dd14b8 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -12,3 +12,309 @@ function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, runnin end y, batchnorm_pullback end + +""" + norm_stats(x, dims) + +Calculates sample mean and (uncorrected) variance of `x` along `dims`. + + - `dims=(1,...,N-2,N)` for BatchNorm + - `dims=(1,...,N-2)` for InstanceNorm and GroupNorm + - `dims=(1,...,S)` where S < N for LayerNorm/Flux.jl/stable/ + +This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately, +because it can share some computation across both. +Implementors may want to overload this function to use custom kernels and more. +""" +function norm_stats(x, dims) + μ = mean(x; dims) + σ² = var(x; dims, mean = μ, corrected = false) + return μ, σ² +end + +function rrule(::typeof(norm_stats), x, dims) + μ, mean_pullback = rrule(mean, x; dims) + σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false) + function norm_stats_pullback(dargs) + dμ, dσ² = unthunk(dargs) + dx = ChainRulesCore.add!!(var_pullback(dμ)[2], mean_pullback(dσ²)[2]) + return (NoTangent(), dx, NoTangent()) + end + return (μ, σ²), norm_stats_pullback +end + +_maybe_reshape(::Nothing, _) = nothing +_maybe_reshape(x, dims) = reshape(x, dims) +_apply_scale_bias(x, ::Nothing, ::Nothing) = x +_apply_scale_bias(x, scale, bias) = x .* scale .+ bias + +""" + norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing}, + bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ)) + +Shared code path for all built-in norm functions. + +`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref), +or extracted from an existing collection such as [`RunningStats`](@ref). +`bias` and `scale` are consistent with cuDNN and Flux.Scale. +We opt for `scale` over `weight` to avoid confusion with dense layers. +If the size of the statistics and affine parameters differ, +use `affine_size` to add padding dimensions as required to match the input. 
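To make the division of labour concrete, here is a minimal sketch (not part of the patch) of how `norm_stats` and `norm_helper` compose, using the un-exported names just as the tests do:

```julia
using NNlib, Statistics

x = randn(Float32, 4, 4, 3, 2)              # W × H × C × N
μ, σ² = NNlib.norm_stats(x, (1, 2, 4))      # per-channel statistics, as batchnorm uses them
y = NNlib.norm_helper(x, μ, σ², nothing, nothing, 1f-5)  # normalize, no affine transform

# Each channel of `y` is now approximately zero-mean and unit-variance.
isapprox(vec(mean(y; dims = (1, 2, 4))), zeros(Float32, 3); atol = 1f-3)
```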
+""" +function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing}, + bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ)) + @ignore_derivatives if isnothing(scale) != isnothing(bias) + error("both scale and bias must be provided or left as nothing") + end + scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size) + return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′) +end + +""" + RunningStats(mean, variance, momentum) + +Contains running mean and variance estimates for stateful norm functions. +`momentum` controls the strength of the moving average update. + +If the parameters are mutable, they will be updated in-place. +Otherwise, they will be replaced wholesale. + +See also [`update_running_stats!`](@ref). +""" +mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} + mean::M + variance::V + momentum::MT +end + +# Conditionally pulls running stats or calculates them on the fly. +# Part of the reason this is a dedicated function is to have a more type stable pullback. +function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims, + use_running_stats::Bool) + if stats !== nothing && use_running_stats + # Maintains consistency with mean/var + sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1) + return reshape(stats.mean, sz), reshape(stats.variance, sz) + end + # No running stats exist or are disabled in inference mode + return norm_stats(x, dims) +end + +# Kludge so we can close over a Union inner pullback type +struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}} + back::B + projector::P +end +function (pb::MaybeNormStatsPullback)(dargs) + _, dx = unthunk(pb.back(dargs)) + return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent()) +end +function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims, + use_running_stats::Bool) + project = ProjectTo(x) + noop_back(_) = (NoTangent(), NoTangent()) + if stats === nothing || !use_running_stats + (μ, σ²), back = rrule(norm_stats, x, dims) + else + # The default is to track, so this only happens when a layer is frozen + sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1) + μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back + end + back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)} + return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project) +end + +""" + update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ², + reduce_dims) where {N} + +Performs a moving average update for layers with tracked statistics. +`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref). +`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref). + +See also [`RunningStats`](@ref). 
+""" +function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims) + V = eltype(σ²) + momentum = stats.momentum + res_mtm = one(V) - momentum + m = prod(size(x, i) for i in reduce_dims) + correction = m / (m - one(V)) + + running_mean, running_var = stats.mean, stats.variance + if ChainRulesCore.is_inplaceable_destination(running_mean) + stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ) + else + stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ) + end + if ChainRulesCore.is_inplaceable_destination(running_var) + stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²) + else + stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²) + end +end + +# Convenience functions +# We follow roughly the same arg order as torch.nn.functional.*_norm: +# input, unique args for this particular norm type, bias + scale, eps; kwargs... + +""" + layernorm(x::AbstractArray{<:Any,N}, ::Val{S}, scale = nothing, bias = nothing, + ϵ=ofeltype(x, 1e-5)) where {N, S} + +Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation. + +Normalizes `x` along the first `S` dimensions. + +For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`. + +See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). + +# Examples + +```jldoctest +julia> using Statistics + +julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels + +julia> y = NNlib.layernorm(xs, Val(3)); + +julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) && + std(y; dims = 1:3) != std(xs; dims = 1:3) +true +``` +""" +function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing, + ϵ = ofeltype(x, 1e-5)) where {N, S} + @ignore_derivatives if S > N + throw(DimensionMismatch("got $S reduction dims for $N-dimensional array")) + end + μ, σ² = norm_stats(x, ntuple(identity, S)) + return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]) +end + +""" + batchnorm(x::AbstractArray{<:Any, N}, + running_stats::Union{RunningStats, Nothing} = nothing, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); + training::Bool = within_grad()) where {N} + +Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation. + +Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice, +where `N-1` is the "channel" (or "feature", for 2D inputs) dimension. + +Provide a [`RunningStats`](@ref) to fix a estimated mean and variance. +`batchnorm` will renormalize the input using these statistics during inference, +and update them using batch-level statistics when training. +To override this behaviour, manually set a value for `training`. + +If specified, `scale` and `bias` will be applied as an additional learned affine transform. + +See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). 
+""" +function batchnorm(x::AbstractArray{<:Any, N}, + running_stats::Union{RunningStats, Nothing} = nothing, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); + training::Bool = within_grad()) where {N} + reduce_dims = ((1:(N - 2))..., N) + μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training) + # Because μ and σ² could be updated in-place, we compute the output first + y = norm_helper(x, μ, σ², scale, bias, ϵ) + @ignore_derivatives if running_stats !== nothing && training + update_running_stats!(running_stats, x, μ, σ², reduce_dims) + end + return y +end + +""" + instancenorm(x::AbstractArray{<:Any, N}, + running_stats::Union{RunningStats, Nothing} = nothing, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); + training::Bool = within_grad()) where {N} + +Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation. + +Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice, + +Provide a [`RunningStats`](@ref) to fix a estimated mean and variance. +`instancenorm` will renormalize the input using these statistics during inference, +and update them using batch-level statistics when training. +To override this behaviour, manually set a value for `training`. + +If specified, `scale` and `bias` will be applied as an additional learned affine transform. + +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref). +""" +function instancenorm(x::AbstractArray{<:Any, N}, + running_stats::Union{RunningStats, Nothing} = nothing, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); + training::Bool = within_grad()) where {N} + affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :) + reduce_dims = ((1:(N - 2))...,) + μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training) + # Because μ and σ² could be updated in-place, we compute the output first + y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size) + ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training + μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N) # Need to sum (C, N) -> (C,) + update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims) + end + return y +end + +""" + groupnorm(x::AbstractArray{<:Any, N}, groups::Integer, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, + ϵ = ofeltype(x, 1e-5)) where {N} + +Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation. + +Normalizes `x` along the first `N - 2` (spatial) dimensions, +where `N-1` is the "channel" (or "feature", for 2D inputs) dimension, +and the channel dimension is divided into `groups` groups along which statistics are computed. +The number of channels must be an integer multiple of the number of groups. + +If specified, `scale` and `bias` will be applied as an additional learned affine transform. + +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref). 
+ +# Examples + +```jldoctest +julia> using Statistics + +julia> xs = rand(3, 3, 4, 2); # a batch of 2 images, each having 4 channels + +julia> y = NNlib.groupnorm(xs, 4); + +julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) && + std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1]) +true + +julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) && + std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2]) +true +``` +""" +function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, + ϵ = ofeltype(x, 1e-5)) where {N} + sz = size(x) + channels = @ignore_derivatives begin + ch = sz[max(1, N - 1)] + newch, remainder = divrem(ch, groups) + remainder == 0 ? newch : + throw(ArgumentError("channels $ch should be multiple of groups $groups")) + end + affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :) + grouped_size = (sz[1:(N - 2)]..., channels, groups, :) + x′ = reshape(x, grouped_size) + μ, σ² = norm_stats(x′, ((1:(N - 2))...,)) + return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz) +end diff --git a/src/utils.jl b/src/utils.jl index 3d23e7383..9edfb2112 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -162,3 +162,15 @@ if VERSION < v"1.7.0-DEV.793" end end + +# This is a terrible hack to prevent the spread of type instabilities +# when the pullback type changes depending on runtime information, +# e.g. when a normalization layer is "active" vs "inactive". +function _rrule_pullback_rt(@nospecialize(fn), args...) + rt = Base.promote_op(rrule, typeof(fn), map(typeof, args)...) + rt <: Tuple{<:Any,<:Any} && return rt.parameters[2] + return rt +end + +# Extracted from Flux. Should this have a docstring and/or be in the docs? +ofeltype(x, y) = convert(float(eltype(x)), y) \ No newline at end of file diff --git a/test/normalization.jl b/test/normalization.jl new file mode 100644 index 000000000..128610297 --- /dev/null +++ b/test/normalization.jl @@ -0,0 +1,273 @@ +FiniteDifferences.to_vec(stats::NNlib.RunningStats) = [], _ -> stats + +randn_sample(shape, μ, σ) = randn(rng, shape) .* σ .+ μ +f32_arange(shape...) 
= Float32.(reshape(1:prod(shape), shape)) + +function make_bn(ch; training = true) + stats, bias, scale = NNlib.RunningStats(zeros(ch), ones(ch), 0.1), zeros(ch), ones(ch) + return x -> NNlib.batchnorm(x, stats, scale, bias; training) +end +function make_in(ch; training = true) + stats, bias, scale = NNlib.RunningStats(zeros(ch), ones(ch), 0.1), zeros(ch), ones(ch) + return x -> NNlib.instancenorm(x, stats, scale, bias; training) +end +function make_gn(ch, groups) + bias, scale = zeros(ch), ones(ch) + return x -> NNlib.groupnorm(x, groups, scale, bias) +end + +@testset "Helpers" begin + # BatchNorm dimensions + let W = 128, C = 4, N = 64 + x = cat([randn_sample((W, W, 1, N), i, i) for i in 1:C]...; dims = 3) + μ, σ² = NNlib.norm_stats(x, (1, 2, 4)) + @test vec(μ)≈1:C rtol=0.05 + @test vec(σ²)≈abs2.(1:C) rtol=0.05 + end + + # LayerNorm dimensions + let W = 128, C = 64, N = 4 + x = cat([randn_sample((W, W, C, 1), i, i) for i in 1:N]...; dims = 4) + μ, σ² = NNlib.norm_stats(x, (1, 2, 3)) + @test vec(μ)≈1:N rtol=0.05 + @test vec(σ²)≈abs2.(1:N) rtol=0.05 + end + + # Group/InstanceNorm dimensions + let W = 128, C = 2, N = 2, shape = (W, W, 1, 1) + x = [randn_sample(shape, 1, 1);;; randn_sample(shape, 2, 2);;;; + randn_sample(shape, 3, 3);;; randn_sample(shape, 4, 4)] + μ, σ² = NNlib.norm_stats(x, (1, 2)) + @test vec(μ)≈1:(C * N) rtol=0.05 + @test vec(σ²)≈abs2.(1:(C * N)) rtol=0.05 + end + + x = rand(rng, 16, 16, 3, 4) + @testset "dims = $dims" for (dims, tsize) in [ + (1, 2, 4) => (1, 1, size(x, 3), 1), + (1, 2, 3) => (1, 1, 1, size(x, 4)), + (1, 2) => (1, 1, size(x, 3), size(x, 4)), + ] + meanvar = (ones(tsize), ones(tsize)) + test_rrule(NNlib.norm_stats, x, dims ⊢ NoTangent(); output_tangent = meanvar) + + running_stats = NNlib.RunningStats(meanvar..., 0.1) + y_ns, back_ns = rrule(NNlib.norm_stats, x, dims) + dx_ns = back_ns(meanvar)[2] + for (stats, training, y, y_ad, dx) in [ + (nothing, true, y_ns, y_ns, dx_ns), + (nothing, false, y_ns, y_ns, dx_ns), + (running_stats, true, y_ns, y_ns, dx_ns), + (running_stats, false, meanvar, meanvar, NoTangent()), + ] + @test NNlib.maybe_norm_stats(stats, x, dims, !training) == y + ŷ, back = rrule(NNlib.maybe_norm_stats, stats, x, dims, !training) + @test ŷ == y_ad + @test back(meanvar) == (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent()) + + test_rrule(NNlib.maybe_norm_stats, stats ⊢ NoTangent(), x, dims ⊢ NoTangent(), + !training; output_tangent = meanvar, check_inferred = false) + end + + ps = ntuple(_ -> rand(rng, tsize...), 4) + gradtest((args...) -> NNlib.norm_helper(args..., size(ps[1])), x, ps..., 1e-5) + end + + p = ones(16, 16) + @test_throws ErrorException NNlib.norm_helper(x, p, p, nothing, p, 1e-5) + @test_throws ErrorException NNlib.norm_helper(x, p, p, p, nothing, 1e-5) +end + +@testset "Layer Norm" begin + full_size = (2, 3, 4, 5) + @testset for xdims in 2:4, kdims in 1:(xdims - 1) + x = rand(rng, full_size[1:xdims]...) + bias, scale = ntuple(_ -> rand(rng, full_size[1:kdims]...), 2) + dims = Val(kdims) + + y = @inferred NNlib.layernorm(x, dims) + @test size(y) == size(x) + y = @inferred NNlib.layernorm(x, dims, scale, bias) + @test size(y) == size(x) + + # FiniteDifferences gives incorrect results on some but not all args, why? 
+ gradtest(x -> NNlib.layernorm(x, dims), x; broken = true) + gradtest((x, s, b) -> NNlib.layernorm(x, dims, s, b), x, scale, bias; skip = true) + end +end + +@testset "Batch Norm" begin + let x = [1.0 3.0 5.0; 2.0 4.0 6.0], bias = zeros(2), scale = ones(2) + @testset for use_stats in (true, false) + stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing + y, back = Zygote.pullback(NNlib.batchnorm, x, stats, scale, bias, 1e-5) + @test y≈[-1.22474 0 1.22474; -1.22474 0 1.22474] atol=1e-5 + + expected_mean, expected_var = [0.3, 0.4], [1.3, 1.3] + if use_stats + # μ of batch will be + # (1. + 3. + 5.) / 3 = 3 + # (2. + 4. + 6.) / 3 = 4 + # + # ∴ update rule with momentum: + # .1 * 3 + 0 = .3 + # .1 * 4 + 0 = .4 + @test stats.mean ≈ expected_mean + # σ² of batch will be + # sum(abs2, [1., 3., 5.] .- 3) / 2 = 2.6... + # sum(abs2, [2., 4., 6.] .- 4) / 2 = 2.6... + # + # ∴ update rule with momentum: + # .1 * (3 / (3 - 1)) * 2.6 + (1 - .1) * 1 = 1.3 + @test stats.variance ≈ expected_var + end + + dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1)) + @test dx≈[3.06186 0.612371 -1.83711; 3.06186 0.612371 -1.83711] atol=1e-5 + @test dscale == zeros(2) + @test dbias == fill(3.0, 2) + @test dstats === nothing + + if use_stats + tmp_mean, tmp_var = copy(stats.mean), copy(stats.variance) + x′ = @inferred NNlib.batchnorm(x, stats, scale, bias, 1e-5) + @test x′[1]≈((1 - expected_mean[1]) / sqrt(expected_var[1])) atol=1e-5 + # Stats should be unchanged + @test stats.mean == tmp_mean + @test stats.variance == tmp_var + end + end + end + + let x = f32_arange(3, 2, 1), m = make_bn(2) + y = reshape(permutedims(x, [2, 1, 3]), 2, :) + y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) + @test m(x) == y + @inferred m(x) + end + + let x = f32_arange(2, 3, 2, 1), m = make_bn(2) + y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) + y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) + @test m(x) == y + @inferred m(x) + end + + let x = f32_arange(2, 2, 3, 2, 1), m = make_bn(2) + y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) + y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) + @test m(x) == y + @inferred m(x) + end + + let x = randn(Float32, 416, 416, 32, 1), m = make_bn(32; training = false) + @test (@allocated m(x)) < 100_000_000 + end +end + +@testset "Instance Norm" begin + let x = reshape(1.0:12.0, 3, 2, 2), bias = zeros(2), scale = ones(2) + @testset for use_stats in (true, false) + stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing + y, back = Zygote.pullback(NNlib.instancenorm, x, stats, scale, bias, 1e-5) + @test y≈[-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474;;; + -1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474] rtol=1e-5 + + expected_mean, expected_var = [0.5, 0.8], [1.0, 1.0] + if use_stats + # μ will be + # (1. + 2. + 3.) / 3 = 2. + # (4. + 5. + 6.) / 3 = 5. + # (7. + 8. + 9.) / 3 = 8. + # (10. + 11. + 12.) / 3 = 11. + # + # ∴ update rule with momentum: + # .1 * (2. + 8.) / 2 + 0 = .5 + # .1 * (5. + 11.) / 2 + 0 = .8 + @test stats.mean ≈ expected_mean + # σ² will be + # sum(abs2, [1. + 2. + 3.] .- 2) / 3 = 2.6... + # sum(abs2, [4. + 5. + 6.] .- 5) / 3 = 2.6... + # sum(abs2, [7. + 8. + 9.] .- 8) / 3 = 2.6... + # sum(abs2, [10. + 11. + 12.] .- 11) / 3 = 2.6... + # + # ∴ update rule with momentum: + # .1 * (3 / (3 - 1)) * 2.6... + (1 - .1) * 1 = 1. 
+ @test stats.variance ≈ expected_var + end + + dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1)) + @test dx≈[3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474;;; + 3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474] rtol=1e-5 + @test dscale == zeros(2) + @test dbias == fill(6.0, 2) + @test dstats === nothing + + if use_stats + tmp_mean, tmp_var = copy(stats.mean), copy(stats.variance) + x′ = @inferred NNlib.instancenorm(x, stats, scale, bias, 1e-5) + @test x′[1]≈((1 - expected_mean[1]) / sqrt(expected_var[1])) atol=1e-5 + # Stats should be unchanged + @test stats.mean == tmp_mean + @test stats.variance == tmp_var + end + end + end + + let m = make_in(2), shape = (2, 4, 1, 2, 3), x = f32_arange(shape...) + y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) + y = reshape(m(y), shape...) + @test m(x) == y + @inferred m(x) + end + + # Instance norm == batch norm when channel and batch dims are squashed + let m_inorm = make_in(2; training = true), m_bnorm = make_bn(12; training = true), + shape = (5, 5, 3, 4, 2, 6), x = f32_arange(shape...) + + x′ = reshape(x, (shape[1:(end - 2)]..., :, 1)) + @test m_inorm(x) == reshape(m_bnorm(x′), shape) + end + + let m = make_in(32), x = randn(Float32, 416, 416, 32, 1) + @test (@allocated m(x)) < 100_000_000 + end +end + +@testset "Group Norm" begin + full_size = (2, 3, 6, 5) + @testset for xdims in 1:3, groups in (1, 2, 3, 6) + x = rand(rng, full_size[(end - xdims):end]...) + bias, scale = ntuple(_ -> rand(rng, full_size[end - 1]), 2) + + y = @inferred NNlib.groupnorm(x, groups) + @test size(y) == size(x) + y = @inferred NNlib.groupnorm(x, groups, scale, bias) + @test size(y) == size(x) + + # FiniteDifferences gives incorrect results on some but not all args, why? + gradtest(x -> NNlib.groupnorm(x, groups), x; broken = true) + gradtest((x, s, b) -> NNlib.groupnorm(x, groups, s, b), x, scale, bias; skip = true) + end + + let m = make_gn(4, 2), shape = (5, 5, 3, 4, 4, 6) + y = Zygote.pullback(m, f32_arange(shape...))[1] + @test size(y) == shape + end + + let m = make_gn(2, 2), shape = (2, 4, 1, 2, 3), x = f32_arange(shape...) + y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) + y = reshape(m(y), shape...) + @test m(x) == y + end + + # Group norm == instance norm when the group size == number of channels + let m_inorm = make_in(4), m_gnorm = make_gn(4, 4), x = f32_arange(2, 2, 3, 4, 5) + @test m_inorm(x) ≈ m_gnorm(x) + end + + # Group norm == batch norm for a group of size 1 and batch of size 1 + let m_bnorm = make_bn(4), m_gnorm = make_gn(4, 4), x = f32_arange(2, 2, 3, 4, 1) + @test m_bnorm(x) ≈ m_gnorm(x) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 8b359ad87..4d0fc8bba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -159,4 +159,8 @@ end else @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them." end + + @testset "Normalization" begin + include("normalization.jl") + end end From e0e61dd264d76036b9083cbf603cb5e6e7fa1da8 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 2 Jan 2023 16:35:48 -0800 Subject: [PATCH 2/6] include in docs --- docs/src/reference.md | 16 ++++++++++++++++ src/normalization.jl | 28 ++++++++++++++-------------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/docs/src/reference.md b/docs/src/reference.md index c01db6b24..ba9cc751c 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -95,6 +95,22 @@ NNlib.unfold NNlib.fold ``` +## Normalization + +These roughly correspond to Flux's `*Norm` layers. 
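For orientation, a rough sketch of calling the functional forms directly (assuming the tree after the last commit in this series; the helpers are not exported, so they are qualified with `NNlib.`):

```julia
using NNlib

x = randn(Float32, 32, 32, 8, 4)
stats = NNlib.RunningStats(zeros(Float32, 8), ones(Float32, 8), 0.1f0)

NNlib.layernorm(x, Val(3))                  # each sample normalized over its first 3 dims
NNlib.batchnorm(x, stats; training = true)  # per channel, tracks running statistics
NNlib.instancenorm(x)                       # per channel and per sample, stateless here
NNlib.groupnorm(x, 2)                       # 8 channels normalized in 2 groups
```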
+ + +```@docs +NNlib.layernorm +NNlib.batchnorm +NNlib.instancenorm +NNlib.groupnorm +NNlib.norm_stats +NNlib.norm_helper +NNlib.RunningStats +NNlib.update_running_stats! +``` + ## Upsampling `Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend. diff --git a/src/normalization.jl b/src/normalization.jl index 593dd14b8..4c1f10bd7 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -18,9 +18,9 @@ end Calculates sample mean and (uncorrected) variance of `x` along `dims`. - - `dims=(1,...,N-2,N)` for BatchNorm - - `dims=(1,...,N-2)` for InstanceNorm and GroupNorm - - `dims=(1,...,S)` where S < N for LayerNorm/Flux.jl/stable/ + - `dims=(1,...,N-2,N)` for batchnorm + - `dims=(1,...,N-2)` for instancenorm and groupnorm + - `dims=(1,...,S)` where S < N for layernorm This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately, because it can share some computation across both. @@ -54,8 +54,8 @@ _apply_scale_bias(x, scale, bias) = x .* scale .+ bias Shared code path for all built-in norm functions. -`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref), -or extracted from an existing collection such as [`RunningStats`](@ref). +`μ` and `σ²` should be calculated on the fly using [`NNlib.norm_stats`](@ref), +or extracted from an existing collection such as [`NNlib.RunningStats`](@ref). `bias` and `scale` are consistent with cuDNN and Flux.Scale. We opt for `scale` over `weight` to avoid confusion with dense layers. If the size of the statistics and affine parameters differ, @@ -79,7 +79,7 @@ Contains running mean and variance estimates for stateful norm functions. If the parameters are mutable, they will be updated in-place. Otherwise, they will be replaced wholesale. -See also [`update_running_stats!`](@ref). +See also [`NNlib.update_running_stats!`](@ref). """ mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} mean::M @@ -129,10 +129,10 @@ end reduce_dims) where {N} Performs a moving average update for layers with tracked statistics. -`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref). -`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref). +`μ` and `σ²` are the sample mean and variance, most likely from [`NNlib.norm_stats`](@ref). +`reduce_dims` should also match the `dims` argument of [`NNlib.norm_stats`](@ref). -See also [`RunningStats`](@ref). +See also [`NNlib.RunningStats`](@ref). """ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims) V = eltype(σ²) @@ -168,7 +168,7 @@ Normalizes `x` along the first `S` dimensions. For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`. -See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). +See also [`NNlib.batchnorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref). # Examples @@ -205,14 +205,14 @@ Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation. Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice, where `N-1` is the "channel" (or "feature", for 2D inputs) dimension. -Provide a [`RunningStats`](@ref) to fix a estimated mean and variance. +Provide a [`NNlib.RunningStats`](@ref) to fix a estimated mean and variance. 
`batchnorm` will renormalize the input using these statistics during inference, and update them using batch-level statistics when training. To override this behaviour, manually set a value for `training`. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). +See also [`NNlib.layernorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref). """ function batchnorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, @@ -247,7 +247,7 @@ To override this behaviour, manually set a value for `training`. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref). +See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.groupnorm`](@ref). """ function instancenorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, @@ -281,7 +281,7 @@ The number of channels must be an integer multiple of the number of groups. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref). +See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.instancenorm`](@ref). # Examples From a9dc138df89b204f584b017207fd1355ad74d887 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 2 Jan 2023 19:06:52 -0800 Subject: [PATCH 3/6] fix CI on 1.6 and MacOS --- src/normalization.jl | 22 +++++++++++----------- test/normalization.jl | 16 +++++++++------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/normalization.jl b/src/normalization.jl index 4c1f10bd7..5eec7f603 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -54,8 +54,8 @@ _apply_scale_bias(x, scale, bias) = x .* scale .+ bias Shared code path for all built-in norm functions. -`μ` and `σ²` should be calculated on the fly using [`NNlib.norm_stats`](@ref), -or extracted from an existing collection such as [`NNlib.RunningStats`](@ref). +`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref), +or extracted from an existing collection such as [`RunningStats`](@ref). `bias` and `scale` are consistent with cuDNN and Flux.Scale. We opt for `scale` over `weight` to avoid confusion with dense layers. If the size of the statistics and affine parameters differ, @@ -79,7 +79,7 @@ Contains running mean and variance estimates for stateful norm functions. If the parameters are mutable, they will be updated in-place. Otherwise, they will be replaced wholesale. -See also [`NNlib.update_running_stats!`](@ref). +See also [`update_running_stats!`](@ref). """ mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} mean::M @@ -129,10 +129,10 @@ end reduce_dims) where {N} Performs a moving average update for layers with tracked statistics. -`μ` and `σ²` are the sample mean and variance, most likely from [`NNlib.norm_stats`](@ref). -`reduce_dims` should also match the `dims` argument of [`NNlib.norm_stats`](@ref). +`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref). +`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref). -See also [`NNlib.RunningStats`](@ref). +See also [`RunningStats`](@ref). 
""" function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims) V = eltype(σ²) @@ -168,7 +168,7 @@ Normalizes `x` along the first `S` dimensions. For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`. -See also [`NNlib.batchnorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref). +See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). # Examples @@ -205,14 +205,14 @@ Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation. Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice, where `N-1` is the "channel" (or "feature", for 2D inputs) dimension. -Provide a [`NNlib.RunningStats`](@ref) to fix a estimated mean and variance. +Provide a [`RunningStats`](@ref) to fix a estimated mean and variance. `batchnorm` will renormalize the input using these statistics during inference, and update them using batch-level statistics when training. To override this behaviour, manually set a value for `training`. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`NNlib.layernorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref). +See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). """ function batchnorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, @@ -247,7 +247,7 @@ To override this behaviour, manually set a value for `training`. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.groupnorm`](@ref). +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref). """ function instancenorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, @@ -281,7 +281,7 @@ The number of channels must be an integer multiple of the number of groups. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.instancenorm`](@ref). +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref). # Examples diff --git a/test/normalization.jl b/test/normalization.jl index 128610297..b38b1ca71 100644 --- a/test/normalization.jl +++ b/test/normalization.jl @@ -35,8 +35,10 @@ end # Group/InstanceNorm dimensions let W = 128, C = 2, N = 2, shape = (W, W, 1, 1) - x = [randn_sample(shape, 1, 1);;; randn_sample(shape, 2, 2);;;; - randn_sample(shape, 3, 3);;; randn_sample(shape, 4, 4)] + # Tile to W x W x 2 x 2 + x = cat(cat(randn_sample(shape, 1, 1), randn_sample(shape, 2, 2); dims = 3), + cat(randn_sample(shape, 3, 3), randn_sample(shape, 4, 4); dims = 3); + dims = 4) μ, σ² = NNlib.norm_stats(x, (1, 2)) @test vec(μ)≈1:(C * N) rtol=0.05 @test vec(σ²)≈abs2.(1:(C * N)) rtol=0.05 @@ -60,7 +62,9 @@ end (running_stats, true, y_ns, y_ns, dx_ns), (running_stats, false, meanvar, meanvar, NoTangent()), ] - @test NNlib.maybe_norm_stats(stats, x, dims, !training) == y + ŷ = NNlib.maybe_norm_stats(stats, x, dims, !training) + @test ŷ[1]≈y[1] rtol=1e-5 + @test ŷ[2]≈y[2] rtol=1e-5 ŷ, back = rrule(NNlib.maybe_norm_stats, stats, x, dims, !training) @test ŷ == y_ad @test back(meanvar) == (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent()) @@ -170,8 +174,7 @@ end @testset for use_stats in (true, false) stats = use_stats ? 
NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing y, back = Zygote.pullback(NNlib.instancenorm, x, stats, scale, bias, 1e-5) - @test y≈[-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474;;; - -1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474] rtol=1e-5 + @test y≈repeat([-1.22474, 0.0, 1.22474], 1, 2, 2) rtol=1e-5 expected_mean, expected_var = [0.5, 0.8], [1.0, 1.0] if use_stats @@ -197,8 +200,7 @@ end end dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1)) - @test dx≈[3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474;;; - 3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474] rtol=1e-5 + @test dx≈repeat([3.6742, 1.22474, -1.22474], 1, 2, 2) rtol=1e-5 @test dscale == zeros(2) @test dbias == fill(6.0, 2) @test dstats === nothing From 5b49a64397426db99b8faadba007785f96697ece Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 2 Jan 2023 21:08:22 -0800 Subject: [PATCH 4/6] Simplify RunningStats, faster var calculation and try fixing 1.6 inference --- src/normalization.jl | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/normalization.jl b/src/normalization.jl index 5eec7f603..f49c161c7 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -67,7 +67,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing}, error("both scale and bias must be provided or left as nothing") end scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size) - return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′) + denom = inv.(sqrt.(σ² .+ ϵ)) + return _apply_scale_bias((x .- μ) .* denom, scale′, bias′) end """ @@ -76,12 +77,11 @@ end Contains running mean and variance estimates for stateful norm functions. `momentum` controls the strength of the moving average update. -If the parameters are mutable, they will be updated in-place. -Otherwise, they will be replaced wholesale. +Parameters should be mutable and will be updated in-place. See also [`update_running_stats!`](@ref). 
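With the `mutable` qualifier dropped just below, in-place mutation of the fields becomes the only update path; a brief sketch of what this assumes of callers (illustrative, not from the patch):

```julia
using NNlib

# Plain `Vector`s (or any mutable array) are required for the fields now.
stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)

x = randn(Float32, 8, 8, 3, 4)
μ, σ² = NNlib.norm_stats(x, (1, 2, 4))
NNlib.update_running_stats!(stats, x, μ, σ², (1, 2, 4))
stats.mean   # updated in place; the struct itself is never rebuilt
```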
""" -mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} +struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} mean::M variance::V momentum::MT @@ -142,16 +142,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di correction = m / (m - one(V)) running_mean, running_var = stats.mean, stats.variance - if ChainRulesCore.is_inplaceable_destination(running_mean) - stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ) - else - stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ) - end - if ChainRulesCore.is_inplaceable_destination(running_var) - stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²) - else - stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²) - end + stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ) + stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²) + return end # Convenience functions @@ -190,7 +183,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = throw(DimensionMismatch("got $S reduction dims for $N-dimensional array")) end μ, σ² = norm_stats(x, ntuple(identity, S)) - return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]) + return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]::Dims{S}) end """ From 2ce9c42c320c1e7e3d604b2eaf8cf9d1baad133c Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 2 Jan 2023 21:23:34 -0800 Subject: [PATCH 5/6] set init=0 in stats update --- src/normalization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/normalization.jl b/src/normalization.jl index f49c161c7..1041d5c0d 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -138,7 +138,7 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di V = eltype(σ²) momentum = stats.momentum res_mtm = one(V) - momentum - m = prod(size(x, i) for i in reduce_dims) + m = prod(size(x, i) for i in reduce_dims; init = 1) correction = m / (m - one(V)) running_mean, running_var = stats.mean, stats.variance From 664c97c0049f65b9e595bedb8f94ab3661c9a684 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sun, 1 Oct 2023 20:48:17 -0700 Subject: [PATCH 6/6] fix tests first --- src/normalization.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/normalization.jl b/src/normalization.jl index 1041d5c0d..2b6e19727 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -191,7 +191,7 @@ end running_stats::Union{RunningStats, Nothing} = nothing, scale::Union{AbstractVector, Nothing} = nothing, bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); - training::Bool = within_grad()) where {N} + training::Bool) where {N} Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation. 
@@ -211,7 +211,7 @@ function batchnorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, scale::Union{AbstractVector, Nothing} = nothing, bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); - training::Bool = within_grad()) where {N} + training::Bool = within_gradient(x)) where {N} reduce_dims = ((1:(N - 2))..., N) μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training) # Because μ and σ² could be updated in-place, we compute the output first @@ -227,7 +227,7 @@ end running_stats::Union{RunningStats, Nothing} = nothing, scale::Union{AbstractVector, Nothing} = nothing, bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); - training::Bool = within_grad()) where {N} + training::Bool)) where {N} Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation. @@ -246,7 +246,7 @@ function instancenorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, scale::Union{AbstractVector, Nothing} = nothing, bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); - training::Bool = within_grad()) where {N} + training::Bool = within_gradient(x)) where {N} affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :) reduce_dims = ((1:(N - 2))...,) μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)