Skip to content

Commit 499c286

Browse files
committed
Structured broadcasting for UpperHessenberg
1 parent c9b6456 commit 499c286

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

src/structuredbroadcast.jl

+39-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
99
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()
1010

11-
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12-
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11+
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12+
LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13+
UpperHessenberg{T}}
14+
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15+
LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16+
UpperHessenberg)
1317
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
1418
end
1519

@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
2731
StructuredMatrixStyle{LowerTriangular}()
2832
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
2933
StructuredMatrixStyle{UpperTriangular}()
34+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
35+
StructuredMatrixStyle{UpperHessenberg}()
3036

3137
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{Diagonal}) =
3238
StructuredMatrixStyle{Bidiagonal}()
3339
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3440
StructuredMatrixStyle{Tridiagonal}()
41+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
42+
StructuredMatrixStyle{UpperHessenberg}()
43+
3544
Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3645
StructuredMatrixStyle{Tridiagonal}()
46+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
47+
StructuredMatrixStyle{UpperHessenberg}()
3748
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3849
StructuredMatrixStyle{Tridiagonal}()
50+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
51+
StructuredMatrixStyle{UpperHessenberg}()
3952

4053
Broadcast.BroadcastStyle(::StructuredMatrixStyle{LowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
4154
StructuredMatrixStyle{LowerTriangular}()
4255
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4356
StructuredMatrixStyle{UpperTriangular}()
57+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) =
58+
StructuredMatrixStyle{UpperHessenberg}()
4459
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitLowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
4560
StructuredMatrixStyle{LowerTriangular}()
4661
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4762
StructuredMatrixStyle{UpperTriangular}()
63+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) =
64+
StructuredMatrixStyle{UpperHessenberg}()
65+
66+
function Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperHessenberg},
67+
::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}})
68+
StructuredMatrixStyle{UpperHessenberg}()
69+
end
4870

49-
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
71+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}) =
5072
StructuredMatrixStyle{Matrix}()
51-
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
73+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
5274
StructuredMatrixStyle{Matrix}()
5375

5476
# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -97,7 +119,7 @@ function structured_broadcast_alloc(bc, ::Type{Tridiagonal},
97119
Tridiagonal(Array{ElType}(undef, n1),Array{ElType}(undef, n),Array{ElType}(undef, n1))
98120
end
99121
function structured_broadcast_alloc(bc, ::Type{T}, ::Type{ElType},
100-
sz::NTuple{2,Integer}) where {ElType,T<:UpperOrLowerTriangular}
122+
sz::NTuple{2,Integer}) where {ElType,T<:Union{UpperOrLowerTriangular, UpperHessenberg}}
101123
T(Array{ElType}(undef, sz))
102124
end
103125
structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} =
@@ -293,6 +315,18 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
293315
return dest
294316
end
295317

318+
function copyto!(dest::UpperHessenberg, bc::Broadcasted{<:StructuredMatrixStyle})
319+
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
320+
axs = axes(dest)
321+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
322+
for j in axs[2]
323+
for i in 1:min(size(dest.data,1), j+1)
324+
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
325+
end
326+
end
327+
return dest
328+
end
329+
296330
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
297331
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
298332
sz = size(A)

test/structuredbroadcast.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ using .Main.SizedArrays
2222
S = SymTridiagonal(rand(N), rand(max(0,N-1)))
2323
U = UpperTriangular(rand(N,N))
2424
L = LowerTriangular(rand(N,N))
25+
UH = UpperHessenberg(rand(N,N))
2526
M = Matrix(rand(N,N))
26-
structuredarrays = (D, B, T, U, L, M, S)
27+
structuredarrays = (D, B, T, U, L, M, S, UH)
2728
fstructuredarrays = map(Array, structuredarrays)
2829
@testset "$(nameof(typeof(X)))" for (X, fX) in zip(structuredarrays, fstructuredarrays)
2930
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@@ -135,6 +136,7 @@ end
135136
T = Tridiagonal(rand(max(0,N-1)), rand(N), rand(max(0,N-1)))
136137
= LowerTriangular(rand(N,N))
137138
= UpperTriangular(rand(N,N))
139+
UH = UpperHessenberg(rand(N,N))
138140
M = Matrix(rand(N,N))
139141

140142
@test broadcast!(sin, copy(D), D)::Diagonal == sin.(D)::Diagonal
@@ -143,13 +145,15 @@ end
143145
@test broadcast!(sin, copy(T), T)::Tridiagonal == sin.(T)::Tridiagonal
144146
@test broadcast!(sin, copy(◣), ◣)::LowerTriangular == sin.(◣)::LowerTriangular
145147
@test broadcast!(sin, copy(◥), ◥)::UpperTriangular == sin.(◥)::UpperTriangular
148+
@test broadcast!(sin, copy(UH), UH)::UpperHessenberg == sin.(UH)::UpperHessenberg
146149
@test broadcast!(sin, copy(M), M)::Matrix == sin.(M)::Matrix
147150
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
148151
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
149152
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
150153
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
151154
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
152155
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
156+
@test broadcast!(*, copy(UH), UH, A) == UpperHessenberg(broadcast(*, UH, A))
153157
@test broadcast!(*, copy(M), M, A) == Matrix(broadcast(*, M, A))
154158

155159
if N > 2
@@ -181,8 +185,9 @@ end
181185
S = SymTridiagonal(rand(N), rand(N - 1))
182186
U = UpperTriangular(rand(N,N))
183187
L = LowerTriangular(rand(N,N))
188+
UH = UpperHessenberg(rand(N,N))
184189
M = Matrix(rand(N,N))
185-
structuredarrays = (M, D, B, T, S, U, L)
190+
structuredarrays = (M, D, B, T, S, U, L, UH)
186191
fstructuredarrays = map(Array, structuredarrays)
187192
for (X, fX) in zip(structuredarrays, fstructuredarrays)
188193
@test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX))

0 commit comments

Comments
 (0)