Skip to content

Commit 43adce8

Browse files
committed
add compiler support for gpuc.lookup
1 parent fe5df66 commit 43adce8

File tree

2 files changed

+63
-23
lines changed

2 files changed

+63
-23
lines changed

src/driver.jl

+60-21
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
164164
end
165165

166166
# GPUCompiler intrinsic that marks deferred compilation
167+
# In contrast to `deferred_codegen`, this doesn't support arbitrary
168+
# jobs as call targets.
167169
function var"gpuc.deferred" end
168170

169171
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
@@ -221,12 +223,28 @@ const __llvm_initialized = Ref(false)
221223
# since those modules have been finalized themselves, and we don't want to re-finalize.
222224
entry = finish_module!(job, ir, entry)
223225

226+
# Strip cast wrappers from an LLVM constant.
#
# While `val` is a `ConstantExpr` whose opcode is `IntToPtr`, `BitCast`, or
# `AddrSpaceCast`, it is replaced by its first operand. Anything else
# (including a `ConstantExpr` with a different opcode) is returned unchanged.
function unwrap_constant(val)
    cast_opcodes = (
        LLVM.API.LLVMIntToPtr,
        LLVM.API.LLVMBitCast,
        LLVM.API.LLVMAddrSpaceCast,
    )
    # Peel one cast layer per iteration; stop at the first non-cast value.
    while val isa ConstantExpr && opcode(val) in cast_opcodes
        val = first(operands(val))
    end
    return val
end
238+
224239
# deferred code generation
225240
has_deferred_jobs = !only_entry && toplevel &&
226-
haskey(functions(ir), "deferred_codegen")
241+
(haskey(functions(ir), "deferred_codegen") ||
242+
haskey(functions(ir), "gpuc.lookup"))
243+
227244
jobs = Dict{CompilerJob, String}(job => entry_fn)
228245
if has_deferred_jobs
229-
dyn_marker = functions(ir)["deferred_codegen"]
246+
dyn_marker = haskey(functions(ir), "deferred_codegen") ? functions(ir)["deferred_codegen"] : nothing
247+
dyn_marker_v2 = haskey(functions(ir), "gpuc.lookup") ? functions(ir)["gpuc.lookup"] : nothing
230248

231249
# iterative compilation (non-recursive)
232250
changed = true
@@ -235,26 +253,40 @@ const __llvm_initialized = Ref(false)
235253

236254
# find deferred compiler
237255
# TODO: recover this information earlier, from the Julia IR
256+
# We can do this now with gpuc.lookup
238257
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
239-
for use in uses(dyn_marker)
240-
# decode the call
241-
call = user(use)::LLVM.CallInst
242-
id = convert(Int, first(operands(call)))
243-
244-
global deferred_codegen_jobs
245-
dyn_val = deferred_codegen_jobs[id]
246-
247-
# get a job in the appropriate world
248-
dyn_job = if dyn_val isa CompilerJob
249-
# trust that the user knows what they're doing
250-
dyn_val
251-
else
252-
ft, tt = dyn_val
253-
dyn_src = methodinstance(ft, tt, tls_world_age())
254-
CompilerJob(dyn_src, job.config)
258+
if dyn_marker !== nothing
259+
for use in uses(dyn_marker)
260+
# decode the call
261+
call = user(use)::LLVM.CallInst
262+
id = convert(Int, first(operands(call)))
263+
264+
global deferred_codegen_jobs
265+
dyn_val = deferred_codegen_jobs[id]
266+
267+
# get a job in the appropriate world
268+
dyn_job = if dyn_val isa CompilerJob
269+
# trust that the user knows what they're doing
270+
dyn_val
271+
else
272+
ft, tt = dyn_val
273+
dyn_src = methodinstance(ft, tt, tls_world_age())
274+
CompilerJob(dyn_src, job.config)
275+
end
276+
277+
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
255278
end
279+
end
256280

257-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
281+
if dyn_marker_v2 !== nothing
282+
for use in uses(dyn_marker_v2)
283+
# decode the call
284+
call = user(use)::LLVM.CallInst
285+
dyn_mi = Base.unsafe_pointer_to_objref(
286+
convert(Ptr{Cvoid}, convert(Int, unwrap_constant(operands(call)[1]))))
287+
dyn_job = CompilerJob(dyn_mi, job.config)
288+
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
289+
end
258290
end
259291

260292
# compile and link
@@ -296,8 +328,15 @@ const __llvm_initialized = Ref(false)
296328
end
297329

298330
# all deferred compilations should have been resolved
299-
@compiler_assert isempty(uses(dyn_marker)) job
300-
unsafe_delete!(ir, dyn_marker)
331+
if dyn_marker !== nothing
332+
@compiler_assert isempty(uses(dyn_marker)) job
333+
unsafe_delete!(ir, dyn_marker)
334+
end
335+
336+
if dyn_marker_v2 !== nothing
337+
@compiler_assert isempty(uses(dyn_marker_v2)) job
338+
unsafe_delete!(ir, dyn_marker_v2)
339+
end
301340
end
302341

303342
if toplevel

src/jlgen.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,8 @@ else
329329
get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
330330
end
331331

332-
struct GPUInterpreter <: CC.AbstractInterpreter
332+
# Abstract supertype for GPU abstract interpreters. `GPUInterpreter` subtypes it, and
# `CC.abstract_call_known` is defined on this type rather than on `GPUInterpreter`.
abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
333+
struct GPUInterpreter <: AbstractGPUInterpreter
333334
world::UInt
334335
method_table::GPUMethodTableView
335336

@@ -465,7 +466,7 @@ struct DeferredCallInfo <: CC.CallInfo
465466
info::CC.CallInfo
466467
end
467468

468-
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
469+
function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
469470
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
470471
max_methods::Int = CC.get_max_methods(interp, f, sv))
471472
(; fargs, argtypes) = arginfo

0 commit comments

Comments
 (0)