-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathnormalise.jl
More file actions
78 lines (59 loc) · 1.91 KB
/
normalise.jl
File metadata and controls
78 lines (59 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Implementation of normalization layers for GraphNeuralNetworks
@doc raw"""
    PairNorm(scale_value = 1.0; scale_individually = false, eps = 1f-5)

PairNorm layer from the paper
[PairNorm: Tackling Oversmoothing in GNNs](https://arxiv.org/abs/1909.12223).

Performs the normalization

```math
\begin{aligned}
\mathbf{x}_i^c &= \mathbf{x}_i - \frac{1}{n} \sum_{j=1}^n \mathbf{x}_j \\
\mathbf{x}_i^{\prime} &= s \cdot
\frac{\mathbf{x}_i^c}{\sqrt{\frac{1}{n} \sum_{j=1}^n {\| \mathbf{x}_j^c \|}^2_2}}
\end{aligned}
```

The input to this layer is the output of GNN layers.

# Arguments

- `scale_value`: Scaling factor `s` used in the normalisation. Default `1.0`.
- `scale_individually`: If set to `true`, the scaling step is computed as
  ```math
  \mathbf{x}^{\prime}_i = s \cdot
  \frac{\mathbf{x}_i^c}{{\| \mathbf{x}_i^c \|}_2}
  ```
  Default `false`.
- `eps`: Small value added to the denominator for numerical stability.
  Default `1f-5`. The Greek-letter keyword `ϵ` is still accepted by the
  constructor for backward compatibility; prefer `eps`.

# Examples

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)
scale_value = 1.0

# create layer
l = GCNConv(3 => 5)
pn = PairNorm(scale_value)

# forward pass of GCN
y = l(g, x)   # size: 5 × num_nodes

# forward pass of PairNorm
ȳ = pn(y)
```
"""
struct PairNorm{V, N}
    scale_value::V            # scaling factor `s`
    ϵ::N                      # stability constant added to the denominator
    scale_individually::Bool  # selects the individual-scaling branch
end
# Register `PairNorm` through the `@functor` macro (defined outside this file;
# presumably Functors.jl, making the layer's fields traversable — confirm).
@functor PairNorm
# Keyword constructor for `PairNorm`.
#
# - `scale_value`: scaling factor stored in the layer (default `1.0f0`).
# - `scale_individually`: stored as-is; selects the individual-scaling branch of
#   the forward pass (default `false`).
# - `eps`: stability constant stored in the layer's `ϵ` field (default `1f-5`).
# - `ϵ`: legacy spelling of `eps`. It is routed through `_greek_ascii_depwarn`,
#   which is defined outside this file (presumably it warns about the Greek
#   keyword and returns the value to use — confirm against its definition).
function PairNorm(
    scale_value::Real=1.0f0;
    scale_individually::Bool=false,
    eps::Real=1f-5,
    ϵ=nothing
)
    # Pass this layer's own name to the deprecation helper; the previous
    # `:BatchNorm` was a copy-paste leftover that mislabelled the warning.
    ε = _greek_ascii_depwarn(ϵ => eps, :PairNorm, "ϵ" => "eps")
    return PairNorm(scale_value, ε, scale_individually)
end
# Forward pass of `PairNorm` on a node-feature array `x` laid out as
# `features × num_nodes` (one column per node, as in the docstring example where
# the input has size `5 × num_nodes`).
#
# 1. Center the features by subtracting the mean node embedding.
# 2. Rescale, either every node by its own L2 norm (`scale_individually`), or
#    all nodes by the root of the mean squared L2 norm.
#
# Returns an array of the same size as `x`. `PN.ϵ` is added to the denominator
# so that all-equal inputs do not divide by zero.
function (PN::PairNorm)(x::AbstractArray)
    # Average over nodes (dims=2), not over features: the formula subtracts
    # (1/n) Σ_j x_j from every node vector.
    centered = x .- mean(x; dims=2)
    # Squared L2 norm of each centered node vector; size 1 × num_nodes.
    sqnorms = sum(abs2, centered; dims=1)
    if PN.scale_individually
        # Per-node denominator, broadcast along the feature dimension.
        return (PN.scale_value .* centered) ./ (PN.ϵ .+ sqrt.(sqnorms))
    else
        # Single shared denominator: sqrt of the mean squared node norm.
        return (PN.scale_value .* centered) ./ (PN.ϵ + sqrt(mean(sqnorms)))
    end
end
# Compact display: the layer name followed by its scale value in parentheses.
function Base.show(io::IO, pn::PairNorm)
    scale = pn.scale_value
    return print(io, "PairNorm(", scale, ")")
end