15
15
import os
16
16
import uuid
17
17
import shutil
18
+ from typed_python .compiler .loaded_module import LoadedModule
19
+ from typed_python .compiler .binary_shared_object import BinarySharedObject
18
20
19
- from typing import Optional , List
20
-
21
- from typed_python .compiler .binary_shared_object import LoadedBinarySharedObject , BinarySharedObject
22
- from typed_python .compiler .directed_graph import DirectedGraph
23
- from typed_python .compiler .typed_call_target import TypedCallTarget
24
21
from typed_python .SerializationContext import SerializationContext
25
22
from typed_python import Dict , ListOf
26
23
@@ -55,173 +52,146 @@ def __init__(self, cacheDir):
55
52
56
53
ensureDirExists (cacheDir )
57
54
58
- self .loadedBinarySharedObjects = Dict (str , LoadedBinarySharedObject )()
55
+ self .loadedModules = Dict (str , LoadedModule )()
59
56
self .nameToModuleHash = Dict (str , str )()
60
57
61
- self .moduleManifestsLoaded = set ()
58
+ self .modulesMarkedValid = set ()
59
+ self .modulesMarkedInvalid = set ()
62
60
63
61
for moduleHash in os .listdir (self .cacheDir ):
64
62
if len (moduleHash ) == 40 :
65
63
self .loadNameManifestFromStoredModuleByHash (moduleHash )
66
64
67
- # the set of functions with an associated module in loadedBinarySharedObjects
68
- self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
69
-
70
- # the set of functions with linked and validated globals (i.e. ready to be run).
71
- self .targetsValidated = set ()
72
-
73
- self .function_dependency_graph = DirectedGraph ()
74
- # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75
- self .global_dependencies = Dict (str , ListOf (str ))()
65
+ self .targetsLoaded = {}
76
66
77
- def hasSymbol (self , linkName : str ) -> bool :
78
- """NB this will return True even if the linkName is ultimately unretrievable."""
67
+ def hasSymbol (self , linkName ):
79
68
return linkName in self .nameToModuleHash
80
69
81
- def getTarget (self , linkName : str ) -> TypedCallTarget :
82
- if not self .hasSymbol (linkName ):
83
- raise ValueError ( f'symbol not found for linkName { linkName } ' )
70
+ def getTarget (self , linkName ) :
71
+ assert self .hasSymbol (linkName )
72
+
84
73
self .loadForSymbol (linkName )
74
+
85
75
return self .targetsLoaded [linkName ]
86
76
87
- def dependencies (self , linkName : str ) -> Optional [ List [ str ]] :
88
- """Returns all the function names that `linkName` depends on"""
89
- return list ( self . function_dependency_graph . outgoing ( linkName ))
77
+ def markModuleHashInvalid (self , hashstr ) :
78
+ with open ( os . path . join ( self . cacheDir , hashstr , "marked_invalid" ), "w" ):
79
+ pass
90
80
91
- def loadForSymbol (self , linkName : str ) -> None :
92
- """Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
81
+ def loadForSymbol (self , linkName ):
93
82
moduleHash = self .nameToModuleHash [linkName ]
94
83
95
84
self .loadModuleByHash (moduleHash )
96
85
97
- if linkName not in self .targetsValidated :
98
- dependantFuncs = self .dependencies (linkName ) + [linkName ]
99
- globalsToLink = {} # dict from modulehash to list of globals.
100
- for funcName in dependantFuncs :
101
- if funcName not in self .targetsValidated :
102
- funcModuleHash = self .nameToModuleHash [funcName ]
103
- # append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104
- globalsToLink [funcModuleHash ] = globalsToLink .get (funcModuleHash , []) + self .global_dependencies .get (funcName , [])
105
-
106
- for moduleHash , globs in globalsToLink .items (): # this works because loadModuleByHash loads submodules too.
107
- if globs :
108
- definitionsToLink = {x : self .loadedBinarySharedObjects [moduleHash ].serializedGlobalVariableDefinitions [x ]
109
- for x in globs
110
- }
111
- self .loadedBinarySharedObjects [moduleHash ].linkGlobalVariables (definitionsToLink )
112
- if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
113
- raise RuntimeError ('failed to validate globals when loading:' , linkName )
114
-
115
- self .targetsValidated .update (dependantFuncs )
116
-
117
- def loadModuleByHash (self , moduleHash : str ) -> None :
86
+ def loadModuleByHash (self , moduleHash ):
118
87
"""Load a module by name.
119
88
120
89
As we load, place all the newly imported typed call targets into
121
90
'nameToTypedCallTarget' so that the rest of the system knows what functions
122
91
have been uncovered.
123
92
"""
124
- if moduleHash in self .loadedBinarySharedObjects :
125
- return
93
+ if moduleHash in self .loadedModules :
94
+ return True
126
95
127
96
targetDir = os .path .join (self .cacheDir , moduleHash )
128
97
129
- # TODO (Will) - store these names as module consts, use one .dat only
130
- with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
131
- callTargets = SerializationContext ().deserialize (f .read ())
132
-
133
- with open (os .path .join (targetDir , "globals_manifest.dat" ), "rb" ) as f :
134
- serializedGlobalVarDefs = SerializationContext ().deserialize (f .read ())
98
+ try :
99
+ with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
100
+ callTargets = SerializationContext ().deserialize (f .read ())
135
101
136
- with open (os .path .join (targetDir , "native_type_manifest .dat" ), "rb" ) as f :
137
- functionNameToNativeType = SerializationContext ().deserialize (f .read ())
102
+ with open (os .path .join (targetDir , "globals_manifest .dat" ), "rb" ) as f :
103
+ globalVarDefs = SerializationContext ().deserialize (f .read ())
138
104
139
- with open (os .path .join (targetDir , "submodules .dat" ), "rb" ) as f :
140
- submodules = SerializationContext ().deserialize (f .read (), ListOf ( str ))
105
+ with open (os .path .join (targetDir , "native_type_manifest .dat" ), "rb" ) as f :
106
+ functionNameToNativeType = SerializationContext ().deserialize (f .read ())
141
107
142
- with open (os .path .join (targetDir , "function_dependencies.dat" ), "rb" ) as f :
143
- dependency_edgelist = SerializationContext ().deserialize (f .read ())
108
+ with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
109
+ submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
110
+ except Exception :
111
+ self .markModuleHashInvalid (moduleHash )
112
+ return False
144
113
145
- with open (os .path .join (targetDir , "global_dependencies.dat" ), "rb" ) as f :
146
- globalDependencies = SerializationContext ().deserialize (f .read ())
114
+ if not LoadedModule .validateGlobalVariables (globalVarDefs ):
115
+ self .markModuleHashInvalid (moduleHash )
116
+ return False
147
117
148
118
# load the submodules first
149
119
for submodule in submodules :
150
- self .loadModuleByHash (submodule )
120
+ if not self .loadModuleByHash (submodule ):
121
+ return False
151
122
152
123
modulePath = os .path .join (targetDir , "module.so" )
153
124
154
125
loaded = BinarySharedObject .fromDisk (
155
126
modulePath ,
156
- serializedGlobalVarDefs ,
157
- functionNameToNativeType ,
158
- globalDependencies
159
-
127
+ globalVarDefs ,
128
+ functionNameToNativeType
160
129
).loadFromPath (modulePath )
161
130
162
- self .loadedBinarySharedObjects [moduleHash ] = loaded
131
+ self .loadedModules [moduleHash ] = loaded
163
132
164
133
self .targetsLoaded .update (callTargets )
165
134
166
- assert not any (key in self .global_dependencies for key in globalDependencies ) # should only happen if there's a hash collision.
167
- self .global_dependencies .update (globalDependencies )
168
-
169
- # update the cache's dependency graph with our new edges.
170
- for function_name , dependant_function_name in dependency_edgelist :
171
- self .function_dependency_graph .addEdge (source = function_name , dest = dependant_function_name )
135
+ return True
172
136
173
- def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies , dependencyEdgelist ):
137
+ def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies ):
174
138
"""Add new code to the compiler cache.
175
139
176
140
Args:
177
- binarySharedObject: a BinarySharedObject containing the actual assembler
178
- we've compiled.
179
- nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
180
- the formal python types for all the objects.
181
- linkDependencies: a set of linknames we depend on directly.
182
- dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183
- module.
141
+ binarySharedObject - a BinarySharedObject containing the actual assembler
142
+ we've compiled
143
+ nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us
144
+ the formal python types for all the objects
145
+ linkDependencies - a set of linknames we depend on directly.
184
146
"""
185
147
dependentHashes = set ()
186
148
187
149
for name in linkDependencies :
188
150
dependentHashes .add (self .nameToModuleHash [name ])
189
151
190
- path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes , dependencyEdgelist )
152
+ path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes )
191
153
192
- self .loadedBinarySharedObjects [hashToUse ] = (
154
+ self .loadedModules [hashToUse ] = (
193
155
binarySharedObject .loadFromPath (os .path .join (path , "module.so" ))
194
156
)
195
157
196
158
for n in binarySharedObject .definedSymbols :
197
159
self .nameToModuleHash [n ] = hashToUse
198
160
199
- # link & validate all globals for the new module
200
- self .loadedBinarySharedObjects [hashToUse ].linkGlobalVariables ()
201
- if not self .loadedBinarySharedObjects [hashToUse ].validateGlobalVariables (
202
- self .loadedBinarySharedObjects [hashToUse ].serializedGlobalVariableDefinitions ):
203
- raise RuntimeError ('failed to validate globals in new module:' , hashToUse )
204
-
205
- def loadNameManifestFromStoredModuleByHash (self , moduleHash ) -> None :
206
- if moduleHash in self .moduleManifestsLoaded :
207
- return
161
+ def loadNameManifestFromStoredModuleByHash (self , moduleHash ):
162
+ if moduleHash in self .modulesMarkedValid :
163
+ return True
208
164
209
165
targetDir = os .path .join (self .cacheDir , moduleHash )
210
166
167
+ # ignore 'marked invalid'
168
+ if os .path .exists (os .path .join (targetDir , "marked_invalid" )):
169
+ # just bail - don't try to read it now
170
+
171
+ # for the moment, we don't try to clean up the cache, because
172
+ # we can't be sure that some process is not still reading the
173
+ # old files.
174
+ self .modulesMarkedInvalid .add (moduleHash )
175
+ return False
176
+
211
177
with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
212
178
submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
213
179
214
180
for subHash in submodules :
215
- self .loadNameManifestFromStoredModuleByHash (subHash )
181
+ if not self .loadNameManifestFromStoredModuleByHash (subHash ):
182
+ self .markModuleHashInvalid (subHash )
183
+ return False
216
184
217
185
with open (os .path .join (targetDir , "name_manifest.dat" ), "rb" ) as f :
218
186
self .nameToModuleHash .update (
219
187
SerializationContext ().deserialize (f .read (), Dict (str , str ))
220
188
)
221
189
222
- self .moduleManifestsLoaded .add (moduleHash )
190
+ self .modulesMarkedValid .add (moduleHash )
191
+
192
+ return True
223
193
224
- def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules , dependencyEdgelist ):
194
+ def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules ):
225
195
"""Write out a disk representation of this module.
226
196
227
197
This includes writing both the shared object, a manifest of the function names
@@ -274,17 +244,11 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
274
244
275
245
# write the type manifest
276
246
with open (os .path .join (tempTargetDir , "globals_manifest.dat" ), "wb" ) as f :
277
- f .write (SerializationContext ().serialize (binarySharedObject .serializedGlobalVariableDefinitions ))
247
+ f .write (SerializationContext ().serialize (binarySharedObject .globalVariableDefinitions ))
278
248
279
249
with open (os .path .join (tempTargetDir , "submodules.dat" ), "wb" ) as f :
280
250
f .write (SerializationContext ().serialize (ListOf (str )(submodules ), ListOf (str )))
281
251
282
- with open (os .path .join (tempTargetDir , "function_dependencies.dat" ), "wb" ) as f :
283
- f .write (SerializationContext ().serialize (dependencyEdgelist )) # might need a listof
284
-
285
- with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
286
- f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
287
-
288
252
try :
289
253
os .rename (tempTargetDir , targetDir )
290
254
except IOError :
@@ -300,7 +264,7 @@ def function_pointer_by_name(self, linkName):
300
264
if moduleHash is None :
301
265
raise Exception ("Can't find a module for " + linkName )
302
266
303
- if moduleHash not in self .loadedBinarySharedObjects :
267
+ if moduleHash not in self .loadedModules :
304
268
self .loadForSymbol (linkName )
305
269
306
- return self .loadedBinarySharedObjects [moduleHash ].functionPointers [linkName ]
270
+ return self .loadedModules [moduleHash ].functionPointers [linkName ]
0 commit comments