-
Notifications
You must be signed in to change notification settings - Fork 55
Mock Enzyme plugin #636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: 09-26-make_gpuinterpreter_extensible
Are you sure you want to change the base?
Mock Enzyme plugin #636
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -37,7 +37,7 @@ struct NeverInlineMeta <: InlineStateMeta end | |||
import GPUCompiler: abstract_call_known, GPUInterpreter | ||||
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo, | ||||
StmtInfo, AbsIntState, EFFECTS_TOTAL, | ||||
MethodResultPure | ||||
MethodResultPure, CallInfo, IRCode | ||||
|
||||
function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f), | ||||
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) | ||||
|
@@ -70,5 +70,178 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec | |||
return nothing | ||||
end | ||||
|
||||
struct MockEnzymeMeta end | ||||
|
||||
end | ||||
# Having to define this function is annoying | ||||
# introduce `abstract type InferenceMeta` | ||||
function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo) | ||||
return nothing | ||||
end | ||||
|
||||
function autodiff end | ||||
|
||||
import GPUCompiler: DeferredCallInfo | ||||
struct AutodiffCallInfo <: CallInfo | ||||
rt | ||||
info::DeferredCallInfo | ||||
end | ||||
|
||||
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff), | ||||
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) | ||||
(; fargs, argtypes) = arginfo | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
@assert f === autodiff | ||||
if length(argtypes) <= 1 | ||||
@static if VERSION < v"1.11.0-" | ||||
return CallMeta(Union{}, Effects(), NoCallInfo()) | ||||
else | ||||
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) | ||||
end | ||||
end | ||||
|
||||
other_fargs = fargs === nothing ? nothing : fargs[2:end] | ||||
other_arginfo = ArgInfo(other_fargs, argtypes[2:end]) | ||||
# TODO: Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing. | ||||
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods) | ||||
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info) | ||||
|
||||
# Real Enzyme must compute `rt` and `exct` according to enzyme semantics | ||||
# and likely perform a unwrapping of fargs... | ||||
rt = call.rt | ||||
|
||||
# TODO: Edges? Effects? | ||||
@static if VERSION < v"1.11.0-" | ||||
# Can't use call.effects since otherwise this call might be just replaced with rt | ||||
return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo)) | ||||
else | ||||
return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo)) | ||||
end | ||||
end | ||||
|
||||
function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f), | ||||
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) | ||||
return nothing | ||||
end | ||||
|
||||
import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature | ||||
|
||||
# We really need a Compiler stdlib | ||||
Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i) | ||||
Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i) | ||||
|
||||
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8 | ||||
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int, | ||||
stmt::Expr, info::AutodiffCallInfo, flag::FlagType, | ||||
sig::Signature, state::InliningState) | ||||
|
||||
# Goal: | ||||
# The IR we want to inline here is: | ||||
# unpack the args .. | ||||
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...) | ||||
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...) | ||||
|
||||
# 0. Obtain primal mi from DeferredCallInfo | ||||
# TODO: remove this code duplication | ||||
deferred_info = info.info | ||||
minfo = deferred_info.info | ||||
results = minfo.results | ||||
if length(results.matches) != 1 | ||||
return nothing | ||||
end | ||||
match = only(results.matches) | ||||
|
||||
# lookup the target mi with correct edge tracking | ||||
# TODO: Effects? | ||||
case = Core.Compiler.compileable_specialization( | ||||
match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info) | ||||
@assert case isa Core.Compiler.InvokeCase | ||||
@assert stmt.head === :call | ||||
|
||||
# Now create the IR we want to inline | ||||
ir = Core.Compiler.IRCode() # contains a placeholder | ||||
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args... | ||||
idx = 0 | ||||
|
||||
# 0. Enzyme proper: Desugar args | ||||
primal_args = args | ||||
primal_argtypes = match.spec_types.parameters[2:end] | ||||
|
||||
adjoint_rt = info.rt | ||||
adjoint_args = args # TODO | ||||
adjoint_argtypes = primal_argtypes | ||||
|
||||
# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call | ||||
expr = Expr(:foreigncall, | ||||
"extern gpuc.lookup", | ||||
Ptr{Cvoid}, | ||||
Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype | ||||
0, | ||||
QuoteNode(:llvmcall), | ||||
deferred_info.meta, | ||||
case.invoke, | ||||
primal_args... | ||||
) | ||||
Comment on lines
+161
to
+184
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aviatesk does this look correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
function Core.Compiler.src_inlining_policy(interp::GPUCompiler#=or EnzymeInterpreter?=#,
@nospecialize(src), info::AutodiffCallInfo, stmt_flag::UInt32)
# Goal:
# The IR we want to return here is:
# unpack the args ..
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
ir = Core.Compiler.IRCode() # contains a placeholder
...
return ir
end By overloading There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Haha, then I have misunderstood the comment in: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I just realized that overloading But I believe this approach (overloading
It seems like I’ve ended up betraying my past self. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I think you once told me to extend handle_call |
||||
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid})) | ||||
|
||||
# 2. Call to magic `__autodiff` | ||||
expr = Expr(:foreigncall, | ||||
"extern __autodiff", | ||||
adjoint_rt, | ||||
Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...), | ||||
0, | ||||
QuoteNode(:llvmcall), | ||||
ptr, | ||||
adjoint_args... | ||||
) | ||||
ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt)) | ||||
|
||||
# Finally replace placeholder return | ||||
ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret) | ||||
ir[Core.SSAValue(1)][:type] = Ptr{Cvoid} | ||||
|
||||
ir = Core.Compiler.compact!(ir) | ||||
|
||||
# which mi to use here? | ||||
# push inlining todos | ||||
# TODO: Effects | ||||
# aviatesk mentioned using inlining_policy instead... | ||||
itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects()) | ||||
@assert itodo.linear_inline_eligible | ||||
push!(todo, (stmt_idx=>itodo)) | ||||
|
||||
return nothing | ||||
end | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think presently Enzyme does more than that as well. To rough approximation, it does the following as its entire compilation step
Post Enzyme optimization (https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8131):
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be more useful to think not about "what does Enzyme do", but "what information does Enzyme need and when". The actual plugin PR is The issue is that Enzyme requires an orthogonal composability axes than the job level. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. E.g. what are the interfaces you need, also what his historic and what are design constraints. As an example, I would like the LLVM IR to be serializable without references to runtime pointers. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think all of the before Enzyme stuff is necessary atm. The caveat being that if there was a better way to make sure functions had better names than us doing the converstion from an inttoptr to the name and restoring it, that would be great. So we definitely need a "right after IR is generated pass plugin" We presently need ideally other hooks into the optimization pipeline to add the other passes we run. For example https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler/optimize.jl#L2522 re 3) We need ways to know the full context of the julia AD request. I think this is basically just the gpucompiler config object [and the restore inttoptr name map] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So a caveat of the implementation here is that we no longer create a EnzymeConfig object when inside a CUDA compilation. We still do when we do a host compilation, but wouldn't anymore for nested compilation. We would need in the lowering from
Right, that's what I added in #633 (comment) So keep in mind that we are talking about two compilation modes, first of all the inside CUDA.jl and the second one for nested. For CPU nested you would still have a lot more control over code flow and could do all the blas things and inttoptr, but those are errors already on the GPU. |
||||
function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module) | ||||
changed = false | ||||
|
||||
for use in LLVM.uses(intrinsic) | ||||
call = LLVM.user(use) | ||||
LLVM.@dispose builder=LLVM.IRBuilder() begin | ||||
LLVM.position!(builder, call) | ||||
ops = LLVM.operands(call) | ||||
target = ops[1] | ||||
if target isa LLVM.ConstantExpr && (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt || | ||||
LLVM.opcode(target) == LLVM.API.LLVMBitCast) | ||||
target = first(LLVM.operands(target)) | ||||
end | ||||
funcT = LLVM.called_type(call) | ||||
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end]) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the last op is the called value |
||||
direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary | ||||
|
||||
LLVM.replace_uses!(call, direct_call) | ||||
end | ||||
if isempty(LLVM.uses(call)) | ||||
LLVM.erase!(call) | ||||
changed = true | ||||
else | ||||
# the validator will detect this | ||||
end | ||||
end | ||||
|
||||
return changed | ||||
end | ||||
|
||||
GPUCompiler.register_plugin!("__autodiff", mock_enzyme!) | ||||
|
||||
end #module |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
abstract_call_known
with this signature would never be called fromCore.Compiler
, so this overload would do nothing?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See https://github.com/JuliaGPU/GPUCompiler.jl/pull/634/files#diff-1bf557ada55697d453a6ccf81e9c263e46e573e324c453c4ecffb70994f36c37R504 for context.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a double extension problem. GPUCompiler provides has a GPUInterpreter.
Both GPUCompiler/CUDA and Enzyme want to modify the rules being applied.
But when we are applying Enzyme to CUDA code we must "inherit" the rules from CUDA, up to now Enzyme had a EnzymeInterpreter, but I would like to get rid of that.
But Enzyme rules shouldn't apply to CUDA code by default. However I also need to teach in an extensible matter GPUCompiler about
autodiff
such that:works.