Skip to content

Commit 983f0f9

Browse files
mbaumanandreasnoack
authored andcommitted
Support sparse broadcast with transposes and adjoints (#26331)
Fixes #25331
1 parent 08445af commit 983f0f9

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

stdlib/SparseArrays/src/higherorderfns.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
973973
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
974974

975975
struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
976-
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
976+
const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
977977
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()
978978

979979
PromoteToSparse(::Val{0}) = PromoteToSparse()
@@ -988,8 +988,10 @@ Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcas
988988
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
989989
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
990990
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
991-
BroadcastStyle(::Type{<:Adjoint{T,<:Vector}}) where T = Broadcast.MatrixStyle() # Adjoint not yet defined when broadcast.jl loaded
992-
BroadcastStyle(::Type{<:Transpose{T,<:Vector}}) where T = Broadcast.MatrixStyle() # Transpose not yet defined when broadcast.jl loaded
991+
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Vector} where T}) = Broadcast.MatrixStyle() # Adjoint not yet defined when broadcast.jl loaded
992+
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Vector} where T}) = Broadcast.MatrixStyle() # Transpose not yet defined when broadcast.jl loaded
993+
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
994+
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
993995
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse()
994996
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse()
995997
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N =

stdlib/SparseArrays/test/higherorderfns.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,12 @@ end
299299
elT = Float64
300300
s = Float32(2.0)
301301
V = sprand(elT, N, p)
302+
Vᵀ = transpose(sprand(elT, 1, N, p))
302303
A = sprand(elT, N, M, p)
303-
fV, fA = Array(V), Array(A)
304+
Aᵀ = transpose(sprand(elT, M, N, p))
305+
fV, fA, fVᵀ, fAᵀ = Array(V), Array(A), Array(Vᵀ), Array(Aᵀ)
304306
# test combinations involving one to three scalars and one to five sparse vectors/matrices
305-
spargseq, dargseq = Iterators.cycle((A, V)), Iterators.cycle((fA, fV))
307+
spargseq, dargseq = Iterators.cycle((A, V, Aᵀ, Vᵀ)), Iterators.cycle((fA, fV, fAᵀ, fVᵀ))
306308
for nargs in 1:5 # number of tensor arguments
307309
nargsl = cld(nargs, 2) # number in "left half" of tensor arguments
308310
nargsr = fld(nargs, 2) # number in "right half" of tensor arguments

0 commit comments

Comments
 (0)