@@ -81,11 +81,8 @@ function transform_gpu!(def, constargs, force_inbounds)
81
81
end
82
82
83
83
struct WorkgroupLoop
84
- indicies:: Vector{Any}
85
84
stmts:: Vector{Any}
86
85
allocations:: Vector{Any}
87
- private_allocations:: Vector{Any}
88
- private:: Set{Symbol}
89
86
terminated_in_sync:: Bool
90
87
end
91
88
@@ -106,26 +103,18 @@ function find_sync(stmt)
106
103
end
107
104
108
105
# TODO proper handling of LineInfo
109
- function split (
110
- stmts,
111
- indicies = Any[], private = Set {Symbol} (),
112
- )
106
+ function split (stmts)
113
107
# 1. Split the code into blocks separated by `@synchronize`
114
- # 2. Aggregate `@index` expressions
115
- # 3. Hoist allocations
116
- # 4. Hoist uniforms
117
108
118
109
current = Any[]
119
110
allocations = Any[]
120
- private_allocations = Any[]
121
111
new_stmts = Any[]
122
112
for stmt in stmts
123
113
has_sync = find_sync (stmt)
124
114
if has_sync
125
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , is_sync (stmt))
115
+ loop = WorkgroupLoop (current, allocations, is_sync (stmt))
126
116
push! (new_stmts, emit (loop))
127
117
allocations = Any[]
128
- private_allocations = Any[]
129
118
current = Any[]
130
119
is_sync (stmt) && continue
131
120
@@ -137,7 +126,7 @@ function split(
137
126
function recurse (expr:: Expr )
138
127
expr = unblock (expr)
139
128
if is_scope_construct (expr) && any (find_sync, expr. args)
140
- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private) ))
129
+ new_args = unblock (split (expr. args))
141
130
return Expr (expr. head, new_args... )
142
131
else
143
132
return Expr (expr. head, map (recurse, expr. args)... )
@@ -151,14 +140,10 @@ function split(
151
140
push! (allocations, stmt)
152
141
continue
153
142
elseif @capture (stmt, @private lhs_ = rhs_)
154
- push! (private, lhs)
155
- push! (private_allocations, :($ lhs = $ rhs))
143
+ push! (allocations, :($ lhs = $ rhs))
156
144
continue
157
145
elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
158
- if @capture (rhs, @index (args__))
159
- push! (indicies, stmt)
160
- continue
161
- elseif @capture (rhs, @localmem (args__) | @uniform (args__))
146
+ if @capture (rhs, @localmem (args__) | @uniform (args__))
162
147
push! (allocations, stmt)
163
148
continue
164
149
elseif @capture (rhs, @private (T_, dims_))
@@ -170,7 +155,6 @@ function split(
170
155
end
171
156
alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
172
157
push! (allocations, :($ lhs = $ alloc))
173
- push! (private, lhs)
174
158
continue
175
159
end
176
160
end
@@ -180,7 +164,7 @@ function split(
180
164
181
165
# everything since the last `@synchronize`
182
166
if ! isempty (current)
183
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , false )
167
+ loop = WorkgroupLoop (current, allocations, false )
184
168
push! (new_stmts, emit (loop))
185
169
end
186
170
return new_stmts
@@ -192,9 +176,7 @@ function emit(loop)
192
176
body = Expr (:block , loop. stmts... )
193
177
loopexpr = quote
194
178
$ (loop. allocations... )
195
- $ (loop. private_allocations... )
196
179
if __active_lane__
197
- $ (loop. indicies... )
198
180
$ (unblock (body))
199
181
end
200
182
end
0 commit comments