
Commit e532812

vchuravy authored and maleadt committed
Add new deferred compilation mechanism.
1 parent d68a7fc commit e532812
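Orientation (editorial note, not part of the commit): the new mechanism replaces the Val-based deferred_codegen entry point with an intrinsic-like function, var"gpuc.deferred", that is called with the target function and its actual arguments. Inference rewrites such calls into "gpuc.lookup" foreigncalls, and emit_llvm later resolves those to raw pointers to the compiled code. A hedged sketch of a call site, using hypothetical child/parent functions; this is only meaningful when compiled through a GPUCompiler-driven pipeline (outside of it the intrinsic has no methods and the call would simply error):

using GPUCompiler

child(x) = x + 1

function parent(x)
    # Under GPUCompiler codegen this yields a Ptr{Cvoid} to the compiled
    # specialization of `child` for typeof(x); how that pointer is then
    # invoked is backend-specific (e.g. device-side launches).
    fptr = GPUCompiler.var"gpuc.deferred"(child, x)
    return fptr
end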

File tree

4 files changed

+280 -33 lines changed


src/driver.jl

+106 -28
@@ -39,6 +39,41 @@ function JuliaContext(f; kwargs...)
 end
 
 
+## deferred compilation
+
+function var"gpuc.deferred" end
+
+# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
+begin
+    # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
+    # this could both be generalized (e.g. supporting actual function calls, instead of
+    # returning a function pointer), and be integrated with the nonrecursive codegen.
+    const deferred_codegen_jobs = Dict{Int, Any}()
+
+    # We make this function explicitly callable so that we can drive OrcJIT's
+    # lazy compilation from, while also enabling recursive compilation.
+    Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
+        ptr
+    end
+
+    @generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
+        id = length(deferred_codegen_jobs) + 1
+        deferred_codegen_jobs[id] = (; ft, tt)
+        # don't bother looking up the method instance, as we'll do so again during codegen
+        # using the world age of the parent.
+        #
+        # this also works around an issue on <1.10, where we don't know the world age of
+        # generated functions so use the current world counter, which may be too new
+        # for the world we're compiling for.
+
+        quote
+            # TODO: add an edge to this method instance to support method redefinitions
+            ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
+        end
+    end
+end
+
+
 ## compiler entrypoint
 
 export compile
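For comparison, the deprecated path kept above is driven by Val-wrapped function and argument types rather than by an ordinary call, and merely registers a job id that is resolved during emit_llvm. A minimal sketch with a hypothetical child function (again only meaningful under GPUCompiler codegen):

child(x::Int) = x + 1

# records (; ft, tt) under a fresh id in `deferred_codegen_jobs` and lowers to
# `ccall("extern deferred_codegen", ...)`, which emit_llvm rewrites to a function pointer
fptr = GPUCompiler.deferred_codegen(Val(child), Val(Tuple{Int}))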
@@ -127,33 +162,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
     error("Unknown compilation output $output")
 end
 
-# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
-# this could both be generalized (e.g. supporting actual function calls, instead of
-# returning a function pointer), and be integrated with the nonrecursive codegen.
-const deferred_codegen_jobs = Dict{Int, Any}()
-
-# We make this function explicitly callable so that we can drive OrcJIT's
-# lazy compilation from, while also enabling recursive compilation.
-Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
-    ptr
-end
-
-@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
-    id = length(deferred_codegen_jobs) + 1
-    deferred_codegen_jobs[id] = (; ft, tt)
-    # don't bother looking up the method instance, as we'll do so again during codegen
-    # using the world age of the parent.
-    #
-    # this also works around an issue on <1.10, where we don't know the world age of
-    # generated functions so use the current world counter, which may be too new
-    # for the world we're compiling for.
-
-    quote
-        # TODO: add an edge to this method instance to support method redefinitions
-        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
-    end
-end
-
 const __llvm_initialized = Ref(false)
 
 @locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -183,6 +191,77 @@ const __llvm_initialized = Ref(false)
     entry = finish_module!(job, ir, entry)
 
     # deferred code generation
+    if haskey(functions(ir), "gpuc.lookup")
+        dyn_marker = functions(ir)["gpuc.lookup"]
+
+        # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
+        # target method instance from the LLVM IR
+        # TODO: drive deferred compilation from the Julia IR instead
+        function find_base_object(val)
+            while true
+                if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
+                                            opcode(val) == LLVM.API.LLVMBitCast ||
+                                            opcode(val) == LLVM.API.LLVMAddrSpaceCast)
+                    val = first(operands(val))
+                elseif val isa LLVM.IntToPtrInst ||
+                       val isa LLVM.BitCastInst ||
+                       val isa LLVM.AddrSpaceCastInst
+                    val = first(operands(val))
+                elseif val isa LLVM.LoadInst
+                    # In 1.11+ we no longer embed integer constants directly.
+                    gv = first(operands(val))
+                    if gv isa LLVM.GlobalValue
+                        val = LLVM.initializer(gv)
+                        continue
+                    end
+                    break
+                else
+                    break
+                end
+            end
+            return val
+        end
+
+        worklist = Dict{Any, Vector{LLVM.CallInst}}()
+        for use in uses(dyn_marker)
+            # decode the call
+            call = user(use)::LLVM.CallInst
+            dyn_mi_inst = find_base_object(operands(call)[1])
+            @compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
+            dyn_mi = Base.unsafe_pointer_to_objref(
+                convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
+            push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
+        end
+
+        for dyn_mi in keys(worklist)
+            dyn_fn_name = compiled[dyn_mi].specfunc
+            dyn_fn = functions(ir)[dyn_fn_name]
+
+            # insert a pointer to the function everywhere the entry is used
+            T_ptr = convert(LLVMType, Ptr{Cvoid})
+            for call in worklist[dyn_mi]
+                @dispose builder=IRBuilder() begin
+                    position!(builder, call)
+                    fptr = if LLVM.version() >= v"17"
+                        T_ptr = LLVM.PointerType()
+                        bitcast!(builder, dyn_fn, T_ptr)
+                    elseif VERSION >= v"1.12.0-DEV.225"
+                        T_ptr = LLVM.PointerType(LLVM.Int8Type())
+                        bitcast!(builder, dyn_fn, T_ptr)
+                    else
+                        ptrtoint!(builder, dyn_fn, T_ptr)
+                    end
+                    replace_uses!(call, fptr)
+                end
+                unsafe_delete!(LLVM.parent(call), call)
+            end
+        end
+
+        # all deferred compilations should have been resolved
+        @compiler_assert isempty(uses(dyn_marker)) job
+        unsafe_delete!(ir, dyn_marker)
+    end
+    ## old, deprecated implementation
     has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
     jobs = Dict{CompilerJob, String}(job => entry_fn)
     if has_deferred_jobs
@@ -194,7 +273,6 @@ const __llvm_initialized = Ref(false)
         changed = false
 
         # find deferred compiler
-        # TODO: recover this information earlier, from the Julia IR
         worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
         for use in uses(dyn_marker)
             # decode the call
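To summarize the LLVM-level handling added to emit_llvm above, here is a hand-written sketch of the before/after shape (illustrative names, not actual compiler output):

# Each `var"gpuc.deferred"` call arrives in emit_llvm as a call to the `gpuc.lookup`
# marker, whose first operand is the target MethodInstance embedded as a constant
# (possibly behind int/pointer casts, or behind a load from a global on 1.11+, which
# is why find_base_object also walks through loads):
#
#     %ptr = call @gpuc.lookup(<constant MethodInstance>, <callee>, <args>...)
#
# The pass maps that MethodInstance to its already-generated code via compiled[mi].specfunc,
# replaces every use of %ptr with the address of that function (ptrtoint on typed-pointer
# builds, bitcast to an opaque pointer on LLVM 17+ / 1.12+), deletes the call, and finally
# removes the now-unused gpuc.lookup marker itself.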

src/irgen.jl

+6
@@ -80,6 +80,12 @@ function irgen(@nospecialize(job::CompilerJob))
     compiled[job.source] =
         (; compiled[job.source].ci, func, specfunc)
 
+    for mi in keys(compiled)
+        mi == job.source && continue
+        ci, func, specfunc = compiled[mi]
+        compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
+    end
+
     # minimal required optimization
     @timeit_debug to "rewrite" begin
         if job.config.kernel && needs_byval(job)

src/jlgen.jl

+154 -5
@@ -1,6 +1,5 @@
 # Julia compiler integration
 
-
 ## world age lookups
 
 # `tls_world_age` should be used to look up the current world age. in most cases, this is

@@ -12,6 +11,7 @@ else
     tls_world_age() = ccall(:jl_get_tls_world_age, UInt, ())
 end
 
+
 ## looking up method instances
 
 export methodinstance, generic_methodinstance
@@ -159,6 +159,7 @@
 
 
 ## code instance cache
+
 const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"
 
 if !HAS_INTEGRATED_CACHE
@@ -318,7 +319,8 @@ else
     get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
 end
 
-struct GPUInterpreter <: CC.AbstractInterpreter
+abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
+struct GPUInterpreter <: AbstractGPUInterpreter
     world::UInt
     method_table::GPUMethodTableView
 
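A note on the refactoring above (editorial): the deferred-compilation hooks added further down are defined on AbstractGPUInterpreter, so presumably an external interpreter can opt in by subtyping it instead of CC.AbstractInterpreter directly. A minimal sketch with a hypothetical type (the rest of the AbstractInterpreter interface still has to be implemented, as GPUInterpreter does):

struct MyInterpreter <: GPUCompiler.AbstractGPUInterpreter
    world::UInt
    # ... method table, code cache, inference/optimization params, etc.
end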
@@ -436,6 +438,112 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
 end
 
 
+## deferred compilation
+
+struct DeferredCallInfo <: CC.CallInfo
+    rt::DataType
+    info::CC.CallInfo
+end
+
+# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
+function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
+                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                max_methods::Int = CC.get_max_methods(interp, f, sv))
+    (; fargs, argtypes) = arginfo
+    if f === var"gpuc.deferred"
+        argvec = argtypes[2:end]
+        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
+        callinfo = DeferredCallInfo(call.rt, call.info)
+        @static if VERSION < v"1.11.0-"
+            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
+        else
+            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
+        end
+    end
+    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
+                                          arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                          max_methods::Int)
+end
+
+# during inlining, refine deferred calls to gpuc.lookup foreigncalls
+const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
+function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int,
+                         stmt::Expr, info::DeferredCallInfo, flag::FlagType,
+                         sig::CC.Signature, state::CC.InliningState)
+    minfo = info.info
+    results = minfo.results
+    if length(results.matches) != 1
+        return nothing
+    end
+    match = only(results.matches)
+
+    # lookup the target mi with correct edge tracking
+    case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
+                                         info)
+    @assert case isa CC.InvokeCase
+    @assert stmt.head === :call
+
+    args = Any[
+        "extern gpuc.lookup",
+        Ptr{Cvoid},
+        Core.svec(Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
+        0,
+        QuoteNode(:llvmcall),
+        case.invoke,
+        stmt.args[2:end]...
+    ]
+    stmt.head = :foreigncall
+    stmt.args = args
+    return nothing
+end
+
+struct DeferredEdges
+    edges::Vector{MethodInstance}
+end
+
+function find_deferred_edges(ir::CC.IRCode)
+    edges = MethodInstance[]
+    # XXX: can we add this instead in handle_call?
+    for stmt in ir.stmts
+        inst = stmt[:inst]
+        inst isa Expr || continue
+        expr = inst::Expr
+        if expr.head === :foreigncall &&
+           expr.args[1] == "extern gpuc.lookup"
+            deferred_mi = expr.args[6]
+            push!(edges, deferred_mi)
+        end
+    end
+    unique!(edges)
+    return edges
+end
+
+if VERSION >= v"1.11.0-"
+    function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode,
+                                       caller::CC.InferenceResult)
+        edges = find_deferred_edges(ir)
+        if !isempty(edges)
+            CC.stack_analysis_result!(caller, DeferredEdges(edges))
+        end
+        @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode,
+                                          caller::CC.InferenceResult)
+    end
+else # v1.10
+    # 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis
+    function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode,
+                       caller::CC.InferenceResult)
+        edges = find_deferred_edges(ir)
+        if !isempty(edges)
+            # HACK: we store the deferred edges in the argescapes field, which is invalid,
+            # but nobody should be running EA on our results.
+            caller.argescapes = DeferredEdges(edges)
+        end
+        @invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState,
+                          ir::CC.IRCode, caller::CC.InferenceResult)
+    end
+end
+
+
 ## world view of the cache
 using Core.Compiler: WorldView
 
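To make the CC.handle_call! rewrite above concrete: for a call var"gpuc.deferred"(child, x) with x::Int, the statement it produces has roughly the following shape (a hand-built illustration; mi stands in for the MethodInstance chosen by compileable_specialization, and child is a hypothetical callee):

mi = nothing  # placeholder for the MethodInstance selected during inlining
ex = Expr(:foreigncall,
          "extern gpuc.lookup",      # resolved to a function pointer by emit_llvm
          Ptr{Cvoid},                # return type: a raw code pointer
          Core.svec(Any, Any, Int),  # Any for the MethodInstance and for the callee itself
          0,                         # nreq (no varargs)
          QuoteNode(:llvmcall),      # calling convention
          mi, :child, :x)            # the MethodInstance, then the original call arguments

Note that find_deferred_edges later recovers the MethodInstance from position args[6] of exactly this expression.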
@@ -584,6 +692,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
         error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
     end
 
+    compiled = IdDict()
+    llvm_mod, outstanding = compile_method_instance(job, compiled)
+    worklist = outstanding
+    while !isempty(worklist)
+        source = pop!(worklist)
+        haskey(compiled, source) && continue
+        job2 = CompilerJob(source, job.config)
+        @debug "Processing..." job2
+        llvm_mod2, outstanding = compile_method_instance(job2, compiled)
+        append!(worklist, outstanding)
+        @assert context(llvm_mod) == context(llvm_mod2)
+        link!(llvm_mod, llvm_mod2)
+    end
+
+    return llvm_mod, compiled
+end
+
+function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
     # populate the cache
     interp = get_interpreter(job)
     cache = CC.code_cache(interp)
@@ -594,7 +720,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
 
     # create a callback to look-up function in our cache,
     # and keep track of the method instances we needed.
-    method_instances = []
+    method_instances = Any[]
    if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
        function lookup_fun(mi, min_world, max_world)
            push!(method_instances, mi)
@@ -659,7 +785,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
     end
 
     # process all compiled method instances
-    compiled = Dict()
     for mi in method_instances
         ci = ci_cache_lookup(cache, mi, job.world, job.world)
         ci === nothing && continue
@@ -696,10 +821,34 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
         compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
     end
 
+    # Collect the deferred edges
+    outstanding = Any[]
+    for mi in method_instances
+        !haskey(compiled, mi) && continue # Equivalent to ci_cache_lookup == nothing
+        ci = compiled[mi].ci
+        @static if VERSION >= v"1.11.0-"
+            edges = CC.traverse_analysis_results(ci) do @nospecialize result
+                return result isa DeferredEdges ? result : return
+            end
+        else
+            edges = ci.argescapes
+            if !(edges isa Union{Nothing, DeferredEdges})
+                edges = nothing
+            end
+        end
+        if edges !== nothing
+            for deferred_mi in (edges::DeferredEdges).edges
+                if !haskey(compiled, deferred_mi)
+                    push!(outstanding, deferred_mi)
+                end
+            end
+        end
+    end
+
     # ensure that the requested method instance was compiled
     @assert haskey(compiled, job.source)
 
-    return llvm_mod, compiled
+    return llvm_mod, outstanding
 end
 
 # partially revert JuliaLang/julia#49391
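Stepping back, compile_method_instance is now split into a driver that owns an IdDict of everything compiled so far and a worker that compiles a single entry point and reports the deferred edges it discovered but did not compile. A standalone toy model of that contract (no GPUCompiler types; `deps` plays the role of the deferred edges):

# worker: "compile" one unit, return its not-yet-compiled dependencies
function compile_one!(compiled::IdDict, unit, deps)
    compiled[unit] = "code for $unit"
    return [d for d in get(deps, unit, ()) if !haskey(compiled, d)]
end

# driver: drain the worklist until every reachable unit has been compiled
function compile_all(root, deps)
    compiled = IdDict()
    worklist = compile_one!(compiled, root, deps)
    while !isempty(worklist)
        unit = pop!(worklist)
        haskey(compiled, unit) && continue   # may have been queued twice
        append!(worklist, compile_one!(compiled, unit, deps))
    end
    return compiled
end

compile_all(:parent, Dict(:parent => (:child,), :child => (:grandchild,)))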
