Skip to content

Commit 82cd0c9

Browse files
pabloferzararslan
authored andcommitted
Fix length calculation of broadcast over tuples (#23887)
(cherry picked from commit 698ef27)
1 parent 0c06745 commit 82cd0c9

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

base/broadcast.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ end
128128
end
129129

130130
Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
131+
# `(x,)`, where `x` is a scalar, broadcasts the same way as `[x]` or `x`
132+
Base.@propagate_inbounds _broadcast_getindex(::Type{Tuple}, A::Tuple{Any}, I) = A[1]
131133
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
132134
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
133135
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
@@ -334,13 +336,32 @@ end
334336
end
335337
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
336338
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
337-
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
339+
tuplebroadcast(f, tuplebroadcast_maxtuple(A, Bs...), A, Bs...)
338340
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
339341
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val{N})
340342
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
341343
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val{N})
342-
first_tuple(A::Tuple, Bs...) = A
343-
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
344+
# When the result of broadcast is a tuple it can only come from mixing n-tuples
345+
# of the same length with scalars and 1-tuples. So, in order to have a
346+
# type-stable broadcast, we need to find a tuple of maximum length (except when
347+
# there are only scalars, empty tuples and 1-tuples, in which case the
348+
# returned value will be an empty tuple).
349+
# The following methods compare broadcast arguments pairwise to determine the
350+
# length of the final tuple.
351+
tuplebroadcast_maxtuple(A, B) =
352+
_tuplebroadcast_maxtuple(containertype(A), containertype(B), A, B)
353+
@inline tuplebroadcast_maxtuple(A, Bs...) =
354+
tuplebroadcast_maxtuple(A, tuplebroadcast_maxtuple(Bs...))
355+
tuplebroadcast_maxtuple(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} = A
356+
# Here we use the containertype trait to easier disambiguate between methods
357+
_tuplebroadcast_maxtuple(::Type{Any}, ::Type{Any}, A, B) = (nothing,)
358+
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Any}, A, B) = A
359+
_tuplebroadcast_maxtuple(::Type{Any}, ::Type{Tuple}, A, B) = B
360+
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A, B::Tuple{Any}) = A
361+
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A::Tuple{Any}, B) = B
362+
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A::Tuple{Any}, ::Tuple{Any}) = A
363+
_tuplebroadcast_maxtuple(::Type{Tuple}, ::Type{Tuple}, A, B) =
364+
throw(DimensionMismatch("tuples could not be broadcast to a common size"))
344365
tuplebroadcast_getargs(::Tuple{}, k) = ()
345366
@inline tuplebroadcast_getargs(As, k) =
346367
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)

test/broadcast.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,8 +515,16 @@ end
515515
Nullable("hello"))
516516
end
517517

518-
# Issue #21291
519-
let t = (0, 1, 2)
520-
o = 1
521-
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
518+
@testset "broadcast resulting in tuples" begin
519+
# Issue #21291
520+
let t = (0, 1, 2)
521+
o = 1
522+
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
523+
end
524+
525+
# Issue #23647
526+
@test (1, 2, 3) .+ (1,) == (1,) .+ (1, 2, 3) == (2, 3, 4)
527+
@test (1,) .+ () == () .+ (1,) == () .+ () == ()
528+
@test (1, 2) .+ (1, 2) == (2, 4)
529+
@test_throws DimensionMismatch (1, 2) .+ (1, 2, 3)
522530
end

0 commit comments

Comments
 (0)