Skip to content

Commit 9663bfb

Browse files
authored
Hook up PreprocessAll flow to EdgeManager
Differential Revision: D74629455 Pull Request resolved: #10842
1 parent 9aaea31 commit 9663bfb

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

exir/backend/backend_api.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
238238
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
239239
call_submodule_node.replace_all_uses_with(call_delegate_node)
240240
owning_graph_module.graph.erase_node(call_submodule_node)
241-
242241
if is_submodule:
243242
assert len(toplevel_input_specs_to_delete) == 0
244243
assert len(toplevel_output_specs_to_delete) == 0
@@ -574,26 +573,29 @@ def lower_all_submodules_to_backend(
574573
# The created exported program for the submodules are in the call_module node's meta data
575574
# We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs
576575
method_to_partitioned_program = {
577-
method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]
576+
method_name: [
577+
copy.deepcopy(node.meta["submodule_program"])
578+
for node in call_submodule_nodes
579+
]
578580
for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
579581
}
580582
method_to_compile_specs = {
581583
method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]
582584
for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
583585
}
584-
backend_found = False
585-
for cls in BackendDetails.__subclasses__():
586-
if backend_id == cls.__name__:
587-
method_to_preprocess_result: dict[str, List[PreprocessResult]] = (
588-
cls.preprocess_multimethod(
589-
method_to_partitioned_program, method_to_compile_specs
590-
)
591-
)
592-
backend_found = True
593586

594-
if not backend_found:
587+
backend_name_to_subclass = {
588+
subclass.__name__: subclass for subclass in BackendDetails.__subclasses__()
589+
}
590+
if backend_id not in backend_name_to_subclass:
595591
raise NotImplementedError(f"Backend {backend_id} was not found.")
596592

593+
method_to_preprocess_result: dict[str, List[PreprocessResult]] = (
594+
backend_name_to_subclass[backend_id].preprocess_multimethod(
595+
method_to_partitioned_program, method_to_compile_specs
596+
)
597+
)
598+
597599
for method_name in method_to_preprocess_result.keys():
598600
owning_program = method_to_tagged_edge_program[method_name]
599601
list_of_preprocess_results = method_to_preprocess_result[method_name]
@@ -612,6 +614,9 @@ def lower_all_submodules_to_backend(
612614
compile_specs=compile_spec,
613615
named_data_store_output=preprocess_result.data_store_output,
614616
)
617+
lowered_module.meta = {
618+
"debug_handle_map": preprocess_result.debug_handle_map,
619+
}
615620
is_submodule = call_submodule_node.meta["is_submodule"]
616621
toplevel_input_specs_to_delete = call_submodule_node.meta[
617622
"toplevel_input_specs_to_delete"
@@ -633,6 +638,20 @@ def lower_all_submodules_to_backend(
633638
)
634639

635640

641+
def remove_used_metadata(graph: torch.fx.Graph) -> None:
642+
"""
643+
Remove the used metadata from the graph.
644+
"""
645+
for node in graph.nodes:
646+
node.meta.pop("delegation_tag", None)
647+
node.meta.pop("backend_id", None)
648+
node.meta.pop("submodule_program", None)
649+
node.meta.pop("toplevel_input_specs_to_delete", None)
650+
node.meta.pop("toplevel_output_specs_to_delete", None)
651+
node.meta.pop("is_submodule", None)
652+
node.meta.pop("submodule_output_node", None)
653+
654+
636655
@dataclass
637656
class MethodProgramsPartitionerSpec:
638657
"""
@@ -748,6 +767,7 @@ def to_backend(
748767
if method_name in method_to_tagged_exported_program:
749768
tagged_exported_program = method_to_tagged_exported_program[method_name]
750769
tagged_exported_program._validate()
770+
remove_used_metadata(tagged_exported_program.graph_module.graph)
751771
partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
752772
root=tagged_exported_program.graph_module,
753773
graph=tagged_exported_program.graph_module.graph,

exir/program/_program.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from executorch.exir._serialize._serialize import serialize_for_executorch
2424
from executorch.exir._serialize.data_serializer import DataSerializer
2525
from executorch.exir._warnings import experimental
26-
from executorch.exir.backend.backend_api import to_backend
26+
from executorch.exir.backend.backend_api import (
27+
MethodProgramsPartitionerSpec,
28+
to_backend,
29+
)
2730
from executorch.exir.backend.partitioner import Partitioner
2831
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
2932
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
@@ -1239,10 +1242,16 @@ def to_edge_transform_and_lower(
12391242
if transform_passes is not None:
12401243
edge_manager = edge_manager.transform(transform_passes)
12411244

1242-
if partitioner is not None:
1245+
max_num_partitioners = 0
1246+
for partitioner_list in partitioner.values():
1247+
max_num_partitioners = max(max_num_partitioners, len(partitioner_list))
1248+
1249+
for i in range(max_num_partitioners):
1250+
method_to_partitioner = {}
12431251
for name, partitioner_list in partitioner.items():
1244-
for curr_partitioner in partitioner_list:
1245-
edge_manager = edge_manager.to_backend({name: curr_partitioner})
1252+
if i < len(partitioner_list):
1253+
method_to_partitioner[name] = partitioner_list[i]
1254+
edge_manager = edge_manager.to_backend(method_to_partitioner)
12461255

12471256
for name, program in edge_manager._edge_programs.items():
12481257
ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
@@ -1475,7 +1484,8 @@ def transform(
14751484

14761485
@et_logger("to_backend")
14771486
def to_backend(
1478-
self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
1487+
self,
1488+
partitioner: Union[Partitioner, Dict[str, Partitioner]],
14791489
) -> "EdgeProgramManager":
14801490
"""
14811491
Returns a semantically-equivalent program to the one given as input,
@@ -1501,17 +1511,18 @@ def to_backend(
15011511
specified subgraphs lowered.
15021512
"""
15031513
new_edge_programs: Dict[str, ExportedProgram] = {}
1504-
if isinstance(partitioner, dict):
1505-
for name, program in self._edge_programs.items():
1506-
if name in partitioner.keys():
1507-
new_edge_programs[name] = to_backend(program, partitioner[name])
1508-
else:
1509-
new_edge_programs[name] = program
1514+
method_to_partitioner: Dict[str, Partitioner] = {}
1515+
if not isinstance(partitioner, dict):
1516+
method_to_partitioner = {name: partitioner for name in self._edge_programs}
1517+
else:
1518+
method_to_partitioner = partitioner
15101519

1511-
else: # apply partitioner to every method
1512-
for name, program in self._edge_programs.items():
1513-
new_edge_programs[name] = to_backend(program, partitioner)
1520+
method_to_programs_and_partitioners = MethodProgramsPartitionerSpec(
1521+
self._edge_programs,
1522+
method_to_partitioner,
1523+
)
15141524

1525+
new_edge_programs = to_backend(method_to_programs_and_partitioners)
15151526
config = EdgeCompileConfig(_check_ir_validity=False)
15161527
return EdgeProgramManager(
15171528
new_edge_programs,

0 commit comments

Comments
 (0)