diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index b408898f7..876e09033 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1145,22 +1145,24 @@ def __init__( Args: domain: The domain of the operator. For onnx operators, this is an empty string. op_type: The name of the operator. - inputs: The input values. When an input is None, it is an empty input. + inputs: The input values. When an input is ``None``, it is an empty input. attributes: The attributes. RefAttr can be used only when the node is defined in a Function. overload: The overload name when the node is invoking a function. num_outputs: The number of outputs of the node. If not specified, the number is 1. - outputs: The output values. If None, the outputs are created during initialization. - version: The version of the operator. If None, the version is unspecified and will follow that of the graph. - graph: The graph that the node belongs to. If None, the node is not added to any graph. - A `Node` must belong to zero or one graph. - name: The name of the node. If None, the node is anonymous. + outputs: The output values. If ``None``, the outputs are created during initialization. + version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph. + graph: The graph that the node belongs to. If ``None``, the node is not added to any graph. + A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph + of the function is assigned to the node. + name: The name of the node. If ``None``, the node is anonymous. The name may be + set by a :class:`Graph` if ``graph`` is specified. doc_string: The documentation string. metadata_props: The metadata properties. Raises: - TypeError: If the attributes are not Attr or RefAttr. - ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs. - ValueError: If an output value is None, when outputs is specified. + TypeError: If the attributes are not :class:`Attr` or :class:`RefAttr`. + ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs. + ValueError: If an output value is ``None``, when outputs is specified. ValueError: If an output value has a producer set already, when outputs is specified. """ self._name = name @@ -1187,7 +1189,11 @@ def __init__( self._version: int | None = version self._metadata: _metadata.MetadataStore | None = None self._metadata_props: dict[str, str] | None = metadata_props - self._graph: Graph | Function | None = graph + # _graph is set by graph.append + self._graph: Graph | None = None + # Add the node to the graph if graph is specified + if graph is not None: + graph.append(self) self.doc_string = doc_string # Add the node as a use of the inputs @@ -1195,10 +1201,6 @@ def __init__( if input_value is not None: input_value._add_usage(self, i) # pylint: disable=protected-access - # Add the node to the graph if graph is specified - if self._graph is not None: - self._graph.append(self) - def _create_outputs( self, num_outputs: int | None, outputs: Sequence[Value] | None ) -> tuple[Value, ...]: @@ -1432,11 +1434,11 @@ def metadata_props(self) -> dict[str, str]: return self._metadata_props @property - def graph(self) -> Graph | Function | None: + def graph(self) -> Graph | None: return self._graph @graph.setter - def graph(self, value: Graph | Function | None) -> None: + def graph(self, value: Graph | None) -> None: self._graph = value def op_identifier(self) -> _protocols.OperatorIdentifier: @@ -2178,7 +2180,7 @@ def sort(self) -> None: # Obtain all nodes from the graph and its subgraphs for sorting nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self)) # Store the sorted nodes of each subgraph - sorted_nodes_by_graph: dict[Graph | Function, list[Node]] = { + sorted_nodes_by_graph: dict[Graph, list[Node]] = { graph: [] for graph in {node.graph for node in nodes if node.graph is not None} } # TODO: Explain why we need to store direct predecessors and children and why diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 226bdfafc..888053a8f 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -67,9 +67,9 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: type=ir.TensorType(tensor.dtype), const_value=tensor, ) - assert isinstance(node.graph, ir.Graph) + assert node.graph is not None node.graph.register_initializer(initializer) - # Replace the constant node with the initilizer + # Replace the constant node with the initializer ir.convenience.replace_all_uses_with(node.outputs[0], initializer) node.graph.remove(node, safe=True) count += 1