Skip to content

Commit 41361cd

Browse files
committed
Revert "Allow for partial module loads in compiler cache."
This reverts commit 2782704.
1 parent be4ce9b commit 41361cd

9 files changed

+167
-260
lines changed

typed_python/compiler/binary_shared_object.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727

2828
class LoadedBinarySharedObject(LoadedModule):
29-
def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlobalVariableDefinitions):
30-
super().__init__(functionPointers, serializedGlobalVariableDefinitions)
29+
def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariableDefinitions):
30+
super().__init__(functionPointers, globalVariableDefinitions)
3131

3232
self.binarySharedObject = binarySharedObject
3333
self.diskPath = diskPath
@@ -36,32 +36,30 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlo
3636
class BinarySharedObject:
3737
"""Models a shared object library (.so) loadable on linux systems."""
3838

39-
def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies):
39+
def __init__(self, binaryForm, functionTypes, globalVariableDefinitions):
4040
"""
4141
Args:
42-
binaryForm: a bytes object containing the actual compiled code for the module
43-
serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition
44-
globalDependencies: a dict from function linkname to the list of global variables it depends on
42+
binaryForm - a bytes object containing the actual compiled code for the module
43+
globalVariableDefinitions - a map from name to GlobalVariableDefinition
4544
"""
4645
self.binaryForm = binaryForm
4746
self.functionTypes = functionTypes
48-
self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
49-
self.globalDependencies = globalDependencies
47+
self.globalVariableDefinitions = globalVariableDefinitions
5048
self.hash = sha_hash(binaryForm)
5149

5250
@property
5351
def definedSymbols(self):
5452
return self.functionTypes.keys()
5553

5654
@staticmethod
57-
def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
55+
def fromDisk(path, globalVariableDefinitions, functionNameToType):
5856
with open(path, "rb") as f:
5957
binaryForm = f.read()
6058

61-
return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
59+
return BinarySharedObject(binaryForm, functionNameToType, globalVariableDefinitions)
6260

6361
@staticmethod
64-
def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
62+
def fromModule(module, globalVariableDefinitions, functionNameToType):
6563
target_triple = llvm.get_process_triple()
6664
target = llvm.Target.from_triple(target_triple)
6765
target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
@@ -82,7 +80,7 @@ def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType,
8280
)
8381

8482
with open(os.path.join(tf, "module.so"), "rb") as so_file:
85-
return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
83+
return BinarySharedObject(so_file.read(), functionNameToType, globalVariableDefinitions)
8684

8785
def load(self, storageDir):
8886
"""Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer"""
@@ -129,7 +127,8 @@ def loadFromPath(self, modulePath):
129127
self,
130128
modulePath,
131129
functionPointers,
132-
self.serializedGlobalVariableDefinitions
130+
self.globalVariableDefinitions
133131
)
132+
loadedModule.linkGlobalVariables()
134133

135134
return loadedModule

typed_python/compiler/compiler_cache.py

+70-106
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@
1515
import os
1616
import uuid
1717
import shutil
18+
from typed_python.compiler.loaded_module import LoadedModule
19+
from typed_python.compiler.binary_shared_object import BinarySharedObject
1820

19-
from typing import Optional, List
20-
21-
from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject
22-
from typed_python.compiler.directed_graph import DirectedGraph
23-
from typed_python.compiler.typed_call_target import TypedCallTarget
2421
from typed_python.SerializationContext import SerializationContext
2522
from typed_python import Dict, ListOf
2623

@@ -55,173 +52,146 @@ def __init__(self, cacheDir):
5552

5653
ensureDirExists(cacheDir)
5754

58-
self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)()
55+
self.loadedModules = Dict(str, LoadedModule)()
5956
self.nameToModuleHash = Dict(str, str)()
6057

61-
self.moduleManifestsLoaded = set()
58+
self.modulesMarkedValid = set()
59+
self.modulesMarkedInvalid = set()
6260

6361
for moduleHash in os.listdir(self.cacheDir):
6462
if len(moduleHash) == 40:
6563
self.loadNameManifestFromStoredModuleByHash(moduleHash)
6664

67-
# the set of functions with an associated module in loadedBinarySharedObjects
68-
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
69-
70-
# the set of functions with linked and validated globals (i.e. ready to be run).
71-
self.targetsValidated = set()
72-
73-
self.function_dependency_graph = DirectedGraph()
74-
# dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75-
self.global_dependencies = Dict(str, ListOf(str))()
65+
self.targetsLoaded = {}
7666

77-
def hasSymbol(self, linkName: str) -> bool:
78-
"""NB this will return True even if the linkName is ultimately unretrievable."""
67+
def hasSymbol(self, linkName):
7968
return linkName in self.nameToModuleHash
8069

81-
def getTarget(self, linkName: str) -> TypedCallTarget:
82-
if not self.hasSymbol(linkName):
83-
raise ValueError(f'symbol not found for linkName {linkName}')
70+
def getTarget(self, linkName):
71+
assert self.hasSymbol(linkName)
72+
8473
self.loadForSymbol(linkName)
74+
8575
return self.targetsLoaded[linkName]
8676

87-
def dependencies(self, linkName: str) -> Optional[List[str]]:
88-
"""Returns all the function names that `linkName` depends on"""
89-
return list(self.function_dependency_graph.outgoing(linkName))
77+
def markModuleHashInvalid(self, hashstr):
78+
with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"):
79+
pass
9080

91-
def loadForSymbol(self, linkName: str) -> None:
92-
"""Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
81+
def loadForSymbol(self, linkName):
9382
moduleHash = self.nameToModuleHash[linkName]
9483

9584
self.loadModuleByHash(moduleHash)
9685

97-
if linkName not in self.targetsValidated:
98-
dependantFuncs = self.dependencies(linkName) + [linkName]
99-
globalsToLink = {} # dict from modulehash to list of globals.
100-
for funcName in dependantFuncs:
101-
if funcName not in self.targetsValidated:
102-
funcModuleHash = self.nameToModuleHash[funcName]
103-
# append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104-
globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, [])
105-
106-
for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too.
107-
if globs:
108-
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
109-
for x in globs
110-
}
111-
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
112-
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
113-
raise RuntimeError('failed to validate globals when loading:', linkName)
114-
115-
self.targetsValidated.update(dependantFuncs)
116-
117-
def loadModuleByHash(self, moduleHash: str) -> None:
86+
def loadModuleByHash(self, moduleHash):
11887
"""Load a module by name.
11988
12089
As we load, place all the newly imported typed call targets into
12190
'nameToTypedCallTarget' so that the rest of the system knows what functions
12291
have been uncovered.
12392
"""
124-
if moduleHash in self.loadedBinarySharedObjects:
125-
return
93+
if moduleHash in self.loadedModules:
94+
return True
12695

12796
targetDir = os.path.join(self.cacheDir, moduleHash)
12897

129-
# TODO (Will) - store these names as module consts, use one .dat only
130-
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
131-
callTargets = SerializationContext().deserialize(f.read())
132-
133-
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
134-
serializedGlobalVarDefs = SerializationContext().deserialize(f.read())
98+
try:
99+
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
100+
callTargets = SerializationContext().deserialize(f.read())
135101

136-
with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
137-
functionNameToNativeType = SerializationContext().deserialize(f.read())
102+
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
103+
globalVarDefs = SerializationContext().deserialize(f.read())
138104

139-
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
140-
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
105+
with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
106+
functionNameToNativeType = SerializationContext().deserialize(f.read())
141107

142-
with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f:
143-
dependency_edgelist = SerializationContext().deserialize(f.read())
108+
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
109+
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
110+
except Exception:
111+
self.markModuleHashInvalid(moduleHash)
112+
return False
144113

145-
with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f:
146-
globalDependencies = SerializationContext().deserialize(f.read())
114+
if not LoadedModule.validateGlobalVariables(globalVarDefs):
115+
self.markModuleHashInvalid(moduleHash)
116+
return False
147117

148118
# load the submodules first
149119
for submodule in submodules:
150-
self.loadModuleByHash(submodule)
120+
if not self.loadModuleByHash(submodule):
121+
return False
151122

152123
modulePath = os.path.join(targetDir, "module.so")
153124

154125
loaded = BinarySharedObject.fromDisk(
155126
modulePath,
156-
serializedGlobalVarDefs,
157-
functionNameToNativeType,
158-
globalDependencies
159-
127+
globalVarDefs,
128+
functionNameToNativeType
160129
).loadFromPath(modulePath)
161130

162-
self.loadedBinarySharedObjects[moduleHash] = loaded
131+
self.loadedModules[moduleHash] = loaded
163132

164133
self.targetsLoaded.update(callTargets)
165134

166-
assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision.
167-
self.global_dependencies.update(globalDependencies)
168-
169-
# update the cache's dependency graph with our new edges.
170-
for function_name, dependant_function_name in dependency_edgelist:
171-
self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name)
135+
return True
172136

173-
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
137+
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies):
174138
"""Add new code to the compiler cache.
175139
176140
Args:
177-
binarySharedObject: a BinarySharedObject containing the actual assembler
178-
we've compiled.
179-
nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
180-
the formal python types for all the objects.
181-
linkDependencies: a set of linknames we depend on directly.
182-
dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183-
module.
141+
binarySharedObject - a BinarySharedObject containing the actual assembler
142+
we've compiled
143+
nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us
144+
the formal python types for all the objects
145+
linkDependencies - a set of linknames we depend on directly.
184146
"""
185147
dependentHashes = set()
186148

187149
for name in linkDependencies:
188150
dependentHashes.add(self.nameToModuleHash[name])
189151

190-
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist)
152+
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes)
191153

192-
self.loadedBinarySharedObjects[hashToUse] = (
154+
self.loadedModules[hashToUse] = (
193155
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
194156
)
195157

196158
for n in binarySharedObject.definedSymbols:
197159
self.nameToModuleHash[n] = hashToUse
198160

199-
# link & validate all globals for the new module
200-
self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables()
201-
if not self.loadedBinarySharedObjects[hashToUse].validateGlobalVariables(
202-
self.loadedBinarySharedObjects[hashToUse].serializedGlobalVariableDefinitions):
203-
raise RuntimeError('failed to validate globals in new module:', hashToUse)
204-
205-
def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
206-
if moduleHash in self.moduleManifestsLoaded:
207-
return
161+
def loadNameManifestFromStoredModuleByHash(self, moduleHash):
162+
if moduleHash in self.modulesMarkedValid:
163+
return True
208164

209165
targetDir = os.path.join(self.cacheDir, moduleHash)
210166

167+
# ignore 'marked invalid'
168+
if os.path.exists(os.path.join(targetDir, "marked_invalid")):
169+
# just bail - don't try to read it now
170+
171+
# for the moment, we don't try to clean up the cache, because
172+
# we can't be sure that some process is not still reading the
173+
# old files.
174+
self.modulesMarkedInvalid.add(moduleHash)
175+
return False
176+
211177
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
212178
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
213179

214180
for subHash in submodules:
215-
self.loadNameManifestFromStoredModuleByHash(subHash)
181+
if not self.loadNameManifestFromStoredModuleByHash(subHash):
182+
self.markModuleHashInvalid(subHash)
183+
return False
216184

217185
with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
218186
self.nameToModuleHash.update(
219187
SerializationContext().deserialize(f.read(), Dict(str, str))
220188
)
221189

222-
self.moduleManifestsLoaded.add(moduleHash)
190+
self.modulesMarkedValid.add(moduleHash)
191+
192+
return True
223193

224-
def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist):
194+
def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules):
225195
"""Write out a disk representation of this module.
226196
227197
This includes writing both the shared object, a manifest of the function names
@@ -274,17 +244,11 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
274244

275245
# write the type manifest
276246
with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f:
277-
f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions))
247+
f.write(SerializationContext().serialize(binarySharedObject.globalVariableDefinitions))
278248

279249
with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f:
280250
f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str)))
281251

282-
with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f:
283-
f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof
284-
285-
with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
286-
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))
287-
288252
try:
289253
os.rename(tempTargetDir, targetDir)
290254
except IOError:
@@ -300,7 +264,7 @@ def function_pointer_by_name(self, linkName):
300264
if moduleHash is None:
301265
raise Exception("Can't find a module for " + linkName)
302266

303-
if moduleHash not in self.loadedBinarySharedObjects:
267+
if moduleHash not in self.loadedModules:
304268
self.loadForSymbol(linkName)
305269

306-
return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName]
270+
return self.loadedModules[moduleHash].functionPointers[linkName]

0 commit comments

Comments
 (0)