Skip to content

Commit 46ffc87

Browse files
committed
improve #17546 for a[...] .= handling of arrays of arrays and dictionaries of arrays
1 parent 3503a59 commit 46ffc87

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

base/broadcast.jl

+25-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module Broadcast
55
using Base.Cartesian
66
using Base: promote_op, promote_eltype, promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape
77
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, , .%, .<<, .>>, .^
8-
export broadcast, broadcast!, bitbroadcast
8+
export broadcast, broadcast!, bitbroadcast, dotview
99
export broadcast_getindex, broadcast_setindex!
1010

1111
## Broadcasting utilities ##
@@ -440,4 +440,28 @@ for (sigA, sigB) in ((BitArray, BitArray),
440440
end
441441
end
442442

443+
############################################################
444+
445+
# x[...] .= f.(y...) ---> broadcast!(f, dotview(x, ...), y...).
446+
# The dotview function defaults to view, but we override it in
447+
# a few cases to get the expected in-place behavior without affecting
448+
# explicit calls to view. (All of this can go away if slices
449+
# are changed to generate views by default.)
450+
451+
dotview(args...) = view(args...)
452+
# avoid splatting penalty in common cases:
453+
for nargs = 0:5
454+
args = Symbol[Symbol("x",i) for i = 1:nargs]
455+
eval(Expr(:(=), Expr(:call, :dotview, args...), Expr(:call, :view, args...)))
456+
end
457+
458+
# for a[i...] .= ... where a is an array-of-arrays, just pass a[i...] directly
459+
# to broadcast!
460+
dotview{T<:AbstractArray,N,I<:Integer}(a::AbstractArray{T,N}, i::Vararg{I,N}) =
461+
a[i...]
462+
463+
# dict[k] .= ... should work if dict[k] is an array
464+
dotview(a::Associative, k) = a[k]
465+
dotview(a::Associative, k1, k2, ks...) = a[tuple(k1,k2,ks...)]
466+
443467
end # module

src/julia-syntax.scm

+1-1
Original file line numberDiff line numberDiff line change
@@ -1549,7 +1549,7 @@
15491549
(let* ((ex (partially-expand-ref expr))
15501550
(stmts (butlast (cdr ex)))
15511551
(refex (last (cdr ex)))
1552-
(nuref `(call (top view) ,(caddr refex) ,@(cdddr refex))))
1552+
(nuref `(call (top dotview) ,(caddr refex) ,@(cdddr refex))))
15531553
`(block ,@stmts ,nuref))
15541554
expr))
15551555

test/broadcast.jl

+10
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,16 @@ let x = [1:4;], y = x
274274
x[2:end] .= 1:3
275275
@test y === x == [0,1,2,3]
276276
end
277+
let a = [[4, 5], [6, 7]]
278+
a[1] .= 3
279+
@test a == [[3, 3], [6, 7]]
280+
end
281+
let d = Dict(:foo => [1,3,7], (3,4) => [5,9])
282+
d[:foo] .+= 2
283+
@test d[:foo] == [3,5,9]
284+
d[3,4] .-= 1
285+
@test d[3,4] == [4,8]
286+
end
277287

278288
# PR 16988
279289
@test Base.promote_op(+, Bool) === Int

0 commit comments

Comments
 (0)