Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
from collections import defaultdict
from collections.abc import Sequence

Expand Down Expand Up @@ -110,8 +110,13 @@
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.common.pipeline_config import (
ArmPassPipelineConfig,
FuseDuplicateUsersConfig,
SoftmaxDecompositionConfig,
)
from executorch.backends.arm.tosa.specification import (
tosa_spec_in_set,
TosaLoweringContext,
Expand All @@ -124,11 +129,45 @@
from torch.fx.passes.infra.pass_base import PassResult
from torch.nn.modules import Module

logger = logging.getLogger(__name__)


class ArmPassManager(PassManager):
def __init__(self, tosa_spec: TosaSpecification) -> None:
self.tosa_spec = tosa_spec
def __init__(self, compile_spec: ArmCompileSpec) -> None:
    """Create a pass manager bound to ``compile_spec``.

    Args:
        compile_spec: Compile spec whose ``tosa_spec`` is cached on the
            manager and whose pass pipeline config determines which pass
            types are skipped.
    """
    self.compile_spec = compile_spec
    self.tosa_spec = compile_spec.tosa_spec
    # add_pass() reads this attribute; it is given an empty placeholder
    # before the base initializer runs, presumably so it is never missing.
    self._skip_pass_types: tuple[type, ...] = ()
    super().__init__()
    # Replace the placeholder with the skip list derived from the compile
    # spec's pipeline config.
    self.configure_skip_passes()

def configure_skip_passes(
    self,
    override_config: ArmPassPipelineConfig | None = None,
) -> tuple[type, ...]:
    """Derive and store the pass types this manager refuses to register.

    Args:
        override_config: Pipeline config to use in place of the one returned
            by ``self.compile_spec.get_pass_pipeline_config()``.

    Returns:
        The pass types now stored in ``self._skip_pass_types``.
    """
    pipeline_config = (
        override_config or self.compile_spec.get_pass_pipeline_config()
    )
    logger.debug(f"Skip Config: {pipeline_config}")

    excluded: set[type] = set()

    # Each softmax mode rules out the decomposition used by the other mode.
    softmax_mode = pipeline_config.softmax
    if softmax_mode == SoftmaxDecompositionConfig.MASKED:
        excluded.add(DecomposeSoftmaxUnstablePass)
    elif softmax_mode == SoftmaxDecompositionConfig.UNSTABLE:
        excluded.add(DecomposeSoftmaxPass)
        excluded.add(DecomposeMaskedFillPass)

    if pipeline_config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED:
        excluded.add(FuseDuplicateUsersPass)

    self._skip_pass_types = tuple(excluded)
    skipped_names = [pass_type.__name__ for pass_type in self._skip_pass_types]
    logger.debug(f"Passes in skip list: {skipped_names}")

    return self._skip_pass_types

def validate_constraints_mandatory(self):
"""
Expand Down Expand Up @@ -165,6 +204,11 @@ def _transform(self, graph_module: GraphModule):
with TosaLoweringContext(self.tosa_spec):
return self(graph_module).graph_module

def add_pass(self, pipeline_pass):
    """Register ``pipeline_pass`` unless its exact type is in the skip list.

    Membership is tested on ``type(pipeline_pass)``, so an instance of a
    subclass of a skipped pass type is still registered.
    """
    should_register = type(pipeline_pass) not in self._skip_pass_types
    if should_register:
        super().add_pass(pipeline_pass)

def _tosa_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
Expand Down Expand Up @@ -373,11 +417,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
DecomposeSqrtPass(),
DecomposeSiluPass(),
DecomposeAvgPool2dPass(),
(
DecomposeSoftmaxUnstablePass()
if self.tosa_spec.is_U55_subset
else DecomposeSoftmaxPass()
),
DecomposeSoftmaxUnstablePass(),
DecomposeSoftmaxPass(),
ConvertMinMaxPass(),
]
)
Expand All @@ -386,7 +427,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_passes(
[
ReplaceInfAndLimitValuesPass(),
DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
DecomposeMaskedFillPass(),
]
)

Expand Down
38 changes: 38 additions & 0 deletions backends/arm/common/arm_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
# JIT compiler flows.
#

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum

from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig
from executorch.backends.arm.tosa import TosaSpecification

from executorch.exir.backend.compile_spec_schema import CompileSpec
Expand All @@ -36,6 +38,7 @@ class DebugMode(Enum):
_DEBUG_ARTIFACT_KEY = "debug_artifact_path"
_DEBUG_MODE_KEY = "dump_debug_info"
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
_TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"

def _set_compile_specs(
self,
Expand All @@ -44,13 +47,15 @@ def _set_compile_specs(
path_for_intermediates: str | None = None,
tosa_debug_mode: DebugMode | None = None,
output_order_workaround: bool = True,
pipeline_config: ArmPassPipelineConfig | None = None,
):
"""Set all values of dataclass directly."""
self.tosa_spec = tosa_spec
self.compiler_flags = compiler_flags
self.path_for_intermediates = path_for_intermediates
self.tosa_debug_mode = tosa_debug_mode
self.output_order_workaround = output_order_workaround
self._pipeline_config = pipeline_config

@classmethod
def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
Expand All @@ -60,6 +65,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
path_for_intermediates: str | None = None
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
output_order_workaround: bool = True
pipeline_config: ArmPassPipelineConfig | None = None
unknown_specs: dict[str, str] = {}
for spec in compile_specs:
key = spec.key
Expand Down Expand Up @@ -98,6 +104,12 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
tosa_debug_mode = ArmCompileSpec.DebugMode[val]
elif key == ArmCompileSpec._OUTPUT_REORDER_KEY:
output_order_workaround = val # type: ignore[assignment]
elif key == ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY:
if pipeline_config is not None:
raise ValueError(
"More than one transform pipeline entry in compile spec."
)
pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val))
else:
unknown_specs[key] = val

Expand All @@ -120,6 +132,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
path_for_intermediates=path_for_intermediates,
tosa_debug_mode=tosa_debug_mode,
output_order_workaround=output_order_workaround,
pipeline_config=pipeline_config,
)
cls.from_list_hook(compile_spec, unknown_specs)
compile_spec.validate()
Expand Down Expand Up @@ -189,8 +202,33 @@ def to_list(self):
)
)

if self._pipeline_config is not None and not self._pipeline_config.is_default():
compile_spec.append(
CompileSpec(
ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY,
self._pipeline_config.serialize(),
)
)
return compile_spec

def get_pass_pipeline_config(self) -> ArmPassPipelineConfig:
    """Return the pass pipeline config, creating the default on first use.

    When no config has been stored yet, the result of
    ``_create_default_pipeline_config()`` is cached on the instance and
    returned by this and later calls. Subclasses may override to tweak
    defaults for specific targets.
    """
    stored = self._pipeline_config
    if stored is not None:
        return stored
    self._pipeline_config = self._create_default_pipeline_config()
    return self._pipeline_config

def set_pass_pipeline_config(self, config: ArmPassPipelineConfig) -> None:
    """Store ``config`` as this compile spec's pass pipeline configuration.

    The stored object is what ``get_pass_pipeline_config()`` returns
    afterwards, replacing any lazily created default.
    """
    self._pipeline_config = config

def _create_default_pipeline_config(self) -> ArmPassPipelineConfig:
    """Build the pipeline config used when none was set explicitly.

    Starts from ``ArmPassPipelineConfig()`` and calls
    ``disable_masked_softmax()`` on it when ``self.tosa_spec.is_U55_subset``
    is true.
    """
    default_config = ArmPassPipelineConfig()
    if not self.tosa_spec.is_U55_subset:
        return default_config
    default_config.disable_masked_softmax()
    return default_config

def get_intermediate_path(self) -> str | None:
"""
Gets the path used for dumping intermediate results such as tosa and pte.
Expand Down
59 changes: 59 additions & 0 deletions backends/arm/common/pipeline_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from dataclasses import dataclass, fields
from enum import auto, Enum
from typing import Any


class SoftmaxDecompositionConfig(Enum):
    """Selects which softmax decomposition the Arm pass pipeline keeps."""

    # ArmPassManager skips DecomposeSoftmaxUnstablePass for this value.
    MASKED = auto()
    # ArmPassManager skips DecomposeSoftmaxPass and DecomposeMaskedFillPass
    # for this value.
    UNSTABLE = auto()


class FuseDuplicateUsersConfig(Enum):
    """Toggles FuseDuplicateUsersPass in the Arm pass pipeline."""

    ENABLED = auto()
    # ArmPassManager adds FuseDuplicateUsersPass to its skip list for this
    # value.
    DISABLED = auto()


@dataclass
class ArmPassPipelineConfig:
    """Switches selecting which optional Arm passes take part in the pipeline.

    Every field holds an ``Enum`` member. ``to_dict``/``from_dict`` round-trip
    the config through member *names*, and ``serialize`` emits that dictionary
    as encoded JSON.
    """

    softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED
    fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED

    def disable_masked_softmax(self) -> None:
        """Switch ``softmax`` to ``SoftmaxDecompositionConfig.UNSTABLE``."""
        self.softmax = SoftmaxDecompositionConfig.UNSTABLE

    def disable_fuse_duplicate_users(self) -> None:
        """Switch ``fuse_duplicate_users`` to ``DISABLED``."""
        self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED

    def is_default(self) -> bool:
        """Return True when every field still holds its declared default."""
        # Compare against a freshly built instance instead of restating each
        # default by hand, so a field added later cannot be forgotten here.
        return self == type(self)()

    def to_dict(self) -> dict[str, str]:
        """Return a ``{field name: enum member name}`` mapping."""
        return {f.name: getattr(self, f.name).name for f in fields(self)}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ArmPassPipelineConfig":
        """Build a config from a mapping produced by ``to_dict``.

        Fields that are absent from ``data`` (or mapped to ``None``) keep
        their defaults; keys that do not name a field are ignored.

        Args:
            data: Mapping of field name to enum member name.

        Returns:
            A new config with the listed fields applied.

        Raises:
            KeyError: If a value is not the name of a member of the field's
                enum.
        """
        config = cls()
        for f in fields(cls):
            raw_value = data.get(f.name)
            if raw_value is None:
                continue
            # Take the enum class from the default value rather than from
            # ``f.type``: the latter is a plain string whenever annotations
            # are stringified (e.g. ``from __future__ import annotations``).
            enum_type = type(getattr(config, f.name))
            try:
                member = enum_type[raw_value]
            except KeyError:
                valid_names = ", ".join(m.name for m in enum_type)
                raise KeyError(
                    f"Invalid value {raw_value!r} for pipeline config field "
                    f"'{f.name}'; expected one of: {valid_names}"
                ) from None
            setattr(config, f.name, member)
        return config

    def serialize(self) -> bytes:
        """Return a serialized representation of this config."""
        return json.dumps(self.to_dict()).encode()

    def __repr__(self):
        # The local is deliberately not called ``fields`` so it does not
        # shadow ``dataclasses.fields`` used elsewhere in this class. The
        # output format (no class name) is kept unchanged.
        rendered = ", ".join(
            f"{name}={value!r}" for name, value in self.__dict__.items()
        )
        return f"({rendered})"
24 changes: 14 additions & 10 deletions backends/arm/ethosu/compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec

from executorch.backends.arm.common.pipeline_config import ( # noqa: unused
ArmPassPipelineConfig,
)
from executorch.backends.arm.tosa import ( # type: ignore[import-not-found]
TosaSpecification,
)

from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found]
CompileSpec,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec


class EthosUCompileSpec(ArmCompileSpec):
Expand Down Expand Up @@ -43,7 +42,6 @@ def __init__(

"""
self.target = target

# Set vela compiler flags
if config_ini is None:
config_ini = "Arm/vela.ini"
Expand All @@ -57,25 +55,26 @@ def __init__(
]
)
# default system config and memory mode
if "ethos-u55" in self.target:
target_lower = self.target.lower()
if "ethos-u55" in target_lower:
if system_config is None:
system_config = "Ethos_U55_High_End_Embedded"
if memory_mode is None:
memory_mode = "Shared_Sram"
elif "ethos-u85" in self.target:
elif "ethos-u85" in target_lower:
if system_config is None:
system_config = "Ethos_U85_SYS_DRAM_Mid"
if memory_mode is None:
memory_mode = "Sram_Only"
else:
raise RuntimeError(f"Unknown ethos target: {self.target}")
raise RuntimeError(f"Unknown ethos target: {target}")

compiler_flags.append(f"--system-config={system_config}")
compiler_flags.append(f"--memory-mode={memory_mode}")

# Set TOSA version.
base_tosa_version = "TOSA-1.0+INT+int16"
if "u55" in self.target:
if "u55" in target_lower:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
tosa_spec = TosaSpecification.create_from_string(base_tosa_version)
Expand Down Expand Up @@ -109,3 +108,8 @@ def validate(self):
def get_output_format(cls) -> str:
"""Return the artifact format emitted by this compile spec."""
return "vela"

def _create_default_pipeline_config(self) -> ArmPassPipelineConfig:
    """Return the default pipeline config; currently defers to the base class."""
    # U55-subset adjustments are treated as TOSA-specification settings
    # rather than Ethos-U compile-spec settings, so they should be added to
    # the base-class default instead of here.
    return super()._create_default_pipeline_config()
13 changes: 7 additions & 6 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,13 @@ class TOSAQuantizer(Quantizer):
def __init__(
self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec
) -> None:

super().__init__()
self.compile_spec: ArmCompileSpec
if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
self.tosa_spec = compile_spec_or_tosa_spec
self.compile_spec = None
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec

self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec)
self.tosa_spec = self.compile_spec.tosa_spec
elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec):
self.compile_spec = compile_spec_or_tosa_spec
self.tosa_spec = self.compile_spec.tosa_spec
Expand Down Expand Up @@ -432,9 +434,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
# TODO: Fix the need to lazily import this.
from executorch.backends.arm._passes import ArmPassManager

return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
graph_module=model
)
pass_manager = ArmPassManager(self.compile_spec)
return pass_manager.transform_for_annotation_pipeline(graph_module=model)

def annotate(self, model: GraphModule) -> GraphModule:
"""Annotate the graph with the configured quantization settings.
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/test/misc/test_call_operator_submodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult

Expand Down Expand Up @@ -58,7 +58,7 @@ def test_call_operator_runs_once_for_cond_submodules() -> None:
graph_module = exported.graph_module

recording_pass = _DepthRecordingPass(graph_module)
pass_manager = ArmPassManager(TosaSpecification.create_from_string("TOSA-1.00+FP"))
pass_manager = ArmPassManager(TosaCompileSpec("TOSA-1.00+FP"))
pass_manager.add_pass(recording_pass)
pass_manager._transform(graph_module)

Expand Down
Loading
Loading