From 2fa787115d3a98084cce0e37e15f497c1ce1ec2a Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Mon, 13 May 2024 10:43:25 -0400
Subject: [PATCH 1/7] Add new deferred compilation mechanism.

---
 src/driver.jl        | 142 ++++++++++++++++++++++++++++--------
 src/irgen.jl         |  11 +++
 src/jlgen.jl         | 167 +++++++++++++++++++++++++++++++++++++++++--
 test/native_tests.jl |  14 ++++
 4 files changed, 298 insertions(+), 36 deletions(-)

diff --git a/src/driver.jl b/src/driver.jl
index 54a6c053..905b04e8 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -39,6 +39,41 @@ function JuliaContext(f; kwargs...)
 end
 
+
+## deferred compilation
+
+function var"gpuc.deferred" end
+
+# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
+begin
+    # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
+    # this could both be generalized (e.g. supporting actual function calls, instead of
+    # returning a function pointer), and be integrated with the nonrecursive codegen.
+    const deferred_codegen_jobs = Dict{Int, Any}()
+
+    # We make this function explicitly callable so that we can drive OrcJIT's
+    # lazy compilation from it, while also enabling recursive compilation.
+    Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
+        ptr
+    end
+
+    @generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
+        id = length(deferred_codegen_jobs) + 1
+        deferred_codegen_jobs[id] = (; ft, tt)
+        # don't bother looking up the method instance, as we'll do so again during codegen
+        # using the world age of the parent.
+        #
+        # this also works around an issue on <1.10, where we don't know the world age of
+        # generated functions so use the current world counter, which may be too new
+        # for the world we're compiling for.
+
+        quote
+            # TODO: add an edge to this method instance to support method redefinitions
+            ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
+        end
+    end
+end
+
 
 ## compiler entrypoint
 
 export compile
@@ -127,33 +162,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
     error("Unknown compilation output $output")
 end
 
-# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
-# this could both be generalized (e.g. supporting actual function calls, instead of
-# returning a function pointer), and be integrated with the nonrecursive codegen.
-const deferred_codegen_jobs = Dict{Int, Any}()
-
-# We make this function explicitly callable so that we can drive OrcJIT's
-# lazy compilation from, while also enabling recursive compilation.
-Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
-    ptr
-end
-
-@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
-    id = length(deferred_codegen_jobs) + 1
-    deferred_codegen_jobs[id] = (; ft, tt)
-    # don't bother looking up the method instance, as we'll do so again during codegen
-    # using the world age of the parent.
-    #
-    # this also works around an issue on <1.10, where we don't know the world age of
-    # generated functions so use the current world counter, which may be too new
-    # for the world we're compiling for.
-
-    quote
-        # TODO: add an edge to this method instance to support method redefinitions
-        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
-    end
-end
-
 const __llvm_initialized = Ref(false)
 
 @locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -183,9 +191,82 @@ const __llvm_initialized = Ref(false)
     entry = finish_module!(job, ir, entry)
 
     # deferred code generation
-    has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
+    run_optimization_for_deferred = false
+    if haskey(functions(ir), "gpuc.lookup")
+        run_optimization_for_deferred = true
+        dyn_marker = functions(ir)["gpuc.lookup"]
+
+        # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
+        # target method instance from the LLVM IR
+        # TODO: drive deferred compilation from the Julia IR instead
+        function find_base_object(val)
+            while true
+                if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
+                                            opcode(val) == LLVM.API.LLVMBitCast ||
+                                            opcode(val) == LLVM.API.LLVMAddrSpaceCast)
+                    val = first(operands(val))
+                elseif val isa LLVM.IntToPtrInst ||
+                       val isa LLVM.BitCastInst ||
+                       val isa LLVM.AddrSpaceCastInst
+                    val = first(operands(val))
+                elseif val isa LLVM.LoadInst
+                    # In 1.11+ we no longer embed integer constants directly.
+                    gv = first(operands(val))
+                    if gv isa LLVM.GlobalValue
+                        val = LLVM.initializer(gv)
+                        continue
+                    end
+                    break
+                else
+                    break
+                end
+            end
+            return val
+        end
+
+        worklist = Dict{Any, Vector{LLVM.CallInst}}()
+        for use in uses(dyn_marker)
+            # decode the call
+            call = user(use)::LLVM.CallInst
+            dyn_mi_inst = find_base_object(operands(call)[1])
+            @compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
+            dyn_mi = Base.unsafe_pointer_to_objref(
+                convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
+            push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
+        end
+
+        for dyn_mi in keys(worklist)
+            dyn_fn_name = compiled[dyn_mi].specfunc
+            dyn_fn = functions(ir)[dyn_fn_name]
+
+            # insert a pointer to the function everywhere the entry is used
+            T_ptr = convert(LLVMType, Ptr{Cvoid})
+            for call in worklist[dyn_mi]
+                @dispose builder=IRBuilder() begin
+                    position!(builder, call)
+                    fptr = if LLVM.version() >= v"17"
+                        T_ptr = LLVM.PointerType()
+                        bitcast!(builder, dyn_fn, T_ptr)
+                    elseif VERSION >= v"1.12.0-DEV.225"
+                        T_ptr = LLVM.PointerType(LLVM.Int8Type())
+                        bitcast!(builder, dyn_fn, T_ptr)
+                    else
+                        ptrtoint!(builder, dyn_fn, T_ptr)
+                    end
+                    replace_uses!(call, fptr)
+                end
+                unsafe_delete!(LLVM.parent(call), call)
+            end
+        end
+
+        # all deferred compilations should have been resolved
+        @compiler_assert isempty(uses(dyn_marker)) job
+        unsafe_delete!(ir, dyn_marker)
+    end
+
+    ## old, deprecated implementation
     jobs = Dict{CompilerJob, String}(job => entry_fn)
-    if has_deferred_jobs
+    if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
+        run_optimization_for_deferred = true
         dyn_marker = functions(ir)["deferred_codegen"]
 
         # iterative compilation (non-recursive)
@@ -194,7 +275,6 @@ const __llvm_initialized = Ref(false)
             changed = false
 
             # find deferred compiler
-            # TODO: recover this information earlier, from the Julia IR
            worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
             for use in uses(dyn_marker)
                 # decode the call
@@ -317,7 +397,7 @@ const __llvm_initialized = Ref(false)
     # deferred codegen has some special optimization requirements,
     # which also need to happen _after_ regular optimization.
     # XXX: make these part of the optimizer pipeline?
-    if has_deferred_jobs
+    if run_optimization_for_deferred
         @dispose pb=NewPMPassBuilder() begin
             add!(pb, NewPMFunctionPassManager()) do fpm
                 add!(fpm, InstCombinePass())
diff --git a/src/irgen.jl b/src/irgen.jl
index 0545a89f..bfbf1463 100644
--- a/src/irgen.jl
+++ b/src/irgen.jl
@@ -80,6 +80,17 @@ function irgen(@nospecialize(job::CompilerJob))
 
     compiled[job.source] = (; compiled[job.source].ci, func, specfunc)
 
+    # Earlier we sanitized global names; this invalidates the func and specfunc
+    # names saved in `compiled`. Update the names now, so that when we use the
+    # compiled mappings to look up the LLVM function for a method instance
+    # (deferred codegen) we have valid targets.
+    for mi in keys(compiled)
+        mi == job.source && continue
+        ci, func, specfunc = compiled[mi]
+        compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
+    end
+
     # minimal required optimization
     @timeit_debug to "rewrite" begin
         if job.config.kernel && needs_byval(job)
diff --git a/src/jlgen.jl b/src/jlgen.jl
index a34bd42e..cc6d5b3f 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -1,6 +1,5 @@
 # Julia compiler integration
 
-
 ## world age lookups
 
 # `tls_world_age` should be used to look up the current world age. in most cases, this is
@@ -12,6 +11,7 @@ else
     tls_world_age() = ccall(:jl_get_tls_world_age, UInt, ())
 end
 
+
 ## looking up method instances
 
 export methodinstance, generic_methodinstance
@@ -159,6 +159,7 @@ end
 
 ## code instance cache
 
+
 const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"
 
 if !HAS_INTEGRATED_CACHE
@@ -318,7 +319,8 @@ else
     get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
 end
 
-struct GPUInterpreter <: CC.AbstractInterpreter
+abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
+struct GPUInterpreter <: AbstractGPUInterpreter
     world::UInt
     method_table::GPUMethodTableView
@@ -436,6 +438,112 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
 end
 
+
+## deferred compilation
+
+struct DeferredCallInfo <: CC.CallInfo
+    rt::DataType
+    info::CC.CallInfo
+end
+
+# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
+function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
+                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                max_methods::Int = CC.get_max_methods(interp, f, sv))
+    (; fargs, argtypes) = arginfo
+    if f === var"gpuc.deferred"
+        argvec = argtypes[2:end]
+        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
+        callinfo = DeferredCallInfo(call.rt, call.info)
+        @static if VERSION < v"1.11.0-"
+            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
+        else
+            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
+        end
+    end
+    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
+                                          arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                          max_methods::Int)
+end
+
+# during inlining, refine deferred calls to gpuc.lookup foreigncalls
+const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
+function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::Int,
+                         stmt::Expr, info::DeferredCallInfo, flag::FlagType,
+                         sig::CC.Signature, state::CC.InliningState)
+    minfo = info.info
+    results = minfo.results
+    if length(results.matches) != 1
+        return nothing
+    end
+    match = only(results.matches)
+
+    # look up the target mi with correct edge tracking
+    case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
+                                         info)
+    @assert case isa CC.InvokeCase
+    @assert stmt.head === :call
+
+    args = Any[
+        "extern gpuc.lookup",
+        Ptr{Cvoid},
+        Core.svec(Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
+        0,
+        QuoteNode(:llvmcall),
+        case.invoke,
+        stmt.args[2:end]...
+    ]
+    stmt.head = :foreigncall
+    stmt.args = args
+    return nothing
+end
+
+struct DeferredEdges
+    edges::Vector{MethodInstance}
+end
+
+function find_deferred_edges(ir::CC.IRCode)
+    edges = MethodInstance[]
+    # XXX: can we add these edges in handle_call! instead?
+    for stmt in ir.stmts
+        inst = stmt[:inst]
+        inst isa Expr || continue
+        expr = inst::Expr
+        if expr.head === :foreigncall &&
+           expr.args[1] == "extern gpuc.lookup"
+            deferred_mi = expr.args[6]
+            push!(edges, deferred_mi)
+        end
+    end
+    unique!(edges)
+    return edges
+end
+
+if VERSION >= v"1.11.0-"
+function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode,
+                                   caller::CC.InferenceResult)
+    edges = find_deferred_edges(ir)
+    if !isempty(edges)
+        CC.stack_analysis_result!(caller, DeferredEdges(edges))
+    end
+    @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode,
+                                      caller::CC.InferenceResult)
+end
+else # v1.10
+# 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis
+function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode,
+                   caller::CC.InferenceResult)
+    edges = find_deferred_edges(ir)
+    if !isempty(edges)
+        # HACK: we store the deferred edges in the argescapes field, which is invalid,
+        #       but nobody should be running EA on our results.
+        caller.argescapes = DeferredEdges(edges)
+    end
+    @invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState,
                      ir::CC.IRCode, caller::CC.InferenceResult)
+end
+end
+
 
 ## world view of the cache
 
 using Core.Compiler: WorldView
@@ -584,6 +692,30 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
         error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
     end
 
+    # A poor man's worklist implementation.
+    # `compiled` contains a mapping from `mi -> (ci, func, specfunc)`.
+    # FIXME: Since we are disabling Julia's internal caching we might generate
+    #        multiple LLVM functions for the same mi.
+    # `outstanding` are the missing edges that were not compiled by `compile_method_instance`;
+    # currently these edges are generated through deferred codegen.
+    compiled = IdDict()
+    llvm_mod, outstanding = compile_method_instance(job, compiled)
+    worklist = outstanding
+    while !isempty(worklist)
+        source = pop!(worklist)
+        haskey(compiled, source) && continue # We have fulfilled the request already
+        # Create a new compiler job for this edge, reusing the config settings from the initial one
+        job2 = CompilerJob(source, job.config)
+        llvm_mod2, outstanding = compile_method_instance(job2, compiled)
+        append!(worklist, outstanding) # merge worklist with new outstanding edges
+        @assert context(llvm_mod) == context(llvm_mod2)
+        link!(llvm_mod, llvm_mod2)
+    end
+
+    return llvm_mod, compiled
+end
+
+function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
     # populate the cache
     interp = get_interpreter(job)
     cache = CC.code_cache(interp)
@@ -594,7 +726,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
 
     # create a callback to look-up function in our cache,
     # and keep track of the method instances we needed.
-    method_instances = []
+    method_instances = Any[]
     if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
         function lookup_fun(mi, min_world, max_world)
             push!(method_instances, mi)
@@ -659,7 +791,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
     end
 
     # process all compiled method instances
-    compiled = Dict()
     for mi in method_instances
         ci = ci_cache_lookup(cache, mi, job.world, job.world)
         ci === nothing && continue
@@ -693,13 +824,39 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
 
         # NOTE: it's not safe to store raw LLVM functions here, since those may get
         #       removed or renamed during optimization, so we store their name instead.
+        # FIXME: Enable this assert when we have a fully featured worklist
+        # @assert !haskey(compiled, mi)
         compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
     end
 
+    # Collect the deferred edges
+    outstanding = Any[]
+    for mi in method_instances
+        !haskey(compiled, mi) && continue # Equivalent to ci_cache_lookup returning nothing
+        ci = compiled[mi].ci
+        @static if VERSION >= v"1.11.0-"
+            edges = CC.traverse_analysis_results(ci) do @nospecialize result
+                return result isa DeferredEdges ? result : return
+            end
+        else
+            edges = ci.argescapes
+            if !(edges isa Union{Nothing, DeferredEdges})
+                edges = nothing
+            end
+        end
+        if edges !== nothing
+            for deferred_mi in (edges::DeferredEdges).edges
+                if !haskey(compiled, deferred_mi)
+                    push!(outstanding, deferred_mi)
+                end
+            end
+        end
+    end
+
     # ensure that the requested method instance was compiled
     @assert haskey(compiled, job.source)
 
-    return llvm_mod, compiled
+    return llvm_mod, outstanding
 end
 
 # partially revert JuliaLangjulia#49391
diff --git a/test/native_tests.jl b/test/native_tests.jl
index 298c1010..9ea6ebd6 100644
--- a/test/native_tests.jl
+++ b/test/native_tests.jl
@@ -162,6 +162,20 @@ end
         ir = fetch(t)
         @test contains(ir, r"add i64 %\d+, 3")
     end
+
+    @testset "deferred" begin
+        @gensym child kernel unrelated
+        @eval @noinline $child(i) = i
+        @eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)
+
+        # smoke test
+        job, _ = Native.create_job(eval(kernel), (Int64,))
+
+        ci, rt = only(GPUCompiler.code_typed(job))
+        @test rt === Ptr{Cvoid}
+
+        ir = sprint(io->GPUCompiler.code_llvm(io, job))
+    end
 end
 
############################################################################################

From b54b5e4b5ed1234f454cb5f17e62515c0e968067 Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Fri, 9 Aug 2024 15:59:30 +0200
Subject: [PATCH 2/7] remove old deferred implementation

---
 examples/jit.jl |  55 ++++++++++++------------
 src/driver.jl   | 110 ++-----------------------------------------------
 2 files changed, 30 insertions(+), 135 deletions(-)

diff --git a/examples/jit.jl b/examples/jit.jl
index 8a70a543..6f0a259c 100644
--- a/examples/jit.jl
+++ b/examples/jit.jl
@@ -116,31 +116,31 @@ function get_trampoline(job)
     return addr
 end
 
-import GPUCompiler: deferred_codegen_jobs
-@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
-    # manual version of native_job because we have a function type
-    source = methodinstance(F, Base.to_tuple_type(tt), world)
-    target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
-    # XXX: do we actually require the Julia runtime?
-    #      with jlruntime=false, we reach an unreachable.
-    params = TestCompilerParams()
-    config = CompilerConfig(target, params; kernel=false)
-    job = CompilerJob(source, config, world)
-    # XXX: invoking GPUCompiler from a generated function is not allowed!
-    #      for things to work, we need to forward the correct world, at least.
-
-    addr = get_trampoline(job)
-    trampoline = pointer(addr)
-    id = Base.reinterpret(Int, trampoline)
-
-    deferred_codegen_jobs[id] = job
-
-    quote
-        ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
-        assume(ptr != C_NULL)
-        return ptr
-    end
-end
+# import GPUCompiler: deferred_codegen_jobs
+# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
+#     # manual version of native_job because we have a function type
+#     source = methodinstance(F, Base.to_tuple_type(tt), world)
+#     target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
+#     # XXX: do we actually require the Julia runtime?
+#     #      with jlruntime=false, we reach an unreachable.
+#     params = TestCompilerParams()
+#     config = CompilerConfig(target, params; kernel=false)
+#     job = CompilerJob(source, config, world)
+#     # XXX: invoking GPUCompiler from a generated function is not allowed!
+#     #      for things to work, we need to forward the correct world, at least.
+
+#     addr = get_trampoline(job)
+#     trampoline = pointer(addr)
+#     id = Base.reinterpret(Int, trampoline)
+
+#     deferred_codegen_jobs[id] = job
+
+#     quote
+#         ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
+#         assume(ptr != C_NULL)
+#         return ptr
+#     end
+# end
 
 @generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
     argtt = tt.parameters[1]
@@ -224,8 +224,9 @@ end
 @inline function call_delayed(f::F, args...) where F
     tt = Tuple{map(Core.Typeof, args)...}
     rt = Core.Compiler.return_type(f, tt)
-    world = GPUCompiler.tls_world_age()
-    ptr = deferred_codegen(f, Val(tt), Val(world))
+    # FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
+    # But that will only be needed here, and in Enzyme...
+    ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
     abi_call(ptr, rt, tt, f, args...)
 end
 
diff --git a/src/driver.jl b/src/driver.jl
index 905b04e8..a60d0393 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -43,37 +43,6 @@ end
 
 function var"gpuc.deferred" end
 
-# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
-begin
-    # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
-    # this could both be generalized (e.g. supporting actual function calls, instead of
-    # returning a function pointer), and be integrated with the nonrecursive codegen.
-    const deferred_codegen_jobs = Dict{Int, Any}()
-
-    # We make this function explicitly callable so that we can drive OrcJIT's
-    # lazy compilation from it, while also enabling recursive compilation.
-    Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
-        ptr
-    end
-
-    @generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
-        id = length(deferred_codegen_jobs) + 1
-        deferred_codegen_jobs[id] = (; ft, tt)
-        # don't bother looking up the method instance, as we'll do so again during codegen
-        # using the world age of the parent.
-        #
-        # this also works around an issue on <1.10, where we don't know the world age of
-        # generated functions so use the current world counter, which may be too new
-        # for the world we're compiling for.
-
-        quote
-            # TODO: add an edge to this method instance to support method redefinitions
-            ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
-        end
-    end
-end
-
 
 ## compiler entrypoint
 
 export compile
@@ -198,7 +167,6 @@ const __llvm_initialized = Ref(false)
 
         # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
         # target method instance from the LLVM IR
-        # TODO: drive deferred compilation from the Julia IR instead
         function find_base_object(val)
@@ -263,80 +231,6 @@ const __llvm_initialized = Ref(false)
         @compiler_assert isempty(uses(dyn_marker)) job
         unsafe_delete!(ir, dyn_marker)
     end
-    ## old, deprecated implementation
-    jobs = Dict{CompilerJob, String}(job => entry_fn)
-    if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
-        run_optimization_for_deferred = true
-        dyn_marker = functions(ir)["deferred_codegen"]
-
-        # iterative compilation (non-recursive)
-        changed = true
-        while changed
-            changed = false
-
-            # find deferred compiler
-            worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
-            for use in uses(dyn_marker)
-                # decode the call
-                call = user(use)::LLVM.CallInst
-                id = convert(Int, first(operands(call)))
-
-                global deferred_codegen_jobs
-                dyn_val = deferred_codegen_jobs[id]
-
-                # get a job in the appopriate world
-                dyn_job = if dyn_val isa CompilerJob
-                    # trust that the user knows what they're doing
-                    dyn_val
-                else
-                    ft, tt = dyn_val
-                    dyn_src = methodinstance(ft, tt, tls_world_age())
-                    CompilerJob(dyn_src, job.config)
-                end
-
-                push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
-            end
-
-            # compile and link
-            for dyn_job in keys(worklist)
-                # cached compilation
-                dyn_entry_fn = get!(jobs, dyn_job) do
-                    dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
-                                               parent_job=job)
-                    dyn_entry_fn = LLVM.name(dyn_meta.entry)
-                    merge!(compiled, dyn_meta.compiled)
-                    @assert context(dyn_ir) == context(ir)
-                    link!(ir, dyn_ir)
-                    changed = true
-                    dyn_entry_fn
-                end
-                dyn_entry = functions(ir)[dyn_entry_fn]
-
-                # insert a pointer to the function everywhere the entry is used
-                T_ptr = convert(LLVMType, Ptr{Cvoid})
-                for call in worklist[dyn_job]
-                    @dispose builder=IRBuilder() begin
-                        position!(builder, call)
-                        fptr = if LLVM.version() >= v"17"
-                            T_ptr = LLVM.PointerType()
-                            bitcast!(builder, dyn_entry, T_ptr)
-                        elseif VERSION >= v"1.12.0-DEV.225"
-                            T_ptr = LLVM.PointerType(LLVM.Int8Type())
-                            bitcast!(builder, dyn_entry, T_ptr)
-                        else
-                            ptrtoint!(builder, dyn_entry, T_ptr)
-                        end
-                        replace_uses!(call, fptr)
-                    end
-                    unsafe_delete!(LLVM.parent(call), call)
-                end
-            end
-        end
-
-        # all deferred compilations should have been resolved
-        @compiler_assert isempty(uses(dyn_marker)) job
-        unsafe_delete!(ir, dyn_marker)
-    end
 
     if libraries
         # load the runtime outside of a timing block (because it recurses into the compiler)
@@ -433,8 +327,8 @@ const __llvm_initialized = Ref(false)
     # finish the module
     #
     # we want to finish the module after optimization, so we cannot do so
-    # during deferred code generation. instead, process the deferred jobs
-    # here.
+    # during deferred code generation. Instead, process the merged module
+    # from all the jobs here.
     if toplevel
         entry = finish_ir!(job, ir, entry)

From 550805b7770cb7149737f20bab7caf51fda18c5a Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Fri, 9 Aug 2024 16:43:13 +0200
Subject: [PATCH 3/7] WIP: Try out gpuc.deferred.with

---
 examples/jit.jl | 65 +++++++++++++++++++++++++++++--------------------
 src/driver.jl   | 12 +++++++++
 2 files changed, 51 insertions(+), 26 deletions(-)

diff --git a/examples/jit.jl b/examples/jit.jl
index 6f0a259c..106506be 100644
--- a/examples/jit.jl
+++ b/examples/jit.jl
@@ -116,31 +116,39 @@ function get_trampoline(job)
     return addr
 end
 
-# import GPUCompiler: deferred_codegen_jobs
-# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
-#     # manual version of native_job because we have a function type
-#     source = methodinstance(F, Base.to_tuple_type(tt), world)
-#     target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
-#     # XXX: do we actually require the Julia runtime?
-#     #      with jlruntime=false, we reach an unreachable.
-#     params = TestCompilerParams()
-#     config = CompilerConfig(target, params; kernel=false)
-#     job = CompilerJob(source, config, world)
-#     # XXX: invoking GPUCompiler from a generated function is not allowed!
-#     #      for things to work, we need to forward the correct world, at least.
-
-#     addr = get_trampoline(job)
-#     trampoline = pointer(addr)
-#     id = Base.reinterpret(Int, trampoline)
-
-#     deferred_codegen_jobs[id] = job
-
-#     quote
-#         ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
-#         assume(ptr != C_NULL)
-#         return ptr
-#     end
-# end
+const runtime_cache = Dict{Any, Ptr{Cvoid}}()
+
+function compiler(job)
+    JuliaContext() do _
+        ir, meta = GPUCompiler.compile(:llvm, job; validate=false)
+        # So 1. serialize the module
+        buf = convert(MemoryBuffer, ir)
+        buf, meta
+    end
+end
+
+function linker(_, (buf, meta))
+    compiler = jit[]
+    lljit = compiler.jit
+    jd = JITDylib(lljit)
+
+    # 2. deserialize and wrap by a ThreadSafeModule
+    ThreadSafeContext() do ts_ctx
+        tsm = context!(context(ts_ctx)) do
+            mod = parse(LLVM.Module, buf)
+            ThreadSafeModule(mod)
+        end
+
+        LLVM.add!(lljit, jd, tsm)
+    end
+    addr = LLVM.lookup(lljit, meta.entry)
+    pointer(addr)
+end
+
+function GPUCompiler.var"gpuc.deferred.with"(config::GPUCompiler.CompilerConfig{<:NativeCompilerTarget}, f::F, args...) where F
+    source = methodinstance(F, Base.to_tuple_type(typeof(args)))
+    GPUCompiler.cached_compilation(runtime_cache, source, config, compiler, linker)::Ptr{Cvoid}
+end
 
 @generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
     argtt = tt.parameters[1]
@@ -226,7 +234,12 @@ end
     rt = Core.Compiler.return_type(f, tt)
     # FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
     # But that will only be needed here, and in Enzyme...
-    ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
+    target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
+    # XXX: do we actually require the Julia runtime?
+    #      with jlruntime=false, we reach an unreachable.
+    params = TestCompilerParams()
+    config = CompilerConfig(target, params; kernel=false)
+    ptr = GPUCompiler.var"gpuc.deferred.with"(config, f, args...)
     abi_call(ptr, rt, tt, f, args...)
 end
 
diff --git a/src/driver.jl b/src/driver.jl
index a60d0393..8167ac06 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -41,8 +41,20 @@ end
 
 ## deferred compilation
 
+"""
+    var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
+
+Behaves as if we were calling `f(args...)`, but instead of performing the call
+it puts down a marker and returns a pointer to the compiled function, which
+can be called later.
+"""
 function var"gpuc.deferred" end
 
+"""
+    var"gpuc.deferred.with"(config::CompilerConfig, f, args...)::Ptr{Cvoid}
+"""
+function var"gpuc.deferred.with" end
+
 
 ## compiler entrypoint
 
 export compile

From 2dada613e6ea694756d91e59bcba576920dddf5d Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Fri, 9 Aug 2024 16:56:17 +0200
Subject: [PATCH 4/7] fixup! remove old deferred implementation

---
 src/driver.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/driver.jl b/src/driver.jl
index 8167ac06..6087210a 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -271,7 +271,7 @@ const __llvm_initialized = Ref(false)
     # global variables. this makes sure that the optimizer can, e.g.,
     # rewrite function signatures.
     if toplevel
-        preserved_gvs = collect(values(jobs))
+        preserved_gvs = [entry_fn]
         for gvar in globals(ir)
             if linkage(gvar) == LLVM.API.LLVMExternalLinkage
                 push!(preserved_gvs, LLVM.name(gvar))

From 4ac6aac3a7e4f91d1ca27b5d003d4edb8544cf92 Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Fri, 9 Aug 2024 16:57:25 +0200
Subject: [PATCH 5/7] fixup! fixup! remove old deferred implementation

---
 src/driver.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/driver.jl b/src/driver.jl
index 6087210a..de6a0810 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -344,10 +344,10 @@ const __llvm_initialized = Ref(false)
     if toplevel
         entry = finish_ir!(job, ir, entry)
 
-        for (job′, fn′) in jobs
-            job′ == job && continue
-            finish_ir!(job′, ir, functions(ir)[fn′])
-        end
+        # for (job′, fn′) in jobs
+        #     job′ == job && continue
+        #     finish_ir!(job′, ir, functions(ir)[fn′])
+        # end
     end
 
     # replace non-entry function definitions with a declaration

From f54410e873da4db11e6b8c547f3976f5fec101c6 Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Fri, 9 Aug 2024 17:03:32 +0200
Subject: [PATCH 6/7] fixup! WIP: Try out gpuc.deferred.with

---
 examples/jit.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/jit.jl b/examples/jit.jl
index 106506be..86fd467c 100644
--- a/examples/jit.jl
+++ b/examples/jit.jl
@@ -123,11 +123,11 @@ function compiler(job)
         ir, meta = GPUCompiler.compile(:llvm, job; validate=false)
         # So 1. serialize the module
         buf = convert(MemoryBuffer, ir)
-        buf, meta
+        buf, LLVM.name(meta.entry)
     end
 end
 
-function linker(_, (buf, meta))
+function linker(_, (buf, entry_fn))
     compiler = jit[]
     lljit = compiler.jit
     jd = JITDylib(lljit)
@@ -141,7 +141,7 @@ function linker(_, (buf, meta))
 
         LLVM.add!(lljit, jd, tsm)
     end
-    addr = LLVM.lookup(lljit, meta.entry)
+    addr = LLVM.lookup(lljit, entry_fn)
     pointer(addr)
 end

From 12f9e51e57e7c039f57bd5f18dce84e1255d4eb1 Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Mon, 26 Aug 2024 15:30:50 +0200
Subject: [PATCH 7/7] Add optimization callbacks that fire on a marker function

---
 src/optim.jl          | 21 ++++++++++++++++++++-
 test/ptx_tests.jl     |  9 +++++++++
 test/ptx_testsetup.jl | 22 ++++++++++++++++++++++
 3 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/src/optim.jl b/src/optim.jl
index 80433d3d..1b1cdee3 100644
--- a/src/optim.jl
+++ b/src/optim.jl
@@ -3,7 +3,7 @@
 function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
     tm = llvm_machine(job.config.target)
 
-    global current_job
+    global current_job # ScopedValue?
     current_job = job
 
     @dispose pb=NewPMPassBuilder() begin
@@ -24,6 +24,8 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
     return
 end
 
+const PIPELINE_CALLBACKS = Dict{String, Any}()
+
 function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
     buildEarlySimplificationPipeline(mpm, job, opt_level)
     add!(mpm, AlwaysInlinerPass())
@@ -41,6 +43,9 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
                 add!(fpm, WarnMissedTransformationsPass())
             end
         end
+        for (name, callback) in PIPELINE_CALLBACKS
+            add!(mpm, CallbackPass(name, callback))
+        end
         buildIntrinsicLoweringPipeline(mpm, job, opt_level)
         buildCleanupPipeline(mpm, job, opt_level)
     end
@@ -423,3 +428,17 @@ function lower_ptls!(mod::LLVM.Module)
     return changed
 end
 LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!)
+
+
+function callback_pass!(name, callback::F, mod::LLVM.Module) where F
+    job = current_job::CompilerJob
+    changed = false
+
+    if haskey(functions(mod), name)
+        marker = functions(mod)[name]
+        changed = callback(job, marker, mod)
+    end
+    return changed
+end
+
+CallbackPass(name, callback) = NewPMModulePass("CallbackPass<$name>", (mod)->callback_pass!(name, callback, mod))
diff --git a/test/ptx_tests.jl b/test/ptx_tests.jl
index c059ba60..79c26e29 100644
--- a/test/ptx_tests.jl
+++ b/test/ptx_tests.jl
@@ -250,6 +250,15 @@ end
!= "" end +@testset "Pipeline callbacks" begin + function kernel(x) + PTX.mark(x) + return + end + ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int})) + @test !occursin("gpucompuler.mark", ir) +end + @testset "exception arguments" begin function kernel(a) unsafe_store!(a, trunc(Int, unsafe_load(a))) diff --git a/test/ptx_testsetup.jl b/test/ptx_testsetup.jl index ed5026f1..42aff2a3 100644 --- a/test/ptx_testsetup.jl +++ b/test/ptx_testsetup.jl @@ -16,6 +16,28 @@ end GPUCompiler.kernel_state_type(@nospecialize(job::PTXCompilerJob)) = PTXKernelState @inline @generated kernel_state() = GPUCompiler.kernel_state_value(PTXKernelState) +function mark(x) + ccall("gpucompiler.mark", llvcmall, Nothing, (Int,), x) +end + +function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module) + changed = false + + for use in uses(intrinsic) + val = user(use) + if isempty(uses(val)) + unsafe_delete!(LLVM.parent(val), val) + changed = true + else + # the validator will detect this + end + end + + return changed +end + +GPUCompiler.PIPELINE_CALLBACKS["gpucompiler.mark"] = remove_mark! + # a version of the test runtime that has some side effects, loading the kernel state # (so that we can test if kernel state arguments are appropriately optimized away) module PTXTestRuntime