@@ -973,35 +973,39 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
973
973
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
974
974
975
975
struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
976
- const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
977
- Broadcast. BroadcastStyle (:: Type{<:StructuredMatrix} ) = PromoteToSparse ()
978
-
979
976
PromoteToSparse (:: Val{0} ) = PromoteToSparse ()
980
977
PromoteToSparse (:: Val{1} ) = PromoteToSparse ()
981
978
PromoteToSparse (:: Val{2} ) = PromoteToSparse ()
982
979
PromoteToSparse (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
983
980
984
- Broadcast. BroadcastStyle (:: PromoteToSparse , :: SPVM ) = PromoteToSparse ()
985
- Broadcast. BroadcastStyle (:: PromoteToSparse , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
986
-
987
- # FIXME : switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
988
- # Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
989
- # Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
990
- # Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
991
- Broadcast. BroadcastStyle (:: Type{<:Adjoint{T,<:Vector} where T} ) = Broadcast. MatrixStyle () # Adjoint not yet defined when broadcast.jl loaded
992
- Broadcast. BroadcastStyle (:: Type{<:Transpose{T,<:Vector} where T} ) = Broadcast. MatrixStyle () # Transpose not yet defined when broadcast.jl loaded
981
+ const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
982
+ Broadcast. BroadcastStyle (:: Type{<:StructuredMatrix} ) = PromoteToSparse ()
993
983
Broadcast. BroadcastStyle (:: Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T} ) = PromoteToSparse ()
994
984
Broadcast. BroadcastStyle (:: Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T} ) = PromoteToSparse ()
995
- Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.VectorStyle ) = PromoteToSparse ()
996
- Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.MatrixStyle ) = PromoteToSparse ()
997
- Broadcast. BroadcastStyle (:: SparseVecStyle , :: Broadcast.DefaultArrayStyle{N} ) where N =
998
- Broadcast. DefaultArrayStyle (Broadcast. _max (Val (N), Val (1 )))
999
- Broadcast. BroadcastStyle (:: SparseMatStyle , :: Broadcast.DefaultArrayStyle{N} ) where N =
1000
- Broadcast. DefaultArrayStyle (Broadcast. _max (Val (N), Val (2 )))
1001
- # end FIXME
1002
985
1003
- broadcast (f, :: PromoteToSparse , :: Nothing , :: Nothing , As:: Vararg{Any,N} ) where {N} =
1004
- broadcast (f, map (_sparsifystructured, As)... )
986
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.DefaultArrayStyle{0} ) = PromoteToSparse ()
987
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.DefaultArrayStyle{1} ) = PromoteToSparse ()
988
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.DefaultArrayStyle{2} ) = PromoteToSparse ()
989
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: SPVM ) = PromoteToSparse ()
990
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
991
+
992
+ # FIXME : currently sparse broadcasts are only well-tested on known array types, while any AbstractArray
993
+ # could report itself as a DefaultArrayStyle().
994
+ # See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details
995
+ is_supported_sparse_broadcast () = true
996
+ is_supported_sparse_broadcast (:: AbstractArray , rest... ) = false
997
+ is_supported_sparse_broadcast (:: AbstractSparseArray , rest... ) = is_supported_sparse_broadcast (rest... )
998
+ is_supported_sparse_broadcast (:: StructuredMatrix , rest... ) = is_supported_sparse_broadcast (rest... )
999
+ is_supported_sparse_broadcast (:: Array , rest... ) = is_supported_sparse_broadcast (rest... )
1000
+ is_supported_sparse_broadcast (t:: Union{Transpose, Adjoint} , rest... ) = is_supported_sparse_broadcast (t. parent, rest... )
1001
+ is_supported_sparse_broadcast (x, rest... ) = BroadcastStyle (typeof (x)) === Broadcast. Scalar () && is_supported_sparse_broadcast (rest... )
1002
+ function broadcast (f, s:: PromoteToSparse , :: Nothing , :: Nothing , As:: Vararg{Any,N} ) where {N}
1003
+ if is_supported_sparse_broadcast (As... )
1004
+ return broadcast (f, map (_sparsifystructured, As)... )
1005
+ else
1006
+ return broadcast (f, Broadcast. ArrayConflict (), nothing , nothing , As... )
1007
+ end
1008
+ end
1005
1009
1006
1010
# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
1007
1011
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
0 commit comments