@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
8
8
StructuredMatrixStyle {T} (:: Val{2} ) where {T} = StructuredMatrixStyle {T} ()
9
9
StructuredMatrixStyle {T} (:: Val{N} ) where {T,N} = Broadcast. DefaultArrayStyle {N} ()
10
10
11
- const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12
- for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11
+ const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12
+ LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13
+ UpperHessenberg{T}}
14
+ for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15
+ LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16
+ UpperHessenberg)
13
17
@eval Broadcast. BroadcastStyle (:: Type{<:$ST} ) = $ (StructuredMatrixStyle {ST} ())
14
18
end
15
19
@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
27
31
StructuredMatrixStyle {LowerTriangular} ()
28
32
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
29
33
StructuredMatrixStyle {UpperTriangular} ()
34
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
35
+ StructuredMatrixStyle {UpperHessenberg} ()
30
36
31
37
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{Diagonal} ) =
32
38
StructuredMatrixStyle {Bidiagonal} ()
33
39
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
34
40
StructuredMatrixStyle {Tridiagonal} ()
41
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
42
+ StructuredMatrixStyle {UpperHessenberg} ()
43
+
35
44
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
36
45
StructuredMatrixStyle {Tridiagonal} ()
46
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
47
+ StructuredMatrixStyle {UpperHessenberg} ()
37
48
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
38
49
StructuredMatrixStyle {Tridiagonal} ()
50
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
51
+ StructuredMatrixStyle {UpperHessenberg} ()
39
52
40
53
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{LowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
41
54
StructuredMatrixStyle {LowerTriangular} ()
42
55
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
43
56
StructuredMatrixStyle {UpperTriangular} ()
57
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
58
+ StructuredMatrixStyle {UpperHessenberg} ()
44
59
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitLowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
45
60
StructuredMatrixStyle {LowerTriangular} ()
46
61
Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
47
62
StructuredMatrixStyle {UpperTriangular} ()
63
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
64
+ StructuredMatrixStyle {UpperHessenberg} ()
65
+
66
+ function Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperHessenberg} ,
67
+ :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}} )
68
+ StructuredMatrixStyle {UpperHessenberg} ()
69
+ end
48
70
49
- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
71
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} ) =
50
72
StructuredMatrixStyle {Matrix} ()
51
- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
73
+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
52
74
StructuredMatrixStyle {Matrix} ()
53
75
54
76
# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -97,7 +119,7 @@ function structured_broadcast_alloc(bc, ::Type{Tridiagonal},
97
119
Tridiagonal (Array {ElType} (undef, n1),Array {ElType} (undef, n),Array {ElType} (undef, n1))
98
120
end
99
121
function structured_broadcast_alloc (bc, :: Type{T} , :: Type{ElType} ,
100
- sz:: NTuple{2,Integer} ) where {ElType,T<: UpperOrLowerTriangular }
122
+ sz:: NTuple{2,Integer} ) where {ElType,T<: Union{ UpperOrLowerTriangular, UpperHessenberg} }
101
123
T (Array {ElType} (undef, sz))
102
124
end
103
125
structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType} =
@@ -293,6 +315,18 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
293
315
return dest
294
316
end
295
317
318
+ function copyto! (dest:: UpperHessenberg , bc:: Broadcasted{<:StructuredMatrixStyle} )
319
+ isvalidstructbc (dest, bc) || return copyto! (dest, convert (Broadcasted{Nothing}, bc))
320
+ axs = axes (dest)
321
+ axes (bc) == axs || Broadcast. throwdm (axes (bc), axs)
322
+ for j in axs[2 ]
323
+ for i in 1 : min (size (dest. data,1 ), j+ 1 )
324
+ @inbounds dest. data[i,j] = bc[CartesianIndex (i, j)]
325
+ end
326
+ end
327
+ return dest
328
+ end
329
+
296
330
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
297
331
function map (f, A:: StructuredMatrix , Bs:: StructuredMatrix... )
298
332
sz = size (A)
0 commit comments