Skip to content

Commit 5e1d5ee

Browse files
Merge pull request #330 from AayushSabharwal/as/setindex
fix: add setindex! for higher dimensional VoA, fix checkbounds allocations
2 parents bc59b23 + 477ba0a commit 5e1d5ee

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

src/vector_of_array.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,18 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}
422422
return VA.u[i][jj] = x
423423
end
424424

425+
Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, x, idxs::Union{Int,Colon,CartesianIndex,AbstractArray{Int},AbstractArray{Bool}}...) where {T, N}
426+
v = view(VA, idxs...)
427+
# error message copied from Base by running `ones(3, 3, 3)[:, 2, :] = 2`
428+
if length(v) != length(x)
429+
throw(ArgumentError("indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?"))
430+
end
431+
for (i, j) in zip(eachindex(v), eachindex(x))
432+
v[i] = x[j]
433+
end
434+
return x
435+
end
436+
425437
# Interface for the two-dimensional indexing, a more standard AbstractArray interface
426438
@inline Base.size(VA::AbstractVectorOfArray) = (size(VA.u[1])..., length(VA.u))
427439
@inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i]
@@ -534,21 +546,24 @@ function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:Abstra
534546
return checkbounds(Bool, VA.u, idxs...)
535547
end
536548
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
537-
if checkbounds(Bool, VA.u, last(idx))
538-
if last(idx) isa Integer
539-
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...))
540-
else
541-
return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...))
542-
end
549+
checkbounds(Bool, VA.u, last(idx)) || return false
550+
for i in last(idx)
551+
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
543552
end
544-
return false
553+
return true
545554
end
546555
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
547556
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
548557
end
549558
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
550559
copyto!.(dest.u, src.u)
551560
end
561+
# Required for broadcasted setindex! when slicing across subarrays
562+
# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])`
563+
# Need this method for `va[2, :, :] .= 3.0`
564+
Base.@propagate_inbounds function Base.maybeview(A::AbstractVectorOfArray, I...)
565+
return view(A, I...)
566+
end
552567

553568
# Operations
554569
function Base.isapprox(A::AbstractVectorOfArray,
@@ -619,7 +634,7 @@ function Base.fill!(VA::AbstractVectorOfArray, x)
619634
return VA
620635
end
621636

622-
Base.reshape(A::VectorOfArray, dims...) = Base.reshape(Array(A), dims...)
637+
Base.reshape(A::AbstractVectorOfArray, dims...) = Base.reshape(Array(A), dims...)
623638

624639
# Need this for ODE_DEFAULT_UNSTABLE_CHECK from DiffEqBase to work properly
625640
@inline Base.any(f, VA::AbstractVectorOfArray) = any(any(f, u) for u in VA.u)
@@ -633,7 +648,7 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
633648
if !allequal(size.(VA.u))
634649
error("Can only convert non-ragged VectorOfArray to Array")
635650
end
636-
return stack(VA)
651+
return Array(VA)
637652
end
638653

639654
# statistics

test/basic_indexing.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,32 @@ w = v .+ 1
191191
@test_broken w isa DiffEqArray # FIXME
192192
@test w.u == map(x -> x .+ 1, v.u)
193193

194+
# setindex!
195+
testva = VectorOfArray([i * ones(3, 3) for i in 1:5])
196+
testva[:, 2] = 7ones(3, 3)
197+
@test testva[:, 2] == 7ones(3, 3)
198+
testva[:, :] = [2i * ones(3, 3) for i in 1:5]
199+
for i in 1:5
200+
@test testva[:, i] == 2i * ones(3, 3)
201+
end
202+
testva[:, 1:2:5] = [5i * ones(3, 3) for i in 1:2:5]
203+
for i in 1:2:5
204+
@test testva[:, i] == 5i * ones(3, 3)
205+
end
206+
testva[CartesianIndex(3, 3, 5)] = 64.0
207+
@test testva[:, 5][3, 3] == 64.0
208+
@test_throws ArgumentError testva[2, 1:2, :] = 108.0
209+
testva[2, 1:2, :] .= 108.0
210+
for i in 1:5
211+
@test all(testva[:, i][2, 1:2] .== 108.0)
212+
end
213+
testva[:, 3, :] = [3i / 7j for i in 1:3, j in 1:5]
214+
for j in 1:5
215+
for i in 1:3
216+
@test testva[i, 3, j] == 3i / 7j
217+
end
218+
end
219+
194220
# edges cases
195221
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
196222
testva = DiffEqArray(x, x)

0 commit comments

Comments
 (0)