@@ -11,6 +11,9 @@ export broadcast_getindex, broadcast_setindex!, dotview, @__dot__
11
11
12
12
const ScalarType = Union{Type{Any}, Type{Nullable}}
13
13
14
+ const OnePlusTuple{T} = Tuple{T,Vararg{T}}
15
+ const TwoPlusTuple{T} = Tuple{T,T,Vararg{T}}
16
+
14
17
# # Broadcasting utilities ##
15
18
# fallbacks for some special cases
16
19
@inline broadcast (f, x:: Number... ) = f (x... )
127
130
end
128
131
129
132
Base. @propagate_inbounds _broadcast_getindex (A, I) = _broadcast_getindex (containertype (A), A, I)
133
+ # `(x,)`, where `x` is a scalar, broadcasts the same way as `[x]` or `x`
134
+ Base. @propagate_inbounds _broadcast_getindex (:: Type{Tuple} , A:: Tuple{Any} , I) = A[1 ]
130
135
Base. @propagate_inbounds _broadcast_getindex (:: Type{Array} , A:: Ref , I) = A[]
131
136
Base. @propagate_inbounds _broadcast_getindex (:: ScalarType , A, I) = A
132
137
Base. @propagate_inbounds _broadcast_getindex (:: Any , A, I) = A[I]
@@ -333,13 +338,32 @@ end
333
338
end
334
339
@inline broadcast_c (f, :: Type{Any} , a... ) = f (a... )
335
340
@inline broadcast_c (f, :: Type{Tuple} , A, Bs... ) =
336
- tuplebroadcast (f, first_tuple (A, Bs... ), A, Bs... )
341
+ tuplebroadcast (f, tuplebroadcast_maxtuple (A, Bs... ), A, Bs... )
337
342
@inline tuplebroadcast (f, :: NTuple{N,Any} , As... ) where {N} =
338
343
ntuple (k -> f (tuplebroadcast_getargs (As, k)... ), Val (N))
339
344
@inline tuplebroadcast (f, :: NTuple{N,Any} , :: Type{T} , As... ) where {N,T} =
340
345
ntuple (k -> f (T, tuplebroadcast_getargs (As, k)... ), Val (N))
341
- first_tuple (A:: Tuple , Bs... ) = A
342
- @inline first_tuple (A, Bs... ) = first_tuple (Bs... )
346
+ # When the result of broadcast is a tuple it can only come from mixing n-tuples
347
+ # of the same length with scalars, empty tuples and 1-tuples. So, in order to
348
+ # have a type-stable broadcast we need to find a tuple of maximum length
349
+ # (except when there's an empty tuple in which case we return an empty tuple as
350
+ # well).
351
+ # The following methods compare broadcast arguments pairwise to determine the
352
+ # length of the final tuple.
353
+ tuplebroadcast_maxtuple (A, B) = (nothing ,)
354
+ tuplebroadcast_maxtuple (:: Tuple{} , A) = ()
355
+ tuplebroadcast_maxtuple (A, :: Tuple{} ) = ()
356
+ tuplebroadcast_maxtuple (A:: OnePlusTuple{Any} , B) = A
357
+ tuplebroadcast_maxtuple (A, B:: OnePlusTuple{Any} ) = B
358
+ tuplebroadcast_maxtuple (A:: OnePlusTuple{Any} , :: Tuple{Any} ) = A
359
+ tuplebroadcast_maxtuple (:: Tuple{Any} , A:: OnePlusTuple{Any} ) = A
360
+ tuplebroadcast_maxtuple (:: OnePlusTuple{Any} , :: Tuple{} ) = ()
361
+ tuplebroadcast_maxtuple (:: Tuple{} , :: OnePlusTuple{Any} ) = ()
362
+ tuplebroadcast_maxtuple (A:: NTuple{N,Any} , :: NTuple{N,Any} ...) where {N} = A
363
+ tuplebroadcast_maxtuple (:: TwoPlusTuple{Any} , :: TwoPlusTuple{Any} ) =
364
+ throw (DimensionMismatch (" tuples could not be broadcast to a common size" ))
365
+ @inline tuplebroadcast_maxtuple (A, Bs... ) =
366
+ tuplebroadcast_maxtuple (A, tuplebroadcast_maxtuple (Bs... ))
343
367
tuplebroadcast_getargs (:: Tuple{} , k) = ()
344
368
@inline tuplebroadcast_getargs (As, k) =
345
369
(_broadcast_getindex (first (As), k), tuplebroadcast_getargs (tail (As), k)... )
0 commit comments