@@ -376,6 +376,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
376376 return values
377377
378378
379+ def _update_graph_or_function_outputs (
380+ graph_or_function : _core .Graph | _core .Function ,
381+ old_values : Sequence [_core .Value ],
382+ new_values : Sequence [_core .Value ],
383+ ):
384+ """Update graph/function outputs."""
385+ replacement_mapping = dict (zip (old_values , new_values ))
386+ for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
387+ if graph_or_function_output in replacement_mapping :
388+ graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
389+
390+
379391def replace_nodes_and_values (
380392 graph_or_function : _core .Graph | _core .Function ,
381393 / ,
@@ -407,10 +419,7 @@ def replace_nodes_and_values(
407419 # Reconnect the users of the deleted values to use the new values
408420 replace_all_uses_with (old_values , new_values )
409421 # Update graph/function outputs if the node generates output
410- replacement_mapping = dict (zip (old_values , new_values ))
411- for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
412- if graph_or_function_output in replacement_mapping :
413- graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
422+ _update_graph_or_function_outputs (graph_or_function , old_values , new_values )
414423
415424 # insert new nodes after the index node
416425 graph_or_function .insert_after (insertion_point , new_nodes )
0 commit comments