Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add entry point to construct an OpaqueClosure from pre-optimized IRCode #44197

Merged
merged 7 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion base/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ function show_method_candidates(io::IO, ex::MethodError, @nospecialize kwargs=()
buf = IOBuffer()
iob0 = iob = IOContext(buf, io)
tv = Any[]
sig0 = method.sig
if func isa Core.OpaqueClosure
sig0 = signature_type(func, typeof(func).parameters[1])
else
sig0 = method.sig
end
while isa(sig0, UnionAll)
push!(tv, sig0.var)
iob = IOContext(iob, :unionall_env => sig0.var)
Expand Down
3 changes: 3 additions & 0 deletions base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ end

# NOTE: second argument is deprecated and is no longer used
function kwarg_decl(m::Method, kwtype = nothing)
if m.sig === Tuple # OpaqueClosure
return Symbol[]
end
mt = get_methodtable(m)
if isdefined(mt, :kwsorter)
kwtype = typeof(mt.kwsorter)
Expand Down
21 changes: 12 additions & 9 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ extern jl_value_t *jl_builtin_getfield;
extern jl_value_t *jl_builtin_tuple;

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);

static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
{
Expand Down Expand Up @@ -51,11 +51,14 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
return jl_module_globalref(module, (jl_sym_t*)expr);
}
else if (jl_is_returnnode(expr)) {
jl_value_t *val = resolve_globals(jl_returnnode_value(expr), module, sparam_vals, binding_effects, eager_resolve);
if (val != jl_returnnode_value(expr)) {
JL_GC_PUSH1(&val);
expr = jl_new_struct(jl_returnnode_type, val);
JL_GC_POP();
jl_value_t *retval = jl_returnnode_value(expr);
if (retval) {
jl_value_t *val = resolve_globals(retval, module, sparam_vals, binding_effects, eager_resolve);
if (val != retval) {
JL_GC_PUSH1(&val);
expr = jl_new_struct(jl_returnnode_type, val);
JL_GC_POP();
}
}
return expr;
}
Expand Down Expand Up @@ -102,7 +105,7 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
if (!jl_is_code_info(ci)) {
jl_error("opaque_closure_method: lambda should be a CodeInfo");
}
jl_method_t *m = jl_make_opaque_closure_method(module, name, nargs, functionloc, (jl_code_info_t*)ci, isva);
jl_method_t *m = jl_make_opaque_closure_method(module, name, jl_unbox_long(nargs), functionloc, (jl_code_info_t*)ci, isva);
return (jl_value_t*)m;
}
if (e->head == jl_cfunction_sym) {
Expand Down Expand Up @@ -782,7 +785,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
// method definition ----------------------------------------------------------

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
{
jl_method_t *m = jl_new_method_uninit(module);
JL_GC_PUSH1(&m);
Expand All @@ -796,7 +799,7 @@ jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name
assert(jl_is_symbol(name));
m->name = (jl_sym_t*)name;
}
m->nargs = jl_unbox_long(nargs) + 1;
m->nargs = nargs + 1;
assert(jl_is_linenode(functionloc));
jl_value_t *file = jl_linenode_file(functionloc);
m->file = jl_is_symbol(file) ? (jl_sym_t*)file : jl_empty_sym;
Expand Down
82 changes: 68 additions & 14 deletions src/opaque_closure.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,23 @@ JL_DLLEXPORT int jl_is_valid_oc_argtype(jl_tupletype_t *argt, jl_method_t *sourc
return 1;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t **env, size_t nenv)
static jl_value_t *prepend_type(jl_value_t *t0, jl_tupletype_t *t)
{
jl_svec_t *sig_args = NULL;
JL_GC_PUSH1(&sig_args);
size_t nsig = 1 + jl_svec_len(t->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, t0);
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(t, i));
}
jl_value_t *sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
JL_GC_POP();
return sigtype;
}

static jl_opaque_closure_t *new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t *captures)
{
if (!jl_is_tuple_type((jl_value_t*)argt)) {
jl_error("OpaqueClosure argument tuple must be a tuple type");
Expand All @@ -40,26 +55,19 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
}
if (jl_nparams(argt) + 1 - jl_is_va_tuple(argt) < source->nargs - source->isva)
jl_error("Argument type tuple has too few required arguments for method");
jl_task_t *ct = jl_current_task;
jl_value_t *sigtype = NULL;
JL_GC_PUSH1(&sigtype);
sigtype = prepend_type(jl_typeof(captures), argt);

jl_value_t *oc_type JL_ALWAYS_LEAFTYPE;
oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub);
JL_GC_PROMISE_ROOTED(oc_type);
jl_value_t *captures = NULL, *sigtype = NULL;
jl_svec_t *sig_args = NULL;
JL_GC_PUSH3(&captures, &sigtype, &sig_args);
captures = jl_f_tuple(NULL, env, nenv);

size_t nsig = 1 + jl_svec_len(argt->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, jl_typeof(captures));
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(argt, i));
}
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
jl_method_instance_t *mi = jl_specializations_get_linfo(source, sigtype, jl_emptysvec);
size_t world = jl_atomic_load_acquire(&jl_world_counter);
jl_code_instance_t *ci = jl_compile_method_internal(mi, world);

jl_task_t *ct = jl_current_task;
jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type);
JL_GC_POP();
oc->source = source;
Expand All @@ -82,6 +90,52 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
return oc;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t **env, size_t nenv)
{
jl_value_t *captures = jl_f_tuple(NULL, env, nenv);
JL_GC_PUSH1(&captures);
jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, source_, captures);
JL_GC_POP();
return oc;
}

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);

JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
int32_t const_flags, size_t min_world, size_t max_world,
uint32_t ipo_effects, uint32_t effects, jl_value_t *argescapes,
uint8_t relocatability);

JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED);

JL_DLLEXPORT jl_opaque_closure_t *jl_new_opaque_closure_from_code_info(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_module_t *mod, jl_code_info_t *ci, int lineno, jl_value_t *file, int nargs, int isva, jl_value_t *env)
{
if (!ci->inferred)
jl_error("CodeInfo must already be inferred");
jl_value_t *root = NULL, *sigtype = NULL;
jl_code_instance_t *inst = NULL;
JL_GC_PUSH3(&root, &sigtype, &inst);
root = jl_box_long(lineno);
root = jl_new_struct(jl_linenumbernode_type, root, file);
root = (jl_value_t*)jl_make_opaque_closure_method(mod, jl_nothing, nargs, root, ci, isva);

sigtype = prepend_type(jl_typeof(env), argt);
jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)root, sigtype, jl_emptysvec);
inst = jl_new_codeinst(mi, rt_ub, NULL, (jl_value_t*)ci,
0, ((jl_method_t*)root)->primary_world, -1, 0, 0, jl_nothing, 0);
jl_mi_cache_insert(mi, inst);

jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, root, env);
JL_GC_POP();
return oc;
}

JL_CALLABLE(jl_new_opaque_closure_jlcall)
{
if (nargs < 4)
Expand Down
46 changes: 46 additions & 0 deletions test/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,49 @@ end
let oc = @opaque a->sin(a)
@test length(code_typed(oc, (Int,))) == 1
end

# constructing an opaque closure from IRCode
using Core.Compiler: IRCode
using Core: CodeInfo

function OC(ir::IRCode, nargs::Int, isva::Bool, env...)
if (isva && nargs > length(ir.argtypes)) || (!isva && nargs != length(ir.argtypes)-1)
throw(ArgumentError("invalid argument count"))
end
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
src.slotflags = UInt8[]
src.slotnames = fill(:none, nargs+1)
Core.Compiler.replace_code_newstyle!(src, ir, nargs+1)
Core.Compiler.widen_all_consts!(src)
src.inferred = true
# NOTE: we need ir.argtypes[1] == typeof(env)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
Tuple{ir.argtypes[2:end]...}, Union{}, Any, @__MODULE__, src, 0, nothing, nargs, isva, env)
end

function OC(src::CodeInfo, env...)
M = src.parent.def
sig = Base.tuple_type_tail(src.parent.specTypes)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
sig, Union{}, Any, @__MODULE__, src, 0, nothing, M.nargs - 1, M.isva, env)
end

let ci = code_typed(+, (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test OC(ir, 2, false)(40, 2) == 42
@test OC(ci)(40, 2) == 42
end

let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test OC(ir, 2, true)(40, 2) === (40, (2,))
@test OC(ci)(40, 2) === (40, (2,))
end

let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test_throws MethodError OC(ir, 2, true)(1, 2, 3)
@test_throws MethodError OC(ci)(1, 2, 3)
end