Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the ability to specify the external name of a ccallable #57763

Merged
merged 2 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1371,8 +1371,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim_m
end
# additionally enqueue the ccallable entrypoint / adapter, which implicitly
# invokes the above ci
push!(codeinfos, rt)
push!(codeinfos, sig)
push!(codeinfos, item)
end
end
while !isempty(tocompile)
Expand Down
10 changes: 5 additions & 5 deletions Compiler/src/verifytrim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ function get_verify_typeinf_trim(codeinfos::Vector{Any})
caches = IdDict{MethodInstance,CodeInstance}()
errors = ErrorList()
parents = ParentMap()
for i = 1:2:length(codeinfos)
for i = 1:length(codeinfos)
item = codeinfos[i]
if item isa CodeInstance
push!(inspected, item)
Expand All @@ -268,14 +268,14 @@ function get_verify_typeinf_trim(codeinfos::Vector{Any})
end
end
end
for i = 1:2:length(codeinfos)
for i = 1:length(codeinfos)
item = codeinfos[i]
if item isa CodeInstance
src = codeinfos[i + 1]::CodeInfo
verify_codeinstance!(item, src, inspected, caches, parents, errors)
else
rt = item::Type
sig = codeinfos[i + 1]::Type
elseif item isa SimpleVector
rt = item[1]::Type
sig = item[2]::Type
ptr = ccall(:jl_get_specialization1,
#= MethodInstance =# Ptr{Cvoid}, (Any, Csize_t, Cint),
sig, this_world, #= mt_cache =# 0)
Expand Down
4 changes: 2 additions & 2 deletions Compiler/test/verifytrim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ let infos = typeinf_ext_toplevel(Any[Core.svec(Base.SecretBuffer, Tuple{Type{Bas

$""", repr)

resize!(infos, 2)
@test infos[1] isa Type && infos[2] isa Type
resize!(infos, 1)
@test infos[1] isa Core.SimpleVector && infos[1][1] isa Type && infos[1][2] isa Type
errors, parents = get_verify_typeinf_trim(infos)
desc = only(errors)
@test !desc.first
Expand Down
22 changes: 14 additions & 8 deletions base/c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ function exit_on_sigint(on::Bool)
ccall(:jl_exit_on_sigint, Cvoid, (Cint,), on)
end

function _ccallable(rt::Type, sigt::Type)
ccall(:jl_extern_c, Cvoid, (Any, Any), rt, sigt)
function _ccallable(name::Union{Nothing, String}, rt::Type, sigt::Type)
ccall(:jl_extern_c, Cvoid, (Any, Any, Any), name, rt, sigt)
end

function expand_ccallable(rt, def)
function expand_ccallable(name, rt, def)
if isa(def,Expr) && (def.head === :(=) || def.head === :function)
sig = def.args[1]
if sig.head === :(::)
Expand Down Expand Up @@ -235,24 +235,30 @@ function expand_ccallable(rt, def)
end
return quote
@__doc__ $(esc(def))
_ccallable($(esc(rt)), $(Expr(:curly, :Tuple, esc(f), map(esc, at)...)))
_ccallable($name, $(esc(rt)), $(Expr(:curly, :Tuple, esc(f), map(esc, at)...)))
end
end
end
error("expected method definition in @ccallable")
end

"""
@ccallable(def)
@ccallable ["name"] function f(...)::RetType ... end

Make the annotated function be callable from C using its name. This can, for example,
be used to expose functionality as a C-API when creating a custom Julia sysimage.
be used to expose functionality as a C API when creating a custom Julia sysimage.

If the first argument is a string, it is used as the external name of the function.
"""
macro ccallable(def)
expand_ccallable(nothing, def)
expand_ccallable(nothing, nothing, def)
end
macro ccallable(rt, def)
expand_ccallable(rt, def)
if rt isa String
expand_ccallable(rt, nothing, def)
else
expand_ccallable(nothing, rt, def)
end
end

# @ccall implementation
Expand Down
9 changes: 6 additions & 3 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,12 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
}
else {
jl_value_t *sig = jl_array_ptr_ref(codeinfos, ++i);
assert(jl_is_type(item) && jl_is_type(sig));
jl_generate_ccallable(clone.getModuleUnlocked(), nullptr, item, sig, params);
assert(jl_is_simplevector(item));
jl_value_t *rt = jl_svecref(item, 0);
jl_value_t *sig = jl_svecref(item, 1);
jl_value_t *nameval = jl_svec_len(item) == 2 ? jl_nothing : jl_svecref(item, 2);
assert(jl_is_type(rt) && jl_is_type(sig));
jl_generate_ccallable(clone.getModuleUnlocked(), nameval, rt, sig, params);
}
}
// finally, make sure all referenced methods get fixed up, particularly if the user declined to compile them
Expand Down
2 changes: 1 addition & 1 deletion src/codegen-stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ JL_DLLEXPORT uint32_t jl_get_LLVM_VERSION_fallback(void)
return 0;
}

JL_DLLEXPORT int jl_compile_extern_c_fallback(LLVMOrcThreadSafeModuleRef llvmmod, void *params, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
JL_DLLEXPORT int jl_compile_extern_c_fallback(LLVMOrcThreadSafeModuleRef llvmmod, void *params, void *sysimg, jl_value_t *name, jl_value_t *declrt, jl_value_t *sigt)
{
// Assume we were able to register the ccallable with the JIT. The
// fact that we didn't is not observable since we cannot compile
Expand Down
29 changes: 8 additions & 21 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7709,14 +7709,14 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con

// do codegen to create a C-callable alias/wrapper, or if sysimg_handle is set,
// restore one from a loaded system image.
const char *jl_generate_ccallable(Module *llvmmod, void *sysimg_handle, jl_value_t *declrt, jl_value_t *sigt, jl_codegen_params_t &params)
const char *jl_generate_ccallable(Module *llvmmod, jl_value_t *nameval, jl_value_t *declrt, jl_value_t *sigt, jl_codegen_params_t &params)
{
++GeneratedCCallables;
jl_datatype_t *ft = (jl_datatype_t*)jl_tparam0(sigt);
assert(jl_is_datatype(ft));
jl_value_t *ff = ft->instance;
assert(ff);
const char *name = jl_symbol_name(ft->name->mt->name);
const char *name = !jl_is_string(nameval) ? jl_symbol_name(ft->name->mt->name) : jl_string_data(nameval);
jl_value_t *crt = declrt;
if (jl_is_abstract_ref_type(declrt)) {
declrt = jl_tparam0(declrt);
Expand All @@ -7738,25 +7738,12 @@ const char *jl_generate_ccallable(Module *llvmmod, void *sysimg_handle, jl_value
function_sig_t sig("cfunction", lcrt, crt, toboxed, false,
argtypes, NULL, false, CallingConv::C, false, &params);
if (sig.err_msg.empty()) {
if (sysimg_handle) {
// restore a ccallable from the system image
void *addr;
int found = jl_dlsym(sysimg_handle, name, &addr, 0);
if (found)
add_named_global(name, addr);
else {
err = jl_get_exceptionf(jl_errorexception_type, "%s not found in sysimg", name);
jl_throw(err);
}
}
else {
//Safe b/c params holds context lock
Function *cw = gen_cfun_wrapper(llvmmod, params, sig, ff, name, declrt, sigt, NULL, NULL, NULL);
auto alias = GlobalAlias::create(cw->getValueType(), cw->getType()->getAddressSpace(),
GlobalValue::ExternalLinkage, name, cw, llvmmod);
if (params.TargetTriple.isOSBinFormatCOFF()) {
alias->setDLLStorageClass(GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
}
//Safe b/c params holds context lock
Function *cw = gen_cfun_wrapper(llvmmod, params, sig, ff, name, declrt, sigt, NULL, NULL, NULL);
auto alias = GlobalAlias::create(cw->getValueType(), cw->getType()->getAddressSpace(),
GlobalValue::ExternalLinkage, name, cw, llvmmod);
if (params.TargetTriple.isOSBinFormatCOFF()) {
alias->setDLLStorageClass(GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
}
JL_GC_POP();
return name;
Expand Down
7 changes: 5 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -4596,7 +4596,7 @@ JL_DLLEXPORT void jl_typeinf_timing_end(uint64_t start, int is_recompile)
}

// declare a C-callable entry point; called during code loading from the toplevel
JL_DLLEXPORT void jl_extern_c(jl_value_t *declrt, jl_tupletype_t *sigt)
JL_DLLEXPORT void jl_extern_c(jl_value_t *name, jl_value_t *declrt, jl_tupletype_t *sigt)
{
// validate arguments. try to do as many checks as possible here to avoid
// throwing errors later during codegen.
Expand Down Expand Up @@ -4627,7 +4627,10 @@ JL_DLLEXPORT void jl_extern_c(jl_value_t *declrt, jl_tupletype_t *sigt)
if (!jl_is_method(meth))
jl_error("@ccallable: could not find requested method");
JL_GC_PUSH1(&meth);
meth->ccallable = jl_svec2(declrt, (jl_value_t*)sigt);
if (name == jl_nothing)
meth->ccallable = jl_svec2(declrt, (jl_value_t*)sigt);
else
meth->ccallable = jl_svec3(declrt, (jl_value_t*)sigt, name);
jl_gc_wb(meth, meth->ccallable);
JL_GC_POP();
}
Expand Down
2 changes: 1 addition & 1 deletion src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ struct jl_codegen_params_t {
~jl_codegen_params_t() JL_NOTSAFEPOINT JL_NOTSAFEPOINT_LEAVE = default;
};

const char *jl_generate_ccallable(Module *llvmmod, void *sysimg_handle, jl_value_t *declrt, jl_value_t *sigt, jl_codegen_params_t &params);
const char *jl_generate_ccallable(Module *llvmmod, jl_value_t *nameval, jl_value_t *declrt, jl_value_t *sigt, jl_codegen_params_t &params);

jl_llvm_functions_t jl_emit_code(
orc::ThreadSafeModule &M,
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,7 @@ JL_DLLEXPORT jl_method_instance_t *jl_new_method_instance_uninit(void);
JL_DLLEXPORT jl_svec_t *jl_svec(size_t n, ...) JL_MAYBE_UNROOTED;
JL_DLLEXPORT jl_svec_t *jl_svec1(void *a);
JL_DLLEXPORT jl_svec_t *jl_svec2(void *a, void *b);
JL_DLLEXPORT jl_svec_t *jl_svec3(void *a, void *b, void *c);
JL_DLLEXPORT jl_svec_t *jl_alloc_svec(size_t n);
JL_DLLEXPORT jl_svec_t *jl_alloc_svec_uninit(size_t n);
JL_DLLEXPORT jl_svec_t *jl_svec_copy(jl_svec_t *a);
Expand Down
2 changes: 1 addition & 1 deletion src/precompile_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ static void *jl_precompile_(jl_array_t *m, int external_linkage)
}
else {
assert(jl_is_simplevector(item));
assert(jl_svec_len(item) == 2);
assert(jl_svec_len(item) == 2 || jl_svec_len(item) == 3);
jl_array_ptr_1d_push(m2, item);
}
}
Expand Down
13 changes: 13 additions & 0 deletions src/simplevector.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ JL_DLLEXPORT jl_svec_t *jl_svec2(void *a, void *b)
return v;
}

JL_DLLEXPORT jl_svec_t *jl_svec3(void *a, void *b, void *c)
{
jl_task_t *ct = jl_current_task;
jl_svec_t *v = (jl_svec_t*)jl_gc_alloc(ct->ptls, sizeof(void*) * 4,
jl_simplevector_type);
jl_set_typetagof(v, jl_simplevector_tag, 0);
jl_svec_set_len_unsafe(v, 3);
jl_svec_data(v)[0] = (jl_value_t*)a;
jl_svec_data(v)[1] = (jl_value_t*)b;
jl_svec_data(v)[2] = (jl_value_t*)c;
return v;
}

JL_DLLEXPORT jl_svec_t *jl_alloc_svec_uninit(size_t n)
{
jl_task_t *ct = jl_current_task;
Expand Down
3 changes: 3 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ precompile_test_harness(false) do dir
# check that @ccallable works from precompiled modules
Base.@ccallable Cint f35014(x::Cint) = x+Cint(1)
Base.@ccallable "f35014_other" f35014_2(x::Cint)::Cint = x+Cint(1)
# check that Tasks work from serialized state
ch1 = Channel(x -> nothing)
Expand Down Expand Up @@ -399,6 +400,8 @@ precompile_test_harness(false) do dir
let foo_ptr = Libdl.dlopen(ocachefile::String, RTLD_NOLOAD)
f35014_ptr = Libdl.dlsym(foo_ptr, :f35014)
@test ccall(f35014_ptr, Int32, (Int32,), 3) == 4
f35014_other_ptr = Libdl.dlsym(foo_ptr, :f35014_other)
@test ccall(f35014_other_ptr, Int32, (Int32,), 3) == 4
end
else
ocachefile = nothing
Expand Down