Skip to content

Commit fc9c280

Browse files
authored
Fix performance bug for * with AbstractQ (#44615)
1 parent b91dd02 commit fc9c280

File tree

5 files changed

+96
-36
lines changed

5 files changed

+96
-36
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

+4-6
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,13 @@ end
168168
function Matrix{T}(A::Bidiagonal) where T
169169
n = size(A, 1)
170170
B = zeros(T, n, n)
171-
if n == 0
172-
return B
173-
end
174-
for i = 1:n - 1
171+
n == 0 && return B
172+
@inbounds for i = 1:n - 1
175173
B[i,i] = A.dv[i]
176174
if A.uplo == 'U'
177-
B[i, i + 1] = A.ev[i]
175+
B[i,i+1] = A.ev[i]
178176
else
179-
B[i + 1, i] = A.ev[i]
177+
B[i+1,i] = A.ev[i]
180178
end
181179
end
182180
B[n,n] = A.dv[n]

stdlib/LinearAlgebra/src/diagonal.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,16 @@ Diagonal{T}(D::Diagonal{T}) where {T} = D
7777
Diagonal{T}(D::Diagonal) where {T} = Diagonal{T}(D.diag)
7878

7979
AbstractMatrix{T}(D::Diagonal) where {T} = Diagonal{T}(D)
80-
Matrix(D::Diagonal) = diagm(0 => D.diag)
81-
Array(D::Diagonal) = Matrix(D)
80+
Matrix(D::Diagonal{T}) where {T} = Matrix{T}(D)
81+
Array(D::Diagonal{T}) where {T} = Matrix{T}(D)
82+
function Matrix{T}(D::Diagonal) where {T}
83+
n = size(D, 1)
84+
B = zeros(T, n, n)
85+
@inbounds for i in 1:n
86+
B[i,i] = D.diag[i]
87+
end
88+
return B
89+
end
8290

8391
"""
8492
Diagonal{T}(undef, n)

stdlib/LinearAlgebra/src/special.jl

+59-10
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,65 @@ function (-)(A::UniformScaling, B::Diagonal{<:Number})
292292
Diagonal(A.λ .- B.diag)
293293
end
294294

295-
rmul!(A::AbstractTriangular, adjB::Adjoint{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) =
296-
rmul!(full!(A), adjB)
297-
*(A::AbstractTriangular, adjB::Adjoint{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) =
298-
*(copyto!(similar(parent(A)), A), adjB)
299-
*(A::BiTriSym, adjB::Adjoint{<:Any,<:Union{QRCompactWYQ, QRPackedQ}}) =
300-
rmul!(copyto!(Array{promote_type(eltype(A), eltype(adjB))}(undef, size(A)...), A), adjB)
301-
*(adjA::Adjoint{<:Any,<:Union{QRCompactWYQ, QRPackedQ}}, B::Diagonal) =
302-
lmul!(adjA, copyto!(Array{promote_type(eltype(adjA), eltype(B))}(undef, size(B)...), B))
303-
*(adjA::Adjoint{<:Any,<:Union{QRCompactWYQ, QRPackedQ}}, B::BiTriSym) =
304-
lmul!(adjA, copyto!(Array{promote_type(eltype(adjA), eltype(B))}(undef, size(B)...), B))
295+
lmul!(Q::AbstractQ, B::AbstractTriangular) = lmul!(Q, full!(B))
296+
lmul!(Q::QRPackedQ, B::AbstractTriangular) = lmul!(Q, full!(B)) # disambiguation
297+
lmul!(Q::Adjoint{<:Any,<:AbstractQ}, B::AbstractTriangular) = lmul!(Q, full!(B))
298+
lmul!(Q::Adjoint{<:Any,<:QRPackedQ}, B::AbstractTriangular) = lmul!(Q, full!(B)) # disambiguation
299+
300+
function _qlmul(Q::AbstractQ, B)
301+
TQB = promote_type(eltype(Q), eltype(B))
302+
if size(Q.factors, 1) == size(B, 1)
303+
Bnew = Matrix{TQB}(B)
304+
elseif size(Q.factors, 2) == size(B, 1)
305+
Bnew = [Matrix{TQB}(B); zeros(TQB, size(Q.factors, 1) - size(B,1), size(B, 2))]
306+
else
307+
throw(DimensionMismatch("first dimension of matrix must have size either $(size(Q.factors, 1)) or $(size(Q.factors, 2))"))
308+
end
309+
lmul!(convert(AbstractMatrix{TQB}, Q), Bnew)
310+
end
311+
function _qlmul(adjQ::Adjoint{<:Any,<:AbstractQ}, B)
312+
TQB = promote_type(eltype(adjQ), eltype(B))
313+
lmul!(adjoint(convert(AbstractMatrix{TQB}, parent(adjQ))), Matrix{TQB}(B))
314+
end
315+
316+
*(Q::AbstractQ, B::AbstractTriangular) = _qlmul(Q, B)
317+
*(Q::Adjoint{<:Any,<:AbstractQ}, B::AbstractTriangular) = _qlmul(Q, B)
318+
*(Q::AbstractQ, B::BiTriSym) = _qlmul(Q, B)
319+
*(Q::Adjoint{<:Any,<:AbstractQ}, B::BiTriSym) = _qlmul(Q, B)
320+
*(Q::AbstractQ, B::Diagonal) = _qlmul(Q, B)
321+
*(Q::Adjoint{<:Any,<:AbstractQ}, B::Diagonal) = _qlmul(Q, B)
322+
323+
rmul!(A::AbstractTriangular, Q::AbstractQ) = rmul!(full!(A), Q)
324+
rmul!(A::AbstractTriangular, Q::Adjoint{<:Any,<:AbstractQ}) = rmul!(full!(A), Q)
325+
326+
function _qrmul(A, Q::AbstractQ)
327+
TAQ = promote_type(eltype(A), eltype(Q))
328+
return rmul!(Matrix{TAQ}(A), convert(AbstractMatrix{TAQ}, Q))
329+
end
330+
function _qrmul(A, adjQ::Adjoint{<:Any,<:AbstractQ})
331+
Q = adjQ.parent
332+
TAQ = promote_type(eltype(A), eltype(Q))
333+
if size(A,2) == size(Q.factors, 1)
334+
Anew = Matrix{TAQ}(A)
335+
elseif size(A,2) == size(Q.factors,2)
336+
Anew = [Matrix{TAQ}(A) zeros(TAQ, size(A, 1), size(Q.factors, 1) - size(Q.factors, 2))]
337+
else
338+
throw(DimensionMismatch("matrix A has dimensions $(size(A)) but matrix B has dimensions $(size(Q))"))
339+
end
340+
return rmul!(Anew, adjoint(convert(AbstractMatrix{TAQ}, Q)))
341+
end
342+
343+
*(A::AbstractTriangular, Q::AbstractQ) = _qrmul(A, Q)
344+
*(A::AbstractTriangular, Q::Adjoint{<:Any,<:AbstractQ}) = _qrmul(A, Q)
345+
*(A::BiTriSym, Q::AbstractQ) = _qrmul(A, Q)
346+
*(A::BiTriSym, Q::Adjoint{<:Any,<:AbstractQ}) = _qrmul(A, Q)
347+
*(A::Diagonal, Q::AbstractQ) = _qrmul(A, Q)
348+
*(A::Diagonal, Q::Adjoint{<:Any,<:AbstractQ}) = _qrmul(A, Q)
349+
350+
*(Q::AbstractQ, B::AbstractQ) = _qlmul(Q, B)
351+
*(Q::Adjoint{<:Any,<:AbstractQ}, B::AbstractQ) = _qrmul(Q, B)
352+
*(Q::AbstractQ, B::Adjoint{<:Any,<:AbstractQ}) = _qlmul(Q, B)
353+
*(Q::Adjoint{<:Any,<:AbstractQ}, B::Adjoint{<:Any,<:AbstractQ}) = _qrmul(Q, B)
305354

306355
# fill[stored]! methods
307356
fillstored!(A::Diagonal, x) = (fill!(A.diag, x); A)

stdlib/LinearAlgebra/src/tridiag.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -571,15 +571,16 @@ function size(M::Tridiagonal, d::Integer)
571571
end
572572
end
573573

574-
function Matrix{T}(M::Tridiagonal{T}) where T
574+
function Matrix{T}(M::Tridiagonal) where {T}
575575
A = zeros(T, size(M))
576-
for i = 1:length(M.d)
576+
n = length(M.d)
577+
n == 0 && return A
578+
for i in 1:n-1
577579
A[i,i] = M.d[i]
578-
end
579-
for i = 1:length(M.d)-1
580580
A[i+1,i] = M.dl[i]
581581
A[i,i+1] = M.du[i]
582582
end
583+
A[n,n] = M.d[n]
583584
A
584585
end
585586
Matrix(M::Tridiagonal{T}) where {T} = Matrix{T}(M)

stdlib/LinearAlgebra/test/special.jl

+18-14
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,21 @@ end
188188

189189

190190
@testset "Triangular Types and QR" begin
191-
for typ in [UpperTriangular,LowerTriangular,LinearAlgebra.UnitUpperTriangular,LinearAlgebra.UnitLowerTriangular]
191+
for typ in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
192192
a = rand(n,n)
193193
atri = typ(a)
194+
matri = Matrix(atri)
194195
b = rand(n,n)
195196
qrb = qr(b, ColumnNorm())
196-
@test *(atri, adjoint(qrb.Q)) Matrix(atri) * qrb.Q'
197-
@test rmul!(copy(atri), adjoint(qrb.Q)) Matrix(atri) * qrb.Q'
197+
@test atri * qrb.Q matri * qrb.Q rmul!(copy(atri), qrb.Q)
198+
@test atri * qrb.Q' matri * qrb.Q' rmul!(copy(atri), qrb.Q')
199+
@test qrb.Q * atri qrb.Q * matri lmul!(qrb.Q, copy(atri))
200+
@test qrb.Q' * atri qrb.Q' * matri lmul!(qrb.Q', copy(atri))
198201
qrb = qr(b, NoPivot())
199-
@test *(atri, adjoint(qrb.Q)) Matrix(atri) * qrb.Q'
200-
@test rmul!(copy(atri), adjoint(qrb.Q)) Matrix(atri) * qrb.Q'
202+
@test atri * qrb.Q matri * qrb.Q rmul!(copy(atri), qrb.Q)
203+
@test atri * qrb.Q' matri * qrb.Q' rmul!(copy(atri), qrb.Q')
204+
@test qrb.Q * atri qrb.Q * matri lmul!(qrb.Q, copy(atri))
205+
@test qrb.Q' * atri qrb.Q' * matri lmul!(qrb.Q', copy(atri))
201206
end
202207
end
203208

@@ -421,19 +426,18 @@ end
421426
end
422427

423428
@testset "BiTriSym*Q' and Q'*BiTriSym" begin
424-
dl = [1, 1, 1];
425-
d = [1, 1, 1, 1];
426-
Tri = Tridiagonal(dl, d, dl)
429+
dl = [1, 1, 1]
430+
d = [1, 1, 1, 1]
431+
D = Diagonal(d)
427432
Bi = Bidiagonal(d, dl, :L)
433+
Tri = Tridiagonal(dl, d, dl)
428434
Sym = SymTridiagonal(d, dl)
429435
F = qr(ones(4, 1))
430436
A = F.Q'
431-
@test Tri*A Matrix(Tri)*A
432-
@test A*Tri A*Matrix(Tri)
433-
@test Bi*A Matrix(Bi)*A
434-
@test A*Bi A*Matrix(Bi)
435-
@test Sym*A Matrix(Sym)*A
436-
@test A*Sym A*Matrix(Sym)
437+
for A in (F.Q, F.Q'), B in (D, Bi, Tri, Sym)
438+
@test B*A Matrix(B)*A
439+
@test A*B A*Matrix(B)
440+
end
437441
end
438442

439443
@testset "Ops on SymTridiagonal ev has the same length as dv" begin

0 commit comments

Comments
 (0)