Skip to content

Commit 86266bd

Browse files
committed
Take a graph
1 parent 224c86f commit 86266bd

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

onnxscript/ir/_core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,7 @@ def __init__(
11351135
num_outputs: int | None = None,
11361136
outputs: Sequence[Value] | None = None,
11371137
version: int | None = None,
1138-
graph: Graph | None = None,
1138+
graph: Graph | Function | None = None,
11391139
name: str | None = None,
11401140
doc_string: str | None = None,
11411141
metadata_props: dict[str, str] | None = None,
@@ -1187,7 +1187,7 @@ def __init__(
11871187
self._version: int | None = version
11881188
self._metadata: _metadata.MetadataStore | None = None
11891189
self._metadata_props: dict[str, str] | None = metadata_props
1190-
self._graph: Graph | None = graph
1190+
self._graph: Graph | Function | None = graph
11911191
self.doc_string = doc_string
11921192

11931193
# Add the node as a use of the inputs
@@ -1432,11 +1432,11 @@ def metadata_props(self) -> dict[str, str]:
14321432
return self._metadata_props
14331433

14341434
@property
1435-
def graph(self) -> Graph | None:
1435+
def graph(self) -> Graph | Function | None:
14361436
return self._graph
14371437

14381438
@graph.setter
1439-
def graph(self, value: Graph | None) -> None:
1439+
def graph(self, value: Graph | Function | None) -> None:
14401440
self._graph = value
14411441

14421442
def op_identifier(self) -> _protocols.OperatorIdentifier:

onnxscript/ir/_tape.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,18 @@ class Tape:
4444
),
4545
ir_version=10,
4646
)
47+
48+
Attributes:
49+
graph_like: The graph to append the new nodes and initializers to. When
50+
it is None, the nodes and initializers are creating without owned by a graph.
51+
Initializers will not be added to functions because it is not supported by ONNX.
4752
"""
4853

49-
def __init__(self) -> None:
54+
def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
5055
self._nodes: list[ir.Node] = []
5156
self._initializers: list[ir.Value] = []
5257
self._used_opsets: UsedOpsets = set()
58+
self.graph_like = graph_like
5359

5460
def __repr__(self) -> str:
5561
return f"Tape(nodes={self._nodes}, initializers={self._initializers})"
@@ -92,7 +98,7 @@ def op(
9298
num_outputs=1,
9399
overload=overload,
94100
version=version,
95-
graph=graph,
101+
graph=graph or self.graph_like,
96102
name=name,
97103
doc_string=doc_string,
98104
metadata_props=metadata_props,
@@ -129,7 +135,7 @@ def op_multi_output(
129135
num_outputs=num_outputs,
130136
overload=overload,
131137
version=version,
132-
graph=graph,
138+
graph=graph or self.graph_like,
133139
name=name,
134140
doc_string=doc_string,
135141
metadata_props=metadata_props,
@@ -148,6 +154,8 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.
148154
name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
149155
)
150156
self._initializers.append(value)
157+
if isinstance(self.graph_like, ir.Graph):
158+
self.graph_like.register_initializer(value)
151159
return value
152160

153161

0 commit comments

Comments
 (0)