Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DaggerMPI subpackage for MPI integrations #356

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lib/DaggerMPI/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name = "DaggerMPI"
uuid = "37bfb287-2338-4693-8557-581796463535"
authors = ["Julian P Samaroo <[email protected]>"]
version = "0.1.0"

[deps]
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
171 changes: 171 additions & 0 deletions lib/DaggerMPI/src/DaggerMPI.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
module DaggerMPI

using Dagger
using MPI

"""
    MPIProcessor{P,C} <: Dagger.Processor

Wraps an inner processor `proc` so task execution can be coordinated across
the ranks of the MPI communicator `comm`. `color_algo` is a callable
`(comm, key) -> color` that maps a task key to the rank ("color") which will
actually execute the task (see `SimpleColoring` for the default).
"""
struct MPIProcessor{P,C} <: Dagger.Processor
    proc::P          # the wrapped local processor that performs real execution
    comm::MPI.Comm   # communicator shared by all participating ranks
    color_algo::C    # callable mapping (comm, key) to an owning rank
end

"""
    SimpleColoring

Default coloring strategy: assigns a task key to a rank by taking the
remainder of the key against the communicator size.
"""
struct SimpleColoring end
function (sc::SimpleColoring)(comm, key)
    nranks = MPI.Comm_size(comm)
    return UInt64(rem(key, nranks))
end

# Number of MPI-wrapped processor callbacks registered by `initialize`;
# -1 means DaggerMPI is not currently initialized.
const MPI_PROCESSORS = Ref{Int}(-1)

# Processors that were registered before `initialize` replaced them;
# `finalize` re-registers these to restore the original configuration.
# NOTE(review): nothing in this file visibly populates this set — confirm
# `initialize` is meant to fill it before clearing the callbacks.
const PREVIOUS_PROCESSORS = Set()


"""
    initialize(comm::MPI.Comm=MPI.COMM_WORLD; color_algo=SimpleColoring())

Initialize MPI (without an atexit finalizer) and replace Dagger's registered
processors with `MPIProcessor` wrappers coordinating over `comm`. The
processors present beforehand are saved into `PREVIOUS_PROCESSORS` so that
`finalize` can restore them. Must not be called twice without an intervening
`finalize`.
"""
function initialize(comm::MPI.Comm=MPI.COMM_WORLD; color_algo=SimpleColoring())
    @assert MPI_PROCESSORS[] == -1 "DaggerMPI already initialized"

    # Force eager_thunk to run
    fetch(Dagger.@spawn 1+1)

    MPI.Init(; finalize_atexit=false)
    procs = Dagger.get_processors(OSProc())

    # Save the pre-existing processors so `finalize` can restore them.
    # (Previously this set was never populated, so `finalize`'s restore
    # loop had nothing to re-register.)
    union!(PREVIOUS_PROCESSORS, procs)

    i = 0
    empty!(Dagger.PROCESSOR_CALLBACKS)
    empty!(Dagger.OSPROC_PROCESSOR_CACHE)
    for proc in procs
        # Callback names are 0-based: "mpiprocessor_0" .. "mpiprocessor_$(n-1)".
        Dagger.add_processor_callback!("mpiprocessor_$i") do
            return MPIProcessor(proc, comm, color_algo)
        end
        i += 1
    end
    MPI_PROCESSORS[] = i

    # FIXME: Hack to populate new processors
    Dagger.get_processors(OSProc())

    return nothing
end

"""
    finalize()

Tear down DaggerMPI: delete the `MPIProcessor` callbacks registered by
`initialize`, re-register any processors saved in `PREVIOUS_PROCESSORS`,
and finalize MPI. Must only be called after `initialize`.
"""
function finalize()
    @assert MPI_PROCESSORS[] > -1 "DaggerMPI not yet initialized"
    # `initialize` registers callbacks named "mpiprocessor_0" through
    # "mpiprocessor_$(n-1)" (0-based), so delete with 0-based indices.
    # (The previous 1-based loop skipped "mpiprocessor_0" and attempted to
    # delete the nonexistent "mpiprocessor_$n".)
    for i in 0:(MPI_PROCESSORS[] - 1)
        Dagger.delete_processor_callback!("mpiprocessor_$i")
    end
    empty!(Dagger.PROCESSOR_CALLBACKS)
    empty!(Dagger.OSPROC_PROCESSOR_CACHE)
    i = 1
    for proc in PREVIOUS_PROCESSORS
        # Restore the processors that existed before `initialize` ran.
        Dagger.add_processor_callback!("old_processor_$i") do
            return proc
        end
        i += 1
    end
    empty!(PREVIOUS_PROCESSORS)
    MPI.Finalize()
    MPI_PROCESSORS[] = -1
end

"""
    MPIColoredValue{T}

References a value stored on some MPI rank. `color` identifies the rank that
holds the actual `value` (on other ranks `value` is `nothing` — see
`Dagger.execute!`), and `comm` is the communicator the value belongs to.
"""
struct MPIColoredValue{T}
    color::UInt64
    value::T
    comm::MPI.Comm
end

"The parent of an `MPIProcessor` is the local `OSProc`."
function Dagger.get_parent(proc::MPIProcessor)
    return Dagger.OSProc()
end

"`MPIProcessor`s are enabled for scheduling by default."
function Dagger.default_enabled(proc::MPIProcessor)
    return true
end


"""
    recv_yield(src, tag, comm)

Busy-loop receive: repeatedly probe for a message from rank `src` with `tag`
on `comm`, yielding to other Julia tasks between probes, and return the
deserialized payload once it arrives.
"""
function recv_yield(src, tag, comm)
    while true
        (have_msg, msg, status) = MPI.Improbe(src, tag, comm, MPI.Status)
        if have_msg
            # Size the receive buffer from the probed message's byte count.
            nbytes = MPI.Get_count(status, UInt8)
            payload = Array{UInt8}(undef, nbytes)
            req = MPI.Imrecv!(MPI.Buffer(payload), msg)
            # Spin (cooperatively) until the matched receive completes.
            while !MPI.Test(req)
                yield()
            end
            return MPI.deserialize(payload)
        end
        # TODO: Sigmoidal backoff
        yield()
    end
end

"""
    Dagger.execute!(proc::MPIProcessor, f, args...)

Execute `f(args...)` only on the rank selected by `proc.color_algo` for the
current task's hash-derived tag; every other rank skips the work and returns
an `MPIColoredValue` wrapping `nothing`.
"""
function Dagger.execute!(proc::MPIProcessor, f, args...)
    rank = MPI.Comm_rank(proc.comm)
    # Derive a stable Int32 tag from the task's unified hash.
    tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash() >> 32))
    owner = proc.color_algo(proc.comm, tag)
    if rank != owner
        # Return nothing, we won't use this value anyway
        @debug "[$rank] Skipped $f on $tag"
        return MPIColoredValue(owner, nothing, proc.comm)
    end
    @debug "[$rank] Executing $f on $tag"
    result = Dagger.execute!(proc.proc, f, args...)
    return MPIColoredValue(owner, result, proc.comm)
end

"""
    Dagger.move(from_proc::MPIProcessor, to_proc::MPIProcessor, x::Dagger.Chunk)

Move a colored value between `MPIProcessor`s on the same communicator. The
rank that generated the value (its color) sends it point-to-point to the rank
that will consume it (determined by re-running the coloring on the consuming
task's hash); ranks that neither generated nor consume it return `nothing`.
"""
function Dagger.move(from_proc::MPIProcessor, to_proc::MPIProcessor, x::Dagger.Chunk)
    @assert from_proc.comm == to_proc.comm "Mixing different MPI communicators is not supported"
    @assert Dagger.chunktype(x) <: MPIColoredValue
    x_value = fetch(x)
    rank = MPI.Comm_rank(from_proc.comm)
    # `tag` identifies the input being moved; `other_tag` identifies the
    # consuming task, whose coloring decides the destination rank `other`.
    tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32))
    other_tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:self) >> 32))
    other = from_proc.color_algo(from_proc.comm, other_tag)
    if x_value.color == rank == other
        # We generated and will use this input
        return Dagger.move(from_proc.proc, to_proc.proc, x_value.value)
    elseif x_value.color == rank
        # We generated this input
        @debug "[$rank] Starting P2P send to [$other] from $tag to $other_tag"
        # NOTE(review): the isend request is never waited on — presumably
        # relies on MPI.jl request finalization; confirm the send completes.
        MPI.isend(x_value.value, other, tag, from_proc.comm)
        @debug "[$rank] Finished P2P send to [$other] from $tag to $other_tag"
        return Dagger.move(from_proc.proc, to_proc.proc, x_value.value)
    elseif other == rank
        # We will use this input
        @debug "[$rank] Starting P2P recv from $tag to $other_tag"
        value = recv_yield(x_value.color, tag, from_proc.comm)
        @debug "[$rank] Finished P2P recv from $tag to $other_tag"
        return Dagger.move(from_proc.proc, to_proc.proc, value)
    else
        # We didn't generate and will not use this input
        return nothing
    end
end

"""
    Dagger.move(from_proc::MPIProcessor, to_proc::Dagger.Processor, x::Dagger.Chunk)

Move a colored value off of MPI onto a plain processor. The rank holding the
value sends it to every other rank (point-to-point, one async send per rank),
so all ranks end up with the unwrapped value.
"""
function Dagger.move(from_proc::MPIProcessor, to_proc::Dagger.Processor, x::Dagger.Chunk)
    @assert Dagger.chunktype(x) <: MPIColoredValue
    x_value = fetch(x)
    rank = MPI.Comm_rank(from_proc.comm)
    tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32))
    if rank == x_value.color
        # FIXME: Broadcast send
        # Emulate a broadcast with concurrent point-to-point sends; @sync
        # waits until every send task has been issued.
        @sync for other in 0:(MPI.Comm_size(from_proc.comm)-1)
            other == rank && continue
            @async begin
                @debug "[$rank] Starting bcast send to [$other] on $tag"
                MPI.isend(x_value.value, other, tag, from_proc.comm)
                @debug "[$rank] Finished bcast send to [$other] on $tag"
            end
        end
        return Dagger.move(from_proc.proc, to_proc, x_value.value)
    else
        @debug "[$rank] Starting bcast recv on $tag"
        value = recv_yield(x_value.color, tag, from_proc.comm)
        @debug "[$rank] Finished bcast recv on $tag"
        return Dagger.move(from_proc.proc, to_proc, value)
    end
end

"""
    Dagger.move(from_proc::Dagger.Processor, to_proc::MPIProcessor, x::Dagger.Chunk)

Move a plain (uncolored) chunk onto an `MPIProcessor`, wrapping it in an
`MPIColoredValue` colored with the local rank.
"""
function Dagger.move(from_proc::Dagger.Processor, to_proc::MPIProcessor, x::Dagger.Chunk)
    @assert !(Dagger.chunktype(x) <: MPIColoredValue)
    rank = MPI.Comm_rank(to_proc.comm)
    # BUG FIX: the communicator must come from `to_proc` — `from_proc` is a
    # generic `Dagger.Processor` and has no `comm` field, so the original
    # `from_proc.comm` would error on every call.
    return MPIColoredValue(rank, Dagger.move(from_proc, to_proc.proc, x), to_proc.comm)
end

end # module
1 change: 1 addition & 0 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ include("chunks.jl")
include("compute.jl")
include("utils/clock.jl")
include("utils/system_uuid.jl")
include("utils/uhash.jl")
include("sch/Sch.jl"); using .Sch

# Array computations
Expand Down
7 changes: 4 additions & 3 deletions src/chunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope}
processor::P
scope::S
persist::Bool
hash::UInt
end

domain(c::Chunk) = c.domain
Expand Down Expand Up @@ -242,16 +243,16 @@ be used.

All other kwargs are passed directly to `MemPool.poolset`.
"""
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S}
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, hash=UInt(0), kwargs...) where {X,P,S}
if device === nothing
device = if Sch.walk_storage_safe(x)
MemPool.GLOBAL_DEVICE[]
else
MemPool.CPURAMDevice()
end
end
ref = poolset(x; device, kwargs...)
Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist)
ref = poolset(move(OSProc(), proc, x); device, kwargs...)
Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist, hash)
end
tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x

Expand Down
41 changes: 34 additions & 7 deletions src/processor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,40 @@ end
# In-Thunk Helpers

"""
thunk_processor()
thunk_processor() -> Dagger.Processor

Get the current processor executing the current thunk.
"""
thunk_processor() = task_local_storage(:_dagger_processor)::Processor

"""
in_thunk()
in_thunk() -> Bool

Returns `true` if currently in a [`Thunk`](@ref) process, else `false`.
"""
in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid)

"""
get_task_hash(kind::Symbol=:self) -> UInt

Returns the unified hash of the current task or of an input to the current
task. If `kind == :self`, then the hash is for the current task; if `kind ==
:input`, then the hash is for the current input to the task that is being
processed. The `:self` hash is available during `Dagger.execute!` and
`Dagger.move`, whereas the `:input` hash is only available during
`Dagger.move`. This hash is consistent across Julia processes (if all
processes are running the same Julia version on the same architecture).
"""
function get_task_hash(kind::Symbol=:self)::UInt
    # Both hashes live in task-local storage, set by the scheduler TLS.
    kind === :self && return task_local_storage(:_dagger_task_hash)::UInt
    kind === :input && return task_local_storage(:_dagger_input_hash)::UInt
    throw(ArgumentError("Invalid task hash kind: $kind"))
end

"""
get_tls()

Expand All @@ -309,6 +330,8 @@ Gets all Dagger TLS variable as a `NamedTuple`.
get_tls() = (
sch_uid=task_local_storage(:_dagger_sch_uid),
sch_handle=task_local_storage(:_dagger_sch_handle),
task_hash=task_local_storage(:_dagger_task_hash),
input_hash=get(task_local_storage(), :_dagger_input_hash, nothing),
processor=thunk_processor(),
time_utilization=task_local_storage(:_dagger_time_utilization),
alloc_utilization=task_local_storage(:_dagger_alloc_utilization),
Expand All @@ -320,9 +343,13 @@ get_tls() = (
Sets all Dagger TLS variables from the `NamedTuple` `tls`.
"""
function set_tls!(tls)
task_local_storage(:_dagger_sch_uid, tls.sch_uid)
task_local_storage(:_dagger_sch_handle, tls.sch_handle)
task_local_storage(:_dagger_processor, tls.processor)
task_local_storage(:_dagger_time_utilization, tls.time_utilization)
task_local_storage(:_dagger_alloc_utilization, tls.alloc_utilization)
task_local_storage(:_dagger_sch_uid, get(tls, :sch_uid, nothing))
task_local_storage(:_dagger_sch_handle, get(tls, :sch_handle, nothing))
task_local_storage(:_dagger_task_hash, get(tls, :task_hash, nothing))
if haskey(tls, :input_hash) && tls.input_hash !== nothing
task_local_storage(:_dagger_input_hash, tls.input_hash)
end
task_local_storage(:_dagger_processor, get(tls, :processor, nothing))
task_local_storage(:_dagger_time_utilization, get(tls, :time_utilization, nothing))
task_local_storage(:_dagger_alloc_utilization, get(tls, :alloc_utilization, nothing))
end
Loading