|
3 | 3 | using LinearAlgebra: MulAddMul, AdjOrTrans, wrap, UpperOrLowerTriangular, rmul!, lmul! |
4 | 4 | @static if VERSION ≥ v"1.12.0-rc" |
5 | 5 | # we need to use the generic wrapper to avoid dispatch to the 2x2or3x3 method |
6 | | - using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag |
7 | | - import LinearAlgebra: symm!, herm! |
| 6 | + using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag, _symm_hemm_generic! |
8 | 7 | end |
9 | 8 | # |
10 | 9 | # BLAS 1 |
|
# SyrkHerkGemm method for strided GPU matrices with a CUBLAS element type.
# The BLAS flag `val` takes part in dispatch only: its value is never read and
# the call is handed unchanged to `LinearAlgebra.generic_matmatmul!`.
function LinearAlgebra.generic_matmatmul_wrapper!(
        C::StridedCuMatrix{T},
        tA::AbstractChar,
        tB::AbstractChar,
        A::StridedCuVecOrMat{T},
        B::StridedCuVecOrMat{T},
        alpha::Number,
        beta::Number,
        val::LinearAlgebra.BlasFlag.SyrkHerkGemm,
    ) where {T<:CublasFloat}
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
end
# SYMM method for strided GPU matrices: forwards the product to the `symm!`
# that is in scope in this file.
#
# The side and triangle characters are derived from the wrapper characters
# `tA`/`tB` by `LinearAlgebra._lrchar_ulchar`. `A` is passed ahead of `B` when
# the side is 'L'; for any other side the two operands are swapped.
# NOTE(review): presumably the swap puts the symmetric operand first, as the
# BLAS `symm` convention expects - confirm against `symm!`.
function LinearAlgebra._symm_hemm_generic!(
        C::StridedCuMatrix{T},
        tA::AbstractChar,
        tB::AbstractChar,
        A::StridedCuMatrix{T},
        B::StridedCuMatrix{T},
        alpha,
        beta,
        ::Val{BlasFlag.SYMM},
    ) where {T}
    side, uplo = LinearAlgebra._lrchar_ulchar(tA, tB)
    if side != 'L'
        # Not a left-side product: the operands trade places.
        return symm!(side, uplo, alpha, B, A, beta, C)
    end
    return symm!(side, uplo, alpha, A, B, beta, C)
end
# HEMM method for strided GPU matrices: forwards the product to the `hemm!`
# that is in scope in this file.
#
# The side and triangle characters are derived from the wrapper characters
# `tA`/`tB` by `LinearAlgebra._lrchar_ulchar`. `A` is passed ahead of `B` when
# the side is 'L'; for any other side the two operands are swapped.
# NOTE(review): presumably the swap puts the Hermitian operand first, as the
# BLAS `hemm` convention expects - confirm against `hemm!`.
function LinearAlgebra._symm_hemm_generic!(
        C::StridedCuMatrix{T},
        tA::AbstractChar,
        tB::AbstractChar,
        A::StridedCuMatrix{T},
        B::StridedCuMatrix{T},
        alpha,
        beta,
        ::Val{BlasFlag.HEMM},
    ) where {T}
    side, uplo = LinearAlgebra._lrchar_ulchar(tA, tB)
    if side != 'L'
        # Not a left-side product: the operands trade places.
        return hemm!(side, uplo, alpha, B, A, beta, C)
    end
    return hemm!(side, uplo, alpha, A, B, beta, C)
end
323 | 338 | end |
324 | 339 |
|
325 | 340 | LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul) = |
|
0 commit comments