Skip to content

Commit 79b911d

Browse files
gogglembauman
authored andcommitted
Introduce StructuredMatrixStyle{Matrix} (fixes #33397) (#33506)
* Introduce `StructuredMatrixStyle{Matrix}` (fixes #33397) * Remove `Matrix` from `const StructuredMatrix` * Prevent `StructuredMatrixStyle{<:Matrix}` from falling back to `DefaultArrayStyle{2}`
1 parent 643ec18 commit 79b911d

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

stdlib/LinearAlgebra/src/structuredbroadcast.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UnitLowerTriangular}, ::Struc
4040
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4141
StructuredMatrixStyle{UpperTriangular}()
4242

43+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
44+
StructuredMatrixStyle{Matrix}()
45+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
46+
StructuredMatrixStyle{Matrix}()
47+
48+
# Make sure that `StructuredMatrixStyle{<:Matrix}` doesn't ever end up falling
49+
# through and give back `DefaultArrayStyle{2}`
50+
Broadcast.BroadcastStyle(T::StructuredMatrixStyle{<:Matrix}, ::StructuredMatrixStyle) = T
51+
Broadcast.BroadcastStyle(::StructuredMatrixStyle, T::StructuredMatrixStyle{<:Matrix}) = T
52+
Broadcast.BroadcastStyle(T::StructuredMatrixStyle{<:Matrix}, ::StructuredMatrixStyle{<:Matrix}) = T
53+
4354
# All other combinations fall back to the default style
4455
Broadcast.BroadcastStyle(::StructuredMatrixStyle, ::StructuredMatrixStyle) = DefaultArrayStyle{2}()
4556

@@ -69,6 +80,8 @@ structured_broadcast_alloc(bc, ::Type{<:UnitLowerTriangular}, ::Type{ElType}, n)
6980
UnitLowerTriangular(Array{ElType}(undef, n, n))
7081
structured_broadcast_alloc(bc, ::Type{<:UnitUpperTriangular}, ::Type{ElType}, n) where {ElType} =
7182
UnitUpperTriangular(Array{ElType}(undef, n, n))
83+
structured_broadcast_alloc(bc, ::Type{<:Matrix}, ::Type{ElType}, n) where {ElType} =
84+
Matrix(Array{ElType}(undef, n, n))
7285

7386
# A _very_ limited list of structure-preserving functions known at compile-time. This list is
7487
# derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must

stdlib/LinearAlgebra/test/structuredbroadcast.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ using Test, LinearAlgebra
1414
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
1515
U = UpperTriangular(rand(N,N))
1616
L = LowerTriangular(rand(N,N))
17-
structuredarrays = (D, B, T, U, L)
17+
M = Matrix(rand(N,N))
18+
structuredarrays = (D, B, T, U, L, M)
1819
fstructuredarrays = map(Array, structuredarrays)
1920
for (X, fX) in zip(structuredarrays, fstructuredarrays)
2021
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@@ -56,19 +57,22 @@ end
5657
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
5758
= LowerTriangular(rand(N,N))
5859
= UpperTriangular(rand(N,N))
60+
M = Matrix(rand(N,N))
5961

6062
@test broadcast!(sin, copy(D), D) == Diagonal(sin.(D))
6163
@test broadcast!(sin, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
6264
@test broadcast!(sin, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
6365
@test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T))
6466
@test broadcast!(sin, copy(◣), ◣) == LowerTriangular(sin.(◣))
6567
@test broadcast!(sin, copy(◥), ◥) == UpperTriangular(sin.(◥))
68+
@test broadcast!(sin, copy(M), M) == Matrix(sin.(M))
6669
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
6770
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
6871
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
6972
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
7073
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
7174
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
75+
@test broadcast!(*, copy(M), M, A) == Matrix(broadcast(*, M, A))
7276

7377
@test_throws ArgumentError broadcast!(cos, copy(D), D) == Diagonal(sin.(D))
7478
@test_throws ArgumentError broadcast!(cos, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@@ -93,7 +97,8 @@ end
9397
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
9498
U = UpperTriangular(rand(N,N))
9599
L = LowerTriangular(rand(N,N))
96-
structuredarrays = (D, B, T, U, L)
100+
M = Matrix(rand(N,N))
101+
structuredarrays = (M, D, B, T, U, L)
97102
fstructuredarrays = map(Array, structuredarrays)
98103
for (X, fX) in zip(structuredarrays, fstructuredarrays)
99104
@test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX))
@@ -123,4 +128,22 @@ end
123128
end
124129
end
125130

131+
@testset "Issue #33397" begin
132+
N = 5
133+
U = UpperTriangular(rand(N, N))
134+
L = LowerTriangular(rand(N, N))
135+
UnitU = UnitUpperTriangular(rand(N, N))
136+
UnitL = UnitLowerTriangular(rand(N, N))
137+
D = Diagonal(rand(N))
138+
@test U .+ L .+ D == U + L + D
139+
@test L .+ U .+ D == L + U + D
140+
@test UnitU .+ UnitL .+ D == UnitU + UnitL + D
141+
@test UnitL .+ UnitU .+ D == UnitL + UnitU + D
142+
@test U .+ UnitL .+ D == U + UnitL + D
143+
@test L .+ UnitU .+ D == L + UnitU + D
144+
@test L .+ U .+ L .+ U == L + U + L + U
145+
@test U .+ L .+ U .+ L == U + L + U + L
146+
@test L .+ UnitL .+ UnitU .+ U .+ D == L + UnitL + UnitU + U + D
147+
@test L .+ U .+ D .+ D .+ D .+ D == L + U + D + D + D + D
148+
end
126149
end

0 commit comments

Comments
 (0)