Skip to content

Commit 2e2badc

Browse files
authored
Merge pull request #210 from JuliaParallel/jps/ucx
Thunk cost estimation, chunk caching, benchmark updates
2 parents 7ff1175 + 940fdf8 commit 2e2badc

17 files changed

+801
-394
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1818

1919
[compat]
2020
Colors = "0.10, 0.11, 0.12"
21-
MemPool = "0.3.3"
21+
MemPool = "0.3.4"
2222
Requires = "1"
2323
StatsBase = "0.28, 0.29, 0.30, 0.31, 0.32, 0.33"
2424
julia = "1.0"

benchmarks/benchmark.jl

+57-25
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ elseif render == "offline"
4040
using FFMPEG, FileIO, ImageMagick
4141
end
4242
const RENDERS = Dict{Int,Dict}()
43+
const live_port = parse(Int, get(ENV, "BENCHMARK_LIVE_PORT", "8000"))
44+
45+
const graph = parse(Bool, get(ENV, "BENCHMARK_GRAPH", "0"))
46+
const profile = parse(Bool, get(ENV, "BENCHMARK_PROFILE", "0"))
4347

4448
_benches = get(ENV, "BENCHMARK", "cpu,cpu+dagger")
4549
const benches = []
@@ -106,7 +110,7 @@ end
106110

107111
theory_flops(nrow, ncol, nfeatures) = 11 * ncol * nrow * nfeatures + 2 * (ncol + nrow) * nfeatures
108112

109-
function nmf_suite(; dagger, accel, kwargs...)
113+
function nmf_suite(ctx; dagger, accel)
110114
suite = BenchmarkGroup()
111115

112116
#= TODO: Re-enable
@@ -179,47 +183,56 @@ function nmf_suite(; dagger, accel, kwargs...)
179183
])
180184
elseif accel == "cpu"
181185
Dagger.Sch.SchedulerOptions()
186+
else
187+
error("Unknown accelerator $accel")
182188
end
183-
ctx = Context(collect((1:nw) .+ 1); kwargs...)
184189
p = sum([length(Dagger.get_processors(OSProc(id))) for id in 2:(nw+1)])
190+
#bsz = ncol ÷ length(workers())
191+
bsz = ncol ÷ 64
185192
nsuite["Workers: $nw"] = @benchmarkable begin
186-
compute($ctx, nnmf($X[], $W[], $H[]); options=$opts)
193+
_ctx = Context($ctx, workers()[1:$nw])
194+
compute(_ctx, nnmf($X[], $W[], $H[]); options=$opts)
187195
end setup=begin
188196
_nw, _scale = $nw, $scale
189197
@info "Starting $_nw worker Dagger NNMF (scale by $_scale)"
190-
if render != ""
191-
Dagger.show_gantt($ctx; width=1800, window_length=20, delay=2, port=4040, live=live)
192-
end
193198
if $accel == "cuda"
194199
# FIXME: Allocate with CUDA.rand if possible
195-
$X[] = Dagger.mapchunks(CUDA.cu, compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nrow, $ncol); options=$opts))
196-
$W[] = Dagger.mapchunks(CUDA.cu, compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nrow, $nfeatures); options=$opts))
197-
$H[] = Dagger.mapchunks(CUDA.cu, compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nfeatures, $ncol); options=$opts))
200+
$X[] = Dagger.mapchunks(CUDA.cu, compute(rand(Blocks($nrow, $bsz), Float32, $nrow, $ncol); options=$opts))
201+
$W[] = Dagger.mapchunks(CUDA.cu, compute(rand(Blocks($nrow, $bsz), Float32, $nrow, $nfeatures); options=$opts))
202+
$H[] = Dagger.mapchunks(CUDA.cu, compute(rand(Blocks($nrow, $bsz), Float32, $nfeatures, $ncol); options=$opts))
198203
elseif $accel == "amdgpu"
199204
$X[] = Dagger.mapchunks(ROCArray, compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nrow, $ncol); options=$opts))
200205
$W[] = Dagger.mapchunks(ROCArray, compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nrow, $nfeatures); options=$opts))
201206
$H[] = Dagger.mapchunks(ROCArray, compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nfeatures, $ncol); options=$opts))
202207
elseif $accel == "cpu"
203-
$X[] = compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nrow, $ncol); options=$opts)
204-
$W[] = compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nrow, $nfeatures); options=$opts)
205-
$H[] = compute(rand(Blocks($nrow, $ncol÷$p), Float32, $nfeatures, $ncol); options=$opts)
208+
$X[] = compute(rand(Blocks($nrow, $bsz), Float32, $nrow, $ncol); options=$opts)
209+
$W[] = compute(rand(Blocks($nrow, $bsz), Float32, $nrow, $nfeatures); options=$opts)
210+
$H[] = compute(rand(Blocks($nrow, $bsz), Float32, $nfeatures, $ncol); options=$opts)
206211
end
207212
end teardown=begin
208-
if render != ""
213+
if render != "" && !live
209214
Dagger.continue_rendering[] = false
210-
video_paths = take!(Dagger.render_results)
211-
try
212-
video_data = Dict(key=>read(video_paths[key]) for key in keys(video_paths))
213-
push!(get!(()->[], RENDERS[$scale], $nw), video_data)
214-
catch
215+
for i in 1:5
216+
isready(Dagger.render_results) && break
217+
sleep(1)
218+
end
219+
if isready(Dagger.render_results)
220+
video_paths = take!(Dagger.render_results)
221+
try
222+
video_data = Dict(key=>read(video_paths[key]) for key in keys(video_paths))
223+
push!(get!(()->[], RENDERS[$scale], $nw), video_data)
224+
catch err
225+
@error "Failed to process render results" exception=(err,catch_backtrace())
226+
end
227+
else
228+
@warn "Failed to fetch render results"
215229
end
216230
end
217231
$X[] = nothing
218232
$W[] = nothing
219233
$H[] = nothing
220234
@everywhere GC.gc()
221235
end
222-
break
223236
nw ÷= 2
224237
end
225238
suite["NNMF scaled by: $scale"] = nsuite
@@ -234,28 +247,42 @@ function main()
234247
output_prefix = "result-$(np)workers-$(nt)threads-$(Dates.now())"
235248

236249
suites = Dict()
250+
graph_opts = if graph && render != ""
251+
(log_sink=Dagger.LocalEventLog(), log_file=output_prefix*".dot")
252+
elseif render != ""
253+
(log_sink=Dagger.LocalEventLog(),)
254+
else
255+
NamedTuple()
256+
end
257+
ctx = Context(collect((1:nw) .+ 1); profile=profile, graph_opts...)
237258
for bench in benches
238259
name = bench.name
239260
println("creating $name benchmarks")
240-
suites[name] = if bench.dagger
241-
nmf_suite(; dagger=true, accel=bench.accel, log_sink=Dagger.LocalEventLog(), log_file=output_prefix*".dot", profile=false)
242-
else
243-
nmf_suite(; dagger=false, accel=bench.accel)
261+
suites[name] = nmf_suite(ctx; dagger=bench.dagger, accel=bench.accel)
262+
end
263+
if render != ""
264+
Dagger.show_gantt(ctx; width=1800, window_length=5, delay=2, port=live_port, live=live)
265+
if live
266+
# Make sure server code is compiled
267+
sleep(1)
268+
run(pipeline(`curl -s localhost:$live_port/`; stdout=devnull))
269+
run(pipeline(`curl -s localhost:$live_port/profile`; stdout=devnull))
270+
@info "Rendering started on port $live_port"
244271
end
245272
end
246273
res = Dict()
247274
for bench in benches
248275
name = bench.name
249276
println("running $name benchmarks")
250277
res[name] = try
251-
run(suites[name]; samples=5, seconds=10*60, gcsample=true)
278+
run(suites[name]; samples=3, seconds=10*60, gcsample=true)
252279
catch err
253280
@error "Error running $name benchmarks" exception=(err,catch_backtrace())
254281
nothing
255282
end
256283
end
257284
for bench in benches
258-
println("benchmark results for $(bench.name): $(res[bench.name])")
285+
println("benchmark results for $(bench.name): $(minimum(res[bench.name]))")
259286
end
260287

261288
println("saving results in $output_prefix.$output_format")
@@ -267,6 +294,11 @@ function main()
267294
serialize(io, outdict)
268295
end
269296
end
297+
298+
if parse(Bool, get(ENV, "BENCHMARK_VISUALIZE", "0"))
299+
run(`$(Base.julia_cmd()) $(joinpath(pwd(), "visualize.jl")) -- $(output_prefix*"."*output_format)`)
300+
end
301+
270302
println("Done.")
271303

272304
# TODO: Compare with multiple results

benchmarks/visualize.jl

+71-33
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,59 @@
1-
using JLD
1+
using JLD, Serialization
22
using BenchmarkTools
33
using TypedTables
44

5-
res = JLD.load(ARGS[1])
5+
res = if endswith(ARGS[1], ".jld")
6+
JLD.load(ARGS[1])
7+
elseif endswith(ARGS[1], ".jls")
8+
deserialize(ARGS[1])
9+
else
10+
error("Unknown file type")
11+
end
612

7-
serial_results = res["results"]["Serial"]
8-
dagger_results = res["results"]["Dagger"]
13+
serial_results = filter(x->!occursin("dagger", x[1]), res["results"])
14+
@assert length(keys(serial_results)) > 0 "No serial results found"
15+
dagger_results = filter(x->occursin("dagger", x[1]), res["results"])
16+
@assert length(keys(dagger_results)) > 0 "No Dagger results found"
17+
18+
scale_set = sort([key=>parse(Int, lstrip(last(split(key, ':')), ' ')) for key in keys(first(serial_results)[2])]; by=x->x[2])
19+
nw_set = sort([key=>parse(Int, lstrip(last(split(key, ':')), ' ')) for key in keys(first(dagger_results)[2][first(first(scale_set))])]; by=x->x[2])
20+
raw_table = NamedTuple[]
21+
for bset_key in keys(res["results"])
22+
bset = res["results"][bset_key]
23+
if typeof(bset[first(first(scale_set))]) <: BenchmarkGroup
24+
procs = parse(Int, lstrip(last(split(first(first(bset[first(first(scale_set))])), ':')), ' '))
25+
for nw in nw_set
26+
for i in 1:length(scale_set)
27+
set_times = [minimum(bset[scale][nw[1]]).time/(10^9) for scale in first.(scale_set)]
28+
push!(raw_table, (name=bset_key, time=set_times[i], scale=last.(scale_set)[i], procs=nw[2]))
29+
end
30+
end
31+
else
32+
set_times = [minimum(bset[scale]).time/(10^9) for scale in first.(scale_set)]
33+
procs = 8 # default for OpenBLAS
34+
for i in 1:length(set_times)
35+
push!(raw_table, (name=bset_key, time=set_times[i], scale=last.(scale_set)[i], procs=procs))
36+
end
37+
end
38+
end
39+
table = Table(raw_table)
940

10-
scale_set = sort([key=>parse(Int, lstrip(last(split(key, ':')), ' ')) for key in keys(serial_results)]; by=x->x[2])
11-
serial_times = [minimum(serial_results[scale]).time/(10^9) for scale in first.(scale_set)]
12-
nw_set = sort([key=>parse(Int, lstrip(last(split(key, ':')), ' ')) for key in keys(dagger_results[first(first(scale_set))])]; by=x->x[2])
41+
btable = copy(table[map(x->!x, occursin.(Ref("dagger"), table.name))])
42+
dtable = copy(table[occursin.(Ref("dagger"), table.name)])
1343

14-
table = Table(name=[:Base for _ in 1:3], time=serial_times, scale=last.(scale_set), procs=[8 for _ in 1:3])
44+
#table = Table(name=[:Base for _ in 1:3], time=serial_times, scale=last.(scale_set), procs=[8 for _ in 1:3])
1545

16-
btable = copy(table)
46+
#btable = copy(table)
1747

48+
#=
1849
for (nw,nw_val) in nw_set
1950
dagger_times = [minimum(dagger_results[scale][nw]).time/(10^9) for scale in first.(scale_set)]
2051
t = Table(name=[:Dagger for _ in 1:3], time=dagger_times, scale=last.(scale_set), procs=[parse(Int,split(nw, ":")[2]) for _ in 1:3])
2152
append!(table, t)
2253
end
54+
=#
2355

24-
dtable = table[table.name .== :Dagger]
56+
#dtable = table[table.name .== :Dagger]
2557

2658
# Plotting
2759

@@ -45,11 +77,11 @@ legend_names = String[]
4577

4678
scales = unique(dtable.scale)
4779

48-
colors = distinguishable_colors(lenght(scales), ColorSchemes.seaborn_deep.colors)
80+
colors = distinguishable_colors(length(scales), ColorSchemes.seaborn_deep.colors)
4981

5082
for (i, scale) in enumerate(scales)
5183
stable = dtable[dtable.scale .== scale]
52-
t1 = first(stable[stable.procs .== 1].time)
84+
t1 = first(stable[stable.procs .== minimum(dtable.procs)].time)
5385
ss_efficiency = strong_scaling.(t1, stable.time, stable.procs)
5486
push!(line_plots, lines!(ssp, stable.procs, ss_efficiency, linewidth=3.0, color = colors[i]))
5587
push!(legend_names, "scale = $scale")
@@ -65,25 +97,32 @@ save("strong_scaling.png", fig)
6597
# too little data
6698

6799
fig = Figure(resolution = (1200, 800))
68-
weak_scaling(t1, tn) = t1/tn
100+
weak_scaling(t1, tn, p_prime, p) = t1/((p_prime/p)*tn)
69101

70-
dtable = table[table.name .== :Dagger]
71-
wstable = filter(row->row.scale == row.procs, dtable)
72-
wstable = sort(wstable, by=r->r.scale)
73-
t1 = first(wstable).time
102+
t1 = first(dtable[map(row->(row.scale == 10) && (row.procs == 1), dtable)]).time
74103

75104
fig = Figure(resolution = (1200, 800))
76-
perf = fig[1, 1] = Axis(fig, title = "Weak scaling")
77-
perf.xlabel = "nprocs"
78-
perf.ylabel = "Efficiency"
105+
perf = fig[1, 1] = Axis(fig, title = "Weak Scaling")
106+
perf.xlabel = "Number of processes"
107+
perf.ylabel = "Scaling efficiency"
108+
109+
line_plots = Any[]
110+
legend_names = String[]
79111

80-
lines!(perf, wstable.procs, weak_scaling.(t1, wstable.time), linewidth=3.0)
112+
wstable = similar(dtable, 0)
113+
for pair in [(10,1),(35,4),(85,8)]
114+
append!(wstable, dtable[map(row->(row.scale == pair[1]) && (row.procs == pair[2]), rows(dtable))])
115+
end
116+
push!(line_plots, lines!(perf, wstable.procs, weak_scaling.(t1, wstable.time, wstable.procs .* 10, wstable.scale), linewidth=3.0))
117+
push!(legend_names, "cpu+dagger")
118+
119+
legend = fig[1, 2] = Legend(fig, line_plots, legend_names)
81120
save("weak_scaling.png", fig)
82121

83122
# 3. Comparison against Base
84123

85124
fig = Figure(resolution = (1200, 800))
86-
perf = fig[1, 1] = Axis(fig, title = "DaggerArrays vs Base")
125+
perf = fig[1, 1] = Axis(fig, title = "Dagger vs Base")
87126
perf.xlabel = "Scaling factor"
88127
perf.ylabel = "time (s)"
89128

@@ -92,7 +131,7 @@ legend_names = String[]
92131

93132
procs = unique(dtable.procs)
94133

95-
colors = distinguishable_colors(lenght(procs) + 1, ColorSchemes.seaborn_deep.colors)
134+
colors = distinguishable_colors(length(procs) + 1, ColorSchemes.seaborn_deep.colors)
96135

97136
for (i, nproc) in enumerate(procs)
98137
stable = dtable[dtable.procs .== nproc]
@@ -109,23 +148,22 @@ save("raw_timings.png", fig)
109148

110149
# 4. Speedup
111150
fig = Figure(resolution = (1200, 800))
112-
speedup = fig[1, 1] = Axis(fig, title = "DaggerArrays vs Base (8 threads)")
113-
speedup.xlabel = "Scaling factor"
114-
speedup.ylabel = "Speedup Base/Dagger"
151+
speedup = fig[1, 1] = Axis(fig, title = "Speedup vs. 1 processor")
152+
speedup.xlabel = "Number of processors"
153+
speedup.ylabel = "Speedup"
115154

116155
line_plots = Any[]
117156
legend_names = String[]
118157

119158
colors = distinguishable_colors(length(procs), ColorSchemes.seaborn_deep.colors)
120159

121-
sort!(btable, by=r->r.scale)
160+
t1 = sort(dtable[dtable.scale .== 10]; by=r->r.procs)
122161

123-
for (i, nproc) in enumerate(unique(dtable.procs))
124-
nproc < 8 && continue
125-
stable = dtable[dtable.procs .== nproc]
126-
sort!(stable, by=r->r.scale)
127-
push!(line_plots, lines!(speedup, stable.scale, btable.time ./ stable.time, linewidth=3.0, color = colors[i]))
128-
push!(legend_names, "Dagger (nprocs = $nproc)")
162+
for (i, scale) in enumerate(unique(dtable.scale))
163+
stable = dtable[dtable.scale .== scale]
164+
sort!(stable, by=r->r.procs)
165+
push!(line_plots, lines!(speedup, stable.procs, stable.time ./ t1.time, linewidth=3.0, color = colors[i]))
166+
push!(legend_names, "Dagger (scale = $scale)")
129167
end
130168

131169
legend = fig[1, 2] = Legend(fig, line_plots, legend_names)

src/chunks.jl

+4
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,12 @@ function unrelease(c::Chunk{<:Any,DRef})
6565
end
6666
unrelease(c::Chunk) = c
6767

68+
Base.:(==)(c1::Chunk, c2::Chunk) = c1.handle == c2.handle
69+
Base.hash(c::Chunk, x::UInt64) = hash(c.handle, x)
70+
6871
collect_remote(chunk::Chunk) =
6972
move(chunk.processor, OSProc(), poolget(chunk.handle))
73+
7074
function collect(ctx::Context, chunk::Chunk; options=nothing)
7175
# delegate fetching to handle by default.
7276
if chunk.handle isa DRef && !(chunk.processor isa OSProc)

src/compute.jl

+10-8
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ end
8282
##### Dag utilities #####
8383

8484
"""
85-
dependents(node::Thunk, deps=Dict{Thunk, Set{Thunk}}()) -> Dict{Thunk, Set{Thunk}}
85+
dependents(node::Thunk) -> Dict{Union{Thunk,Chunk}, Set{Thunk}}
8686
8787
Find the set of direct dependents for each task.
8888
"""
8989
function dependents(node::Thunk)
90-
deps = Dict{Thunk, Set{Thunk}}()
90+
deps = Dict{Union{Thunk,Chunk}, Set{Thunk}}()
9191
visited = Set{Thunk}()
9292
to_visit = Set{Thunk}()
9393
push!(to_visit, node)
@@ -98,10 +98,12 @@ function dependents(node::Thunk)
9898
deps[next] = Set{Thunk}()
9999
end
100100
for inp in inputs(next)
101-
if inp isa Thunk
102-
s::Set{Thunk} = get!(()->Set{Thunk}(), deps, inp)
101+
if istask(inp) || (inp isa Chunk)
102+
s = get!(()->Set{Thunk}(), deps, inp)
103103
push!(s, next)
104-
!(inp in visited) && push!(to_visit, inp)
104+
if istask(inp) && !(inp in visited)
105+
push!(to_visit, inp)
106+
end
105107
end
106108
end
107109
push!(visited, next)
@@ -110,14 +112,14 @@ function dependents(node::Thunk)
110112
end
111113

112114
"""
113-
noffspring(dpents::Dict{Thunk, Set{Thunk}}) -> Dict{Thunk, Int}
115+
noffspring(dpents::Dict{Union{Thunk,Chunk}, Set{Thunk}}) -> Dict{Thunk, Int}
114116
115117
Recursively find the number of tasks dependent on each task in the DAG.
116118
Takes a Dict as returned by [`dependents`](@ref).
117119
"""
118-
function noffspring(dpents::Dict{Thunk, Set{Thunk}})
120+
function noffspring(dpents::Dict{Union{Thunk,Chunk}, Set{Thunk}})
119121
noff = Dict{Thunk,Int}()
120-
to_visit = collect(keys(dpents))
122+
to_visit = collect(filter(istask, keys(dpents)))
121123
while !isempty(to_visit)
122124
next = popfirst!(to_visit)
123125
haskey(noff, next) && continue

0 commit comments

Comments
 (0)