From 7f42668c65fec757545ce450d59cbb7b175ee7dc Mon Sep 17 00:00:00 2001 From: Casper Date: Fri, 15 Apr 2022 16:54:57 +0200 Subject: [PATCH 1/9] Initial commit --- src/layers/pool.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 88ed3ddd0..2f9e4ca6e 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -44,6 +44,25 @@ end (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g))) +@doc raw""" + GlobalConcatPool(aggr) + +```math +\mathbf{x}_i' = [\mathbf{x}_i; \square_{i \in V} \mathbf{x}_i] + +""" +struct GlobalConcatPool{F} <: GNNLayer + aggr::F +end + +function (l::GlobalConcatPool)(g::GNNGraph, x::AbstractArray) + g_feat = reduce_nodes(l.aggr, g, x) + feat_arr = gather(g_feat, graph_indicator(g)) + return vcat(x, feat_arr) +end + +(l::GlobalConcatPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g))) + @doc raw""" GlobalAttentionPool(fgate, ffeat=identity) From 4612532c87190611c76bb0376675e44d562a417b Mon Sep 17 00:00:00 2001 From: Casper Date: Fri, 15 Apr 2022 17:21:02 +0200 Subject: [PATCH 2/9] use broadcast nodes --- src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 2f9e4ca6e..2fecf3fff 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -57,7 +57,7 @@ end function (l::GlobalConcatPool)(g::GNNGraph, x::AbstractArray) g_feat = reduce_nodes(l.aggr, g, x) - feat_arr = gather(g_feat, graph_indicator(g)) + feat_arr = broadcast_nodes(g, g_feat) return vcat(x, feat_arr) end From 871275f3079aac8da309e2092706279aad9a9be9 Mon Sep 17 00:00:00 2001 From: casper2002casper Date: Sat, 16 Apr 2022 10:35:58 +0200 Subject: [PATCH 3/9] Update src/layers/pool.jl Co-authored-by: Carlo Lucibello --- src/layers/pool.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 2fecf3fff..67f4e58f7 100644 --- a/src/layers/pool.jl +++ 
b/src/layers/pool.jl @@ -48,8 +48,7 @@ end GlobalConcatPool(aggr) ```math -\mathbf{x}_i' = [\mathbf{x}_i; \square_{i \in V} \mathbf{x}_i] - +\mathbf{x}_i' = [\mathbf{x}_i; \square_{j \in V} \mathbf{x}_j] """ struct GlobalConcatPool{F} <: GNNLayer aggr::F From 3d211c16331eac8cbacece5c87c82bd415f6f892 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 16 Apr 2022 12:52:16 +0200 Subject: [PATCH 4/9] Add to exports --- src/GraphNeuralNetworks.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 86bab1896..011763160 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -64,6 +64,7 @@ export SGConv, # layers/pool + ConcatPool, GlobalPool, GlobalAttentionPool, TopKPool, From 12fbc3ff2c98b7432a069307aa574fcb69804583 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 16 Apr 2022 12:52:56 +0200 Subject: [PATCH 5/9] Generalize --- src/layers/pool.jl | 54 ++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 67f4e58f7..c9c6e36c8 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,5 +1,40 @@ using DataStructures: nlargest +@doc raw""" + ConcatPool(pooling_layer) + +```math +\mathbf{x}_i' = [\mathbf{x}_i; \mathbf{u}_V] +``` + +# Arguments + +- `pooling_layer`: a layer mapping the node features of a graph to a single graph-level feature vector ``\mathbf{u}_V``, e.g. `GlobalPool(mean)`. + +# Examples + +```julia +using Flux, GraphNeuralNetworks, Graphs + +add_pool = ConcatPool(GlobalPool(mean)) + +g = GNNGraph(rand_graph(10, 4)) +X = rand(32, 10) +add_pool(g, X) # => 64x10 matrix +``` +""" +struct ConcatPool <: GNNLayer + pool::GNNLayer +end + +function (l::ConcatPool)(g::GNNGraph, x::AbstractArray) + g_feat = applylayer(l.pool, g, x) + feat_arr = broadcast_nodes(g, g_feat) + return vcat(x, feat_arr) +end + +(l::ConcatPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g))) + @doc raw""" GlobalPool(aggr) @@ -44,25 +79,6 @@ end (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g))) -@doc raw""" - 
GlobalConcatPool(aggr) - -```math -\mathbf{x}_i' = [\mathbf{x}_i; \square_{j \in V} \mathbf{x}_j] -""" -struct GlobalConcatPool{F} <: GNNLayer - aggr::F -end - -function (l::GlobalConcatPool)(g::GNNGraph, x::AbstractArray) - g_feat = reduce_nodes(l.aggr, g, x) - feat_arr = broadcast_nodes(g, g_feat) - return vcat(x, feat_arr) -end - -(l::GlobalConcatPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g))) - - @doc raw""" GlobalAttentionPool(fgate, ffeat=identity) From 3b8e42439e1a8e87cfffee2f54de6018f738f05c Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 16 Apr 2022 12:53:06 +0200 Subject: [PATCH 6/9] Add tests --- test/layers/pool.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/layers/pool.jl b/test/layers/pool.jl index f7bb74a83..7b8eeb015 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -21,6 +21,35 @@ test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph) end + @testset "ConcatPool" begin + p = GlobalPool(+) + l = ConcatPool(p) + n = 10 + chin = 6 + X = rand(Float32, chin, n) + g = GNNGraph(random_regular_graph(n, 4), ndata=X, graph_type=GRAPH_T) + y = p(g, X) + u = l(g, X) + + @test size(u) == (chin*2, n) + @test u[1:chin,:] ≈ X + @test u[chin+1:end,:] ≈ repeat(p, 1, n) + + n = [1, 2, 3] + ng = length(n) + g = Flux.batch([GNNGraph(random_regular_graph(n, 4), + ndata=rand(Float32, chin, n[i]), + graph_type=GRAPH_T) + for i=1:ng]) + y = p(g, g.ndata.x) + u = l(g, g.ndata.x) + @test size(u) == (chin*2, sum(n)) + @test u[1:chin,:] ≈ g.ndata.x + @test u[chin+1:end,:] ≈ hcat([repeat(y[:,i], 1, n[i]) for i=1:ng]...) 
+ + test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph) + end + @testset "GlobalAttentionPool" begin n = 10 chin = 6 From 0e520e4400b0f225f977d8b1920d3ae88f100bc8 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 16 Apr 2022 12:54:10 +0200 Subject: [PATCH 7/9] Fix test --- test/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/pool.jl b/test/layers/pool.jl index 7b8eeb015..9bf18c896 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -37,7 +37,7 @@ n = [1, 2, 3] ng = length(n) - g = Flux.batch([GNNGraph(random_regular_graph(n, 4), + g = Flux.batch([GNNGraph(random_regular_graph(n[i], 4), ndata=rand(Float32, chin, n[i]), graph_type=GRAPH_T) for i=1:ng]) From 16b03c96c94e1941e50a867b7ccfb167a1e73e76 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 16 Apr 2022 13:20:29 +0200 Subject: [PATCH 8/9] Use correct graph size test --- test/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/pool.jl b/test/layers/pool.jl index 9bf18c896..175cfbfe4 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -35,7 +35,7 @@ @test u[1:chin,:] ≈ X @test u[chin+1:end,:] ≈ repeat(p, 1, n) - n = [1, 2, 3] + n = [5, 6, 7] ng = length(n) g = Flux.batch([GNNGraph(random_regular_graph(n[i], 4), ndata=rand(Float32, chin, n[i]), From 866e41616e273f34a5b08fc7f3b968b8c134716d Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 16 Apr 2022 13:53:42 +0200 Subject: [PATCH 9/9] last fix --- test/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/pool.jl b/test/layers/pool.jl index 175cfbfe4..6d84b01d4 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -33,7 +33,7 @@ @test size(u) == (chin*2, n) @test u[1:chin,:] ≈ X - @test u[chin+1:end,:] ≈ repeat(p, 1, n) + @test u[chin+1:end,:] ≈ repeat(y, 1, n) n = [5, 6, 7] ng = length(n)