Skip to content

experiments with precompile and invalidations #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -40,6 +41,7 @@ NNlib = "0.8"
NNlibCUDA = "0.2"
NearestNeighbors = "0.4"
Reexport = "1"
SnoopPrecompile = "1"
StatsBase = "0.33"
julia = "1.7"

Expand Down
55 changes: 55 additions & 0 deletions invalidations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using SnoopCompileCore

# Record every method invalidation triggered while loading the packages below.
# `@snoopr` must wrap the `using` statements themselves so that load-time
# invalidations are captured; the result feeds `invalidation_trees` later on.
invalidations = @snoopr begin
using GraphNeuralNetworks
using Flux
# using CUDA
# using Graphs
# using Random, Statistics, LinearAlgebra
end

"""
    workload()

Run a small, representative GNN workload: batch a few random graphs, build a
`GNNChain`, and perform one forward pass. Returns the model output. Used to
drive inference snooping below.
"""
function workload()
    ngraphs = 3
    gbatch = Flux.batch([rand_graph(5, 10) for _ in 1:ngraphs])
    feats = rand(Float32, 4, gbatch.num_nodes)
    net = GNNChain(GCNConv(4 => 4, relu),
                   GCNConv(4 => 4),
                   GlobalPool(max),
                   Dense(4, 1))
    out = net(gbatch, feats)
    # @assert size(out) == (1, ngraphs)
    return out
end

# Profile type inference while the workload runs once; `tinf` is the inference
# timing tree consumed by `precompile_blockers`.
tinf = @snoopi_deep begin
workload()
end

using SnoopCompile
# Organize the raw invalidation log into per-method trees.
trees = invalidation_trees(invalidations)
# Invalidations that block precompilation benefits for this workload.
# NOTE(review): `staletrees` is currently computed but never displayed or
# used below — consider `@show`/`show`-ing it or removing the line.
staletrees = precompile_blockers(trees, tinf)

@show length(uinvalidated(invalidations)) # show total invalidations

# Trees are sorted by impact, so the last entry is the worst offender.
show(trees[end]) # show the most invalidating method

# Count number of children (number of invalidations per invalidated method)
n_invalidations = map(SnoopCompile.countchildren, trees)

# (optional) plot the number of children per method invalidations
import Plots
Plots.plot(
1:length(trees),
n_invalidations;
markershape=:circle,
xlabel="i-th method invalidation",
label="Number of children per method invalidations"
)

# (optional) report invalidations summary
using PrettyTables # needed for `report_invalidations` to be defined
SnoopCompile.report_invalidations(;
invalidations,
# Keep only the path portion after `.julia/packages/` for readability.
process_filename = x -> last(split(x, ".julia/packages/")),
n_rows = 0, # no-limit (show all invalidations)
)
8 changes: 4 additions & 4 deletions src/GNNGraphs/gatherscatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
# Gather entries of `x` at indices `i` via NNlib; `nothing` passes through
# unchanged so absent features need no special-casing by callers.
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
_gather(x::Nothing, i) = nothing

_scatter(aggr, src::Nothing, idx, n) = nothing
_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
_scatter(aggr, src::Dict, idx, n) = Dict(k => _scatter(aggr, v, idx, n) for (k, v) in src)
# The `aggr::A ... where A` annotation forces Julia to compile a specialized
# method instance per aggregation function: by default Julia avoids
# specializing on a `Function` argument that is only passed through to another
# call, which here led to dynamic dispatch and invalidations.
# `nothing` features scatter to `nothing`; container inputs recurse per field.
_scatter(aggr::A, src::Nothing, idx, n) where A = nothing
_scatter(aggr::A, src::NamedTuple, idx, n) where A = map(s -> _scatter(aggr, s, idx, n), src)
_scatter(aggr::A, src::Tuple, idx, n) where A = map(s -> _scatter(aggr, s, idx, n), src)
_scatter(aggr::A, src::Dict, idx, n) where A = Dict(k => _scatter(aggr, v, idx, n) for (k, v) in src)

function _scatter(aggr,
src::AbstractArray,
Expand Down
7 changes: 6 additions & 1 deletion src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module GraphNeuralNetworks

using Statistics: mean
using LinearAlgebra, Random
using Base: tail
using Base: tail, Fix1, Fix2
using CUDA
using Flux
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch
Expand All @@ -11,6 +11,7 @@ using NNlib, NNlibCUDA
using NNlib: scatter, gather
using ChainRulesCore
using Reexport
using SnoopPrecompile
using SparseArrays, Graphs # not needed but if removed Documenter will complain

include("GNNGraphs/GNNGraphs.jl")
Expand Down Expand Up @@ -83,4 +84,8 @@ include("msgpass.jl")
include("mldatasets.jl")
include("deprecations.jl")

# Execute the workload in precompile.jl under SnoopPrecompile so that every
# call it makes is compiled and cached into the package precompile image.
@precompile_all_calls begin
include("precompile.jl")
end

end
30 changes: 15 additions & 15 deletions src/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref).
"""
function propagate end

function propagate(f, g::GNNGraph, aggr; xi = nothing, xj = nothing, e = nothing)
propagate(f, g, aggr, xi, xj, e)
# Keyword-argument entry point for message passing; forwards to the positional
# method. The `f::F ... where F` annotation forces specialization on the
# message function (avoiding dynamic dispatch on a pass-through `Function`).
function propagate(f::F, g::GNNGraph, aggr; xi = nothing, xj = nothing, e = nothing) where F
propagate(f, g, aggr, xi, xj, e)
end

function propagate(f, g::GNNGraph, aggr, xi, xj, e = nothing)
function propagate(f::F, g::GNNGraph, aggr, xi, xj, e = nothing) where F
m = apply_edges(f, g, xi, xj, e)
m̄ = aggregate_neighbors(g, aggr, m)
return m̄
Expand All @@ -87,12 +87,12 @@ end
# https://github.com/JuliaLang/julia/issues/15276
## and zygote issues
# https://github.com/FluxML/Zygote.jl/issues/1317
function propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
e = nothing)
propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
function propagate(f::F, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
e = nothing) where F
propagate(Fix1(f, l), g, aggr, xi, xj, e)
end
function propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
function propagate(f::F, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing) where F
propagate(Fix1(f, l), g, aggr, xi, xj, e)
end

## APPLY EDGES
Expand Down Expand Up @@ -135,11 +135,11 @@ See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref).
"""
function apply_edges end

function apply_edges(f, g::GNNGraph; xi = nothing, xj = nothing, e = nothing)
# Keyword-argument convenience wrapper; forwards to the positional method.
# `f::F ... where F` forces specialization on the edge function.
function apply_edges(f::F, g::GNNGraph; xi = nothing, xj = nothing, e = nothing) where F
apply_edges(f, g, xi, xj, e)
end

function apply_edges(f, g::GNNGraph, xi, xj, e = nothing)
function apply_edges(f::F, g::GNNGraph, xi, xj, e = nothing) where F
check_num_nodes(g, xi)
check_num_nodes(g, xj)
check_num_edges(g, e)
Expand All @@ -154,12 +154,12 @@ end
# https://github.com/JuliaLang/julia/issues/15276
## and zygote issues
# https://github.com/FluxML/Zygote.jl/issues/1317
function apply_edges(f, g::GNNGraph, l::GNNLayer; xi = nothing, xj = nothing, e = nothing)
apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e)
function apply_edges(f::F, g::GNNGraph, l::GNNLayer; xi = nothing, xj = nothing, e = nothing) where F
apply_edges(Fix1(f, l), g, xi, xj, e)
end

function apply_edges(f, g::GNNGraph, l::GNNLayer, xi, xj, e = nothing)
apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e)
function apply_edges(f::F, g::GNNGraph, l::GNNLayer, xi, xj, e = nothing) where F
apply_edges(Fix1(f, l), g, xi, xj, e)
end

## AGGREGATE NEIGHBORS
Expand All @@ -176,7 +176,7 @@ features
Neighborhood aggregation is the second step of [`propagate`](@ref),
where it comes after [`apply_edges`](@ref).
"""
function aggregate_neighbors(g::GNNGraph, aggr, m)
function aggregate_neighbors(g::GNNGraph, aggr::A, m) where {A}
check_num_edges(g, m)
s, t = edge_index(g)
return GNNGraphs._scatter(aggr, m, t, g.num_nodes)
Expand Down
16 changes: 16 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

"""
    workflow1()

Exercise a representative workload — batched random-graph construction, a
multi-layer `GNNChain` forward pass, and a gradient computation — so the
enclosing `@precompile_all_calls` caches the compiled code paths.
Returns the model gradient.
"""
function workflow1()
    nnodes, dim = 10, 6
    ngraphs = 5
    g = Flux.batch([rand_graph(nnodes, 3 * nnodes) for _ in 1:ngraphs])
    feats = rand(Float32, dim, g.num_nodes)
    net = GNNChain(GCNConv(dim => dim, relu),
                   GraphConv(dim => dim, tanh),
                   GATv2Conv(dim => dim ÷ 2, relu, heads = 2),
                   GlobalPool(max),
                   Dense(dim, 1))
    # Forward pass compiles the inference paths.
    net(g, feats)
    # Backward pass additionally compiles the AD pullbacks.
    return gradient(m -> sum(m(g, feats)), net)[1]
end

# Run once at include time so the calls above are actually recorded.
workflow1()