Skip to content

Commit 6cda11c

Browse files
jishnubN5N3
authored andcommitted
Fix (l/r)mul! with Diagonal/Bidiagonal (#55052)
1 parent ca40417 commit 6cda11c

File tree

5 files changed

+99
-2
lines changed

5 files changed

+99
-2
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

+26
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,32 @@ end
411411
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo)
412412
\(B::Number, A::Bidiagonal) = Bidiagonal(B\A.dv, B\A.ev, A.uplo)
413413

414+
# B .= D * B
415+
function lmul!(D::Diagonal, B::Bidiagonal)
416+
_muldiag_size_check(D, B)
417+
(; dv, ev) = B
418+
isL = B.uplo == 'L'
419+
dv[1] = D.diag[1] * dv[1]
420+
for i in axes(ev,1)
421+
ev[i] = D.diag[i + isL] * ev[i]
422+
dv[i+1] = D.diag[i+1] * dv[i+1]
423+
end
424+
return B
425+
end
426+
427+
# B .= B * D
428+
function rmul!(B::Bidiagonal, D::Diagonal)
429+
_muldiag_size_check(B, D)
430+
(; dv, ev) = B
431+
isU = B.uplo == 'U'
432+
dv[1] *= D.diag[1]
433+
for i in axes(ev,1)
434+
ev[i] *= D.diag[i + isU]
435+
dv[i+1] *= D.diag[i+1]
436+
end
437+
return B
438+
end
439+
414440
function ==(A::Bidiagonal, B::Bidiagonal)
415441
if A.uplo == B.uplo
416442
return A.dv == B.dv && A.ev == B.ev

stdlib/LinearAlgebra/src/diagonal.jl

+43-2
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,49 @@ end
294294
(*)(D::Diagonal, A::HermOrSym) =
295295
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)
296296

297-
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
298-
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
297+
function rmul!(A::AbstractMatrix, D::Diagonal)
298+
_muldiag_size_check(A, D)
299+
for I in CartesianIndices(A)
300+
row, col = Tuple(I)
301+
@inbounds A[row, col] *= D.diag[col]
302+
end
303+
return A
304+
end
305+
# T .= T * D
306+
function rmul!(T::Tridiagonal, D::Diagonal)
307+
_muldiag_size_check(T, D)
308+
(; dl, d, du) = T
309+
d[1] *= D.diag[1]
310+
for i in axes(dl,1)
311+
dl[i] *= D.diag[i]
312+
du[i] *= D.diag[i+1]
313+
d[i+1] *= D.diag[i+1]
314+
end
315+
return T
316+
end
317+
318+
function lmul!(D::Diagonal, B::AbstractVecOrMat)
319+
_muldiag_size_check(D, B)
320+
for I in CartesianIndices(B)
321+
row = I[1]
322+
@inbounds B[I] = D.diag[row] * B[I]
323+
end
324+
return B
325+
end
326+
327+
# in-place multiplication with a diagonal
328+
# T .= D * T
329+
function lmul!(D::Diagonal, T::Tridiagonal)
330+
_muldiag_size_check(D, T)
331+
(; dl, d, du) = T
332+
d[1] = D.diag[1] * d[1]
333+
for i in axes(dl,1)
334+
dl[i] = D.diag[i+1] * dl[i]
335+
du[i] = D.diag[i] * du[i]
336+
d[i+1] = D.diag[i+1] * d[i+1]
337+
end
338+
return T
339+
end
299340

300341
function (*)(A::AdjOrTransAbsMat, D::Diagonal)
301342
Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag)))

stdlib/LinearAlgebra/test/bidiag.jl

+9
Original file line numberDiff line numberDiff line change
@@ -827,4 +827,13 @@ end
827827
@test_throws "cannot set entry" B[1,2] = 4
828828
end
829829

830+
@testset "rmul!/lmul! with banded matrices" begin
831+
dv, ev = rand(4), rand(3)
832+
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
833+
D = Diagonal(dv)
834+
@test rmul!(copy(A), D) A * D
835+
@test lmul!(D, copy(A)) D * A
836+
end
837+
end
838+
830839
end # module TestBidiagonal

stdlib/LinearAlgebra/test/diagonal.jl

+13
Original file line numberDiff line numberDiff line change
@@ -1180,4 +1180,17 @@ end
11801180
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal
11811181
end
11821182

1183+
@testset "rmul!/lmul! with banded matrices" begin
1184+
@testset "$(nameof(typeof(B)))" for B in (
1185+
Bidiagonal(rand(4), rand(3), :L),
1186+
Tridiagonal(rand(3), rand(4), rand(3))
1187+
)
1188+
BA = Array(B)
1189+
D = Diagonal(rand(size(B,1)))
1190+
DA = Array(D)
1191+
@test rmul!(copy(B), D) B * D BA * DA
1192+
@test lmul!(D, copy(B)) D * B DA * BA
1193+
end
1194+
end
1195+
11831196
end # module TestDiagonal

stdlib/LinearAlgebra/test/tridiag.jl

+8
Original file line numberDiff line numberDiff line change
@@ -802,4 +802,12 @@ end
802802
end
803803
end
804804

805+
@testset "rmul!/lmul! with banded matrices" begin
806+
dl, d, du = rand(3), rand(4), rand(3)
807+
A = Tridiagonal(dl, d, du)
808+
D = Diagonal(d)
809+
@test rmul!(copy(A), D) A * D
810+
@test lmul!(D, copy(A)) D * A
811+
end
812+
805813
end # module TestTridiagonal

0 commit comments

Comments
 (0)