Skip to content

Commit 7ebadea

Browse files
committed
Crime doesn't pay. (Stop stealing.)
1 parent a7036b9 commit 7ebadea

File tree

3 files changed

+35
-55
lines changed

3 files changed

+35
-55
lines changed

src/ThreadingUtilities.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ using VectorizationBase:
44
pause, StaticInt, StridedPointer, stridedpointer, offsets, cache_linesize, align, __vload, __vstore!, num_threads, assume, False, register_size, NativeTypes
55

66
@enum ThreadState::UInt32 begin
7-
TASK = 0 # 3: task available
8-
EXEC = 1 # 2: task executed
9-
WAIT = 2 # 1: waiting
10-
SPIN = 3 # 0: spinning
7+
TASK = 0 # 0: task available
8+
WAIT = 1 # 1: waiting
9+
SPIN = 2 # 2: spinning
1110
end
1211
const TASKS = Task[]
1312
const THREADBUFFERSIZE = 512

src/threadtasks.jl

+30-49
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,51 @@
11
struct ThreadTask
2-
p::Ptr{UInt}
2+
p::Ptr{UInt}
33
end
44
Base.pointer(tt::ThreadTask) = tt.p
55

66
@inline taskpointer(tid::T) where {T} = THREADPOOLPTR[] + tid*(THREADBUFFERSIZE%T)
77

88
@inline function _call(p::Ptr{UInt})
9-
fptr = load(p + sizeof(UInt), Ptr{Cvoid})
10-
assume(fptr ≠ C_NULL)
11-
ccall(fptr, Cvoid, (Ptr{UInt},), p)
9+
fptr = load(p + sizeof(UInt), Ptr{Cvoid})
10+
assume(fptr ≠ C_NULL)
11+
ccall(fptr, Cvoid, (Ptr{UInt},), p)
1212
end
1313
@inline function launch(f::F, tid::Integer, args::Vararg{Any,K}) where {F,K}
14-
p = taskpointer(tid)
15-
f(p, args...)
16-
state = _atomic_xchg!(p, TASK)
17-
state == WAIT && wake_thread!(tid)
18-
return nothing
14+
p = taskpointer(tid)
15+
f(p, args...)
16+
state = _atomic_xchg!(p, TASK) # exchange must happen atomically, to prevent it from switching to `WAIT` after reading
17+
state == WAIT && wake_thread!(tid)
18+
return nothing
1919
end
2020

2121
function (tt::ThreadTask)()
22-
p = pointer(tt)
23-
max_wait = one(UInt32) << 20
24-
wait_counter = max_wait
25-
GC.@preserve THREADPOOL begin
26-
while true
27-
# if _atomic_state(p) == TASK
28-
if _atomic_cas_cmp!(p, TASK, EXEC)
29-
_call(p)
30-
# store!(p, SPIN)
31-
_atomic_store!(p, SPIN)
32-
wait_counter = zero(UInt32)
33-
continue
34-
end
35-
pause()
36-
if (wait_counter += one(UInt32)) > max_wait
37-
wait_counter = zero(UInt32)
38-
_atomic_cas_cmp!(p, SPIN, WAIT) && Base.wait()
39-
end
40-
end
22+
p = pointer(tt)
23+
max_wait = one(UInt32) << 20
24+
wait_counter = max_wait
25+
GC.@preserve THREADPOOL begin
26+
while true
27+
if _atomic_state(p) == TASK
28+
# if _atomic_cas_cmp!(p, TASK, EXEC)
29+
_call(p)
30+
# store!(p, SPIN)
31+
_atomic_store!(p, SPIN)
32+
wait_counter = zero(UInt32)
33+
continue
34+
end
35+
pause()
36+
if (wait_counter += one(UInt32)) > max_wait
37+
wait_counter = zero(UInt32)
38+
_atomic_cas_cmp!(p, SPIN, WAIT) && Base.wait()
39+
end
4140
end
41+
end
4242
end
4343

4444
# 1-based tid, pushes into task 2-nthreads()
4545
# function wake_thread!(tid::T) where {T <: Unsigned}
4646
@noinline function wake_thread!(_tid::T) where {T <: Integer}
4747
tid = _tid % Int
48-
store!(taskpointer(_tid), TASK)
48+
# store!(taskpointer(_tid), TASK)
4949
tidp1 = tid + one(tid)
5050
assume(unsigned(length(Base.Workqueues)) > unsigned(tid))
5151
assume(unsigned(length(TASKS)) > unsigned(tidp1))
@@ -56,29 +56,10 @@ end
5656
# 1-based tid
5757
@inline wait(tid::Integer) = wait(taskpointer(tid))
5858
@inline function wait(p::Ptr{UInt})
59-
# TASK = 0
60-
# EXEC = 1
61-
# WAIT = 2
62-
# SPIN = 3
63-
s = _atomic_umax!(p, EXEC) # s = old state, state gets set to EXEC if s == TASK or s == EXEC
64-
if s == TASK # thread hasn't begun yet for some reason, so we steal the work.
65-
_call(p)
66-
store!(p, SPIN)
67-
return
68-
end
6959
counter = 0x00000000
70-
while _atomic_state(p) == EXEC
60+
while _atomic_state(p) == TASK
7161
pause()
7262
((counter += 0x00000001) > 0x00010000) && yield()
7363
end
7464
end
7565

76-
77-
# function launch_thread(f::F, tid) where {F}
78-
# cfunc = @cfunction($mapper, Cvoid, ());
79-
80-
# fptr = Base.unsafe_convert(Ptr{Cvoid}, cfunc)
81-
82-
# ccall(fptr, Cvoid, ())
83-
84-
# end

test/threadingutilities.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ end
5353

5454
@testset "ThreadingUtilities.jl" begin
5555
for tid ∈ eachindex(ThreadingUtilities.TASKS)
56-
@test unsafe_load(Ptr{UInt32}(ThreadingUtilities.taskpointer(tid))) == 0x00000002
56+
@test unsafe_load(Ptr{UInt32}(ThreadingUtilities.taskpointer(tid))) == 0x00000001
5757
end
5858
@test all(eachindex(ThreadingUtilities.TASKS)) do tid
5959
ThreadingUtilities.load(ThreadingUtilities.taskpointer(tid), ThreadingUtilities.ThreadState) === ThreadingUtilities.WAIT
@@ -67,8 +67,8 @@ end
6767
GC.@preserve x begin
6868
ThreadingUtilities._atomic_store!(pointer(x), zero(UInt))
6969
@test ThreadingUtilities._atomic_xchg!(pointer(x), ThreadingUtilities.WAIT) == ThreadingUtilities.TASK
70+
@test ThreadingUtilities._atomic_umax!(pointer(x), ThreadingUtilities.TASK) == ThreadingUtilities.WAIT
7071
@test ThreadingUtilities._atomic_umax!(pointer(x), ThreadingUtilities.SPIN) == ThreadingUtilities.WAIT
71-
@test ThreadingUtilities._atomic_umax!(pointer(x), ThreadingUtilities.EXEC) == ThreadingUtilities.SPIN
7272
@test ThreadingUtilities.load(pointer(x), ThreadingUtilities.ThreadState) == ThreadingUtilities.SPIN
7373
end
7474
for tid ∈ eachindex(ThreadingUtilities.TASKS)

0 commit comments

Comments
 (0)