diff --git a/src/hessenberg.jl b/src/hessenberg.jl index d705c00e..a01a700b 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -59,6 +59,9 @@ size(H::UpperHessenberg) = size(H.data) axes(H::UpperHessenberg) = axes(H.data) parent(H::UpperHessenberg) = H.data +upperhessenbergdata(H::UpperHessenberg) = H.data +upperhessenbergdata(A) = A + # similar behaves like UpperTriangular similar(H::UpperHessenberg, ::Type{T}) where {T} = UpperHessenberg(similar(H.data, T)) similar(H::UpperHessenberg, ::Type{T}, dims::Dims{N}) where {T,N} = similar(H.data, T, dims) diff --git a/src/structuredbroadcast.jl b/src/structuredbroadcast.jl index 79a78dba..23c4d1d0 100644 --- a/src/structuredbroadcast.jl +++ b/src/structuredbroadcast.jl @@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}() StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}() -const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}} -for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular) +const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T}, + LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}, + UpperHessenberg{T}} +for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal, + LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular, + UpperHessenberg) @eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}()) end @@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt StructuredMatrixStyle{LowerTriangular}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) = StructuredMatrixStyle{UpperTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{UpperHessenberg}) = + StructuredMatrixStyle{UpperHessenberg}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{Diagonal}) = StructuredMatrixStyle{Bidiagonal}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}}) = StructuredMatrixStyle{Tridiagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) = + StructuredMatrixStyle{UpperHessenberg}() + Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) = StructuredMatrixStyle{Tridiagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) = + StructuredMatrixStyle{UpperHessenberg}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) = StructuredMatrixStyle{Tridiagonal}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) = + StructuredMatrixStyle{UpperHessenberg}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{LowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) = StructuredMatrixStyle{LowerTriangular}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) = StructuredMatrixStyle{UpperTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) = + StructuredMatrixStyle{UpperHessenberg}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitLowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) = StructuredMatrixStyle{LowerTriangular}() Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) = StructuredMatrixStyle{UpperTriangular}() +Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) = + StructuredMatrixStyle{UpperHessenberg}() + +function Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperHessenberg}, + ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}}) + StructuredMatrixStyle{UpperHessenberg}() +end -Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) = +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}) = StructuredMatrixStyle{Matrix}() -Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) = +Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) = StructuredMatrixStyle{Matrix}() # Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling @@ -97,7 +119,7 @@ function structured_broadcast_alloc(bc, ::Type{Tridiagonal}, Tridiagonal(Array{ElType}(undef, n1),Array{ElType}(undef, n),Array{ElType}(undef, n1)) end function structured_broadcast_alloc(bc, ::Type{T}, ::Type{ElType}, - sz::NTuple{2,Integer}) where {ElType,T<:UpperOrLowerTriangular} + sz::NTuple{2,Integer}) where {ElType,T<:Union{UpperOrLowerTriangular, UpperHessenberg}} T(Array{ElType}(undef, sz)) end structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} = @@ -278,6 +300,7 @@ function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} end _preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A) _preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A) +_preprocess_broadcasted(::Type{UpperHessenberg}, A) = upperhessenbergdata(A) function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) @@ -305,6 +328,19 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle} return dest end +function copyto!(dest::UpperHessenberg, bc::Broadcasted{<:StructuredMatrixStyle}) + isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + bc_unwrapped = preprocess_broadcasted(UpperHessenberg, bc) + for j in axs[2] + for i in 1:min(size(dest.data,1), j+1) + @inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)] + end + end + return dest +end + # We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check function map(f, A::StructuredMatrix, Bs::StructuredMatrix...) sz = size(A) diff --git a/test/structuredbroadcast.jl b/test/structuredbroadcast.jl index 5e6a68d7..da9df54c 100644 --- a/test/structuredbroadcast.jl +++ b/test/structuredbroadcast.jl @@ -22,8 +22,9 @@ using .Main.SizedArrays S = SymTridiagonal(rand(N), rand(max(0,N-1))) U = UpperTriangular(rand(N,N)) L = LowerTriangular(rand(N,N)) + UH = UpperHessenberg(rand(N,N)) M = Matrix(rand(N,N)) - structuredarrays = (D, B, T, U, L, M, S) + structuredarrays = (D, B, T, U, L, M, S, UH) fstructuredarrays = map(Array, structuredarrays) @testset "$(nameof(typeof(X)))" for (X, fX) in zip(structuredarrays, fstructuredarrays) @test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX)) @@ -135,6 +136,7 @@ end T = Tridiagonal(rand(max(0,N-1)), rand(N), rand(max(0,N-1))) ◣ = LowerTriangular(rand(N,N)) ◥ = UpperTriangular(rand(N,N)) + UH = UpperHessenberg(rand(N,N)) M = Matrix(rand(N,N)) @test broadcast!(sin, copy(D), D)::Diagonal == sin.(D)::Diagonal @@ -143,6 +145,7 @@ end @test broadcast!(sin, copy(T), T)::Tridiagonal == sin.(T)::Tridiagonal @test broadcast!(sin, copy(◣), ◣)::LowerTriangular == sin.(◣)::LowerTriangular @test broadcast!(sin, copy(◥), ◥)::UpperTriangular == sin.(◥)::UpperTriangular + @test broadcast!(sin, copy(UH), UH)::UpperHessenberg == sin.(UH)::UpperHessenberg @test broadcast!(sin, copy(M), M)::Matrix == sin.(M)::Matrix @test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A)) @test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U) @@ -150,6 +153,7 @@ end @test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A)) @test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A)) @test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A)) + @test broadcast!(*, copy(UH), UH, A) == UpperHessenberg(broadcast(*, UH, A)) @test broadcast!(*, copy(M), M, A) == Matrix(broadcast(*, M, A)) if N > 2 @@ -181,8 +185,9 @@ end S = SymTridiagonal(rand(N), rand(N - 1)) U = UpperTriangular(rand(N,N)) L = LowerTriangular(rand(N,N)) + UH = UpperHessenberg(rand(N,N)) M = Matrix(rand(N,N)) - structuredarrays = (M, D, B, T, S, U, L) + structuredarrays = (M, D, B, T, S, U, L, UH) fstructuredarrays = map(Array, structuredarrays) for (X, fX) in zip(structuredarrays, fstructuredarrays) @test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX)) @@ -396,4 +401,11 @@ end end end +@testset "Rectangular UpperHessenberg" begin + UH = UpperHessenberg(ones(4,3)) + UH2 = UH .+ UH .- UH + @test UH2 == UH + @test UH2 isa UpperHessenberg +end + end