@@ -238,7 +238,7 @@ class VariableBuffer():
238238 allocTemplate : NodeTemplate #: NodeTemplate: Holds the buffer's allocation code
239239 deallocTemplate : NodeTemplate #: NodeTemplate: Holds the buffer's deallocation code
240240
241- def __init__ (self , name : str = '' , shape = [1 ], alias_of : Optional [List [str ]] = [] ):
241+ def __init__ (self , name : str = '' , shape = [1 ], aliases : Optional [List [str ]] = None ):
242242 self .name : str = name #: str: Canonical name that this buffer is registered as in the NetworkContext
243243 self .shape : Sequence [
244244 int ] = shape #: Sequence[int]: Represents the dimensions of the underlying tensor as a sequence of dimension sizes
@@ -257,7 +257,7 @@ def __init__(self, name: str = '', shape = [1], alias_of: Optional[List[str]] =
257257 self .is_input : bool = False
258258 self .is_output : bool = False
259259
260- self .alias_of : List [str ] = alias_of if alias_of is not None else []
260+ self .aliases : Set [str ] = set ( aliases ) if aliases is not None else set ()
261261
262262 def _bufferRepresentation (self ) -> Dict :
263263 return {"type" : self ._instance , "name" : self .name , "size" : int (np .prod (self .shape ))}
@@ -324,42 +324,7 @@ def __getstate__(self):
324324 def fromNode (cls , node : gs .Node ):
325325 return (cls (name = node .name , shape = node .shape if not isinstance (node , gs .Constant ) else node .values .shape ))
326326
327- def add_aliases (self , aliases_to_add : List [str ]):
328- """
329- Adds list of aliases to the alias_of attribute.
330- Parameters
331- ----------
332- alias_to_add : List[str]
333- List of names of aliases to add to the alias_of attribute.
334- Returns
335- -------
336- None
337- """
338-
339- if not hasattr (self , "alias_of" ):
340- return None
341-
342- for alias in aliases_to_add :
343- if alias not in self .alias_of :
344- self .alias_of .append (alias )
345-
346- return None
347-
348- def get_aliases_of (self ):
349- """
350- Getter function for the alias_of attribute.
351- Returns
352- -------
353- List[str]
354- List of names o all aliases of this VariableBuffer.
355- """
356-
357- if hasattr (self , "alias_of" ):
358- return self .alias_of
359- else :
360- return list ()
361-
362- def has_live_ancestors (self , ctxt : NetworkContext ) -> bool :
327+ def has_live_aliases (self , ctxt : NetworkContext ) -> bool :
363328 """Checks whether this VariableBuffer has any live ancestors, i.e. buffers that are still live and are aliased by this buffer.
364329 Parameters
365330 ----------
@@ -370,14 +335,18 @@ def has_live_ancestors(self, ctxt: NetworkContext) -> bool:
370335 bool
371336 True if this VariableBuffer has any live ancestors, False otherwise
372337 """
373- if not hasattr (self , "alias_of" ):
374- return False
375-
376- for alias in self .alias_of :
377- if ctxt .lookup (alias )._live :
378- return True
379-
380- return False
338+ # Do a breadth-first search across the aliasing double-linked list
339+ live = self ._live
340+ queue = set (self .aliases )
341+ visited = set (self .name )
342+ while len (queue ) > 0 :
343+ next = queue .pop ()
344+ buffNext = ctxt .lookup (next )
345+ assert isinstance (buffNext , VariableBuffer )
346+ live |= buffNext ._live
347+ visited .add (next )
348+ queue |= buffNext .aliases - visited
349+ return live
381350
382351 def sizeInBytes (self ) -> int :
383352 """Returns the size of this VariableBuffer in bytes
@@ -398,28 +367,13 @@ class TransientBuffer(VariableBuffer):
398367 """
399368
400369 def __init__ (self , name : str = '' , size = 0 ):
401- self .name = name
402- self .size = size #: int: Total BYTE size of this TransientBuffer
403-
404- # Do not override - Should be written in the parsing passes
405- self ._users = []
370+ super ().__init__ (name , shape = (size ,))
406371
407372 # Do not override - Should be written in the parsing passes
408373 self ._type : Type [Pointer ] = PointerClass (VoidType )
409-
410- # Do not override - Should be written in the deployment passes
411- self ._live = False
412-
413- # Do not override - Set in Templates depending on platform
414- self ._deploy = True
415-
416- self .is_input : bool = False
417- self .is_output : bool = False
418-
419- self .alias_of : List [str ] = []
374+ self .size = size
420375
421376 def __eq__ (self , other ):
422-
423377 ret = all ([self .name == other .name , self .size == other .size ])
424378 return ret
425379
@@ -432,10 +386,6 @@ def __str__(self) -> str:
432386 def __repr__ (self ) -> str :
433387 return f'TransientBuffer: name: { self .name } , size: { self .size } '
434388
435- @classmethod
436- def fromVariableBuffer (cls , buffer : VariableBuffer ):
437- ret = cls (name = buffer .name , size = np .prod (buffer .shape ) * buffer ._type .typeWidth // 8 )
438-
439389 def sizeInBytes (self ) -> int :
440390 return int (self .size )
441391
@@ -479,12 +429,6 @@ def __repr__(self) -> str:
479429 def _bufferRepresentation (self ) -> Dict :
480430 return {"type" : self ._type , "name" : self .name , "size" : int (np .prod (self .shape )), "values" : self ._valueString ()}
481431
482- @classmethod
483- def fromVariableBuffer (cls , buffer : VariableBuffer , values ):
484- ret = cls (name = buffer .name , shape = buffer .shape , values = values )
485-
486- return ret
487-
488432
489433class StructBuffer (VariableBuffer ):
490434 """Class to represent Struct object needed by the generated C Code
@@ -999,12 +943,15 @@ def hoistReference(self,
999943 ref ._instance = ref ._type (name , ctxt = self )
1000944 return ref
1001945
1002- def hoistConstant (self , node : gs .Node , name : str = '' , _type : Optional [Type [Pointer ]] = None ) -> str :
1003- """Register a ConstantBuffer extracted directly from a graphsurgeon Node
946+ def hoistConstant (self ,
947+ constant : gs .Constant ,
948+ name : Optional [str ] = None ,
949+ _type : Optional [Type [Pointer ]] = None ) -> str :
950+ """Register a ConstantBuffer extracted directly from a graphsurgeon Constant
1004951
1005952 Parameters
1006953 ----------
1007- node : gs.Node
954+ constant : gs.Constant
1008955 graphsurgeon.Node containing a single constant output
1009956 name : str
1010957 Name of the ConstantBuffer to be registered
@@ -1017,21 +964,18 @@ def hoistConstant(self, node: gs.Node, name: str = '', _type: Optional[Type[Poin
1017964 Returns the name of the newly registed ConstantBuffer
1018965
1019966 """
967+ assert len (constant .outputs ) <= 1 , f"Constant { constant .name } has more than one output"
1020968
1021- assert len ( node . outputs ) <= 1 , f"Constant { node . name } has more than one output"
969+ name = name if name is not None else constant . name
1022970
1023- if name == "" :
1024- name = node .name
971+ # LMACAN: The shape needs to be copied into a tuple for pickling to work. Don't ask me why..
972+ buffer = self .ConstantBuffer (name , tuple (constant .shape ), constant .values )
973+ self .add (buffer , 'global' )
1025974
1026- # SCHEREMO: This is currently heuristic, but should be annotated in ONNX
1027- localBuffer = self .VariableBuffer .fromNode (node = node )
1028- globalBuffer = self .ConstantBuffer .fromVariableBuffer (localBuffer , values = node .values )
1029- globalBuffer .name = name
1030- globalBuffer ._type = _type
975+ if _type is not None :
976+ self .annotateType (name , _type )
1031977
1032- self .add (globalBuffer , 'global' )
1033-
1034- return globalBuffer .name
978+ return name
1035979
1036980 def addUser (self , name : str , node : gs .Node ):
1037981 """Adds an operator's name to the _user list of a VariableBuffer in the context
0 commit comments