Skip to content

Commit 49fa788

Browse files
authored
Actually fix the 1.12 symm/herm issue (#2932)
1 parent bb0a27f commit 49fa788

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

lib/cublas/linalg.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
using LinearAlgebra: MulAddMul, AdjOrTrans, wrap, UpperOrLowerTriangular, rmul!, lmul!
44
@static if VERSION ≥ v"1.12.0-rc"
55
# we need to use the generic wrapper to avoid dispatch to the 2x2or3x3 method
6-
using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag
7-
import LinearAlgebra: symm!, herm!
6+
using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag, _symm_hemm_generic!
87
end
98
#
109
# BLAS 1
@@ -320,6 +319,22 @@ end
320319
function LinearAlgebra.generic_matmatmul_wrapper!(C::StridedCuMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::StridedCuVecOrMat{T}, B::StridedCuVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T<:CublasFloat}
321320
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
322321
end
322+
function LinearAlgebra._symm_hemm_generic!(C::StridedCuMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}, alpha, beta, ::Val{BlasFlag.SYMM}) where {T}
323+
lrchar, ulchar = LinearAlgebra._lrchar_ulchar(tA, tB)
324+
if lrchar == 'L'
325+
symm!(lrchar, ulchar, alpha, A, B, beta, C)
326+
else
327+
symm!(lrchar, ulchar, alpha, B, A, beta, C)
328+
end
329+
end
330+
function LinearAlgebra._symm_hemm_generic!(C::StridedCuMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}, alpha, beta, ::Val{BlasFlag.HEMM}) where {T}
331+
lrchar, ulchar = LinearAlgebra._lrchar_ulchar(tA, tB)
332+
if lrchar == 'L'
333+
hemm!(lrchar, ulchar, alpha, A, B, beta, C)
334+
else
335+
hemm!(lrchar, ulchar, alpha, B, A, beta, C)
336+
end
337+
end
323338
end
324339

325340
LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul) =

0 commit comments

Comments
 (0)