Skip to content

Commit a018c17

Browse files
committed
Rework deferred compilation mechanism.
1 parent 316668b commit a018c17

File tree

5 files changed

+293
-129
lines changed

5 files changed

+293
-129
lines changed

examples/jit.jl

+28-27
Original file line numberDiff line numberDiff line change
@@ -116,31 +116,31 @@ function get_trampoline(job)
116116
return addr
117117
end
118118

119-
import GPUCompiler: deferred_codegen_jobs
120-
@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
121-
# manual version of native_job because we have a function type
122-
source = methodinstance(F, Base.to_tuple_type(tt), world)
123-
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
124-
# XXX: do we actually require the Julia runtime?
125-
# with jlruntime=false, we reach an unreachable.
126-
params = TestCompilerParams()
127-
config = CompilerConfig(target, params; kernel=false)
128-
job = CompilerJob(source, config, world)
129-
# XXX: invoking GPUCompiler from a generated function is not allowed!
130-
# for things to work, we need to forward the correct world, at least.
131-
132-
addr = get_trampoline(job)
133-
trampoline = pointer(addr)
134-
id = Base.reinterpret(Int, trampoline)
135-
136-
deferred_codegen_jobs[id] = job
137-
138-
quote
139-
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
140-
assume(ptr != C_NULL)
141-
return ptr
142-
end
143-
end
119+
# import GPUCompiler: deferred_codegen_jobs
120+
# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
121+
# # manual version of native_job because we have a function type
122+
# source = methodinstance(F, Base.to_tuple_type(tt), world)
123+
# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
124+
# # XXX: do we actually require the Julia runtime?
125+
# # with jlruntime=false, we reach an unreachable.
126+
# params = TestCompilerParams()
127+
# config = CompilerConfig(target, params; kernel=false)
128+
# job = CompilerJob(source, config, world)
129+
# # XXX: invoking GPUCompiler from a generated function is not allowed!
130+
# # for things to work, we need to forward the correct world, at least.
131+
132+
# addr = get_trampoline(job)
133+
# trampoline = pointer(addr)
134+
# id = Base.reinterpret(Int, trampoline)
135+
136+
# deferred_codegen_jobs[id] = job
137+
138+
# quote
139+
# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
140+
# assume(ptr != C_NULL)
141+
# return ptr
142+
# end
143+
# end
144144

145145
@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
146146
argtt = tt.parameters[1]
@@ -224,8 +224,9 @@ end
224224
@inline function call_delayed(f::F, args...) where F
225225
tt = Tuple{map(Core.Typeof, args)...}
226226
rt = Core.Compiler.return_type(f, tt)
227-
world = GPUCompiler.tls_world_age()
228-
ptr = deferred_codegen(f, Val(tt), Val(world))
227+
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
228+
# But that will only be needed here, and in Enzyme...
229+
ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
229230
abi_call(ptr, rt, tt, f, args...)
230231
end
231232

src/driver.jl

+78-97
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ function JuliaContext(f; kwargs...)
3939
end
4040

4141

42+
## deferred compilation
43+
44+
"""
45+
var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
46+
47+
As if we were to call `f(args...)` but instead we are
48+
putting down a marker and returning a function pointer to later
49+
call.
50+
"""
51+
function var"gpuc.deferred" end
52+
4253
## compiler entrypoint
4354

4455
export compile
@@ -127,33 +138,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
127138
error("Unknown compilation output $output")
128139
end
129140

130-
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
131-
# this could both be generalized (e.g. supporting actual function calls, instead of
132-
# returning a function pointer), and be integrated with the nonrecursive codegen.
133-
const deferred_codegen_jobs = Dict{Int, Any}()
134-
135-
# We make this function explicitly callable so that we can drive OrcJIT's
136-
# lazy compilation from, while also enabling recursive compilation.
137-
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
138-
ptr
139-
end
140-
141-
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
142-
id = length(deferred_codegen_jobs) + 1
143-
deferred_codegen_jobs[id] = (; ft, tt)
144-
# don't bother looking up the method instance, as we'll do so again during codegen
145-
# using the world age of the parent.
146-
#
147-
# this also works around an issue on <1.10, where we don't know the world age of
148-
# generated functions so use the current world counter, which may be too new
149-
# for the world we're compiling for.
150-
151-
quote
152-
# TODO: add an edge to this method instance to support method redefinitions
153-
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
154-
end
155-
end
156-
157141
const __llvm_initialized = Ref(false)
158142

159143
@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -183,79 +167,76 @@ const __llvm_initialized = Ref(false)
183167
entry = finish_module!(job, ir, entry)
184168

185169
# deferred code generation
186-
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
187-
jobs = Dict{CompilerJob, String}(job => entry_fn)
188-
if has_deferred_jobs
189-
dyn_marker = functions(ir)["deferred_codegen"]
190-
191-
# iterative compilation (non-recursive)
192-
changed = true
193-
while changed
194-
changed = false
195-
196-
# find deferred compiler
197-
# TODO: recover this information earlier, from the Julia IR
198-
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
199-
for use in uses(dyn_marker)
200-
# decode the call
201-
call = user(use)::LLVM.CallInst
202-
id = convert(Int, first(operands(call)))
203-
204-
global deferred_codegen_jobs
205-
dyn_val = deferred_codegen_jobs[id]
206-
207-
# get a job in the appopriate world
208-
dyn_job = if dyn_val isa CompilerJob
209-
# trust that the user knows what they're doing
210-
dyn_val
170+
run_optimization_for_deferred = false
171+
if haskey(functions(ir), "gpuc.lookup")
172+
run_optimization_for_deferred = true
173+
dyn_marker = functions(ir)["gpuc.lookup"]
174+
175+
# gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
176+
# target method instance from the LLVM IR
177+
function find_base_object(val)
178+
while true
179+
if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
180+
opcode(val) == LLVM.API.LLVMBitCast ||
181+
opcode(val) == LLVM.API.LLVMAddrSpaceCast)
182+
val = first(operands(val))
183+
elseif val isa LLVM.IntToPtrInst ||
184+
val isa LLVM.BitCastInst ||
185+
val isa LLVM.AddrSpaceCastInst
186+
val = first(operands(val))
187+
elseif val isa LLVM.LoadInst
188+
# In 1.11+ we no longer embed integer constants directly.
189+
gv = first(operands(val))
190+
if gv isa LLVM.GlobalValue
191+
val = LLVM.initializer(gv)
192+
continue
193+
end
194+
break
211195
else
212-
ft, tt = dyn_val
213-
dyn_src = methodinstance(ft, tt, tls_world_age())
214-
CompilerJob(dyn_src, job.config)
196+
break
215197
end
216-
217-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
218198
end
199+
return val
200+
end
219201

220-
# compile and link
221-
for dyn_job in keys(worklist)
222-
# cached compilation
223-
dyn_entry_fn = get!(jobs, dyn_job) do
224-
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
225-
parent_job=job)
226-
dyn_entry_fn = LLVM.name(dyn_meta.entry)
227-
merge!(compiled, dyn_meta.compiled)
228-
@assert context(dyn_ir) == context(ir)
229-
link!(ir, dyn_ir)
230-
changed = true
231-
dyn_entry_fn
232-
end
233-
dyn_entry = functions(ir)[dyn_entry_fn]
234-
235-
# insert a pointer to the function everywhere the entry is used
236-
T_ptr = convert(LLVMType, Ptr{Cvoid})
237-
for call in worklist[dyn_job]
238-
@dispose builder=IRBuilder() begin
239-
position!(builder, call)
240-
fptr = if LLVM.version() >= v"17"
241-
T_ptr = LLVM.PointerType()
242-
bitcast!(builder, dyn_entry, T_ptr)
243-
elseif VERSION >= v"1.12.0-DEV.225"
244-
T_ptr = LLVM.PointerType(LLVM.Int8Type())
245-
bitcast!(builder, dyn_entry, T_ptr)
246-
else
247-
ptrtoint!(builder, dyn_entry, T_ptr)
248-
end
249-
replace_uses!(call, fptr)
202+
worklist = Dict{Any, Vector{LLVM.CallInst}}()
203+
for use in uses(dyn_marker)
204+
# decode the call
205+
call = user(use)::LLVM.CallInst
206+
dyn_mi_inst = find_base_object(operands(call)[1])
207+
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
208+
dyn_mi = Base.unsafe_pointer_to_objref(
209+
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
210+
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
211+
end
212+
213+
for dyn_mi in keys(worklist)
214+
dyn_fn_name = compiled[dyn_mi].specfunc
215+
dyn_fn = functions(ir)[dyn_fn_name]
216+
217+
# insert a pointer to the function everywhere the entry is used
218+
T_ptr = convert(LLVMType, Ptr{Cvoid})
219+
for call in worklist[dyn_mi]
220+
@dispose builder=IRBuilder() begin
221+
position!(builder, call)
222+
fptr = if LLVM.version() >= v"17"
223+
T_ptr = LLVM.PointerType()
224+
bitcast!(builder, dyn_fn, T_ptr)
225+
elseif VERSION >= v"1.12.0-DEV.225"
226+
T_ptr = LLVM.PointerType(LLVM.Int8Type())
227+
bitcast!(builder, dyn_fn, T_ptr)
228+
else
229+
ptrtoint!(builder, dyn_fn, T_ptr)
250230
end
251-
erase!(call)
231+
replace_uses!(call, fptr)
252232
end
233+
unsafe_delete!(LLVM.parent(call), call)
253234
end
254235
end
255236

256237
# all deferred compilations should have been resolved
257238
@compiler_assert isempty(uses(dyn_marker)) job
258-
erase!(dyn_marker)
239+
unsafe_delete!(ir, dyn_marker)
259240
end
260241

261242
if libraries
@@ -285,7 +266,7 @@ const __llvm_initialized = Ref(false)
285266
# global variables. this makes sure that the optimizer can, e.g.,
286267
# rewrite function signatures.
287268
if toplevel
288-
preserved_gvs = collect(values(jobs))
269+
preserved_gvs = [entry_fn]
289270
for gvar in globals(ir)
290271
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
291272
push!(preserved_gvs, LLVM.name(gvar))
@@ -317,7 +298,7 @@ const __llvm_initialized = Ref(false)
317298
# deferred codegen has some special optimization requirements,
318299
# which also need to happen _after_ regular optimization.
319300
# XXX: make these part of the optimizer pipeline?
320-
if has_deferred_jobs
301+
if run_optimization_for_deferred
321302
@dispose pb=NewPMPassBuilder() begin
322303
add!(pb, NewPMFunctionPassManager()) do fpm
323304
add!(fpm, InstCombinePass())
@@ -353,15 +334,15 @@ const __llvm_initialized = Ref(false)
353334
# finish the module
354335
#
355336
# we want to finish the module after optimization, so we cannot do so
356-
# during deferred code generation. instead, process the deferred jobs
357-
# here.
337+
# during deferred code generation. Instead, process the merged module
338+
# from all the jobs here.
358339
if toplevel
359340
entry = finish_ir!(job, ir, entry)
360341

361-
for (job′, fn′) in jobs
362-
job′ == job && continue
363-
finish_ir!(job′, ir, functions(ir)[fn′])
364-
end
342+
# for (job′, fn′) in jobs
343+
# job′ == job && continue
344+
# finish_ir!(job′, ir, functions(ir)[fn′])
345+
# end
365346
end
366347

367348
# replace non-entry function definitions with a declaration

src/irgen.jl

+11
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ function irgen(@nospecialize(job::CompilerJob))
8080
compiled[job.source] =
8181
(; compiled[job.source].ci, func, specfunc)
8282

83+
# Earlier we sanitize global names, this invalidates the
84+
# func, specfunc names saved in compiled. Update the names now,
85+
# such that when we use the compiled mappings to look up the
86+
# llvm function for a methodinstance (deferred codegen) we have
87+
# valid targets.
88+
for mi in keys(compiled)
89+
mi == job.source && continue
90+
ci, func, specfunc = compiled[mi]
91+
compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
92+
end
93+
8394
# minimal required optimization
8495
@timeit_debug to "rewrite" begin
8596
if job.config.kernel && needs_byval(job)

0 commit comments

Comments
 (0)