Skip to content

Commit fa3fe32

Browse files
committed
Reimplement and generalize all-scalar optimization. Add documentation. Explicitly return dest in various broadcast!-related methods. This is to make things easier on inference. Found by @timholy. Collapse spbroadcast_args! into broadcast! as suggested by @Sacha0.
1 parent 6c94bd8 commit fa3fe32

File tree

3 files changed

+91
-36
lines changed

3 files changed

+91
-36
lines changed

base/broadcast.jl

+35-12
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,6 @@ broadcast_indices
239239
# special cases defined for performance
240240
broadcast(f, x::Number...) = f(x...)
241241
@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...)
242-
@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} =
243-
Base.axes(x) == Base.axes(y) ? copyto!(x, y) : _broadcast!(identity, x, y)
244-
245-
# special cases for "X .= ..." (broadcast!) assignments
246-
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
247-
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X)
248242

249243
## logic for deciding the BroadcastStyle
250244
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
@@ -261,7 +255,7 @@ longest(::Tuple{}, ::Tuple{}) = ()
261255
# combine_styles operates on values (arbitrarily many)
262256
combine_styles(c) = result_style(BroadcastStyle(typeof(c)))
263257
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
264-
combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
258+
@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
265259

266260
# result_style works on types (singletons and pairs), and leverages `BroadcastStyle`
267261
result_style(s::BroadcastStyle) = s
@@ -397,6 +391,7 @@ Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) =
397391
result = @ncall $nargs f val
398392
@inbounds B[I] = result
399393
end
394+
return B
400395
end
401396
end
402397

@@ -433,6 +428,7 @@ end
433428
@inbounds C[ind:bitcache_size] = false
434429
dumpbitcache(Bc, cind, C)
435430
end
431+
return B
436432
end
437433
end
438434

@@ -445,11 +441,38 @@ Note that `dest` is only used to store the result, and does not supply
445441
arguments to `f` unless it is also listed in the `As`,
446442
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
447443
"""
448-
@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} =
449-
_broadcast!(f, C, A, Bs...)
444+
@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...)
445+
@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...)
446+
447+
# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
448+
@inline function broadcast!(f, dest, ::Nothing, As::Vararg{Any, N}) where N
449+
if f isa typeof(identity) && N == 1
450+
A = As[1]
451+
if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
452+
return copyto!(dest, A)
453+
end
454+
end
455+
_broadcast!(f, dest, As...)
456+
return dest
457+
end
458+
459+
# Optimization for the all-Scalar case.
460+
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
461+
if dest isa AbstractArray
462+
if f isa typeof(identity) && N == 1
463+
return fill!(dest, As[1])
464+
else
465+
@inbounds for I in eachindex(dest)
466+
dest[I] = f(As...)
467+
end
468+
return dest
469+
end
470+
end
471+
_broadcast!(f, dest, As...)
472+
return dest
473+
end
450474

451-
# This indirection allows size-dependent implementations (e.g., see the copying `identity`
452-
# specialization above)
475+
# This indirection allows size-dependent implementations.
453476
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
454477
shape = broadcast_indices(C)
455478
@boundscheck check_broadcast_indices(shape, A, Bs...)
@@ -630,7 +653,7 @@ function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::In
630653
dest = Base.similar(Array{typeof(val)}, shape)
631654
end
632655
dest[I] = val
633-
return _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
656+
_broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
634657
end
635658

636659
broadcast(f, ::Union{Scalar,Unknown}, ::Nothing, ::Nothing, a...) = f(a...)

base/sparse/higherorderfns.jl

+20-24
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ end
9393
# (3) broadcast[!] entry points
9494
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
9595
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
96-
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
96+
97+
@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf
9798
isempty(C) && return _finishempty!(C)
9899
fofnoargs = f()
99100
if _iszero(fofnoargs) # f() is zero, so empty C
@@ -106,14 +107,7 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
106107
end
107108
return C
108109
end
109-
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
110-
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
111-
Base.Broadcast.check_broadcast_indices(axes(C), A, Bs...)
112-
fofzeros = f(_zeros_eltypes(A, Bs...)...)
113-
fpreszeros = _iszero(fofzeros)
114-
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
115-
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
116-
end
110+
117111
# the following three similar defs are necessary for type stability in the mixed vector/matrix case
118112
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
119113
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
@@ -1006,28 +1000,30 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe
10061000
broadcast(f, ::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} =
10071001
broadcast(f, map(_sparsifystructured, As)...)
10081002

1009-
# ambiguity resolution
1010-
broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) =
1011-
fill!(dest, x)
1012-
broadcast!(f, dest::SparseVecOrMat, x::Number...) =
1013-
spbroadcast_args!(f, dest, SPVM, x...)
1014-
10151003
# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
10161004
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
10171005
# we can handle it here, otherwise see below for the promotion machinery.
1018-
broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N =
1019-
spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...)
1020-
function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N
1006+
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
1007+
if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A)
1008+
return copyto!(dest, A)
1009+
end
1010+
_aresameshape(dest, A, Bs...) && return _noshapecheck_map!(f, dest, A, Bs...)
1011+
Base.Broadcast.check_broadcast_indices(axes(dest), A, Bs...)
1012+
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1013+
fpreszeros = _iszero(fofzeros)
1014+
fpreszeros ? _broadcast_zeropres!(f, dest, A, Bs...) :
1015+
_broadcast_notzeropres!(f, fofzeros, dest, A, Bs...)
1016+
return dest
1017+
end
1018+
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
10211019
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
10221020
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
1023-
return broadcast!(parevalf, dest, passedsrcargstup...)
1021+
broadcast!(parevalf, dest, passedsrcargstup...)
1022+
return dest
10241023
end
1025-
function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N
1024+
function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
10261025
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
1027-
end
1028-
function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N
1029-
# Fallback. From a performance perspective would it be best to densify?
1030-
Broadcast._broadcast!(f, dest, mixedsrcargs...)
1026+
return dest
10311027
end
10321028

10331029
_sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)

doc/src/manual/interfaces.md

+36
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ perhaps range-types `Ind` of your own design. For more information, see [Arrays
404404
| `broadcast(f, As...)` | Complete bypass of broadcasting machinery |
405405
| `broadcast(f, ::DestStyle, ::Nothing, ::Nothing, As...)` | Bypass after container type is computed |
406406
| `broadcast(f, ::DestStyle, ::Type{ElType}, inds::Tuple, As...)` | Bypass after container type, eltype, and indices are computed |
407+
| `broadcast!(f, dest::DestType, ::Nothing, As...)` | Bypass in-place broadcast, specialization on destination type |
408+
| `broadcast!(f, dest, ::BroadcastStyle, As...)` | Bypass in-place broadcast, specialization on `BroadcastStyle` |
407409

408410
[Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by
409411
"dot" operations like `A .+ b`. Any `AbstractArray` type supports broadcasting,
@@ -591,3 +593,37 @@ yields another `SparseVecStyle`, that its combination with a 2-dimensional array
591593
yields a `SparseMatStyle`, and anything of higher dimensionality falls back to the dense arbitrary-dimensional framework.
592594
These rules allow broadcasting to keep the sparse representation for operations that result
593595
in one or two dimensional outputs, but produce an `Array` for any other dimensionality.
596+
597+
### [Extending `broadcast!`](@id extending-in-place-broadcast)
598+
599+
Extending `broadcast!` (in-place broadcast) should be done with care, as it is easy to introduce
600+
ambiguities between packages. To avoid these ambiguities, we adhere to the following conventions.
601+
602+
First, if you want to specialize on the destination type, say `DestType`, then you should
603+
define a method with the following signature:
604+
605+
```julia
606+
broadcast!(f, dest::DestType, ::Nothing, As...)
607+
```
608+
609+
Note that no bounds should be placed on the types of `f` and `As...`.
610+
611+
Second, if specialized `broadcast!` behavior is desired depending on the input types,
612+
you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) to
613+
determine a custom `BroadcastStyle` (say, `MyBroadcastStyle`) from the input types, and then define a method with the following
614+
signature:
615+
616+
```julia
617+
broadcast!(f, dest, ::MyBroadcastStyle, As...)
618+
```
619+
620+
Note that no bounds should be placed on the types of `f`, `dest`, and `As...`.
621+
622+
Third, simultaneously specializing on both the type of `dest` and the `BroadcastStyle` is fine. In this case,
623+
you may also specialize on the types of the source arguments (`As...`). For example, these method signatures are OK:
624+
625+
```julia
626+
broadcast!(f, dest::DestType, ::MyBroadcastStyle, As...)
627+
broadcast!(f, dest::DestType, ::MyBroadcastStyle, As::AbstractArray...)
628+
broadcast!(f, dest::DestType, ::Broadcast.Scalar, As::Number...)
629+
```

0 commit comments

Comments
 (0)