Skip to content

Commit 35638c7

Browse files
committed
[ci skip]
Address comments. Reimplement and generalize all-scalar optimization. Fix allocation tests for sparse broadcast!.
1 parent 6ebf703 commit 35638c7

File tree

2 files changed

+36
-16
lines changed

2 files changed

+36
-16
lines changed

base/broadcast.jl

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -442,21 +442,36 @@ Note that `dest` is only used to store the result, and does not supply
442442
arguments to `f` unless it is also listed in the `As`,
443443
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
444444
"""
445-
broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...)
446-
broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...)
447-
@inline function broadcast!(f, C, ::Void, A, Bs::Vararg{Any,N}) where N
448-
if isa(f, typeof(identity)) && N == 0
449-
if isa(A, Number)
450-
return fill!(C, A)
451-
elseif isa(C, AbstractArray) && isa(A, AbstractArray) && Base.axes(C) == Base.axes(A)
452-
return copy!(C, A)
445+
@inline broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...)
446+
@inline broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...)
447+
448+
# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
449+
@inline function broadcast!(f, dest, ::Void, As::Vararg{Any, N}) where N
450+
if f isa typeof(identity) && N == 1
451+
A = As[1]
452+
if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
453+
return copy!(dest, A)
453454
end
454455
end
455-
return _broadcast!(f, C, A, Bs...)
456+
return _broadcast!(f, dest, As...)
456457
end
457458

458-
# This indirection allows size-dependent implementations (e.g., see the copying `identity`
459-
# specialization above)
459+
# Optimization for the all-Scalar case.
460+
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
461+
if dest isa AbstractArray
462+
if f isa typeof(identity) && N == 1
463+
return fill!(dest, As[1])
464+
else
465+
@inbounds for I in eachindex(dest)
466+
dest[I] = f(As...)
467+
end
468+
return dest
469+
end
470+
end
471+
return _broadcast!(f, dest, As...)
472+
end
473+
474+
# This indirection allows size-dependent implementations.
460475
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
461476
shape = broadcast_indices(C)
462477
@boundscheck check_broadcast_indices(shape, A, Bs...)

base/sparse/higherorderfns.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ end
9494
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
9595
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
9696

97-
function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
97+
@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::BroadcastStyle) where Tf
9898
isempty(C) && return _finishempty!(C)
9999
fofnoargs = f()
100100
if _iszero(fofnoargs) # f() is zero, so empty C
@@ -107,11 +107,16 @@ function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
107107
end
108108
return C
109109
end
110-
function broadcast!(f, dest::SparseVecOrMat, ::Void, A, Bs::Vararg{Any,N}) where N
111-
if isa(f, typeof(identity)) && N == 0 && isa(A, Number)
112-
return fill!(dest, A)
110+
@inline function broadcast!(f::Tf, dest::SparseVecOrMat, style::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N}
111+
if f isa typeof(identity) && N == 1
112+
A = As[1]
113+
if A isa Number
114+
return fill!(dest, A)
115+
elseif A isa AbstractArray && Base.axes(dest) == Base.axes(A)
116+
return copy!(dest, A)
117+
end
113118
end
114-
return spbroadcast_args!(f, dest, Broadcast.combine_styles(A, Bs...), A, Bs...)
119+
return spbroadcast_args!(f, dest, style, As...)
115120
end
116121

117122
# the following three similar defs are necessary for type stability in the mixed vector/matrix case

0 commit comments

Comments
 (0)