Skip to content

Commit 8072f4c

Browse files
committed
Forbid divergent execution of work-group barriers
1 parent 90a10d7 commit 8072f4c

File tree

2 files changed

+102
-11
lines changed

2 files changed

+102
-11
lines changed

src/KernelAbstractions.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ end
297297
After a `@synchronize` statement all reads and writes to global and local memory
298298
from each thread in the workgroup are visible to all other threads in the
299299
workgroup.
300+
301+
!!! note
302+
`@synchronize()` must be encountered by all workitems of a work-group executing the kernel or by none at all.
300303
"""
301304
macro synchronize()
302305
return quote
@@ -314,10 +317,15 @@ workgroup. `cond` is not allowed to have any visible side effects.
314317
# Platform differences
315318
- `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
316319
- `CPU`: This synchronization will always occur.
320+
321+
!!! warning
322+
This variant of the `@synchronize` macro violates the requirement that `@synchronize` must be encountered
323+
by all workitems of a work-group executing the kernel or by none at all.
324+
Since v0.9.34 this variant of the macro is deprecated and lowers to `@synchronize()`.
317325
"""
318326
macro synchronize(cond)
319327
return quote
320-
$(esc(cond)) && $__synchronize()
328+
$__synchronize()
321329
end
322330
end
323331

src/macros.jl

+93-10
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,105 @@ function transform_gpu!(def, constargs, force_inbounds)
5858
end
5959
end
6060
pushfirst!(def[:args], :__ctx__)
61-
body = def[:body]
61+
new_stmts = Expr[]
62+
body = MacroTools.flatten(def[:body])
63+
stmts = body.args
64+
push!(new_stmts, Expr(:aliasscope))
65+
push!(new_stmts, :(__active_lane__ = $__validindex(__ctx__)))
6266
if force_inbounds
63-
body = quote
64-
@inbounds $(body)
65-
end
67+
push!(new_stmts, Expr(:inbounds, true))
6668
end
67-
body = quote
68-
if $__validindex(__ctx__)
69-
$(body)
70-
end
71-
return nothing
69+
append!(new_stmts, split(emit_gpu, body.args))
70+
if force_inbounds
71+
push!(new_stmts, Expr(:inbounds, :pop))
7272
end
73+
push!(new_stmts, Expr(:popaliasscope))
74+
push!(new_stmts, :(return nothing))
7375
def[:body] = Expr(
7476
:let,
7577
Expr(:block, let_constargs...),
76-
body,
78+
Expr(:block, new_stmts...),
7779
)
7880
return
7981
end
82+
83+
struct WorkgroupLoop
84+
stmts::Vector{Any}
85+
terminated_in_sync::Bool
86+
end
87+
88+
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
89+
90+
function is_scope_construct(expr::Expr)
91+
return expr.head === :block # ||
92+
# expr.head === :let
93+
end
94+
95+
function find_sync(stmt)
96+
result = false
97+
postwalk(stmt) do expr
98+
result |= is_sync(expr)
99+
expr
100+
end
101+
return result
102+
end
103+
104+
# TODO: properly handle line-number information (`LineNumberNode`s)
105+
function split(emit, stmts)
106+
# 1. Split the code into blocks separated by `@synchronize`
107+
108+
current = Any[]
109+
new_stmts = Any[]
110+
for stmt in stmts
111+
has_sync = find_sync(stmt)
112+
if has_sync
113+
loop = WorkgroupLoop(current, is_sync(stmt))
114+
push!(new_stmts, emit(loop))
115+
current = Any[]
116+
is_sync(stmt) && continue
117+
118+
# Recurse into scope constructs
119+
# TODO: This currently implements hard scoping
120+
# probably need to implement soft scoping
121+
# by not deepcopying the environment.
122+
recurse(x) = x
123+
function recurse(expr::Expr)
124+
expr = unblock(expr)
125+
if is_scope_construct(expr) && any(find_sync, expr.args)
126+
new_args = unblock(split(emit, expr.args))
127+
return Expr(expr.head, new_args...)
128+
else
129+
return Expr(expr.head, map(recurse, expr.args)...)
130+
end
131+
end
132+
push!(new_stmts, recurse(stmt))
133+
continue
134+
end
135+
136+
push!(current, stmt)
137+
end
138+
139+
# everything since the last `@synchronize`
140+
if !isempty(current)
141+
loop = WorkgroupLoop(current, false)
142+
push!(new_stmts, emit(loop))
143+
end
144+
return new_stmts
145+
end
146+
147+
function emit_gpu(loop)
148+
stmts = Any[]
149+
150+
body = Expr(:block, loop.stmts...)
151+
loopexpr = quote
152+
if __active_lane__
153+
$(unblock(body))
154+
end
155+
end
156+
push!(stmts, loopexpr)
157+
if loop.terminated_in_sync
158+
push!(stmts, :($__synchronize()))
159+
end
160+
161+
return unblock(Expr(:block, stmts...))
162+
end

0 commit comments

Comments
 (0)