Skip to content

Commit 26769f7

Browse files
committed
[ci skip]
Fix JuliaLang#24914 (WIP). Address comments. Reimplement and generalize all-scalar optimization. Fix allocation tests for sparse broadcast!. Revert back to calling Broadcast.combine_styles twice. Fix more allocation tests in higherorderfns.jl (by @timholy).
1 parent a5c9e88 commit 26769f7

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

base/broadcast.jl

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,6 @@ broadcast_indices
239239
# special cases defined for performance
240240
# Fast path: when every argument is a `Number`, apply `f` to them directly.
function broadcast(f, x::Number...)
    return f(x...)
end
241241
# Fast path: tuples that all have the same length `N` are combined elementwise with `map`.
@inline function broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N}
    return map(f, t, ts...)
end
242-
@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} =
243-
Base.axes(x) == Base.axes(y) ? copyto!(x, y) : _broadcast!(identity, x, y)
244-
245-
# special cases for "X .= ..." (broadcast!) assignments
246-
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
247-
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X)
248242

249243
## logic for deciding the BroadcastStyle
250244
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
@@ -261,7 +255,7 @@ longest(::Tuple{}, ::Tuple{}) = ()
261255
# combine_styles operates on values (arbitrarily many)
262256
# One value: take the `BroadcastStyle` of its type and pass it through `result_style`.
function combine_styles(c)
    return result_style(BroadcastStyle(typeof(c)))
end
263257
# Two values: resolve the style of each one separately, then merge the pair.
function combine_styles(c1, c2)
    first_style = combine_styles(c1)
    second_style = combine_styles(c2)
    return result_style(first_style, second_style)
end
264-
combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
258+
# Three or more values: merge the style of the first value with the combined style of
# all remaining values (computed recursively).
@inline function combine_styles(c1, c2, cs...)
    head_style = combine_styles(c1)
    tail_style = combine_styles(c2, cs...)
    return result_style(head_style, tail_style)
end
265259

266260
# result_style works on types (singletons and pairs), and leverages `BroadcastStyle`
267261
# A single style is already the result; return it unchanged.
function result_style(s::BroadcastStyle)
    return s
end
@@ -445,11 +439,36 @@ Note that `dest` is only used to store the result, and does not supply
445439
arguments to `f` unless it is also listed in the `As`,
446440
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
447441
"""
448-
@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} =
449-
_broadcast!(f, C, A, Bs...)
442+
# Entry point: compute the combined style of the source arguments and re-dispatch with
# that style inserted as the third positional argument.
@inline function broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N}
    style = combine_styles(As...)
    return broadcast!(f, dest, style, As...)
end
443+
# Generic-style method: the concrete style is dropped and replaced by `nothing`, so the
# call is re-dispatched on a `::Void` style argument.
@inline function broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N}
    return broadcast!(f, dest, nothing, As...)
end
444+
445+
# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
446+
# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
#
# When `f` is `identity` and the single source is an `AbstractArray` whose axes equal
# those of `dest`, the operation is a plain element copy. Every other call is forwarded
# to `_broadcast!`.
@inline function broadcast!(f, dest, ::Void, As::Vararg{Any, N}) where N
    if f isa typeof(identity) && N == 1
        A = As[1]
        if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
            # Use `copyto!`: the two-argument `copy!(dest, src)` was renamed to
            # `copyto!` in Julia 0.7.
            return copyto!(dest, A)
        end
    end
    return _broadcast!(f, dest, As...)
end
455+
456+
# Optimization for the all-Scalar case.
457+
# Optimization for the all-Scalar case. It only applies when `dest` is an
# `AbstractArray`; any other destination is handed to `_broadcast!`.
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
    dest isa AbstractArray || return _broadcast!(f, dest, As...)
    if f isa typeof(identity) && N == 1
        # `identity` of a single source: store that value at every index
        return fill!(dest, As[1])
    end
    # `f` is called once per destination index (not hoisted out of the loop)
    @inbounds for idx in eachindex(dest)
        dest[idx] = f(As...)
    end
    return dest
end
450470

451-
# This indirection allows size-dependent implementations (e.g., see the copying `identity`
452-
# specialization above)
471+
# This indirection allows size-dependent implementations.
453472
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
454473
shape = broadcast_indices(C)
455474
@boundscheck check_broadcast_indices(shape, A, Bs...)

base/sparse/higherorderfns.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ end
9393
# (3) broadcast[!] entry points
9494
# Single sparse-vector argument: delegate to `_noshapecheck_map`.
function broadcast(f::Tf, A::SparseVector) where {Tf}
    return _noshapecheck_map(f, A)
end
9595
# Single sparse-matrix argument: delegate to `_noshapecheck_map`.
function broadcast(f::Tf, A::SparseMatrixCSC) where {Tf}
    return _noshapecheck_map(f, A)
end
96-
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
96+
97+
@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
9798
isempty(C) && return _finishempty!(C)
9899
fofnoargs = f()
99100
if _iszero(fofnoargs) # f() is zero, so empty C
@@ -106,14 +107,18 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
106107
end
107108
return C
108109
end
109-
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
110-
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
111-
Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
112-
fofzeros = f(_zeros_eltypes(A, Bs...)...)
113-
fpreszeros = _iszero(fofzeros)
114-
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
115-
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
110+
# Sparse-destination `broadcast!` with one or more source arguments.
#
# `identity` with a single source is special-cased: a `Number` source fills `dest`, and
# an `AbstractArray` source whose axes equal those of `dest` is copied. Every other call
# is routed to `spbroadcast_args!`, selected by the combined style of the sources.
@inline function broadcast!(f::Tf, dest::SparseVecOrMat, ::Void, As::Vararg{Any,N}) where {Tf,N}
    if f isa typeof(identity) && N == 1
        A = As[1]
        if A isa Number
            return fill!(dest, A)
        elseif A isa AbstractArray && Base.axes(dest) == Base.axes(A)
            # Use `copyto!`: the two-argument `copy!(dest, src)` was renamed to
            # `copyto!` in Julia 0.7.
            return copyto!(dest, A)
        end
    end
    return spbroadcast_args!(f, dest, Broadcast.combine_styles(As...), As...)
end
121+
117122
# the following three similar defs are necessary for type stability in the mixed vector/matrix case
118123
# All-`SparseVector` arguments: when every argument has the same shape take the
# `_noshapecheck_map` path, otherwise fall back to `_diffshape_broadcast`.
function broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N}
    if _aresameshape(A, Bs...)
        return _noshapecheck_map(f, A, Bs...)
    else
        return _diffshape_broadcast(f, A, Bs...)
    end
end
@@ -1006,26 +1011,26 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe
10061011
# `PromoteToSparse` style: pass every argument through `_sparsifystructured`, then
# broadcast over the converted arguments.
function broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N}
    converted = map(_sparsifystructured, As)
    return broadcast(f, converted...)
end
10081013

1009-
# ambiguity resolution
1010-
broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) =
1011-
fill!(dest, x)
1012-
broadcast!(f, dest::SparseVecOrMat, x::Number...) =
1013-
spbroadcast_args!(f, dest, SPVM, x...)
1014-
10151014
# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
10161015
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
10171016
# we can handle it here, otherwise see below for the promotion machinery.
1018-
broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N =
1019-
spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...)
1020-
function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N
1017+
# All sources are `SparseVecOrMat`.
# If `C` and every source have the same shape, this reduces to `_noshapecheck_map!`.
# Otherwise the broadcast indices are checked and `f` is evaluated on the output of
# `_zeros_eltypes` (presumably zeros of the source element types; confirm against its
# definition) to pick the zero-preserving or non-zero-preserving kernel.
function spbroadcast_args!(f::Tf, C, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
    if _aresameshape(C, A, Bs...)
        return _noshapecheck_map!(f, C, A, Bs...)
    end
    Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
    zeroresult = f(_zeros_eltypes(A, Bs...)...)
    if _iszero(zeroresult)
        return _broadcast_zeropres!(f, C, A, Bs...)
    else
        return _broadcast_notzeropres!(f, zeroresult, C, A, Bs...)
    end
end
1025+
# `mixedsrcargs` contains nothing but SparseVecOrMat and scalars.
# `capturescalars` yields a replacement function together with the source arguments that
# are still to be passed; `broadcast!` is then re-entered with that pair.
function spbroadcast_args!(f::Tf, dest, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
    capturedf, remainingargs = capturescalars(f, mixedsrcargs)
    return broadcast!(capturedf, dest, remainingargs...)
end
1025-
function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N
1030+
# `PromoteToSparse` style: pass every source through `_sparsifystructured`, then re-enter
# `broadcast!` with the converted sources.
function spbroadcast_args!(f::Tf, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
    converted = map(_sparsifystructured, mixedsrcargs)
    return broadcast!(f, dest, converted...)
end
1028-
function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N
1033+
# Fallback for any other style: hand the call to the generic `Broadcast._broadcast!`.
# Open question carried over from the original author: from a performance perspective,
# would it be best to densify?
function spbroadcast_args!(f::Tf, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
    return Broadcast._broadcast!(f, dest, mixedsrcargs...)
end

0 commit comments

Comments
 (0)