
Mock Enzyme plugin #636

Draft · wants to merge 1 commit into base: 09-26-make_gpuinterpreter_extensible
2 changes: 0 additions & 2 deletions test/native_tests.jl
@@ -171,8 +171,6 @@ end
# smoke test
job, _ = Native.create_job(eval(kernel), (Int64,))

# TODO: Add a `kernel=true` test

ci, rt = only(GPUCompiler.code_typed(job))
@test rt === Ptr{Cvoid}

177 changes: 175 additions & 2 deletions test/plugin_testsetup.jl
@@ -37,7 +37,7 @@ struct NeverInlineMeta <: InlineStateMeta end
import GPUCompiler: abstract_call_known, GPUInterpreter
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
StmtInfo, AbsIntState, EFFECTS_TOTAL,
MethodResultPure
MethodResultPure, CallInfo, IRCode

function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
@@ -70,5 +70,178 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
return nothing
end

struct MockEnzymeMeta end

end
# Having to define this fallback function is annoying;
# TODO: introduce an `abstract type InferenceMeta` with a default handler
function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo)
return nothing
end

function autodiff end

import GPUCompiler: DeferredCallInfo
struct AutodiffCallInfo <: CallInfo
rt
info::DeferredCallInfo
end

Contributor:
abstract_call_known with this signature would never be called from Core.Compiler, so this overload would do nothing?

Member Author:
I have a double extension problem. GPUCompiler provides a GPUInterpreter.
Both GPUCompiler/CUDA and Enzyme want to modify the rules being applied.

But when we apply Enzyme to CUDA code we must "inherit" the rules from CUDA. Up to now Enzyme had an EnzymeInterpreter, but I would like to get rid of that.

Enzyme rules shouldn't apply to CUDA code by default. However, I also need to teach GPUCompiler about autodiff in an extensible manner so that:

function kernel(args...)
    autodiff(f, ....)
end
@cuda kernel(args...)

works.

function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
(; fargs, argtypes) = arginfo

Contributor:
Suggested change:
@assert f === autodiff

@assert f === autodiff
if length(argtypes) <= 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end

other_fargs = fargs === nothing ? nothing : fargs[2:end]
other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
# TODO: Shouldn't we change absint over to MockEnzymeMeta() here? Otherwise we fill the cache for nothing.
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)

# Real Enzyme must compute `rt` and `exct` according to Enzyme semantics
# and likely perform an unwrapping of fargs...
rt = call.rt

# TODO: Edges? Effects?
@static if VERSION < v"1.11.0-"
# Can't use call.effects since otherwise this call might be just replaced with rt
return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo))
else
return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo))
end
end

function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
return nothing
end

import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature

# We really need a Compiler stdlib
Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i)
Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i)

const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
sig::Signature, state::InliningState)

# Goal:
# The IR we want to inline here is:
# unpack the args ..
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)

# 0. Obtain primal mi from DeferredCallInfo
# TODO: remove this code duplication
deferred_info = info.info
minfo = deferred_info.info
results = minfo.results
if length(results.matches) != 1
return nothing
end
match = only(results.matches)

# lookup the target mi with correct edge tracking
# TODO: Effects?
case = Core.Compiler.compileable_specialization(
match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
@assert case isa Core.Compiler.InvokeCase
@assert stmt.head === :call

# Now create the IR we want to inline
ir = Core.Compiler.IRCode() # contains a placeholder
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
idx = 0

# 0. Enzyme proper: Desugar args
primal_args = args
primal_argtypes = match.spec_types.parameters[2:end]

adjoint_rt = info.rt
adjoint_args = args # TODO
adjoint_argtypes = primal_argtypes

# 1. Since Julia's inliner works bottom-up, we need to pretend that we inlined the deferred call
expr = Expr(:foreigncall,
"extern gpuc.lookup",
Ptr{Cvoid},
Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
0,
QuoteNode(:llvmcall),
deferred_info.meta,
case.invoke,
primal_args...
)

Member Author (on lines +161 to +184):
@aviatesk does this look correct?

Contributor:
handle_call! isn’t meant to be overloaded, so I think this approach is preferred:

function Core.Compiler.src_inlining_policy(interp::GPUInterpreter#=or EnzymeInterpreter?=#,
    @nospecialize(src), info::AutodiffCallInfo, stmt_flag::UInt32)
    # Goal:
    # The IR we want to return here is:
    # unpack the args ..
    # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...) 
    # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
    ir = Core.Compiler.IRCode() # contains a placeholder
    ...
    return ir
end

By overloading src_inlining_policy (or inlining_policy in older versions), we can apply this custom inlining to const-propped call sites and semi-concrete interpreted call sites as well.

Contributor:
Ah, I just realized that overloading src_inlining_policy wouldn't be enough.. We need to overload retrieve_ir_for_inlining too, but it doesn't take info::CallInfo, so maybe we need to tweak the interface..

But I believe this approach (overloading inlining_policy) works at least for pre-1.11.

Haha, then I have misunderstood the comment in:

https://github.com/JuliaLang/julia/blob/b9d9b69165493f6fc03870d975be05c67f14a30b/base/compiler/ssair/inlining.jl#L1668-L1669

It seems like I’ve ended up betraying my past self.
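
A rough sketch of what such a tweaked overload could look like (hypothetical: the current retrieve_ir_for_inlining interface does not take the CallInfo, and the body below is only a placeholder):

# Hypothetical sketch only, assuming the inlining interface were extended to pass
# the CallInfo through so the custom IR for AutodiffCallInfo can be returned here.
function Core.Compiler.retrieve_ir_for_inlining(mi::Core.MethodInstance,
                                                @nospecialize(src), info::AutodiffCallInfo)
    ir = Core.Compiler.IRCode() # contains a placeholder
    # ... build the gpuc.lookup + __autodiff IR here, as in handle_call! ...
    return ir
end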

Member Author:
Yeah, I think you once told me to extend handle_call!

ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))

# 2. Call to magic `__autodiff`
expr = Expr(:foreigncall,
"extern __autodiff",
adjoint_rt,
Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...),
0,
QuoteNode(:llvmcall),
ptr,
adjoint_args...
)
ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt))

# Finally replace placeholder return
ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}

ir = Core.Compiler.compact!(ir)

# which mi to use here?
# push inlining todos
# TODO: Effects
# aviatesk mentioned using inlining_policy instead...
itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects())
@assert itodo.linear_inline_eligible
push!(todo, (stmt_idx=>itodo))

return nothing
end

Member Author:
Feedback from @wsmoses

This pass would need access to the compiled dictionary so that Enzyme can do a lookup from "emitted function" to Julia function. This would argue for a different phase order in #633.

Or we add two callbacks: one "early" and one "as part of the optimization pipeline".
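
Purely as an illustration of the two-callback idea (hypothetical API: register_plugin! currently takes just a name and a single callback, and prepare_for_enzyme! is an invented placeholder):

# Hypothetical sketch: one hook right after IR generation and one inside the
# optimization pipeline, both only firing when the "__autodiff" marker is present.
GPUCompiler.register_plugin!("__autodiff", mock_enzyme!;
                             early = prepare_for_enzyme!)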

Contributor:
I think presently Enzyme does more than that as well. To a rough approximation, it does the following as its entire compilation step:

  1. "Before anything else happens"
    Set each llvmf to know about its worldage, methodinstance, and return type: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6703
    Make things inline-ready [e.g. remove some TBAA which is broken]: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6705
    Rewrite some nvvm and related intrinsics https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6709
    Mark various type-unstable calls as inactive and change inttoptr'd ccalls into calls by name [storing the actual int value to restore later]: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6714
    Replace unhandled BLAS calls with a fallback: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6755
    Annotate types and activities: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6894
    Mark custom rules and related as noinline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7035
    Lower the calling convention of the function being differentiated: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7465

  2. Optimization pipeline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7514
    Currently we use a modified optimization pipeline that also adds new passes which we found to be critical for performance (namely the new jl_inst_simplify pass, among others, for interprocedural dead-argument elimination)

  3. AD
    First we run a Julia analysis pass if the function being differentiated was a closure and the user requested that we error if it is written to: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7659
    Upgrading some memcpy's to load/store: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7799
    Actually generating the derivatives: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7801
    Inverse of the preserve nvvm pass above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7897
    Restoring the actual inttoptr => function name from above https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8074
    Other immediate post Enzyme passes: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8085 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8087 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8105 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8129

Post Enzyme optimization (https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8131):

  • This includes running new passes to fix garbage collection/etc but presumably can just be scheduled

Member Author:
It would be more useful to think not about "what does Enzyme do", but "what information does Enzyme need, and when".

The actual plugin PR is #633; we can add more callbacks, as long as the callbacks only trigger if we detect that a marker function is present in the module.

The issue is that Enzyme requires a composability axis orthogonal to the job level.

Member Author:
E.g. what are the interfaces you need, what is historic, and what are the design constraints?

As an example, I would like the LLVM IR to be serializable without references to runtime pointers.
Otherwise caching will be impossible, and CUDA.jl is already caching its LLVM IR (post-optimization currently, so we could just throw away the Enzyme metadata).

Contributor:
I think all of the before-Enzyme stuff is necessary atm. The caveat being that if there were a better way to make sure functions had better names than us doing the conversion from an inttoptr to the name and restoring it, that would be great. So we definitely need a "right after IR is generated" pass plugin.

We presently, ideally, also need other hooks into the optimization pipeline to add the other passes we run. For example https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler/optimize.jl#L2522

Re 3) We need ways to know the full context of the Julia AD request. I think this is basically just the GPUCompiler config object [and the restore-inttoptr name map].

Member Author:
So a caveat of the implementation here is that we no longer create an EnzymeConfig object when inside a CUDA compilation. We still do for a host compilation, but wouldn't anymore for nested compilation.

In the lowering from autodiff to the llvmcall of autodiff, we would need to capture the necessary information and encode it on the calling side (e.g. similar to the Clang plugin for Enzyme).

"right after IR is generated pass plugin"

Right, that's what I added in #633 (comment).

So keep in mind that we are talking about two compilation modes: first the one inside CUDA.jl, and second the nested one.

For CPU nested compilation you would still have a lot more control over code flow and could do all the BLAS things and inttoptr, but those are already errors on the GPU.
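
# Mock "AD": rewrite every call to the `__autodiff` intrinsic into a direct call
# to the function pointer obtained from `gpuc.lookup` (its first operand), i.e.
# the "derivative" is just the primal. Real Enzyme would emit derivative code here.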

function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
changed = false

for use in LLVM.uses(intrinsic)
call = LLVM.user(use)
LLVM.@dispose builder=LLVM.IRBuilder() begin
LLVM.position!(builder, call)
ops = LLVM.operands(call)
target = ops[1]
if target isa LLVM.ConstantExpr && (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt ||
LLVM.opcode(target) == LLVM.API.LLVMBitCast)
target = first(LLVM.operands(target))
end
funcT = LLVM.called_type(call)
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])

Contributor:
the last op is the called value

direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary

LLVM.replace_uses!(call, direct_call)
end
if isempty(LLVM.uses(call))
LLVM.erase!(call)
changed = true
else
# the validator will detect this
end
end

return changed
end
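
# The callback registered for "__autodiff" is meant to run only when that marker
# function is present in the module (see the plugin mechanism discussed above and #633).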

GPUCompiler.register_plugin!("__autodiff", mock_enzyme!)

end #module
24 changes: 24 additions & 0 deletions test/ptx_tests.jl
@@ -504,4 +504,28 @@ end
ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta()))
@test occursin("call fastcc i64 @julia_inline", ir)
end

@testset "Mock Enzyme" begin
function f(x)
x^2
end

function kernel(a, x)
y = Plugin.autodiff(f, x)
unsafe_store!(a, y)
nothing
end

# This tests deferred_codegen with kernel=true
@show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64})
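
# Before optimization the lowered call still targets the `__autodiff` intrinsic;
# after optimization the mock plugin has rewritten it into a direct call to the primal.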

ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=false))
@test occursin("call double @__autodiff", ir)
@test !occursin("call fastcc double @julia_f", ir)

ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=true))
@test !occursin("call double @__autodiff", ir)
@test occursin("call fastcc double @julia_f", ir)
end

end #testitem