Skip to content

Commit 875c18d

Browse files
committed
Refactor: update_graph_outputs in a helper (#62)
Signed-off-by: Johansmm <[email protected]>
1 parent d794fca commit 875c18d

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
376376
return values
377377

378378

379+
def _update_graph_or_function_outputs(
380+
graph_or_function: _core.Graph | _core.Function,
381+
old_values: Sequence[_core.Value],
382+
new_values: Sequence[_core.Value],
383+
):
384+
"""Update graph/function outputs."""
385+
replacement_mapping = dict(zip(old_values, new_values))
386+
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
387+
if graph_or_function_output in replacement_mapping:
388+
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
389+
390+
379391
def replace_nodes_and_values(
380392
graph_or_function: _core.Graph | _core.Function,
381393
/,
@@ -407,10 +419,7 @@ def replace_nodes_and_values(
407419
# Reconnect the users of the deleted values to use the new values
408420
replace_all_uses_with(old_values, new_values)
409421
# Update graph/function outputs if the node generates output
410-
replacement_mapping = dict(zip(old_values, new_values))
411-
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
412-
if graph_or_function_output in replacement_mapping:
413-
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
422+
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)
414423

415424
# insert new nodes after the index node
416425
graph_or_function.insert_after(insertion_point, new_nodes)

0 commit comments

Comments
 (0)