Skip to content

Commit 210658c

Browse files
committed
Forbid divergent execution of work-group barriers
1 parent f88ee87 commit 210658c

File tree

1 file changed

+6
-24
lines changed

1 file changed

+6
-24
lines changed

src/macros.jl

+6 -24
Original file line number | Diff line number | Diff line change
@@ -81,11 +81,8 @@ function transform_gpu!(def, constargs, force_inbounds)
8181
end
8282

8383
struct WorkgroupLoop
84-
indicies::Vector{Any}
8584
stmts::Vector{Any}
8685
allocations::Vector{Any}
87-
private_allocations::Vector{Any}
88-
private::Set{Symbol}
8986
terminated_in_sync::Bool
9087
end
9188

@@ -106,26 +103,18 @@ function find_sync(stmt)
106103
end
107104

108105
# TODO proper handling of LineInfo
109-
function split(
110-
stmts,
111-
indicies = Any[], private = Set{Symbol}(),
112-
)
106+
function split(stmts)
113107
# 1. Split the code into blocks separated by `@synchronize`
114-
# 2. Aggregate `@index` expressions
115-
# 3. Hoist allocations
116-
# 4. Hoist uniforms
117108

118109
current = Any[]
119110
allocations = Any[]
120-
private_allocations = Any[]
121111
new_stmts = Any[]
122112
for stmt in stmts
123113
has_sync = find_sync(stmt)
124114
if has_sync
125-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
115+
loop = WorkgroupLoop(current, allocations, is_sync(stmt))
126116
push!(new_stmts, emit(loop))
127117
allocations = Any[]
128-
private_allocations = Any[]
129118
current = Any[]
130119
is_sync(stmt) && continue
131120

@@ -137,7 +126,7 @@ function split(
137126
function recurse(expr::Expr)
138127
expr = unblock(expr)
139128
if is_scope_construct(expr) && any(find_sync, expr.args)
140-
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
129+
new_args = unblock(split(expr.args))
141130
return Expr(expr.head, new_args...)
142131
else
143132
return Expr(expr.head, map(recurse, expr.args)...)
@@ -151,14 +140,10 @@ function split(
151140
push!(allocations, stmt)
152141
continue
153142
elseif @capture(stmt, @private lhs_ = rhs_)
154-
push!(private, lhs)
155-
push!(private_allocations, :($lhs = $rhs))
143+
push!(allocations, :($lhs = $rhs))
156144
continue
157145
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
158-
if @capture(rhs, @index(args__))
159-
push!(indicies, stmt)
160-
continue
161-
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
146+
if @capture(rhs, @localmem(args__) | @uniform(args__))
162147
push!(allocations, stmt)
163148
continue
164149
elseif @capture(rhs, @private(T_, dims_))
@@ -170,7 +155,6 @@ function split(
170155
end
171156
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
172157
push!(allocations, :($lhs = $alloc))
173-
push!(private, lhs)
174158
continue
175159
end
176160
end
@@ -180,7 +164,7 @@ function split(
180164

181165
# everything since the last `@synchronize`
182166
if !isempty(current)
183-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), false)
167+
loop = WorkgroupLoop(current, allocations, false)
184168
push!(new_stmts, emit(loop))
185169
end
186170
return new_stmts
@@ -192,9 +176,7 @@ function emit(loop)
192176
body = Expr(:block, loop.stmts...)
193177
loopexpr = quote
194178
$(loop.allocations...)
195-
$(loop.private_allocations...)
196179
if __active_lane__
197-
$(loop.indicies...)
198180
$(unblock(body))
199181
end
200182
end

0 commit comments

Comments
 (0)