@@ -238,7 +238,6 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
238
238
call_delegate_node .meta ["val" ] = submodule_output_node .meta ["val" ]
239
239
call_submodule_node .replace_all_uses_with (call_delegate_node )
240
240
owning_graph_module .graph .erase_node (call_submodule_node )
241
-
242
241
if is_submodule :
243
242
assert len (toplevel_input_specs_to_delete ) == 0
244
243
assert len (toplevel_output_specs_to_delete ) == 0
@@ -574,26 +573,29 @@ def lower_all_submodules_to_backend(
574
573
# The created exported program for the submodules are in the call_module node's meta data
575
574
# We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs
576
575
method_to_partitioned_program = {
577
- method_name : [node .meta ["submodule_program" ] for node in call_submodule_nodes ]
576
+ method_name : [
577
+ copy .deepcopy (node .meta ["submodule_program" ])
578
+ for node in call_submodule_nodes
579
+ ]
578
580
for method_name , call_submodule_nodes in method_to_submodules_nodes .items ()
579
581
}
580
582
method_to_compile_specs = {
581
583
method_name : [node .meta ["compile_spec" ] for node in call_submodule_nodes ]
582
584
for method_name , call_submodule_nodes in method_to_submodules_nodes .items ()
583
585
}
584
- backend_found = False
585
- for cls in BackendDetails .__subclasses__ ():
586
- if backend_id == cls .__name__ :
587
- method_to_preprocess_result : dict [str , List [PreprocessResult ]] = (
588
- cls .preprocess_multimethod (
589
- method_to_partitioned_program , method_to_compile_specs
590
- )
591
- )
592
- backend_found = True
593
586
594
- if not backend_found :
587
+ backend_name_to_subclass = {
588
+ subclass .__name__ : subclass for subclass in BackendDetails .__subclasses__ ()
589
+ }
590
+ if backend_id not in backend_name_to_subclass :
595
591
raise NotImplementedError (f"Backend { backend_id } was not found." )
596
592
593
+ method_to_preprocess_result : dict [str , List [PreprocessResult ]] = (
594
+ backend_name_to_subclass [backend_id ].preprocess_multimethod (
595
+ method_to_partitioned_program , method_to_compile_specs
596
+ )
597
+ )
598
+
597
599
for method_name in method_to_preprocess_result .keys ():
598
600
owning_program = method_to_tagged_edge_program [method_name ]
599
601
list_of_preprocess_results = method_to_preprocess_result [method_name ]
@@ -612,6 +614,9 @@ def lower_all_submodules_to_backend(
612
614
compile_specs = compile_spec ,
613
615
named_data_store_output = preprocess_result .data_store_output ,
614
616
)
617
+ lowered_module .meta = {
618
+ "debug_handle_map" : preprocess_result .debug_handle_map ,
619
+ }
615
620
is_submodule = call_submodule_node .meta ["is_submodule" ]
616
621
toplevel_input_specs_to_delete = call_submodule_node .meta [
617
622
"toplevel_input_specs_to_delete"
@@ -633,6 +638,20 @@ def lower_all_submodules_to_backend(
633
638
)
634
639
635
640
641
+ def remove_used_metadata (graph : torch .fx .Graph ) -> None :
642
+ """
643
+ Remove the used metadata from the graph.
644
+ """
645
+ for node in graph .nodes :
646
+ node .meta .pop ("delegation_tag" , None )
647
+ node .meta .pop ("backend_id" , None )
648
+ node .meta .pop ("submodule_program" , None )
649
+ node .meta .pop ("toplevel_input_specs_to_delete" , None )
650
+ node .meta .pop ("toplevel_output_specs_to_delete" , None )
651
+ node .meta .pop ("is_submodule" , None )
652
+ node .meta .pop ("submodule_output_node" , None )
653
+
654
+
636
655
@dataclass
637
656
class MethodProgramsPartitionerSpec :
638
657
"""
@@ -748,6 +767,7 @@ def to_backend(
748
767
if method_name in method_to_tagged_exported_program :
749
768
tagged_exported_program = method_to_tagged_exported_program [method_name ]
750
769
tagged_exported_program ._validate ()
770
+ remove_used_metadata (tagged_exported_program .graph_module .graph )
751
771
partitioned_and_lowered_exported_programs [method_name ] = ExportedProgram (
752
772
root = tagged_exported_program .graph_module ,
753
773
graph = tagged_exported_program .graph_module .graph ,
0 commit comments