Skip to content

Commit e488f51

Browse files
amontoison (Alexis Montoison)
authored and committed
[CUSPARSE] Interface generic mv! for SparseMatrixBSR
1 parent 64df292 commit e488f51

File tree

4 files changed

+83
-84
lines changed

4 files changed

+83
-84
lines changed

lib/cusparse/generic.jl

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -152,22 +152,16 @@ function vv!(transx::SparseChar, X::CuSparseVector{T}, Y::DenseCuVector{T}, inde
152152
return result[]
153153
end
154154

155-
function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},CuSparseMatrixCSR{TA},CuSparseMatrixCOO{TA}}, X::DenseCuVector{T},
155+
function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix{TA}, X::DenseCuVector{T},
156156
beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_SPMV_ALG_DEFAULT) where {TA, T}
157157

158+
(A isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.6.3") && throw(ErrorException("This operation is not supported by the current CUDA version."))
159+
158160
# Support transa = 'C' for real matrices
159161
transa = T <: Real && transa == 'C' ? 'T' : transa
160162

161-
if isa(A, CuSparseMatrixCSC)
162-
# cusparseSpMV completely supports CSC matrices with CUSPARSE.version() ≥ v"12.0".
163-
# We use Aᵀ to model them as CSR matrices for older versions of CUSPARSE.
164-
descA = CuSparseMatrixDescriptor(A, index, transposed=true)
165-
n,m = size(A)
166-
transa = transa == 'N' ? 'T' : 'N'
167-
else
168-
descA = CuSparseMatrixDescriptor(A, index)
169-
m,n = size(A)
170-
end
163+
descA = CuSparseMatrixDescriptor(A, index)
164+
m,n = size(A)
171165

172166
if transa == 'N'
173167
chkmvdims(X,n,Y,m)
@@ -318,12 +312,12 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
318312
return out[]
319313
end
320314
with_workspace(bufferSize) do buffer
321-
# We should find a way to reuse the buffer (issue #1362)
322-
if !(A isa CuSparseMatrixCOO)
323-
cusparseSpMM_preprocess(
324-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
325-
descC, T, algo, buffer)
326-
end
315+
# Uncomment if we find a way to reuse the buffer (issue #1362)
316+
# if !(A isa CuSparseMatrixCOO)
317+
# cusparseSpMM_preprocess(
318+
# handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
319+
# descC, T, algo, buffer)
320+
# end
327321
cusparseSpMM(
328322
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
329323
descC, T, algo, buffer)
@@ -372,12 +366,12 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMa
372366
return out[]
373367
end
374368
with_workspace(bufferSize) do buffer
375-
# We should find a way to reuse the buffer (issue #1362)
376-
if !(B isa CuSparseMatrixCOO)
377-
cusparseSpMM_preprocess(
378-
handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta),
379-
descC, T, algo, buffer)
380-
end
369+
# Uncomment if we find a way to reuse the buffer (issue #1362)
370+
# if !(B isa CuSparseMatrixCOO)
371+
# cusparseSpMM_preprocess(
372+
# handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta),
373+
# descC, T, algo, buffer)
374+
# end
381375
cusparseSpMM(
382376
handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta),
383377
descC, T, algo, buffer)

lib/cusparse/level2.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
# sparse linear algebra functions that perform operations between sparse matrices and dense
22
# vectors
33

4-
export sv2!, sv2, gemvi!
4+
export sv2!, sv2, mv2!, gemvi!
55

66
for (fname,elty) in ((:cusparseSbsrmv, :Float32),
77
(:cusparseDbsrmv, :Float64),
88
(:cusparseCbsrmv, :ComplexF32),
99
(:cusparseZbsrmv, :ComplexF64))
1010
@eval begin
11-
function mv!(transa::SparseChar,
12-
alpha::Number,
13-
A::CuSparseMatrixBSR{$elty},
14-
X::CuVector{$elty},
15-
beta::Number,
16-
Y::CuVector{$elty},
17-
index::SparseChar)
11+
function mv2!(transa::SparseChar,
12+
alpha::Number,
13+
A::CuSparseMatrixBSR{$elty},
14+
X::CuVector{$elty},
15+
beta::Number,
16+
Y::CuVector{$elty},
17+
index::SparseChar)
1818

1919
# Support transa = 'C' for real matrices
2020
transa = $elty <: Real && transa == 'C' ? 'T' : transa

test/libraries/cusparse.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -756,8 +756,7 @@ end
756756
alpha = rand(elty)
757757
beta = rand(elty)
758758
@testset "$(typeof(d_A))" for d_A in [CuSparseMatrixCSR(A),
759-
CuSparseMatrixCSC(A),
760-
CuSparseMatrixBSR(A, blockdim)]
759+
CuSparseMatrixCSC(A)]
761760
d_x = CuArray(x)
762761
d_y = CuArray(y)
763762
@test_throws DimensionMismatch CUSPARSE.mv!('T',alpha,d_A,d_x,beta,d_y,'O')
@@ -766,9 +765,16 @@ end
766765
h_z = collect(d_y)
767766
z = alpha * A * x + beta * y
768767
@test z ≈ h_z
769-
#if d_A isa CuSparseMatrixCSR
770-
# @test d_y' * (d_A * d_x) ≈ (d_y' * d_A) * d_x
771-
#end
768+
end
769+
@testset "$(typeof(d_A))" for d_A in [CuSparseMatrixBSR(A, blockdim)]
770+
d_x = CuArray(x)
771+
d_y = CuArray(y)
772+
@test_throws DimensionMismatch CUSPARSE.mv2!('T',alpha,d_A,d_x,beta,d_y,'O')
773+
@test_throws DimensionMismatch CUSPARSE.mv2!('N',alpha,d_A,d_y,beta,d_x,'O')
774+
CUSPARSE.mv2!('N',alpha,d_A,d_x,beta,d_y,'O')
775+
h_z = collect(d_y)
776+
z = alpha * A * x + beta * y
777+
@test z ≈ h_z
772778
end
773779
end
774780

test/libraries/cusparse/generic.jl

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using CUDA.CUSPARSE
22
using SparseArrays
33
using LinearAlgebra
44

5-
@testset "generic mv!" for T in [Float32, Float64]
5+
@testset "generic mv! -- $T" for T in [Float32, Float64]
66
m = 10
77
A = sprand(T, m, m, 0.1)
88
x = rand(Complex{T}, m)
@@ -17,7 +17,7 @@ using LinearAlgebra
1717
dA = CuSparseMatrixCSR(dA)
1818
mv!('N', one(T), dA, dx, zero(T), dy, 'O')
1919
@test Array(dy) ≈ A * x
20-
20+
2121
A_bad = sprand(T, m+1, m, 0.1)
2222
dA_bad = adapt(CuArray, A_bad)
2323
@test_throws DimensionMismatch("Y must have length $(m+1), but has length $m") mv!('N', one(T), dA_bad, dx, zero(T), dy, 'O')
@@ -32,9 +32,7 @@ SPMV_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT],
3232
CUSPARSE.CUSPARSE_SPMV_CSR_ALG1,
3333
CUSPARSE.CUSPARSE_SPMV_CSR_ALG2],
3434
CuSparseMatrixCOO => [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT,
35-
CUSPARSE.CUSPARSE_SPMV_COO_ALG1,
36-
],
37-
)
35+
CUSPARSE.CUSPARSE_SPMV_COO_ALG1])
3836

3937
SPMM_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT],
4038
CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT,
@@ -43,10 +41,9 @@ SPMM_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT],
4341
CUSPARSE.CUSPARSE_SPMM_CSR_ALG3],
4442
CuSparseMatrixCOO => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT,
4543
CUSPARSE.CUSPARSE_SPMM_COO_ALG1,
44+
CUSPARSE.CUSPARSE_SPMM_COO_ALG2,
4645
CUSPARSE.CUSPARSE_SPMM_COO_ALG3,
47-
CUSPARSE.CUSPARSE_SPMM_COO_ALG4]
48-
)
49-
46+
CUSPARSE.CUSPARSE_SPMM_COO_ALG4])
5047

5148
if CUSPARSE.version() >= v"12.1.3"
5249
push!(SPMV_ALGOS[CuSparseMatrixCOO], CUSPARSE.CUSPARSE_SPMV_COO_ALG2)
@@ -57,15 +54,20 @@ if CUSPARSE.version() >= v"12.5.1"
5754
CUSPARSE.CUSPARSE_SPMM_BSR_ALG1]
5855
end
5956

57+
if CUSPARSE.version() >= v"12.6.3"
58+
SPMV_ALGOS[CuSparseMatrixBSR] = [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT,
59+
CUSPARSE.CUSPARSE_SPMV_BSR_ALG1]
60+
end
61+
6062
for SparseMatrixType in keys(SPMV_ALGOS)
6163
@testset "$SparseMatrixType -- mv! algo=$algo" for algo in SPMV_ALGOS[SparseMatrixType]
6264
@testset "mv! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
6365
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
64-
SparseMatrixType == CuSparseMatrixCSC && T <: Complex && transa == 'C' && continue
66+
(SparseMatrixType == CuSparseMatrixBSR) && (transa != 'N') && continue
6567
A = sprand(T, 20, 10, 0.1)
6668
B = transa == 'N' ? rand(T, 10) : rand(T, 20)
6769
C = transa == 'N' ? rand(T, 20) : rand(T, 10)
68-
dA = SparseMatrixType(A)
70+
dA = SparseMatrixType == CuSparseMatrixBSR ? SparseMatrixType(A,1) : SparseMatrixType(A)
6971
dB = CuArray(B)
7072
dC = CuArray(C)
7173

@@ -83,7 +85,6 @@ for SparseMatrixType in keys(SPMM_ALGOS)
8385
@testset "mm! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
8486
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
8587
@testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)]
86-
CUSPARSE.version() < v"12.0" && SparseMatrixType == CuSparseMatrixCSC && T <: Complex && transa == 'C' && continue
8788
algo == CUSPARSE.CUSPARSE_SPMM_CSR_ALG3 && (transa != 'N' || transb != 'N') && continue
8889
(SparseMatrixType == CuSparseMatrixBSR) && (transa != 'N') && continue
8990
A = sprand(T, 10, 10, 0.1)
@@ -122,7 +123,6 @@ for SparseMatrixType in keys(SPMM_ALGOS)
122123
@testset "$T" for T in [Float32, Float64, ComplexF32, ComplexF64]
123124
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
124125
@testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)]
125-
CUSPARSE.version() < v"12.0" && SparseMatrixType == CuSparseMatrixCSR && T <: Complex && transb == 'C' && continue
126126
algo == CUSPARSE.CUSPARSE_SPMM_CSR_ALG3 && (transa != 'N' || transb != 'N') && continue
127127
A = rand(T, 10, 10)
128128
B = transb == 'N' ? sprand(T, 10, 5, 0.5) : sprand(T, 5, 10, 0.5)
@@ -313,16 +313,14 @@ end
313313
@test Z ≈ collect(dY)
314314
end
315315

316-
SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT],
317-
CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT])
318-
if CUSPARSE.version() >= v"12.0"
319-
append!(SPGEMM_ALGOS[CuSparseMatrixCSR], (CUSPARSE.CUSPARSE_SPGEMM_ALG1,
320-
CUSPARSE.CUSPARSE_SPGEMM_ALG2,
321-
CUSPARSE.CUSPARSE_SPGEMM_ALG3))
322-
append!(SPGEMM_ALGOS[CuSparseMatrixCSC], (CUSPARSE.CUSPARSE_SPGEMM_ALG1,
323-
CUSPARSE.CUSPARSE_SPGEMM_ALG2,
324-
CUSPARSE.CUSPARSE_SPGEMM_ALG3))
325-
end
316+
SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT,
317+
CUSPARSE.CUSPARSE_SPGEMM_ALG1,
318+
CUSPARSE.CUSPARSE_SPGEMM_ALG2,
319+
CUSPARSE.CUSPARSE_SPGEMM_ALG3],
320+
CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT,
321+
CUSPARSE.CUSPARSE_SPGEMM_ALG1,
322+
CUSPARSE.CUSPARSE_SPGEMM_ALG2,
323+
CUSPARSE.CUSPARSE_SPGEMM_ALG3])
326324
# Algorithms CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC and
327325
# CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC are dedicated to the cusparseSpGEMMreuse routine.
328326

@@ -391,39 +389,40 @@ for SparseMatrixType in keys(SPGEMM_ALGOS)
391389
end
392390
end
393391

394-
if CUSPARSE.version() >= v"11.4.1"
392+
SDDMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SDDMM_ALG_DEFAULT])
395393

396-
SDDMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SDDMM_ALG_DEFAULT])
394+
# if CUSPARSE.version() >= v"12.1.0"
395+
# SDDMM_ALGOS[CuSparseMatrixBSR] = [CUSPARSE_SDDMM_ALG_DEFAULT]
396+
# end
397397

398-
for SparseMatrixType in keys(SDDMM_ALGOS)
399-
@testset "$SparseMatrixType -- sddmm! algo=$algo" for algo in SDDMM_ALGOS[SparseMatrixType]
400-
@testset "sddmm! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
401-
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
402-
@testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)]
403-
T <: Complex && (transa == 'C' || transb == 'C') && continue
404-
mA = transa == 'N' ? 25 : 10
405-
nA = transa == 'N' ? 10 : 25
406-
mB = transb == 'N' ? 10 : 35
407-
nB = transb == 'N' ? 35 : 10
398+
for SparseMatrixType in keys(SDDMM_ALGOS)
399+
@testset "$SparseMatrixType -- sddmm! algo=$algo" for algo in SDDMM_ALGOS[SparseMatrixType]
400+
@testset "sddmm! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
401+
@testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
402+
@testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)]
403+
T <: Complex && (transa == 'C' || transb == 'C') && continue
404+
mA = transa == 'N' ? 25 : 10
405+
nA = transa == 'N' ? 10 : 25
406+
mB = transb == 'N' ? 10 : 35
407+
nB = transb == 'N' ? 35 : 10
408408

409-
A = rand(T,mA,nA)
410-
B = rand(T,mB,nB)
411-
C = sprand(T,25,35,0.3)
409+
A = rand(T,mA,nA)
410+
B = rand(T,mB,nB)
411+
C = sprand(T,25,35,0.3)
412412

413-
spyC = copy(C)
414-
spyC.nzval .= one(T)
413+
spyC = copy(C)
414+
spyC.nzval .= one(T)
415415

416-
dA = CuArray(A)
417-
dB = CuArray(B)
418-
dC = SparseMatrixType(C)
416+
dA = CuArray(A)
417+
dB = CuArray(B)
418+
dC = SparseMatrixType(C)
419419

420-
alpha = rand(T)
421-
beta = rand(T)
420+
alpha = rand(T)
421+
beta = rand(T)
422422

423-
D = alpha * (opa(A) * opb(B)) .* spyC + beta * C
424-
sddmm!(transa, transb, alpha, dA, dB, beta, dC, 'O', algo)
425-
@test collect(dC) ≈ D
426-
end
423+
D = alpha * (opa(A) * opb(B)) .* spyC + beta * C
424+
sddmm!(transa, transb, alpha, dA, dB, beta, dC, 'O', algo)
425+
@test collect(dC) ≈ D
427426
end
428427
end
429428
end

0 commit comments

Comments
 (0)