Skip to content

Commit 332dbaa

Browse files
committed
Fix length calculation of broadcast over tuples
1 parent 674e64b commit 332dbaa

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

base/broadcast.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ export broadcast_getindex, broadcast_setindex!, dotview, @__dot__
1111

1212
const ScalarType = Union{Type{Any}, Type{Nullable}}
1313

14+
const OnePlusTuple{T} = Tuple{T,Vararg{T}}
15+
const TwoPlusTuple{T} = Tuple{T,T,Vararg{T}}
16+
1417
## Broadcasting utilities ##
1518
# fallbacks for some special cases
1619
@inline broadcast(f, x::Number...) = f(x...)
@@ -127,6 +130,8 @@ end
127130
end
128131

129132
Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
133+
# `(x,)`, where `x` is a scalar, broadcasts the same way as `[x]` or `x`
134+
Base.@propagate_inbounds _broadcast_getindex(::Type{Tuple}, A::Tuple{Any}, I) = A[1]
130135
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
131136
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
132137
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
@@ -333,13 +338,32 @@ end
333338
end
334339
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
335340
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
336-
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
341+
tuplebroadcast(f, tuplebroadcast_maxtuple(A, Bs...), A, Bs...)
337342
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
338343
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N))
339344
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
340345
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N))
341-
first_tuple(A::Tuple, Bs...) = A
342-
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
346+
# When the result of broadcast is a tuple it can only come from mixing n-tuples
347+
# of the same length with scalars, empty tuples and 1-tuples. So, in order to
348+
# have a type-stable broadcast we need to find a tuple of maximum length
349+
# (except when there's an empty tuple in which case we return an empty tuple as
350+
# well).
351+
# The following methods compare broadcast arguments pairwise to determine the
352+
# length of the final tuple.
353+
tuplebroadcast_maxtuple(A, B) = (nothing,)
354+
tuplebroadcast_maxtuple(::Tuple{}, A) = ()
355+
tuplebroadcast_maxtuple(A, ::Tuple{}) = ()
356+
tuplebroadcast_maxtuple(A::OnePlusTuple{Any}, B) = A
357+
tuplebroadcast_maxtuple(A, B::OnePlusTuple{Any}) = B
358+
tuplebroadcast_maxtuple(A::OnePlusTuple{Any}, ::Tuple{Any}) = A
359+
tuplebroadcast_maxtuple(::Tuple{Any}, A::OnePlusTuple{Any}) = A
360+
tuplebroadcast_maxtuple(::OnePlusTuple{Any}, ::Tuple{}) = ()
361+
tuplebroadcast_maxtuple(::Tuple{}, ::OnePlusTuple{Any}) = ()
362+
tuplebroadcast_maxtuple(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} = A
363+
tuplebroadcast_maxtuple(::TwoPlusTuple{Any}, ::TwoPlusTuple{Any}) =
364+
throw(DimensionMismatch("tuples could not be broadcast to a common size"))
365+
@inline tuplebroadcast_maxtuple(A, Bs...) =
366+
tuplebroadcast_maxtuple(A, tuplebroadcast_maxtuple(Bs...))
343367
tuplebroadcast_getargs(::Tuple{}, k) = ()
344368
@inline tuplebroadcast_getargs(As, k) =
345369
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)

test/broadcast.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,16 @@ end
519519
Nullable("hello"))
520520
end
521521

522-
# Issue #21291
523-
let t = (0, 1, 2)
524-
o = 1
525-
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
522+
@testset "broadcast resulting in tuples" begin
523+
# Issue #21291
524+
let t = (0, 1, 2)
525+
o = 1
526+
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
527+
end
528+
529+
# Issue #23647
530+
@test (1, 2, 3) .+ (1,) == (1,) .+ (1, 2, 3) == (2, 3, 4)
531+
@test (1,) .+ () == () .+ (1,) == () .+ () == ()
532+
@test (1, 2) .+ (1, 2) == (2, 4)
533+
@test_throws DimensionMismatch (1, 2) .+ (1, 2, 3)
526534
end

0 commit comments

Comments
 (0)