@@ -164,6 +164,8 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
164
164
end
165
165
166
166
# GPUCompiler intrinsic that marks deferred compilation
167
+ # In contrast to `deferred_codegen`, this doesn't support arbitrary
168
+ # jobs as call targets.
167
169
function var"gpuc.deferred" end
168
170
169
171
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
@@ -221,12 +223,28 @@ const __llvm_initialized = Ref(false)
221
223
# since those modules have been finalized themselves, and we don't want to re-finalize.
222
224
entry = finish_module! (job, ir, entry)
223
225
226
# Strip value-preserving cast wrappers from a constant.
#
# While `val` is a `ConstantExpr` whose opcode is `inttoptr`, `bitcast` or
# `addrspacecast`, replace it with its first operand. Stop at the first value
# that is not a `ConstantExpr`, or whose opcode is anything else, and return
# that value unchanged.
function unwrap_constant(val)
    # opcodes that only re-type their single operand and can be peeled off
    cast_opcodes = (LLVM.API.LLVMIntToPtr,
                    LLVM.API.LLVMBitCast,
                    LLVM.API.LLVMAddrSpaceCast)
    while val isa ConstantExpr
        # query the opcode once per layer instead of once per comparison
        opcode(val) in cast_opcodes || break
        val = first(operands(val))
    end
    return val
end
238
+
224
239
# deferred code generation
225
240
has_deferred_jobs = ! only_entry && toplevel &&
226
- haskey (functions (ir), " deferred_codegen" )
241
+ (haskey (functions (ir), " deferred_codegen" ) ||
242
+ haskey (functions (ir), " gpuc.lookup" ))
243
+
227
244
jobs = Dict {CompilerJob, String} (job => entry_fn)
228
245
if has_deferred_jobs
229
- dyn_marker = functions (ir)[" deferred_codegen" ]
246
+ dyn_marker = haskey (functions (ir), " deferred_codegen" ) ? functions (ir)[" deferred_codegen" ] : nothing
247
+ dyn_marker_v2 = haskey (functions (ir), " gpuc.lookup" ) ? functions (ir)[" gpuc.lookup" ] : nothing
230
248
231
249
# iterative compilation (non-recursive)
232
250
changed = true
@@ -235,26 +253,40 @@ const __llvm_initialized = Ref(false)
235
253
236
254
# find deferred compiler
237
255
# TODO: recover this information earlier, from the Julia IR
256
+ # We can do this now with gpuc.lookup
238
257
worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ()
239
- for use in uses (dyn_marker)
240
- # decode the call
241
- call = user (use):: LLVM.CallInst
242
- id = convert (Int, first (operands (call)))
243
-
244
- global deferred_codegen_jobs
245
- dyn_val = deferred_codegen_jobs[id]
246
-
247
- # get a job in the appopriate world
248
- dyn_job = if dyn_val isa CompilerJob
249
- # trust that the user knows what they're doing
250
- dyn_val
251
- else
252
- ft, tt = dyn_val
253
- dyn_src = methodinstance (ft, tt, tls_world_age ())
254
- CompilerJob (dyn_src, job. config)
258
+ if dyn_marker != = nothing
259
+ for use in uses (dyn_marker)
260
+ # decode the call
261
+ call = user (use):: LLVM.CallInst
262
+ id = convert (Int, first (operands (call)))
263
+
264
+ global deferred_codegen_jobs
265
+ dyn_val = deferred_codegen_jobs[id]
266
+
267
+ # get a job in the appropriate world
268
+ dyn_job = if dyn_val isa CompilerJob
269
+ # trust that the user knows what they're doing
270
+ dyn_val
271
+ else
272
+ ft, tt = dyn_val
273
+ dyn_src = methodinstance (ft, tt, tls_world_age ())
274
+ CompilerJob (dyn_src, job. config)
275
+ end
276
+
277
+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
255
278
end
279
+ end
256
280
257
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
281
+ if dyn_marker_v2 != = nothing
282
+ for use in uses (dyn_marker_v2)
283
+ # decode the call
284
+ call = user (use):: LLVM.CallInst
285
+ dyn_mi = Base. unsafe_pointer_to_objref (
286
+ convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
287
+ dyn_job = CompilerJob (dyn_mi, job. config)
288
+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
289
+ end
258
290
end
259
291
260
292
# compile and link
@@ -296,8 +328,15 @@ const __llvm_initialized = Ref(false)
296
328
end
297
329
298
330
# all deferred compilations should have been resolved
299
- @compiler_assert isempty (uses (dyn_marker)) job
300
- unsafe_delete! (ir, dyn_marker)
331
+ if dyn_marker != = nothing
332
+ @compiler_assert isempty (uses (dyn_marker)) job
333
+ unsafe_delete! (ir, dyn_marker)
334
+ end
335
+
336
+ if dyn_marker_v2 != = nothing
337
+ @compiler_assert isempty (uses (dyn_marker_v2)) job
338
+ unsafe_delete! (ir, dyn_marker_v2)
339
+ end
301
340
end
302
341
303
342
if toplevel
0 commit comments