@@ -2,7 +2,7 @@ using CUDA.CUSPARSE
22using SparseArrays
33using LinearAlgebra
44
5- @testset " generic mv!" for T in [Float32, Float64]
5+ @testset " generic mv! -- $T " for T in [Float32, Float64]
66 m = 10
77 A = sprand (T, m, m, 0.1 )
88 x = rand (Complex{T}, m)
@@ -17,7 +17,7 @@ using LinearAlgebra
1717 dA = CuSparseMatrixCSR (dA)
1818 mv! (' N' , one (T), dA, dx, zero (T), dy, ' O' )
1919 @test Array (dy) ≈ A * x
20-
20+
2121 A_bad = sprand (T, m+ 1 , m, 0.1 )
2222 dA_bad = adapt (CuArray, A_bad)
2323 @test_throws DimensionMismatch (" Y must have length $(m+ 1 ) , but has length $m " ) mv! (' N' , one (T), dA_bad, dx, zero (T), dy, ' O' )
@@ -32,9 +32,7 @@ SPMV_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT],
3232 CUSPARSE. CUSPARSE_SPMV_CSR_ALG1,
3333 CUSPARSE. CUSPARSE_SPMV_CSR_ALG2],
3434 CuSparseMatrixCOO => [CUSPARSE. CUSPARSE_SPMV_ALG_DEFAULT,
35- CUSPARSE. CUSPARSE_SPMV_COO_ALG1,
36- ],
37- )
35+ CUSPARSE. CUSPARSE_SPMV_COO_ALG1])
3836
3937SPMM_ALGOS = Dict (CuSparseMatrixCSC => [CUSPARSE. CUSPARSE_SPMM_ALG_DEFAULT],
4038 CuSparseMatrixCSR => [CUSPARSE. CUSPARSE_SPMM_ALG_DEFAULT,
@@ -43,10 +41,9 @@ SPMM_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT],
4341 CUSPARSE. CUSPARSE_SPMM_CSR_ALG3],
4442 CuSparseMatrixCOO => [CUSPARSE. CUSPARSE_SPMM_ALG_DEFAULT,
4543 CUSPARSE. CUSPARSE_SPMM_COO_ALG1,
44+ CUSPARSE. CUSPARSE_SPMM_COO_ALG2,
4645 CUSPARSE. CUSPARSE_SPMM_COO_ALG3,
47- CUSPARSE. CUSPARSE_SPMM_COO_ALG4]
48- )
49-
46+ CUSPARSE. CUSPARSE_SPMM_COO_ALG4])
5047
5148if CUSPARSE. version () >= v " 12.1.3"
5249 push! (SPMV_ALGOS[CuSparseMatrixCOO], CUSPARSE. CUSPARSE_SPMV_COO_ALG2)
@@ -57,15 +54,20 @@ if CUSPARSE.version() >= v"12.5.1"
5754 CUSPARSE. CUSPARSE_SPMM_BSR_ALG1]
5855end
5956
57+ if CUSPARSE. version () >= v " 12.6.3"
58+ SPMV_ALGOS[CuSparseMatrixBSR] = [CUSPARSE. CUSPARSE_SPMV_ALG_DEFAULT,
59+ CUSPARSE. CUSPARSE_SPMV_BSR_ALG1]
60+ end
61+
6062for SparseMatrixType in keys (SPMV_ALGOS)
6163 @testset " $SparseMatrixType -- mv! algo=$algo " for algo in SPMV_ALGOS[SparseMatrixType]
6264 @testset " mv! $T " for T in [Float32, Float64, ComplexF32, ComplexF64]
6365 @testset " transa = $transa " for (transa, opa) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
64- SparseMatrixType == CuSparseMatrixCSC && T <: Complex && transa == ' C ' && continue
66+ ( SparseMatrixType == CuSparseMatrixBSR) && ( transa != ' N ' ) && continue
6567 A = sprand (T, 20 , 10 , 0.1 )
6668 B = transa == ' N' ? rand (T, 10 ) : rand (T, 20 )
6769 C = transa == ' N' ? rand (T, 20 ) : rand (T, 10 )
68- dA = SparseMatrixType (A)
70+ dA = SparseMatrixType == CuSparseMatrixBSR ? SparseMatrixType (A, 1 ) : SparseMatrixType (A)
6971 dB = CuArray (B)
7072 dC = CuArray (C)
7173
@@ -83,7 +85,6 @@ for SparseMatrixType in keys(SPMM_ALGOS)
8385 @testset " mm! $T " for T in [Float32, Float64, ComplexF32, ComplexF64]
8486 @testset " transa = $transa " for (transa, opa) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
8587 @testset " transb = $transb " for (transb, opb) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
86- CUSPARSE. version () < v " 12.0" && SparseMatrixType == CuSparseMatrixCSC && T <: Complex && transa == ' C' && continue
8788 algo == CUSPARSE. CUSPARSE_SPMM_CSR_ALG3 && (transa != ' N' || transb != ' N' ) && continue
8889 (SparseMatrixType == CuSparseMatrixBSR) && (transa != ' N' ) && continue
8990 A = sprand (T, 10 , 10 , 0.1 )
@@ -122,7 +123,6 @@ for SparseMatrixType in keys(SPMM_ALGOS)
122123 @testset " $T " for T in [Float32, Float64, ComplexF32, ComplexF64]
123124 @testset " transa = $transa " for (transa, opa) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
124125 @testset " transb = $transb " for (transb, opb) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
125- CUSPARSE. version () < v " 12.0" && SparseMatrixType == CuSparseMatrixCSR && T <: Complex && transb == ' C' && continue
126126 algo == CUSPARSE. CUSPARSE_SPMM_CSR_ALG3 && (transa != ' N' || transb != ' N' ) && continue
127127 A = rand (T, 10 , 10 )
128128 B = transb == ' N' ? sprand (T, 10 , 5 , 0.5 ) : sprand (T, 5 , 10 , 0.5 )
@@ -313,16 +313,14 @@ end
313313 @test Z ≈ collect (dY)
314314end
315315
316- SPGEMM_ALGOS = Dict (CuSparseMatrixCSR => [CUSPARSE. CUSPARSE_SPGEMM_DEFAULT],
317- CuSparseMatrixCSC => [CUSPARSE. CUSPARSE_SPGEMM_DEFAULT])
318- if CUSPARSE. version () >= v " 12.0"
319- append! (SPGEMM_ALGOS[CuSparseMatrixCSR], (CUSPARSE. CUSPARSE_SPGEMM_ALG1,
320- CUSPARSE. CUSPARSE_SPGEMM_ALG2,
321- CUSPARSE. CUSPARSE_SPGEMM_ALG3))
322- append! (SPGEMM_ALGOS[CuSparseMatrixCSC], (CUSPARSE. CUSPARSE_SPGEMM_ALG1,
323- CUSPARSE. CUSPARSE_SPGEMM_ALG2,
324- CUSPARSE. CUSPARSE_SPGEMM_ALG3))
325- end
316+ SPGEMM_ALGOS = Dict (CuSparseMatrixCSR => [CUSPARSE. CUSPARSE_SPGEMM_DEFAULT,
317+ CUSPARSE. CUSPARSE_SPGEMM_ALG1,
318+ CUSPARSE. CUSPARSE_SPGEMM_ALG2,
319+ CUSPARSE. CUSPARSE_SPGEMM_ALG3],
320+ CuSparseMatrixCSC => [CUSPARSE. CUSPARSE_SPGEMM_DEFAULT,
321+ CUSPARSE. CUSPARSE_SPGEMM_ALG1,
322+ CUSPARSE. CUSPARSE_SPGEMM_ALG2,
323+ CUSPARSE. CUSPARSE_SPGEMM_ALG3])
326324# Algorithms CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC and
327325# CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC are dedicated to the cusparseSpGEMMreuse routine.
328326
@@ -391,39 +389,40 @@ for SparseMatrixType in keys(SPGEMM_ALGOS)
391389 end
392390end
393391
394- if CUSPARSE . version () >= v " 11.4.1 "
392+ SDDMM_ALGOS = Dict (CuSparseMatrixCSR => [CUSPARSE . CUSPARSE_SDDMM_ALG_DEFAULT])
395393
396- SDDMM_ALGOS = Dict (CuSparseMatrixCSR => [CUSPARSE. CUSPARSE_SDDMM_ALG_DEFAULT])
394+ # if CUSPARSE.version() >= v"12.1.0"
395+ # SDDMM_ALGOS[CuSparseMatrixBSR] = [CUSPARSE_SDDMM_ALG_DEFAULT]
396+ # end
397397
398- for SparseMatrixType in keys (SDDMM_ALGOS)
399- @testset " $SparseMatrixType -- sddmm! algo=$algo " for algo in SDDMM_ALGOS[SparseMatrixType]
400- @testset " sddmm! $T " for T in [Float32, Float64, ComplexF32, ComplexF64]
401- @testset " transa = $transa " for (transa, opa) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
402- @testset " transb = $transb " for (transb, opb) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
403- T <: Complex && (transa == ' C' || transb == ' C' ) && continue
404- mA = transa == ' N' ? 25 : 10
405- nA = transa == ' N' ? 10 : 25
406- mB = transb == ' N' ? 10 : 35
407- nB = transb == ' N' ? 35 : 10
398+ for SparseMatrixType in keys (SDDMM_ALGOS)
399+ @testset " $SparseMatrixType -- sddmm! algo=$algo " for algo in SDDMM_ALGOS[SparseMatrixType]
400+ @testset " sddmm! $T " for T in [Float32, Float64, ComplexF32, ComplexF64]
401+ @testset " transa = $transa " for (transa, opa) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
402+ @testset " transb = $transb " for (transb, opb) in [(' N' , identity), (' T' , transpose), (' C' , adjoint)]
403+ T <: Complex && (transa == ' C' || transb == ' C' ) && continue
404+ mA = transa == ' N' ? 25 : 10
405+ nA = transa == ' N' ? 10 : 25
406+ mB = transb == ' N' ? 10 : 35
407+ nB = transb == ' N' ? 35 : 10
408408
409- A = rand (T,mA,nA)
410- B = rand (T,mB,nB)
411- C = sprand (T,25 ,35 ,0.3 )
409+ A = rand (T,mA,nA)
410+ B = rand (T,mB,nB)
411+ C = sprand (T,25 ,35 ,0.3 )
412412
413- spyC = copy (C)
414- spyC. nzval .= one (T)
413+ spyC = copy (C)
414+ spyC. nzval .= one (T)
415415
416- dA = CuArray (A)
417- dB = CuArray (B)
418- dC = SparseMatrixType (C)
416+ dA = CuArray (A)
417+ dB = CuArray (B)
418+ dC = SparseMatrixType (C)
419419
420- alpha = rand (T)
421- beta = rand (T)
420+ alpha = rand (T)
421+ beta = rand (T)
422422
423- D = alpha * (opa (A) * opb (B)) .* spyC + beta * C
424- sddmm! (transa, transb, alpha, dA, dB, beta, dC, ' O' , algo)
425- @test collect (dC) ≈ D
426- end
423+ D = alpha * (opa (A) * opb (B)) .* spyC + beta * C
424+ sddmm! (transa, transb, alpha, dA, dB, beta, dC, ' O' , algo)
425+ @test collect (dC) ≈ D
427426 end
428427 end
429428 end
0 commit comments