@@ -58,22 +58,105 @@ function transform_gpu!(def, constargs, force_inbounds)
58
58
end
59
59
end
60
60
pushfirst! (def[:args ], :__ctx__ )
61
- body = def[:body ]
61
+ new_stmts = Expr[]
62
+ body = MacroTools. flatten (def[:body ])
63
+ stmts = body. args
64
+ push! (new_stmts, Expr (:aliasscope ))
65
+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
62
66
if force_inbounds
63
- body = quote
64
- @inbounds $ (body)
65
- end
67
+ push! (new_stmts, Expr (:inbounds , true ))
66
68
end
67
- body = quote
68
- if $ __validindex (__ctx__)
69
- $ (body)
70
- end
71
- return nothing
69
+ append! (new_stmts, split (emit_gpu, body. args))
70
+ if force_inbounds
71
+ push! (new_stmts, Expr (:inbounds , :pop ))
72
72
end
73
+ push! (new_stmts, Expr (:popaliasscope ))
74
+ push! (new_stmts, :(return nothing ))
73
75
def[:body ] = Expr (
74
76
:let ,
75
77
Expr (:block , let_constargs... ),
76
- body ,
78
+ Expr ( :block , new_stmts ... ) ,
77
79
)
78
80
return
79
81
end
82
+
83
+ struct WorkgroupLoop
84
+ stmts:: Vector{Any}
85
+ terminated_in_sync:: Bool
86
+ end
87
+
88
+ is_sync (expr) = @capture (expr, @synchronize () | @synchronize (a_))
89
+
90
+ function is_scope_construct (expr:: Expr )
91
+ return expr. head === :block # ||
92
+ # expr.head === :let
93
+ end
94
+
95
+ function find_sync (stmt)
96
+ result = false
97
+ postwalk (stmt) do expr
98
+ result |= is_sync (expr)
99
+ expr
100
+ end
101
+ return result
102
+ end
103
+
104
+ # TODO proper handling of LineInfo
105
+ function split (emit, stmts)
106
+ # 1. Split the code into blocks separated by `@synchronize`
107
+
108
+ current = Any[]
109
+ new_stmts = Any[]
110
+ for stmt in stmts
111
+ has_sync = find_sync (stmt)
112
+ if has_sync
113
+ loop = WorkgroupLoop (current, is_sync (stmt))
114
+ push! (new_stmts, emit (loop))
115
+ current = Any[]
116
+ is_sync (stmt) && continue
117
+
118
+ # Recurse into scope constructs
119
+ # TODO : This currently implements hard scoping
120
+ # probably need to implemet soft scoping
121
+ # by not deepcopying the environment.
122
+ recurse (x) = x
123
+ function recurse (expr:: Expr )
124
+ expr = unblock (expr)
125
+ if is_scope_construct (expr) && any (find_sync, expr. args)
126
+ new_args = unblock (split (emit, expr. args))
127
+ return Expr (expr. head, new_args... )
128
+ else
129
+ return Expr (expr. head, map (recurse, expr. args)... )
130
+ end
131
+ end
132
+ push! (new_stmts, recurse (stmt))
133
+ continue
134
+ end
135
+
136
+ push! (current, stmt)
137
+ end
138
+
139
+ # everything since the last `@synchronize`
140
+ if ! isempty (current)
141
+ loop = WorkgroupLoop (current, false )
142
+ push! (new_stmts, emit (loop))
143
+ end
144
+ return new_stmts
145
+ end
146
+
147
+ function emit_gpu (loop)
148
+ stmts = Any[]
149
+
150
+ body = Expr (:block , loop. stmts... )
151
+ loopexpr = quote
152
+ if __active_lane__
153
+ $ (unblock (body))
154
+ end
155
+ end
156
+ push! (stmts, loopexpr)
157
+ if loop. terminated_in_sync
158
+ push! (stmts, :($ __synchronize ()))
159
+ end
160
+
161
+ return unblock (Expr (:block , stmts... ))
162
+ end
0 commit comments