Skip to content

Commit 04f727e

Browse files
use GBMatrix
1 parent cba6565 commit 04f727e

File tree

4 files changed

+8
-3
lines changed

4 files changed

+8
-3
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1717
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1818
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
20+
SuiteSparseGraphBLAS = "c2e53296-7b14-11e9-1210-bddfa8111e1d"
2021

2122
[compat]
2223
CUDA = "3.3"

src/GraphNeuralNetworks.jl

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
1414
using MacroTools: @forward
1515
using NNlib, NNlibCUDA
1616
using ChainRulesCore
17+
using SuiteSparseGraphBLAS: GBMatrix
1718
import LightGraphs
1819
using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv,
1920
adjacency_matrix, degree

src/graph_conversions.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function to_sparse(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothi
111111
if T != eltype(A)
112112
A = T.(A)
113113
end
114-
return sparse(A), num_nodes, num_edges
114+
return GBMatrix(sparse(A)), num_nodes, num_edges
115115
end
116116

117117
function to_sparse(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=nothing)
@@ -125,7 +125,8 @@ function to_sparse(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
125125
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
126126
A = sparse(s, t, eweight, num_nodes, num_nodes)
127127
num_edges = length(s)
128-
A, num_nodes, num_edges
128+
129+
GBMatrix(A), num_nodes, num_edges
129130
end
130131

131132
@non_differentiable to_coo(x...)

test/runtests.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ tests = [
2323
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
2424

2525
# Testing all graph types. :sparse is a bit broken at the moment
26-
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :sparse, :dense)
26+
# @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :sparse, :dense)
27+
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:sparse,)
28+
2729
global GRAPH_T = graph_type
2830
for t in tests
2931
include("$t.jl")

0 commit comments

Comments
 (0)