Skip to content

Commit 4f14a5e

Browse files
committed
WIP:Try out gpuc.deferred.with
1 parent 60fa61f commit 4f14a5e

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

examples/jit.jl

+39-26
Original file line numberDiff line numberDiff line change
@@ -116,31 +116,39 @@ function get_trampoline(job)
116116
return addr
117117
end
118118

119-
# import GPUCompiler: deferred_codegen_jobs
120-
# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
121-
# # manual version of native_job because we have a function type
122-
# source = methodinstance(F, Base.to_tuple_type(tt), world)
123-
# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
124-
# # XXX: do we actually require the Julia runtime?
125-
# # with jlruntime=false, we reach an unreachable.
126-
# params = TestCompilerParams()
127-
# config = CompilerConfig(target, params; kernel=false)
128-
# job = CompilerJob(source, config, world)
129-
# # XXX: invoking GPUCompiler from a generated function is not allowed!
130-
# # for things to work, we need to forward the correct world, at least.
131-
132-
# addr = get_trampoline(job)
133-
# trampoline = pointer(addr)
134-
# id = Base.reinterpret(Int, trampoline)
135-
136-
# deferred_codegen_jobs[id] = job
137-
138-
# quote
139-
# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
140-
# assume(ptr != C_NULL)
141-
# return ptr
142-
# end
143-
# end
119+
# Cache handed to `GPUCompiler.cached_compilation` by the `gpuc.deferred.with`
# method below; its values are the `Ptr{Cvoid}` results returned for each entry.
const runtime_cache = Dict{Any, Ptr{Cvoid}}()
120+
121+
# Compile `job` to LLVM IR and return a `(buffer, entry_name)` tuple: the IR
# converted to a `MemoryBuffer`, and the name of the entry function reported in
# the compilation metadata. Used as the compile callback passed to
# `GPUCompiler.cached_compilation`.
function compiler(job)
122+
JuliaContext() do _
123+
ir, meta = GPUCompiler.compile(:llvm, job; validate=false)
124+
# Step 1: serialize the module into a memory buffer. The serialized form is
# what gets returned, not the module itself.
125+
buf = convert(MemoryBuffer, ir)
126+
buf, LLVM.name(meta.entry)
127+
end
128+
end
129+
130+
# Link callback passed to `GPUCompiler.cached_compilation`: takes the
# `(buf, entry_fn)` tuple (the shape returned by `compiler(job)`), adds the
# module to the JIT and returns a pointer to the entry function.
# The first positional argument is ignored.
function linker(_, (buf, entry_fn))
131+
# NOTE: this local `compiler` is the object stored in the `jit` ref; inside
# this function it shadows the top-level `compiler(job)` function.
compiler = jit[]
132+
lljit = compiler.jit
133+
# A new JITDylib is created on every call.
jd = JITDylib(lljit)
134+
135+
# Step 2: deserialize the buffer and wrap the module in a ThreadSafeModule
136+
ThreadSafeContext() do ts_ctx
137+
tsm = context!(context(ts_ctx)) do
138+
mod = parse(LLVM.Module, buf)
139+
ThreadSafeModule(mod)
140+
end
141+
142+
LLVM.add!(lljit, jd, tsm)
143+
end
144+
# Look the entry function up by name and return its address as a pointer.
addr = LLVM.lookup(lljit, entry_fn)
145+
pointer(addr)
146+
end
147+
148+
# Method of `GPUCompiler.var"gpuc.deferred.with"` for configs whose target is a
# `NativeCompilerTarget`. Looks up the method instance for `f` (via its type `F`)
# applied to the types of `args`, then goes through
# `GPUCompiler.cached_compilation` with `runtime_cache`, `compiler` and `linker`.
# The result is asserted to be a `Ptr{Cvoid}`.
function GPUCompiler.var"gpuc.deferred.with"(config::GPUCompiler.CompilerConfig{<:NativeCompilerTarget}, f::F, args...) where F
149+
# Only the types of `f` and `args` are used here, not their values.
source = methodinstance(F, Base.to_tuple_type(typeof(args)))
150+
GPUCompiler.cached_compilation(runtime_cache, source, config, compiler, linker)::Ptr{Cvoid}
151+
end
144152

145153
@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
146154
argtt = tt.parameters[1]
@@ -226,7 +234,12 @@ end
226234
rt = Core.Compiler.return_type(f, tt)
227235
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
228236
# But that will only be needed here, and in Enzyme...
229-
ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
237+
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
238+
# XXX: do we actually require the Julia runtime?
239+
# with jlruntime=false, we reach an unreachable.
240+
params = TestCompilerParams()
241+
config = CompilerConfig(target, params; kernel=false)
242+
ptr = GPUCompiler.var"gpuc.deferred.with"(config, f, args...)
230243
abi_call(ptr, rt, tt, f, args...)
231244
end
232245

src/driver.jl

+12
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,20 @@ end
4141

4242
## deferred compilation
4343

44+
"""
45+
var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
46+
47+
As if we were to call `f(args...)` but instead we are
48+
putting down a marker and returning a function pointer to later
49+
call.
50+
"""
4451
function var"gpuc.deferred" end
4552

53+
"""
54+
var"gpuc.deferred.with"(config::CompilerConfig, f, args...)::Ptr{Cvoid}
55+
"""
56+
function var"gpuc.deferred.with" end
57+
4658
## compiler entrypoint
4759

4860
export compile

0 commit comments

Comments
 (0)