Skip to content

Commit 03c27ff

Browse files
committed
Various alloc reductions and optimizations
Sch: Don't return values in Tasks Sch: Switch from state.cache to thunk.cache_ref tests: Improve test_throws_unwrap error comparisons
1 parent 38230f3 commit 03c27ff

30 files changed

+2186
-1369
lines changed

src/Dagger.jl

+14-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ if !isdefined(Base, :ScopedValues)
2121
else
2222
import Base.ScopedValues: ScopedValue, with
2323
end
24+
import TaskLocalValues: TaskLocalValue
2425

2526
if !isdefined(Base, :get_extension)
2627
import Requires: @require
@@ -34,9 +35,13 @@ import Adapt
3435
include("lib/util.jl")
3536
include("utils/dagdebug.jl")
3637

38+
# Logging Basics
39+
include("utils/logging.jl")
40+
3741
# Distributed data
3842
include("utils/locked-object.jl")
3943
include("utils/tasks.jl")
44+
include("utils/reuse.jl")
4045

4146
import MacroTools: @capture
4247
include("options.jl")
@@ -48,6 +53,7 @@ include("task-tls.jl")
4853
include("scopes.jl")
4954
include("utils/scopes.jl")
5055
include("dtask.jl")
56+
include("argument.jl")
5157
include("queue.jl")
5258
include("thunk.jl")
5359
include("submission.jl")
@@ -64,34 +70,34 @@ include("sch/Sch.jl"); using .Sch
6470
# Data dependency task queue
6571
include("datadeps.jl")
6672

73+
# File IO
74+
include("file-io.jl")
75+
6776
# Array computations
6877
include("array/darray.jl")
6978
include("array/alloc.jl")
7079
include("array/map-reduce.jl")
7180
include("array/copy.jl")
72-
73-
# File IO
74-
include("file-io.jl")
75-
81+
include("array/random.jl")
7682
include("array/operators.jl")
7783
include("array/indexing.jl")
7884
include("array/setindex.jl")
7985
include("array/matrix.jl")
8086
include("array/sparse_partition.jl")
87+
include("array/parallel-blocks.jl")
8188
include("array/sort.jl")
8289
include("array/linalg.jl")
8390
include("array/mul.jl")
8491
include("array/cholesky.jl")
8592

93+
# Custom Logging Events
94+
include("utils/logging-events.jl")
95+
8696
# Visualization
8797
include("visualization.jl")
8898
include("ui/gantt-common.jl")
8999
include("ui/gantt-text.jl")
90100

91-
# Logging
92-
include("utils/logging-events.jl")
93-
include("utils/logging.jl")
94-
95101
# Precompilation
96102
import PrecompileTools: @compile_workload
97103
include("precompile.jl")

src/argument.jl

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
mutable struct ArgPosition
2+
positional::Bool
3+
idx::Int
4+
kw::Symbol
5+
end
6+
ArgPosition() = ArgPosition(true, 0, :NULL)
7+
ArgPosition(pos::ArgPosition) = ArgPosition(pos.positional, pos.idx, pos.kw)
8+
ispositional(pos::ArgPosition) = pos.positional
9+
iskw(pos::ArgPosition) = !pos.positional
10+
function pos_idx(pos::ArgPosition)
11+
@assert pos.positional
12+
@assert pos.idx > 0
13+
@assert pos.kw == :NULL
14+
return pos.idx
15+
end
16+
function pos_kw(pos::ArgPosition)
17+
@assert !pos.positional
18+
@assert pos.idx == 0
19+
@assert pos.kw != :NULL
20+
return pos.kw
21+
end
22+
mutable struct Argument
23+
pos::ArgPosition
24+
value
25+
end
26+
Argument(pos::Integer, value) = Argument(ArgPosition(true, pos, :NULL), value)
27+
Argument(kw::Symbol, value) = Argument(ArgPosition(false, 0, kw), value)
28+
ispositional(arg::Argument) = ispositional(arg.pos)
29+
iskw(arg::Argument) = iskw(arg.pos)
30+
pos_idx(arg::Argument) = pos_idx(arg.pos)
31+
pos_kw(arg::Argument) = pos_kw(arg.pos)
32+
value(arg::Argument) = arg.value
33+
valuetype(arg::Argument) = typeof(arg.value)
34+
Base.iterate(arg::Argument) = (arg.pos, true)
35+
function Base.iterate(arg::Argument, state::Bool)
36+
if state
37+
return (arg.value, false)
38+
else
39+
return nothing
40+
end
41+
end
42+
43+
Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value)
44+
chunktype(arg::Argument) = chunktype(value(arg))

src/array/darray.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ domainchunks(d::DArray) = d.subdomains
173173
size(x::DArray) = size(domain(x))
174174
stage(ctx, c::DArray) = c
175175

176-
function Base.collect(d::DArray; tree=false)
176+
function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N}
177177
a = fetch(d)
178178
if isempty(d.chunks)
179179
return Array{eltype(d)}(undef, size(d)...)
@@ -183,6 +183,13 @@ function Base.collect(d::DArray; tree=false)
183183
return fetch(a.chunks[1])
184184
end
185185

186+
if copyto
187+
C = Array{T,N}(undef, size(a))
188+
DC = view(C, Blocks(size(a)...))
189+
copyto!(DC, a)
190+
return C
191+
end
192+
186193
dimcatfuncs = [(x...) -> d.concat(x..., dims=i) for i in 1:ndims(d)]
187194
if tree
188195
collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks)))

src/array/indexing.jl

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import TaskLocalValues: TaskLocalValue
2-
31
### getindex
42

53
struct GetIndex{T,N} <: ArrayOp{T,N}

src/array/parallel-blocks.jl

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
export ParallelBlocks
2+
3+
using Statistics
4+
5+
struct ParallelBlocks{N} <: Dagger.AbstractSingleBlocks{N}
6+
n::Int
7+
end
8+
ParallelBlocks(n::Integer) = ParallelBlocks{0}(n)
9+
ParallelBlocks{N}(dist::ParallelBlocks) where N = ParallelBlocks{N}(dist.n)
10+
ParallelBlocks() = ParallelBlocks(Dagger.num_processors())
11+
12+
Base.convert(::Type{ParallelBlocks{N}}, dist::ParallelBlocks) where N =
13+
ParallelBlocks{N}(dist.n)
14+
15+
wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, dist::ParallelBlocks) =
16+
wrap_chunks(chunks, N, dist.n)
17+
wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, n::Integer) =
18+
convert(Array{Any}, reshape(chunks, ntuple(i->i == 1 ? n : 1, N)))
19+
20+
function _finish_allocation(f::Function, dist::ParallelBlocks, dims::NTuple{N,Int}) where N
21+
d = ArrayDomain(map(x->1:x, dims))
22+
s = reshape([d for _ in 1:dist.n],
23+
ntuple(i->i == 1 ? dist.n : 1, N))
24+
data = [f(dims) for _ in 1:dist.n]
25+
dist = ParallelBlocks{N}(dist)
26+
chunks = wrap_chunks(map(Dagger.tochunk, data), N, dist)
27+
return Dagger.DArray(eltype(first(data)), d, s, chunks, dist)
28+
end
29+
30+
for fn in [:rand, :randn, :zeros, :ones]
31+
@eval begin
32+
function Base.$fn(dist::ParallelBlocks, ::Type{ET}, dims::Dims) where {ET}
33+
f(block) = $fn(ET, block)
34+
_finish_allocation(f, dist, dims)
35+
end
36+
Base.$fn(dist::ParallelBlocks, T::Type, dims::Integer...) = $fn(dist, T, dims)
37+
Base.$fn(dist::ParallelBlocks, T::Type, dims::Tuple) = $fn(dist, T, dims)
38+
Base.$fn(dist::ParallelBlocks, dims::Integer...) = $fn(dist, Float64, dims)
39+
Base.$fn(dist::ParallelBlocks, dims::Tuple) = $fn(dist, Float64, dims)
40+
end
41+
end
42+
# FIXME: sprand
43+
44+
function Dagger.distribute(data::AbstractArray{T,N}, dist::ParallelBlocks) where {T,N}
45+
dims = size(data)
46+
d = ArrayDomain(map(x->1:x, dims))
47+
s = Dagger.DomainBlocks(ntuple(_->1, N),
48+
ntuple(i->[dims[i]], N))
49+
chunks = [Dagger.tochunk(copy(data)) for _ in 1:dist.n]
50+
new_dist = ParallelBlocks{N}(dist)
51+
return Dagger.DArray(T, d, s, wrap_chunks(chunks, N, dist), new_dist)
52+
end
53+
54+
_invalid_call_pblocks(f::Symbol) =
55+
error("`$f` is not valid for a `DArray` partitioned with `ParallelBlocks`.\nConsider `Dagger.pmap($f, x)` instead.")
56+
57+
Base.collect(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}) =
58+
_invalid_call_pblocks(:collect)
59+
Base.getindex(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, x...) =
60+
_invalid_call_pblocks(:getindex)
61+
Base.setindex!(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, value, x...) =
62+
_invalid_call_pblocks(:setindex!)
63+
64+
function pmap(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
65+
# TODO: Chunks might not be `Array`s
66+
# FIXME
67+
#AT = Array{T,N}
68+
#ET = eltype(Base.promote_op(f, AT))
69+
ET = Any
70+
new_chunks = map(A.chunks) do chunk
71+
Dagger.@spawn f(chunk)
72+
end
73+
return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning)
74+
end
75+
# FIXME: More useful `show` method
76+
Base.show(io::IO, ::MIME"text/plain", A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} =
77+
print(io, typeof(A))
78+
pfetch(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} =
79+
map(fetch, A.chunks)
80+
pcollect(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} =
81+
map(collect, pfetch(A))
82+
83+
function Base.map(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
84+
ET = Base.promote_op(f, T)
85+
new_chunks = map(A.chunks) do chunk
86+
Dagger.@spawn map(f, chunk)
87+
end
88+
return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning)
89+
end
90+
function Base.map!(f::Function,
91+
x::Dagger.DArray{T1,N1,ParallelBlocks{N1}} where {T1,N1},
92+
y::Dagger.DArray{T2,N2,ParallelBlocks{N2}} where {T2,N2})
93+
x_dist = x.partitioning
94+
y_dist = y.partitioning
95+
if x_dist.n != y_dist.n
96+
throw(ArgumentError("Can't `map!` over non-matching `ParallelBlocks` distributions: $(x_dist.n) != $(y_dist.n)"))
97+
end
98+
@sync for i in 1:x_dist.n
99+
Dagger.@spawn map!(f, x.chunks[i], y.chunks[i])
100+
end
101+
end
102+
103+
#=
104+
function Base.reduce(f::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}};
105+
dims=:) where {T,N}
106+
error("Out-of-place Reduce")
107+
if dims == Base.:(:)
108+
localpart = fetch(Dagger.reduce_async(f, x))
109+
return MPI.Allreduce(localpart, f, comm)
110+
elseif dims === nothing
111+
localpart = fetch(x.chunks[1])
112+
return MPI.Allreduce(localpart, f, comm)
113+
else
114+
error("Not yet implemented")
115+
end
116+
end
117+
=#
118+
function allreduce!(op::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}}; nchunks::Integer=0) where {T,N}
119+
if nchunks == 0
120+
nchunks = x.partitioning.n
121+
end
122+
@assert nchunks == x.partitioning.n "Number of chunks must match the number of partitions"
123+
124+
# Split each chunk along the last dimension
125+
chunk_size = cld(size(x, ndims(x)), nchunks)
126+
chunk_dist = Blocks(ntuple(i->i == N ? chunk_size : size(x, i), N))
127+
chunk_ds = partition(chunk_dist, x.subdomains[1])
128+
num_par_chunks = length(x.chunks)
129+
130+
# Allocate temporary buffer
131+
y = copy(x)
132+
133+
# Ring-reduce into temporary buffer
134+
Dagger.spawn_datadeps() do
135+
for j in 1:length(chunk_ds)
136+
for i in 1:num_par_chunks
137+
for step in 1:(num_par_chunks-1)
138+
from_idx = i
139+
to_idx = mod1(i+step, num_par_chunks)
140+
from_chunk = x.chunks[from_idx]
141+
to_chunk = y.chunks[to_idx]
142+
sd = chunk_ds[mod1(j+i-1, length(chunk_ds))].indexes
143+
# FIXME: Specify aliasing based on `sd`
144+
Dagger.@spawn _reduce_view!(op,
145+
InOut(to_chunk), sd,
146+
In(from_chunk), sd)
147+
end
148+
end
149+
end
150+
151+
# Copy from temporary buffer back to origin
152+
for i in 1:num_par_chunks
153+
Dagger.@spawn copyto!(Out(x.chunks[i]), In(y.chunks[i]))
154+
end
155+
end
156+
157+
return x
158+
end
159+
function _reduce_view!(op, to, to_view, from, from_view)
160+
to_viewed = view(to, to_view...)
161+
from_viewed = view(from, from_view...)
162+
reduce!(op, to_viewed, from_viewed)
163+
return
164+
end
165+
function reduce!(op, to, from)
166+
to .= op.(to, from)
167+
end
168+
169+
function Statistics.mean!(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
170+
allreduce!(+, A)
171+
len = length(A.chunks)
172+
map!(x->x ./ len, A, A)
173+
return A
174+
end

src/array/random.jl

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Random
2+
3+
function Random.randn!(A::DArray{T}) where T
4+
Ac = A.chunks
5+
6+
Dagger.spawn_datadeps() do
7+
for chunk in Ac
8+
Dagger.@spawn randn!(InOut(chunk))
9+
end
10+
end
11+
12+
return A
13+
end

src/chunks.jl

+18-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ Base.length(s::Shard) = length(s.chunks)
250250
### Core Stuff
251251

252252
"""
253-
tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, kwargs...) -> Chunk
253+
tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk
254254
255255
Create a chunk from data `x` which resides on `proc` and which has scope
256256
`scope`.
@@ -262,9 +262,12 @@ will be inspected to determine if it's safe to serialize; if so, the default
262262
MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will
263263
be used.
264264
265+
If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a
266+
new `Chunk`.
267+
265268
All other kwargs are passed directly to `MemPool.poolset`.
266269
"""
267-
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S}
270+
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, rewrap=false, kwargs...) where {X,P,S}
268271
if device === nothing
269272
device = if Sch.walk_storage_safe(x)
270273
MemPool.GLOBAL_DEVICE[]
@@ -275,7 +278,15 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cac
275278
ref = poolset(x; device, kwargs...)
276279
Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist)
277280
end
278-
tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x
281+
function tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; rewrap=false, kwargs...)
282+
if rewrap
283+
return remotecall_fetch(x.handle.owner) do
284+
tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...)
285+
end
286+
else
287+
return x
288+
end
289+
end
279290

280291
function savechunk(data, dir, f)
281292
sz = open(joinpath(dir, f), "w") do io
@@ -302,9 +313,13 @@ function unwrap_weak_checked(c::WeakChunk)
302313
@assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))"
303314
return cw
304315
end
316+
wrap_weak(c::Chunk) = WeakChunk(c)
317+
isweak(c::WeakChunk) = true
318+
isweak(c::Chunk) = false
305319
is_task_or_chunk(c::WeakChunk) = true
306320
Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) =
307321
error("Cannot serialize a WeakChunk")
322+
chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c))
308323

309324
Base.@deprecate_binding AbstractPart Union{Chunk, Thunk}
310325
Base.@deprecate_binding Part Chunk

0 commit comments

Comments
 (0)