Skip to content

De-duplicate edges in typeinfer instead of gf.c #58117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,12 +578,17 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::
nothing
end

# record the backedges
function store_backedges(caller::CodeInstance, edges::SimpleVector)
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
i = 1
while true
i > length(edges) && return nothing
# Iterator over the back-edges that need registering for a given forward edge list.
#
# Each element is a tuple `(invokesig, item)`, where `item` is a `Core.Binding`, a
# `MethodInstance`, or a `MethodTable`. `invokesig` is `nothing` for binding edges and for
# regular dispatch edges; otherwise it is the signature entry that precedes the callee in
# the forward edge list.
#
# The iteration state is the index of the next entry of `forward_edges` to examine.
struct ForwardToBackedgeIterator
# forward edge list that is translated into back-edges during iteration
forward_edges::SimpleVector
end

function Base.iterate(it::ForwardToBackedgeIterator, i::Int = 1)
edges = it.forward_edges
i > length(edges) && return nothing
while i ≤ length(edges)
item = edges[i]
if item isa Int
i += 2
Expand All @@ -593,32 +598,55 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
i += 1
continue
elseif isa(item, Core.Binding)
i += 1
maybe_add_binding_backedge!(item, caller)
continue
return ((nothing, item), i + 1)
end
if isa(item, CodeInstance)
item = item.def
end
if isa(item, MethodInstance) # regular dispatch
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), item, nothing, caller)
i += 1
item = get_ci_mi(item)
return ((nothing, item), i + 1)
elseif isa(item, MethodInstance) # regular dispatch
return ((nothing, item), i + 1)
else
invokesig = item
callee = edges[i+1]
if isa(callee, MethodTable) # abstract dispatch (legacy style edges)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, item, caller)
i += 2
continue
elseif isa(callee, Method)
# ignore `Method`-edges (from e.g. failed `abstract_call_method`)
i += 2
continue
# `invoke` edge
elseif isa(callee, CodeInstance)
callee = get_ci_mi(callee)
isa(callee, Method) && (i += 2; continue) # ignore `Method`-edges (from e.g. failed `abstract_call_method`)
if isa(callee, MethodTable)
# abstract dispatch (legacy style edges)
return ((invokesig, callee), i + 2)
else
# `invoke` edge
callee = isa(callee, CodeInstance) ? get_ci_mi(callee) : callee::MethodInstance
return ((invokesig, callee), i + 2)
end
end
end
return nothing
end

# record the backedges
function store_backedges(caller::CodeInstance, edges::SimpleVector)
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance

backedges = ForwardToBackedgeIterator(edges)
for (i, (invokesig, item)) in enumerate(backedges)
# check for any duplicate edges we've already registered
duplicate_found = false
for (i′, (invokesig′, item′)) in enumerate(backedges)
i == i′ && break
if item′ === item && invokesig′ == invokesig
duplicate_found = true
break
end
end

if !duplicate_found
if item isa Core.Binding
maybe_add_binding_backedge!(item, caller)
elseif item isa MethodTable
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), item, invokesig, caller)
else
item::MethodInstance
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), item, invokesig, caller)
end
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, item, caller)
i += 2
end
end
nothing
Expand Down
15 changes: 13 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1997,6 +1997,9 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
jl_gc_wb(callee, backedges);
}
else {
#ifndef JL_NDEBUG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This introduces a race condition into the code (because of GC, threads, etc.). It is probably useful to know that this scan is costly on some benchmarks, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify where the race condition is?

Do you mean that inference may simultaneously try to call `store_backedges` for the same caller CI from two different threads?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. We might now be able to assume that each CI is unique, whereas in older versions the MIs were not expected to be unique.

// It is the caller's (inference's) responsibility to de-duplicate edges. Here we are only
// checking its work.
size_t i = 0, l = jl_array_nrows(backedges);
for (i = 0; i < l; i++) {
// optimized version of while (i < l) i = get_next_edge(callee->backedges, i, &invokeTypes, &mi);
Expand All @@ -2012,6 +2015,8 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
break;
}
}
assert(!found && "duplicate back-edge registered");
#endif
}
if (!found)
push_edge(backedges, invokesig, caller);
Expand All @@ -2037,14 +2042,20 @@ JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *t
else {
// check if the edge is already present and avoid adding a duplicate
size_t i, l = jl_array_nrows(mt->backedges);
#ifndef JL_NDEBUG
// It is the caller's (inference's) responsibility to de-duplicate edges. Here we are only
// checking its work.
int found = 0;
for (i = 1; i < l; i += 2) {
if (jl_array_ptr_ref(mt->backedges, i) == (jl_value_t*)caller) {
if (jl_types_equal(jl_array_ptr_ref(mt->backedges, i - 1), typ)) {
JL_UNLOCK(&mt->writelock);
return;
found = 1;
break;
}
}
}
assert(!found && "duplicate back-edge registered");
#endif
// reuse an already cached instance of this type, if possible
// TODO: use jl_cache_type_(tt) like cache_method does, instead of this linear scan?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth seeing if we can bypass this remaining linear scan in the MethodTable edge insertion.

If I disable it manually, the timing drops to 440 milliseconds.

for (i = 1; i < l; i += 2) {
Expand Down
10 changes: 9 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,12 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
}
goto done_fields; // for now
}
if (s->incremental && jl_is_mtable(v)) {
jl_methtable_t *mt = (jl_methtable_t *)v;
// Any back-edges will be re-validated and added by staticdata.jl, so
// drop them from the image here
record_field_change((jl_value_t**)&mt->backedges, NULL);
}
if (s->incremental && jl_is_method_instance(v)) {
jl_method_instance_t *mi = (jl_method_instance_t*)v;
jl_value_t *def = mi->def.value;
Expand All @@ -877,12 +883,14 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
// we only need 3 specific fields of this (the rest are restored afterward, if valid)
// in particular, cache is repopulated by jl_mi_cache_insert for all foreign function,
// so must not be present here
record_field_change((jl_value_t**)&mi->backedges, NULL);
record_field_change((jl_value_t**)&mi->cache, NULL);
}
else {
assert(!needs_recaching(v, s->query_cache));
}
// Any back-edges will be re-validated and added by staticdata.jl, so
// drop them from the image here
record_field_change((jl_value_t**)&mi->backedges, NULL);
// n.b. opaque closures cannot be inspected and relied upon like a
// normal method since they can get improperly introduced by generated
// functions, so if they appeared at all, we will probably serialize
Expand Down