Skip to content

Commit c47e08b

Browse files
Use ismutable from ArrayInterface and clean up copyat via recursivecopy
1 parent edcf0bd commit c47e08b

File tree

5 files changed

+10
-25
lines changed

5 files changed

+10
-25
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ julia 1.0
22
Requires
33
RecipesBase 0.1.0
44
StaticArrays
5+
ArrayInterface 0.1.0

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ __precompile__()
22

33
module RecursiveArrayTools
44

5-
using Requires, RecipesBase, StaticArrays, Statistics
5+
using Requires, RecipesBase, StaticArrays, Statistics,
6+
ArrayInterface
67

78
abstract type AbstractVectorOfArray{T, N} <: AbstractArray{T, N} end
89
abstract type AbstractDiffEqArray{T, N} <: AbstractVectorOfArray{T, N} end

src/array_partition.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
186186
recursivecopy!(a, b)
187187
end
188188
end
189+
recursivecopy(A::ArrayPartition) = ArrayPartition(copy.(A.x))
189190

190191
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
191192

@@ -227,7 +228,7 @@ combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.Broadca
227228
combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
228229
@inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))
229230

230-
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T,S}}) where {T, S}
231+
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T,S}}) where {T, S}
231232
Style = combine_styles((S.parameters...,))
232233
ArrayPartitionStyle(Style)
233234
end
@@ -270,7 +271,7 @@ _npartitions(args::Tuple{}) = 0
270271
@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
271272
unpack(x,::Any) = x
272273
unpack(x::ArrayPartition, i) = x.x[i]
273-
274+
274275
@inline unpack_args(i, args::Tuple) = (unpack(args[1], i), unpack_args(i, Base.tail(args))...)
275276
unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
276277
unpack_args(::Any, args::Tuple{}) = ()

src/utils.jl

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,7 @@
1-
"""
2-
is_mutable_type(x::DataType)
3-
4-
Query whether a type is mutable or not, see
5-
https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19.
6-
"""
7-
Base.@pure is_mutable_type(x::DataType) = x.mutable
8-
91
function recursivecopy(a)
102
deepcopy(a)
113
end
12-
13-
recursivecopy(a::Number) = copy(a)
14-
4+
recursivecopy(a::Union{SVector,SMatrix,SArray,Number}) = copy(a)
155
function recursivecopy(a::AbstractArray{T,N}) where {T<:Number,N}
166
copy(a)
177
end
@@ -66,7 +56,7 @@ end
6656

6757
function copyat_or_push!(a::AbstractVector{T},i::Int,x,nc::Type{Val{perform_copy}}=Val{true}) where {T,perform_copy}
6858
@inbounds if length(a) >= i
69-
if T <: Number || T <: SArray || (T <: FieldVector && !is_mutable_type(T)) || !perform_copy
59+
if !ismutable(T) || !perform_copy
7060
# TODO: Check for `setindex!`` if T <: StaticArray and use `copy!(b[i],a[i])`
7161
# or `b[i] = a[i]`, see https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19
7262
a[i] = x
@@ -75,16 +65,7 @@ function copyat_or_push!(a::AbstractVector{T},i::Int,x,nc::Type{Val{perform_copy
7565
end
7666
else
7767
if perform_copy
78-
if typeof(x) <: Array && !(eltype(x) <: Number)
79-
push!(a,recursivecopy(x))
80-
elseif typeof(x) <: Array || typeof(x) <: ArrayPartition ||
81-
typeof(x) <: AbstractVectorOfArray
82-
push!(a,copy(x))
83-
elseif typeof(x) <: Union{SVector,SMatrix,SArray,Number} # Only immutable
84-
push!(a,x)
85-
else
86-
push!(a,deepcopy(x))
87-
end
68+
push!(a,recursivecopy(x))
8869
else
8970
push!(a,x)
9071
end

src/vector_of_array.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ end
6666

6767
# Tools for creating similar objects
6868
@inline Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T} = VectorOfArray([similar(VA[i], T) for i in eachindex(VA)])
69+
recursivecopy(VA::VectorOfArray) = VectorOfArray(copy.(VA.u))
6970

7071
# fill!
7172
# For DiffEqArray it ignores ts and fills only u

0 commit comments

Comments
 (0)