@@ -39,6 +39,17 @@ function JuliaContext(f; kwargs...)
39
39
end
40
40
41
41
42
+ # # deferred compilation
43
+
44
+ """
45
+ var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
46
+
47
+ As if we were to call `f(args...)` but instead we are
48
+ putting down a marker and returning a function pointer to later
49
+ call.
50
+ """
51
+ function var"gpuc.deferred" end
52
+
42
53
# # compiler entrypoint
43
54
44
55
export compile
@@ -127,33 +138,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
127
138
error (" Unknown compilation output $output " )
128
139
end
129
140
130
- # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
131
- # this could both be generalized (e.g. supporting actual function calls, instead of
132
- # returning a function pointer), and be integrated with the nonrecursive codegen.
133
- const deferred_codegen_jobs = Dict {Int, Any} ()
134
-
135
- # We make this function explicitly callable so that we can drive OrcJIT's
136
- # lazy compilation from it, while also enabling recursive compilation.
137
- Base. @ccallable Ptr{Cvoid} function deferred_codegen (ptr:: Ptr{Cvoid} )
138
- ptr
139
- end
140
-
141
- @generated function deferred_codegen (:: Val{ft} , :: Val{tt} ) where {ft,tt}
142
- id = length (deferred_codegen_jobs) + 1
143
- deferred_codegen_jobs[id] = (; ft, tt)
144
- # don't bother looking up the method instance, as we'll do so again during codegen
145
- # using the world age of the parent.
146
- #
147
- # this also works around an issue on <1.10, where we don't know the world age of
148
- # generated functions so use the current world counter, which may be too new
149
- # for the world we're compiling for.
150
-
151
- quote
152
- # TODO : add an edge to this method instance to support method redefinitions
153
- ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (Int,), $ id)
154
- end
155
- end
156
-
157
141
const __llvm_initialized = Ref (false )
158
142
159
143
@locked function emit_llvm (@nospecialize (job:: CompilerJob ); toplevel:: Bool ,
@@ -183,79 +167,76 @@ const __llvm_initialized = Ref(false)
183
167
entry = finish_module! (job, ir, entry)
184
168
185
169
# deferred code generation
186
- has_deferred_jobs = toplevel && ! only_entry && haskey ( functions (ir), " deferred_codegen " )
187
- jobs = Dict {CompilerJob, String} (job => entry_fn )
188
- if has_deferred_jobs
189
- dyn_marker = functions (ir)[" deferred_codegen " ]
190
-
191
- # iterative compilation (non-recursive)
192
- changed = true
193
- while changed
194
- changed = false
195
-
196
- # find deferred compiler
197
- # TODO : recover this information earlier, from the Julia IR
198
- worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ( )
199
- for use in uses (dyn_marker)
200
- # decode the call
201
- call = user (use) :: LLVM.CallInst
202
- id = convert (Int, first (operands (call) ))
203
-
204
- global deferred_codegen_jobs
205
- dyn_val = deferred_codegen_jobs[id]
206
-
207
- # get a job in the appropriate world
208
- dyn_job = if dyn_val isa CompilerJob
209
- # trust that the user knows what they're doing
210
- dyn_val
170
+ run_optimization_for_deferred = false
171
+ if haskey ( functions (ir), " gpuc.lookup " )
172
+ run_optimization_for_deferred = true
173
+ dyn_marker = functions (ir)[" gpuc.lookup " ]
174
+
175
+ # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
176
+ # target method instance from the LLVM IR
177
+ function find_base_object (val)
178
+ while true
179
+ if val isa ConstantExpr && ( opcode (val) == LLVM . API . LLVMIntToPtr ||
180
+ opcode (val) == LLVM . API . LLVMBitCast ||
181
+ opcode (val) == LLVM . API . LLVMAddrSpaceCast)
182
+ val = first ( operands (val) )
183
+ elseif val isa LLVM . IntToPtrInst ||
184
+ val isa LLVM . BitCastInst ||
185
+ val isa LLVM. AddrSpaceCastInst
186
+ val = first (operands (val ))
187
+ elseif val isa LLVM . LoadInst
188
+ # In 1.11+ we no longer embed integer constants directly.
189
+ gv = first ( operands (val))
190
+ if gv isa LLVM . GlobalValue
191
+ val = LLVM . initializer (gv)
192
+ continue
193
+ end
194
+ break
211
195
else
212
- ft, tt = dyn_val
213
- dyn_src = methodinstance (ft, tt, tls_world_age ())
214
- CompilerJob (dyn_src, job. config)
196
+ break
215
197
end
216
-
217
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
218
198
end
199
+ return val
200
+ end
219
201
220
- # compile and link
221
- for dyn_job in keys (worklist)
222
- # cached compilation
223
- dyn_entry_fn = get! (jobs, dyn_job) do
224
- dyn_ir, dyn_meta = codegen (:llvm , dyn_job; toplevel= false ,
225
- parent_job= job)
226
- dyn_entry_fn = LLVM. name (dyn_meta. entry)
227
- merge! (compiled, dyn_meta. compiled)
228
- @assert context (dyn_ir) == context (ir)
229
- link! (ir, dyn_ir)
230
- changed = true
231
- dyn_entry_fn
232
- end
233
- dyn_entry = functions (ir)[dyn_entry_fn]
234
-
235
- # insert a pointer to the function everywhere the entry is used
236
- T_ptr = convert (LLVMType, Ptr{Cvoid})
237
- for call in worklist[dyn_job]
238
- @dispose builder= IRBuilder () begin
239
- position! (builder, call)
240
- fptr = if LLVM. version () >= v " 17"
241
- T_ptr = LLVM. PointerType ()
242
- bitcast! (builder, dyn_entry, T_ptr)
243
- elseif VERSION >= v " 1.12.0-DEV.225"
244
- T_ptr = LLVM. PointerType (LLVM. Int8Type ())
245
- bitcast! (builder, dyn_entry, T_ptr)
246
- else
247
- ptrtoint! (builder, dyn_entry, T_ptr)
248
- end
249
- replace_uses! (call, fptr)
202
+ worklist = Dict {Any, Vector{LLVM.CallInst}} ()
203
+ for use in uses (dyn_marker)
204
+ # decode the call
205
+ call = user (use):: LLVM.CallInst
206
+ dyn_mi_inst = find_base_object (operands (call)[1 ])
207
+ @compiler_assert isa (dyn_mi_inst, LLVM. ConstantInt) job
208
+ dyn_mi = Base. unsafe_pointer_to_objref (
209
+ convert (Ptr{Cvoid}, convert (Int, dyn_mi_inst)))
210
+ push! (get! (worklist, dyn_mi, LLVM. CallInst[]), call)
211
+ end
212
+
213
+ for dyn_mi in keys (worklist)
214
+ dyn_fn_name = compiled[dyn_mi]. specfunc
215
+ dyn_fn = functions (ir)[dyn_fn_name]
216
+
217
+ # insert a pointer to the function everywhere the entry is used
218
+ T_ptr = convert (LLVMType, Ptr{Cvoid})
219
+ for call in worklist[dyn_mi]
220
+ @dispose builder= IRBuilder () begin
221
+ position! (builder, call)
222
+ fptr = if LLVM. version () >= v " 17"
223
+ T_ptr = LLVM. PointerType ()
224
+ bitcast! (builder, dyn_fn, T_ptr)
225
+ elseif VERSION >= v " 1.12.0-DEV.225"
226
+ T_ptr = LLVM. PointerType (LLVM. Int8Type ())
227
+ bitcast! (builder, dyn_fn, T_ptr)
228
+ else
229
+ ptrtoint! (builder, dyn_fn, T_ptr)
250
230
end
251
- erase ! (call)
231
+ replace_uses ! (call, fptr )
252
232
end
233
+ unsafe_delete! (LLVM. parent (call), call)
253
234
end
254
235
end
255
236
256
237
# all deferred compilations should have been resolved
257
238
@compiler_assert isempty (uses (dyn_marker)) job
258
- erase! ( dyn_marker)
239
+ unsafe_delete! (ir, dyn_marker)
259
240
end
260
241
261
242
if libraries
@@ -285,7 +266,7 @@ const __llvm_initialized = Ref(false)
285
266
# global variables. this makes sure that the optimizer can, e.g.,
286
267
# rewrite function signatures.
287
268
if toplevel
288
- preserved_gvs = collect ( values (jobs))
269
+ preserved_gvs = [entry_fn]
289
270
for gvar in globals (ir)
290
271
if linkage (gvar) == LLVM. API. LLVMExternalLinkage
291
272
push! (preserved_gvs, LLVM. name (gvar))
@@ -317,7 +298,7 @@ const __llvm_initialized = Ref(false)
317
298
# deferred codegen has some special optimization requirements,
318
299
# which also need to happen _after_ regular optimization.
319
300
# XXX : make these part of the optimizer pipeline?
320
- if has_deferred_jobs
301
+ if run_optimization_for_deferred
321
302
@dispose pb= NewPMPassBuilder () begin
322
303
add! (pb, NewPMFunctionPassManager ()) do fpm
323
304
add! (fpm, InstCombinePass ())
@@ -353,15 +334,15 @@ const __llvm_initialized = Ref(false)
353
334
# finish the module
354
335
#
355
336
# we want to finish the module after optimization, so we cannot do so
356
- # during deferred code generation. instead , process the deferred jobs
357
- # here.
337
+ # during deferred code generation. Instead, process the merged module
338
+ # from all the jobs here.
358
339
if toplevel
359
340
entry = finish_ir! (job, ir, entry)
360
341
361
- for (job′, fn′) in jobs
362
- job′ == job && continue
363
- finish_ir! (job′, ir, functions (ir)[fn′])
364
- end
342
+ # for (job′, fn′) in jobs
343
+ # job′ == job && continue
344
+ # finish_ir!(job′, ir, functions(ir)[fn′])
345
+ # end
365
346
end
366
347
367
348
# replace non-entry function definitions with a declaration
0 commit comments