Skip to content

Commit ee7b906

Browse files
committed
Parameterize thunk cost on signature
1 parent 539d02e commit ee7b906

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

src/sch/Sch.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Fields:
5959
- worker_loadavg::Dict{Int,NTuple{3,Float64}} - Worker load average
6060
- worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} - Communication channels between the scheduler and each worker
6161
- procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}} - Cached linked list of processors ready to be used
62-
- function_cost_cache::Dict{Function,Float64} - Cache of estimated CPU time required to compute the function
62+
- function_cost_cache::Dict{Type{<:Tuple},Float64} - Cache of estimated CPU time required to compute the given signature
6363
- halt::Base.RefValue{Bool} - Flag indicating, when set, that the scheduler should halt immediately
6464
- lock::ReentrantLock - Lock around operations which modify the state
6565
- futures::Dict{Thunk, Vector{ThunkFuture}} - Futures registered for waiting on the result of a thunk.
@@ -82,7 +82,7 @@ struct ComputeState
8282
worker_loadavg::Dict{Int,NTuple{3,Float64}}
8383
worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}
8484
procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}
85-
function_cost_cache::Dict{Function,Float64}
85+
function_cost_cache::Dict{Type{<:Tuple},Float64}
8686
halt::Base.RefValue{Bool}
8787
lock::ReentrantLock
8888
futures::Dict{Thunk, Vector{ThunkFuture}}
@@ -106,7 +106,7 @@ function start_state(deps::Dict, node_order, chan)
106106
Dict{Int,NTuple{3,Float64}}(),
107107
Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(),
108108
Ref{Union{ProcessorCacheEntry,Nothing}}(nothing),
109-
Dict{Function,Float64}(),
109+
Dict{Type{<:Tuple},Float64}(),
110110
Ref{Bool}(false),
111111
ReentrantLock(),
112112
Dict{Thunk, Vector{ThunkFuture}}(),
@@ -386,7 +386,8 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions())
386386
state.worker_pressure[pid][typeof(proc)] = metadata.pressure
387387
state.worker_loadavg[pid] = metadata.loadavg
388388
node = state.thunk_dict[thunk_id]
389-
state.function_cost_cache[node.f] = (metadata.threadtime + get(state.function_cost_cache, node.f, 0.0)) / 2
389+
sig = signature(node, state)
390+
state.function_cost_cache[sig] = (metadata.threadtime + get(state.function_cost_cache, sig, 0.0)) / 2
390391
state.cache[node] = res
391392
if node.options !== nothing && node.options.checkpoint !== nothing
392393
try
@@ -489,14 +490,14 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
489490

490491
return true
491492
end
492-
function has_capacity(p, gp, procutil, f)
493+
function has_capacity(p, gp, procutil, sig)
493494
T = typeof(p)
494495
extra_util = get(procutil, T, 1.0)
495496
real_util = state.worker_pressure[gp][T]
496-
if (T === Dagger.ThreadProc) && haskey(state.function_cost_cache, f)
497+
if (T === Dagger.ThreadProc) && haskey(state.function_cost_cache, sig)
497498
# Assume that the extra pressure is between estimated and measured
498499
# TODO: Generalize this to arbitrary processor types
499-
extra_util = min(extra_util, state.function_cost_cache[f])
500+
extra_util = min(extra_util, state.function_cost_cache[sig])
500501
end
501502
# TODO: update real_util based on loadavg
502503
cap = state.worker_capacity[gp][T]
@@ -514,15 +515,16 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
514515
# Select a new task and get its options
515516
task = pop!(state.ready)
516517
opts = merge(ctx.options, task.options)
518+
sig = signature(task, state)
517519

518520
# Try to select a processor
519521
selected_entry = nothing
520522
entry = state.procs_cache_list[]
521523
cap, extra_util = nothing, nothing
522524
procs_found = false
523525
# N.B. if we only have one processor, we need to select it now
524-
if can_use_proc(task, entry.proc, opts)
525-
has_cap, cap, extra_util = has_capacity(entry.proc, entry.gproc.pid, opts.procutil, task.f)
526+
if can_use_proc(task, entry.gproc, entry.proc, opts)
527+
has_cap, cap, extra_util = has_capacity(entry.proc, entry.gproc.pid, opts.procutil, sig)
526528
if has_cap
527529
selected_entry = entry
528530
else
@@ -542,8 +544,8 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
542544
end
543545
end
544546

545-
if can_use_proc(task, entry.proc, opts)
546-
has_cap, cap, extra_util = has_capacity(entry.proc, entry.gproc.pid, opts.procutil, task.f)
547+
if can_use_proc(task, entry.gproc, entry.proc, opts)
548+
has_cap, cap, extra_util = has_capacity(entry.proc, entry.gproc.pid, opts.procutil, sig)
547549
if has_cap
548550
# Select this processor
549551
selected_entry = entry

src/sch/util.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,8 @@ function fetch_report(task)
121121
end
122122
end
123123
end
124+
125+
"""
    signature(task::Thunk, state)

Build the call-signature type of `task` as `Tuple{typeof(task.f), argtypes...}`.

Each element of `task.inputs` for which `istask` is true is first replaced by its
entry in `state.cache`. A `Chunk` then contributes its `chunktype` field; every
other value contributes `typeof(value)`.
"""
function signature(task::Thunk, state)
    argtypes = map(task.inputs) do arg
        # Upstream thunks are looked up in the scheduler cache before typing.
        val = istask(arg) ? state.cache[arg] : arg
        if val isa Chunk
            val.chunktype
        else
            typeof(val)
        end
    end
    return Tuple{typeof(task.f), argtypes...}
end

0 commit comments

Comments
 (0)