-
Notifications
You must be signed in to change notification settings - Fork 238
/
Copy pathdevice.jl
110 lines (88 loc) · 3.45 KB
/
device.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# on-device sparse array functionality
using SparseArrays
# NOTE: this functionality is currently very bare-bones, only defining the array types
# without any device-compatible sparse array functionality
# core types
# Public device-side sparse container types: one sparse vector type and four
# sparse matrix storage layouts (CSC, CSR, BSR, COO). Each is a plain struct of
# CuDeviceVector buffers plus size/nnz metadata.
export CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDeviceMatrixCSR,
CuSparseDeviceMatrixBSR, CuSparseDeviceMatrixCOO
"""
    CuSparseDeviceVector{Tv,Ti,A}

Device-side sparse vector container holding:

- `iPtr`: stored index buffer (element type `Ti`),
- `nzVal`: stored value buffer (element type `Tv`),
- `len`: logical length of the vector,
- `nnz`: number of stored entries.

`A` is forwarded as the second parameter of `CuDeviceVector` (presumably the
device address space; confirm). Only the storage is defined; there is no
device-side sparse arithmetic.
"""
struct CuSparseDeviceVector{Tv,Ti,A} <: AbstractSparseVector{Tv,Ti}
    iPtr::CuDeviceVector{Ti,A,Ti}
    nzVal::CuDeviceVector{Tv,A,Ti}
    len::Int
    nnz::Ti
end

# The logical length is stored directly in `len`. This type has no `dims` field,
# so computing the length from `dims` (as the matrix types do) would throw.
Base.length(g::CuSparseDeviceVector) = g.len
Base.size(g::CuSparseDeviceVector) = (g.len,)
SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz
"""
    CuSparseDeviceMatrixCSC{Tv,Ti,A}

Device-side sparse matrix container in CSC layout. Fields:

- `colPtr`: column pointer buffer (element type `Ti`),
- `rowVal`: row index buffer (element type `Ti`),
- `nzVal`: stored value buffer (element type `Tv`),
- `dims`: `(rows, columns)` logical size,
- `nnz`: number of stored entries.

`A` is forwarded to `CuDeviceVector` (presumably the device address space;
confirm). Only the storage is defined; there is no device-side sparse arithmetic.
"""
struct CuSparseDeviceMatrixCSC{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    colPtr::CuDeviceVector{Ti,A,Ti}
    rowVal::CuDeviceVector{Ti,A,Ti}
    nzVal::CuDeviceVector{Tv,A,Ti}
    dims::NTuple{2,Int}
    nnz::Ti
end

function Base.length(mat::CuSparseDeviceMatrixCSC)
    nrows, ncols = mat.dims
    return nrows * ncols
end

function Base.size(mat::CuSparseDeviceMatrixCSC)
    return mat.dims
end

function SparseArrays.nnz(mat::CuSparseDeviceMatrixCSC)
    return mat.nnz
end
"""
    CuSparseDeviceMatrixCSR{Tv,Ti,A}

Device-side sparse matrix container in CSR layout. Fields:

- `rowPtr`: row pointer buffer (element type `Ti`),
- `colVal`: column index buffer (element type `Ti`),
- `nzVal`: stored value buffer (element type `Tv`),
- `dims`: `(rows, columns)` logical size,
- `nnz`: number of stored entries.

`A` is forwarded to `CuDeviceVector` (presumably the device address space;
confirm). Only the storage is defined; there is no device-side sparse arithmetic.
"""
struct CuSparseDeviceMatrixCSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    rowPtr::CuDeviceVector{Ti,A,Ti}
    colVal::CuDeviceVector{Ti,A,Ti}
    nzVal::CuDeviceVector{Tv,A,Ti}
    dims::NTuple{2,Int}
    nnz::Ti
end

function Base.length(mat::CuSparseDeviceMatrixCSR)
    shape = mat.dims
    return shape[1] * shape[2]
end

function Base.size(mat::CuSparseDeviceMatrixCSR)
    return mat.dims
end

function SparseArrays.nnz(mat::CuSparseDeviceMatrixCSR)
    return mat.nnz
end
"""
    CuSparseDeviceMatrixBSR{Tv,Ti,A}

Device-side sparse matrix container in BSR layout. Fields:

- `rowPtr`: row pointer buffer (element type `Ti`),
- `colVal`: column index buffer (element type `Ti`),
- `nzVal`: stored value buffer (element type `Tv`),
- `dims`: `(rows, columns)` logical size,
- `blockDim`: block size,
- `dir`: a `Char` flag (presumably the block storage direction; confirm),
- `nnz`: stored-entry count.

`A` is forwarded to `CuDeviceVector` (presumably the device address space;
confirm). Only the storage is defined; there is no device-side sparse arithmetic.
"""
struct CuSparseDeviceMatrixBSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    rowPtr::CuDeviceVector{Ti,A,Ti}
    colVal::CuDeviceVector{Ti,A,Ti}
    nzVal::CuDeviceVector{Tv,A,Ti}
    dims::NTuple{2,Int}
    blockDim::Ti
    dir::Char
    nnz::Ti
end

function Base.length(mat::CuSparseDeviceMatrixBSR)
    nrows, ncols = mat.dims
    return nrows * ncols
end

function Base.size(mat::CuSparseDeviceMatrixBSR)
    return mat.dims
end

function SparseArrays.nnz(mat::CuSparseDeviceMatrixBSR)
    return mat.nnz
end
"""
    CuSparseDeviceMatrixCOO{Tv,Ti,A}

Device-side sparse matrix container in COO layout. Fields:

- `rowInd`: row index buffer (element type `Ti`),
- `colInd`: column index buffer (element type `Ti`),
- `nzVal`: stored value buffer (element type `Tv`),
- `dims`: `(rows, columns)` logical size,
- `nnz`: number of stored entries.

`A` is forwarded to `CuDeviceVector` (presumably the device address space;
confirm). Only the storage is defined; there is no device-side sparse arithmetic.
"""
struct CuSparseDeviceMatrixCOO{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    rowInd::CuDeviceVector{Ti,A,Ti}
    colInd::CuDeviceVector{Ti,A,Ti}
    nzVal::CuDeviceVector{Tv,A,Ti}
    dims::NTuple{2,Int}
    nnz::Ti
end

function Base.length(mat::CuSparseDeviceMatrixCOO)
    shape = mat.dims
    return shape[1] * shape[2]
end

function Base.size(mat::CuSparseDeviceMatrixCOO)
    return mat.dims
end

function SparseArrays.nnz(mat::CuSparseDeviceMatrixCOO)
    return mat.nnz
end
# input/output
# Multi-line REPL display of a device sparse vector: a header with the element
# count, then the index and value buffers. Each line is rendered to a String
# first (same as string interpolation), then written in one call without a
# trailing newline.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceVector)
    rows = [
        string(length(A), "-element device sparse vector at:"),
        string(" iPtr: ", A.iPtr),
        string(" nzVal: ", A.nzVal),
    ]
    print(io, join(rows, '\n'))
    return nothing
end
# Multi-line REPL display of a CSR device matrix: a header with the element
# count (rows * columns), then the row pointer, column index, and value buffers.
# Lines are rendered to Strings first and written in one call, with no trailing
# newline.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSR)
    rows = [
        string(length(A), "-element device sparse matrix CSR at:"),
        string(" rowPtr: ", A.rowPtr),
        string(" colVal: ", A.colVal),
        string(" nzVal: ", A.nzVal),
    ]
    print(io, join(rows, '\n'))
    return nothing
end
# Multi-line REPL display of a CSC device matrix: a header with the element
# count (rows * columns), then the column pointer, row index, and value buffers.
# Lines are rendered to Strings first and written in one call, with no trailing
# newline.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSC)
    rows = [
        string(length(A), "-element device sparse matrix CSC at:"),
        string(" colPtr: ", A.colPtr),
        string(" rowVal: ", A.rowVal),
        string(" nzVal: ", A.nzVal),
    ]
    print(io, join(rows, '\n'))
    return nothing
end
# Multi-line REPL display of a BSR device matrix: a header with the element
# count (rows * columns), then the row pointer, column index, and value buffers.
# `blockDim` and `dir` are not shown. Lines are rendered to Strings first and
# written in one call, with no trailing newline.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixBSR)
    rows = [
        string(length(A), "-element device sparse matrix BSR at:"),
        string(" rowPtr: ", A.rowPtr),
        string(" colVal: ", A.colVal),
        string(" nzVal: ", A.nzVal),
    ]
    print(io, join(rows, '\n'))
    return nothing
end
# Multi-line REPL display of a COO device matrix: a header with the element
# count (rows * columns), then the row index, column index, and value buffers.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCOO)
    println(io, "$(length(A))-element device sparse matrix COO at:")
    # COO stores row *indices* in `rowInd`; this type has no `rowPtr` field, so
    # reading `A.rowPtr` would throw and make every COO matrix undisplayable.
    println(io, " rowInd: $(A.rowInd)")
    println(io, " colInd: $(A.colInd)")
    print(io, " nzVal: $(A.nzVal)")
end