Skip to content

Commit 6366f40

Browse files
authored
port partr multiq to julia (#44653)
This is a direct translation of the existing C implementation and is not necessarily fully idiomatic Julia; it is made in preparation for future improvements.
1 parent 62e0729 commit 6366f40

13 files changed

+229
-233
lines changed

base/Base.jl

+1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ include("condition.jl")
277277
include("threads.jl")
278278
include("lock.jl")
279279
include("channels.jl")
280+
include("partr.jl")
280281
include("task.jl")
281282
include("threads_overloads.jl")
282283
include("weakkeydict.jl")

base/boot.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ primitive type Char <: AbstractChar 32 end
224224
primitive type Int8 <: Signed 8 end
225225
#primitive type UInt8 <: Unsigned 8 end
226226
primitive type Int16 <: Signed 16 end
227-
primitive type UInt16 <: Unsigned 16 end
227+
#primitive type UInt16 <: Unsigned 16 end
228228
#primitive type Int32 <: Signed 32 end
229229
#primitive type UInt32 <: Unsigned 32 end
230230
#primitive type Int64 <: Signed 64 end

base/partr.jl

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
module Partr
4+
5+
using ..Threads: SpinLock
6+
7+
# a task heap
8+
mutable struct taskheap
    # taken (via trylock) before the heap contents are modified
    const lock::SpinLock
    # backing storage, allocated up front with `tasks_per_heap` slots
    const tasks::Vector{Task}
    # number of tasks currently stored in `tasks`
    @atomic ntasks::Int32
    # starts at typemax(UInt16); lowered on insert, recomputed on deletemin
    @atomic priority::UInt16

    function taskheap()
        spin = SpinLock()
        storage = Vector{Task}(undef, tasks_per_heap)
        return new(spin, storage, zero(Int32), typemax(UInt16))
    end
end
15+
16+
# multiqueue parameters
17+
const heap_d = UInt32(8)
18+
const heap_c = UInt32(2)
19+
20+
# size of each heap
21+
const tasks_per_heap = Int32(65536) # TODO: this should be smaller by default, but growable!
22+
23+
# the multiqueue's heaps
24+
global heaps::Vector{taskheap}
25+
global heap_p::UInt32 = 0
26+
27+
# unbias state for the RNG
28+
global cong_unbias::UInt32 = 0
29+
30+
31+
# Call the runtime's `jl_rand_ptls` with `max` and `unbias`, and shift the
# returned value up by one.
function cong(max::UInt32, unbias::UInt32)
    draw = ccall(:jl_rand_ptls, UInt32, (UInt32, UInt32), max, unbias)
    return draw + UInt32(1)
end
32+
33+
# Return typemax(UInt32) minus one more than `typemax(UInt32) % max`.
function unbias_cong(max::UInt32)
    top = typemax(UInt32)
    leftover = top % max
    return top - (leftover + UInt32(1))
end
36+
37+
38+
# Set up the multiqueue for `nthreads` threads: `heap_c` heaps per thread,
# the matching `cong_unbias` value, and one fresh `taskheap` in every slot.
function multiq_init(nthreads)
    global heap_p = heap_c * nthreads
    global cong_unbias = unbias_cong(UInt32(heap_p))
    global heaps = Vector{taskheap}(undef, heap_p)
    for slot in eachindex(heaps)
        heaps[slot] = taskheap()
    end
    return nothing
end
47+
48+
49+
# Move the task at `idx` toward index 1 for as long as its priority value is
# strictly smaller than that of its parent in the `heap_d`-ary heap.
function sift_up(heap::taskheap, idx::Int32)
    tasks = heap.tasks
    while idx > Int32(1)
        parent = (idx - Int32(2)) ÷ heap_d + Int32(1)
        # heap order already holds at this level: stop
        tasks[idx].priority < tasks[parent].priority || break
        tasks[parent], tasks[idx] = tasks[idx], tasks[parent]
        idx = parent
    end
end
62+
63+
64+
# Restore heap order below `idx`: every child of `idx` whose priority value is
# strictly smaller than the one currently at `idx` is swapped up, and the
# subtree under that child is then fixed recursively.
#
# Does nothing when `idx` is outside `1:heap.ntasks`.
function sift_down(heap::taskheap, idx::Int32)
    Int32(1) <= idx <= heap.ntasks || return
    tasks = heap.tasks
    # `heap_d` is a UInt32, so mixing it with the Int32 `idx` would promote the
    # child indices to UInt32, and the recursive call below would then fail with
    # a MethodError because this method only accepts Int32. Do the index
    # arithmetic in Int and convert back explicitly when recursing.
    d = Int(heap_d)
    parent = Int(idx)
    for child = (d * parent - d + 2):(d * parent + 1)
        child > tasks_per_heap && break
        if isassigned(tasks, child) &&
                tasks[child].priority < tasks[parent].priority
            tasks[parent], tasks[child] = tasks[child], tasks[parent]
            # child <= tasks_per_heap here, so the narrowing conversion is exact
            sift_down(heap, Int32(child))
        end
    end
    return
end
78+
79+
80+
# Record `priority` on `task` and add the task to one of the heaps, picked with
# `cong`. Returns `true` on success and `false` when the picked heap already
# holds `tasks_per_heap` tasks.
function multiq_insert(task::Task, priority::UInt16)
    task.priority = priority

    # keep drawing heaps until one can be locked without waiting
    local heap
    while true
        heap = heaps[cong(heap_p, cong_unbias)]
        trylock(heap.lock) && break
    end

    n = heap.ntasks
    if n >= tasks_per_heap
        unlock(heap.lock)
        # multiq insertion failed, increase #tasks per heap
        return false
    end

    n += Int32(1)
    @atomic :monotonic heap.ntasks = n
    heap.tasks[n] = task
    sift_up(heap, n)
    # lower the heap's advertised priority if the new task beats it
    if task.priority < heap.priority
        @atomic :monotonic heap.priority = task.priority
    end
    unlock(heap.lock)
    return true
end
105+
106+
107+
# Remove and return the first task of a heap chosen by sampling two heaps with
# `cong` and keeping the one with the smaller priority value. Returns `nothing`
# when the attempt loop runs out without locking a suitable heap.
function multiq_deletemin()
    local rn1

    @label retry
    GC.safepoint()
    for attempt = UInt32(1):heap_p
        # the last iteration only reports failure; it never samples
        attempt == heap_p && return nothing

        rn1 = cong(heap_p, cong_unbias)
        rn2 = cong(heap_p, cong_unbias)
        p1 = heaps[rn1].priority
        p2 = heaps[rn2].priority
        if p1 > p2
            p1 = p2
            rn1 = rn2
        elseif p1 == p2 && p1 == typemax(UInt16)
            # both samples advertise typemax(UInt16): sample again
            continue
        end

        if trylock(heaps[rn1].lock)
            # leave the loop holding the lock only if the priority read
            # before locking is still the heap's priority
            p1 == heaps[rn1].priority && break
            unlock(heaps[rn1].lock)
        end
    end

    heap = heaps[rn1]
    task = heap.tasks[1]
    tid = Threads.threadid()
    if ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1) == 0
        # the runtime refused to assign the task to this thread: start over
        unlock(heap.lock)
        @goto retry
    end

    ntasks = heap.ntasks
    @atomic :monotonic heap.ntasks = ntasks - Int32(1)
    # move the last task into the vacated first slot and clear the last slot
    heap.tasks[1] = heap.tasks[ntasks]
    Base._unsetindex!(heap.tasks, Int(ntasks))
    newprio = typemax(UInt16)
    if ntasks > 1
        sift_down(heap, Int32(1))
        newprio = heap.tasks[1].priority
    end
    @atomic :monotonic heap.priority = newprio
    unlock(heap.lock)

    return task
end
155+
156+
157+
# Return `true` when each of the first `heap_p` heaps reports zero tasks,
# and `false` as soon as one reports a nonzero count.
function multiq_check_empty()
    for i = UInt32(1):heap_p
        heaps[i].ntasks == 0 || return false
    end
    return true
end
165+
166+
end

base/task.jl

+24-16
Original file line numberDiff line numberDiff line change
@@ -711,12 +711,14 @@ const StickyWorkqueue = InvasiveLinkedListSynchronized{Task}
711711
global const Workqueues = [StickyWorkqueue()]
712712
global const Workqueue = Workqueues[1] # default work queue is thread 1
713713
function __preinit_threads__()
714-
if length(Workqueues) < Threads.nthreads()
715-
resize!(Workqueues, Threads.nthreads())
716-
for i = 2:length(Workqueues)
714+
nt = Threads.nthreads()
715+
if length(Workqueues) < nt
716+
resize!(Workqueues, nt)
717+
for i = 2:nt
717718
Workqueues[i] = StickyWorkqueue()
718719
end
719720
end
721+
Partr.multiq_init(nt)
720722
nothing
721723
end
722724

@@ -741,7 +743,7 @@ function enq_work(t::Task)
741743
end
742744
push!(Workqueues[tid], t)
743745
else
744-
if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
746+
if !Partr.multiq_insert(t, t.priority)
745747
# if multiq is full, give to a random thread (TODO fix)
746748
if tid == 0
747749
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
@@ -907,24 +909,30 @@ function ensure_rescheduled(othertask::Task)
907909
end
908910

909911
function trypoptask(W::StickyWorkqueue)
910-
isempty(W) && return
911-
t = popfirst!(W)
912-
if t._state !== task_state_runnable
913-
# assume this somehow got queued twice,
914-
# probably broken now, but try discarding this switch and keep going
915-
# can't throw here, because it's probably not the fault of the caller to wait
916-
# and don't want to use print() here, because that may try to incur a task switch
917-
ccall(:jl_safe_printf, Cvoid, (Ptr{UInt8}, Int32...),
918-
"\nWARNING: Workqueue inconsistency detected: popfirst!(Workqueue).state != :runnable\n")
919-
return
912+
while !isempty(W)
913+
t = popfirst!(W)
914+
if t._state !== task_state_runnable
915+
# assume this somehow got queued twice,
916+
# probably broken now, but try discarding this switch and keep going
917+
# can't throw here, because it's probably not the fault of the caller to wait
918+
# and don't want to use print() here, because that may try to incur a task switch
919+
ccall(:jl_safe_printf, Cvoid, (Ptr{UInt8}, Int32...),
920+
"\nWARNING: Workqueue inconsistency detected: popfirst!(Workqueue).state != :runnable\n")
921+
continue
922+
end
923+
return t
920924
end
921-
return t
925+
return Partr.multiq_deletemin()
926+
end
927+
928+
function checktaskempty()
929+
return Partr.multiq_check_empty()
922930
end
923931

924932
@noinline function poptask(W::StickyWorkqueue)
925933
task = trypoptask(W)
926934
if !(task isa Task)
927-
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W)
935+
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any, Any), trypoptask, W, checktaskempty)
928936
end
929937
set_next_task(task)
930938
nothing

src/builtins.c

+3-2
Original file line numberDiff line numberDiff line change
@@ -2021,10 +2021,11 @@ void jl_init_primitives(void) JL_GC_DISABLED
20212021

20222022
add_builtin("Bool", (jl_value_t*)jl_bool_type);
20232023
add_builtin("UInt8", (jl_value_t*)jl_uint8_type);
2024-
add_builtin("Int32", (jl_value_t*)jl_int32_type);
2025-
add_builtin("Int64", (jl_value_t*)jl_int64_type);
2024+
add_builtin("UInt16", (jl_value_t*)jl_uint16_type);
20262025
add_builtin("UInt32", (jl_value_t*)jl_uint32_type);
20272026
add_builtin("UInt64", (jl_value_t*)jl_uint64_type);
2027+
add_builtin("Int32", (jl_value_t*)jl_int32_type);
2028+
add_builtin("Int64", (jl_value_t*)jl_int64_type);
20282029
#ifdef _P64
20292030
add_builtin("Int", (jl_value_t*)jl_int64_type);
20302031
#else

src/gc.c

-4
Original file line numberDiff line numberDiff line change
@@ -2824,7 +2824,6 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp
28242824
gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception);
28252825
}
28262826

2827-
void jl_gc_mark_enqueued_tasks(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp);
28282827
extern jl_value_t *cmpswap_names JL_GLOBALLY_ROOTED;
28292828

28302829
// mark the initial root set
@@ -2833,9 +2832,6 @@ static void mark_roots(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
28332832
// modules
28342833
gc_mark_queue_obj(gc_cache, sp, jl_main_module);
28352834

2836-
// tasks
2837-
jl_gc_mark_enqueued_tasks(gc_cache, sp);
2838-
28392835
// invisible builtin values
28402836
if (jl_an_empty_vec_any != NULL)
28412837
gc_mark_queue_obj(gc_cache, sp, jl_an_empty_vec_any);

src/init.c

+3-3
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,6 @@ static void post_boot_hooks(void)
780780
jl_char_type = (jl_datatype_t*)core("Char");
781781
jl_int8_type = (jl_datatype_t*)core("Int8");
782782
jl_int16_type = (jl_datatype_t*)core("Int16");
783-
jl_uint16_type = (jl_datatype_t*)core("UInt16");
784783
jl_float16_type = (jl_datatype_t*)core("Float16");
785784
jl_float32_type = (jl_datatype_t*)core("Float32");
786785
jl_float64_type = (jl_datatype_t*)core("Float64");
@@ -792,10 +791,11 @@ static void post_boot_hooks(void)
792791

793792
jl_bool_type->super = jl_integer_type;
794793
jl_uint8_type->super = jl_unsigned_type;
795-
jl_int32_type->super = jl_signed_type;
796-
jl_int64_type->super = jl_signed_type;
794+
jl_uint16_type->super = jl_unsigned_type;
797795
jl_uint32_type->super = jl_unsigned_type;
798796
jl_uint64_type->super = jl_unsigned_type;
797+
jl_int32_type->super = jl_signed_type;
798+
jl_int64_type->super = jl_signed_type;
799799

800800
jl_errorexception_type = (jl_datatype_t*)core("ErrorException");
801801
jl_stackovf_exception = jl_new_struct_uninit((jl_datatype_t*)core("StackOverflowError"));

src/jl_exported_funcs.inc

-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@
119119
XX(jl_egal__bits) \
120120
XX(jl_egal__special) \
121121
XX(jl_eh_restore_state) \
122-
XX(jl_enqueue_task) \
123122
XX(jl_enter_handler) \
124123
XX(jl_enter_threaded_region) \
125124
XX(jl_environ) \

src/jltypes.c

+9-5
Original file line numberDiff line numberDiff line change
@@ -2152,6 +2152,8 @@ void jl_init_types(void) JL_GC_DISABLED
21522152
jl_any_type, jl_emptysvec, 64);
21532153
jl_uint8_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt8"), core,
21542154
jl_any_type, jl_emptysvec, 8);
2155+
jl_uint16_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt16"), core,
2156+
jl_any_type, jl_emptysvec, 16);
21552157

21562158
jl_ssavalue_type = jl_new_datatype(jl_symbol("SSAValue"), core, jl_any_type, jl_emptysvec,
21572159
jl_perm_symsvec(1, "id"),
@@ -2516,7 +2518,7 @@ void jl_init_types(void) JL_GC_DISABLED
25162518
"inferred",
25172519
//"edges",
25182520
//"absolute_max",
2519-
"ipo_purity_bits", "purity_bits",
2521+
"ipo_purity_bits", "purity_bits",
25202522
"argescapes",
25212523
"isspecsig", "precompile", "invoke", "specptr", // function object decls
25222524
"relocatability"),
@@ -2610,7 +2612,7 @@ void jl_init_types(void) JL_GC_DISABLED
26102612
NULL,
26112613
jl_any_type,
26122614
jl_emptysvec,
2613-
jl_perm_symsvec(14,
2615+
jl_perm_symsvec(15,
26142616
"next",
26152617
"queue",
26162618
"storage",
@@ -2624,8 +2626,9 @@ void jl_init_types(void) JL_GC_DISABLED
26242626
"rngState3",
26252627
"_state",
26262628
"sticky",
2627-
"_isexception"),
2628-
jl_svec(14,
2629+
"_isexception",
2630+
"priority"),
2631+
jl_svec(15,
26292632
jl_any_type,
26302633
jl_any_type,
26312634
jl_any_type,
@@ -2639,7 +2642,8 @@ void jl_init_types(void) JL_GC_DISABLED
26392642
jl_uint64_type,
26402643
jl_uint8_type,
26412644
jl_bool_type,
2642-
jl_bool_type),
2645+
jl_bool_type,
2646+
jl_uint16_type),
26432647
jl_emptysvec,
26442648
0, 1, 6);
26452649
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);

src/julia.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1881,12 +1881,12 @@ typedef struct _jl_task_t {
18811881
_Atomic(uint8_t) _state;
18821882
uint8_t sticky; // record whether this Task can be migrated to a new thread
18831883
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with
1884+
// multiqueue priority
1885+
uint16_t priority;
18841886

18851887
// hidden state:
18861888
// id of owning thread - does not need to be defined until the task runs
18871889
_Atomic(int16_t) tid;
1888-
// multiqueue priority
1889-
int16_t prio;
18901890
// saved gc stack top for context switches
18911891
jl_gcframe_t *gcstack;
18921892
size_t world_age;

0 commit comments

Comments
 (0)