Skip to content

Commit 706f894

Browse files
braxtonmckeeWilliam Grant
authored andcommitted
Prune the graph of inflight functions to not include the ones we don't need.
When we first call a python function 'f' with a specific set of arguments, we may not know its return type the first time we try to convert it. To ensure we have a stable typing graph, we repeatedly update the active functions in our graph until the type graph is stable. This can lead to many copies of the same function, or even multiple signatures of the same function, only one of which we'll use. This change prunes those away before we submit them to the LLVM layer.
1 parent 22f01c5 commit 706f894

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

typed_python/compiler/directed_graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def hasEdge(self, source, dest):
4444
return False
4545
return dest in self.sourceToDest[source]
4646

47+
def clearOutgoing(self, node):
48+
for child in list(self.outgoing(node)):
49+
self.dropEdge(node, child)
50+
4751
def outgoing(self, node):
4852
return self.sourceToDest.get(node, set())
4953

typed_python/compiler/python_to_native_converter.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,26 @@ def __init__(self):
5858
# (priority, node) pairs that need to recompute
5959
self._dirty_inflight_functions_with_order = SortedSet(key=lambda pair: pair[0])
6060

61+
def reachableInSet(self, rootIdentity, activeSet):
62+
"""Produce the subset of 'activeSet' that are reachable from 'rootIdentity'"""
63+
reachable = set()
64+
65+
def walk(node):
66+
if node in reachable or node not in activeSet:
67+
return
68+
69+
reachable.add(node)
70+
71+
for child in self._dependencies.outgoing(node):
72+
walk(child)
73+
74+
walk(rootIdentity)
75+
76+
return reachable
77+
78+
def clearOutgoingEdgesFor(self, identity):
79+
self._dependencies.clearOutgoing(identity)
80+
6181
def dropNode(self, node):
6282
self._dependencies.dropNode(node, False)
6383
if node in self._identity_levels:
@@ -341,8 +361,7 @@ def defineNonPythonFunction(self, name, identityTuple, context):
341361
if self._currentlyConverting is None:
342362
# force the function to resolve immediately
343363
self._resolveAllInflightFunctions()
344-
self._installInflightFunctions()
345-
self._inflight_function_conversions.clear()
364+
self._installInflightFunctions(identity)
346365

347366
return self.getTarget(linkName)
348367

@@ -540,6 +559,8 @@ def _resolveAllInflightFunctions(self):
540559

541560
# this calls back into convert with dependencies
542561
# they get registered as dirty
562+
self._dependencies.clearOutgoingEdgesFor(identity)
563+
543564
nativeFunction, actual_output_type = functionConverter.convertToNativeFunction()
544565

545566
if nativeFunction is not None:
@@ -900,7 +921,7 @@ def convert(
900921
if isRoot:
901922
try:
902923
self._resolveAllInflightFunctions()
903-
self._installInflightFunctions()
924+
self._installInflightFunctions(identity)
904925
return self.getTarget(name)
905926
finally:
906927
self._inflight_function_conversions.clear()
@@ -915,7 +936,7 @@ def convert(
915936
raise RuntimeError(f"Unexpected conversion error for {name}")
916937
return None
917938

918-
def _installInflightFunctions(self):
939+
def _installInflightFunctions(self, rootIdentity):
919940
"""Add all function definitions corresponding to keys in inflight_function_conversions to the relevant dictionaries."""
920941
if VALIDATE_FUNCTION_DEFINITIONS_STABLE:
921942
# this should always be true, but its expensive so we have it off by default
@@ -929,7 +950,17 @@ def _installInflightFunctions(self):
929950
finally:
930951
self._currentlyConverting = None
931952

953+
# restrict to the set of inflight functions that are reachable from rootName
954+
# we produce copies of functions that we don't actually need to compile during
955+
# early phases of type inference
956+
reachable = self._dependencies.reachableInSet(
957+
rootIdentity,
958+
set(self._inflight_function_conversions)
959+
)
960+
932961
for identifier, functionConverter in self._inflight_function_conversions.items():
962+
if identifier not in reachable:
963+
continue
933964
outboundTargets = []
934965
for outboundFuncId in self._dependencies.getNamesDependedOn(identifier):
935966
name = self._link_name_for_identity[outboundFuncId]
@@ -987,3 +1018,6 @@ def _installInflightFunctions(self):
9871018

9881019
self._definitions[name] = nativeFunction
9891020
self._new_native_functions.add(name)
1021+
1022+
self._inflight_definitions.clear()
1023+
self._inflight_function_conversions.clear()

0 commit comments

Comments
 (0)