Skip to content

Commit 764fbd6

Browse files
authored
Merge branch 'main' into android-config-api-2
2 parents 5004fdd + a68cdae commit 764fbd6

File tree

68 files changed

+1561
-527
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1561
-527
lines changed

.github/workflows/doc-build.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ jobs:
2121
- name: Check URLs
2222
run: bash ./scripts/check_urls.sh
2323

24-
check-links:
24+
check-xrefs:
2525
runs-on: ubuntu-latest
2626
steps:
2727
- uses: actions/checkout@v3
2828
- name: Check Links
29-
run: bash ./scripts/check_links.sh
29+
run: bash ./scripts/check_xrefs.sh
3030

3131
build:
3232
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/arm/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from .arm_backend import ArmCompileSpecBuilder # noqa # usort: skip
7+
from .tosa_backend import TOSABackend # noqa # usort: skip
8+
from .tosa_partitioner import TOSAPartitioner # noqa # usort: skip
9+
from .ethosu_backend import EthosUBackend # noqa # usort: skip
10+
from .ethosu_partitioner import EthosUPartitioner # noqa # usort: skip

backends/arm/_passes/convert_expand_copy_to_repeat.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

8+
import logging
99
from typing import cast
1010

1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass
1313

14+
logger = logging.getLogger(__name__)
15+
1416

1517
class ConvertExpandCopyToRepeatPass(ExportPass):
1618
"""
@@ -41,6 +43,14 @@ def call_operator(self, op, args, kwargs, meta):
4143
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4244
for i in range(expanded_rank)
4345
]
46+
47+
if all((x == 1 for x in multiples)):
48+
# All dimensions/repetitions occur only once. Remove node
49+
# altogether since it's in practice just a copy.
50+
logger.warning("Found redundant expand node (no-op). Removing it.")
51+
52+
return args[0]
53+
4454
return super().call_operator(
4555
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
4656
)

backends/arm/ethosu_backend.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import logging
1515
from typing import final, List
1616

17-
from executorch.backends.arm.arm_vela import vela_compile
17+
from executorch.backends.arm import TOSABackend
1818

19-
from executorch.backends.arm.tosa_backend import TOSABackend
19+
from executorch.backends.arm.arm_vela import vela_compile
2020
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2121
from executorch.exir.backend.compile_spec_schema import CompileSpec
2222
from torch.export.exported_program import ExportedProgram

backends/arm/ethosu_partitioner.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from executorch.backends.arm.arm_backend import (
1111
is_ethosu,
1212
) # usort: skip
13-
from executorch.backends.arm.ethosu_backend import EthosUBackend
14-
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
13+
from executorch.backends.arm import EthosUBackend, TOSAPartitioner
1514
from executorch.exir.backend.compile_spec_schema import CompileSpec
1615
from executorch.exir.backend.partitioner import DelegationSpec
1716
from torch.fx.passes.operator_support import OperatorSupportBase

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def is_node_supported(
207207
exir_ops.edge.aten._log_softmax.default,
208208
exir_ops.edge.aten.sub.Tensor,
209209
exir_ops.edge.aten.tanh.default,
210+
exir_ops.edge.aten.upsample_bilinear2d.vec,
210211
exir_ops.edge.aten.upsample_nearest2d.vec,
211212
exir_ops.edge.aten.var.correction,
212213
exir_ops.edge.aten.var.dim,
@@ -365,6 +366,7 @@ def is_node_supported(
365366
exir_ops.edge.aten.sigmoid.default,
366367
exir_ops.edge.aten.sub.Tensor,
367368
exir_ops.edge.aten.tanh.default,
369+
exir_ops.edge.aten.upsample_bilinear2d.vec,
368370
exir_ops.edge.aten.upsample_nearest2d.vec,
369371
exir_ops.edge.aten.gelu.default,
370372
):

backends/arm/operators/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
op_to_copy,
4747
op_to_dim_order_copy,
4848
op_transpose,
49+
op_upsample_bilinear2d,
4950
op_upsample_nearest2d,
5051
op_view,
5152
op_where,

backends/arm/operators/op_mul.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@ def define_node(
4141
inputs: List[TosaArg],
4242
output: TosaArg,
4343
) -> None:
44-
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
44+
if (
45+
inputs[0].dtype != ts.DType.INT8
46+
or inputs[1].dtype != ts.DType.INT8
47+
or output.dtype != ts.DType.INT8
48+
):
49+
raise ValueError(
50+
f"Inputs and output for {self.target} need to be INT8, got "
51+
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
52+
)
4553

4654
dim_order = (
4755
inputs[0].dim_order
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import torch
10+
11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from executorch.backends.arm.tosa_quant_utils import build_rescale
19+
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
20+
from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore
21+
22+
23+
@register_node_visitor
24+
class UpsampleBilinear2dVisitor_0_80(NodeVisitor):
25+
target = "aten.upsample_bilinear2d.vec"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: torch.fx.Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].shape is not None and output.shape is not None
39+
), "Only static shapes are supported"
40+
41+
input_dtype = inputs[0].dtype
42+
43+
# tosa_shape output is NHWC, take HW
44+
input_size_yx = torch.tensor(
45+
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
46+
)
47+
# Ignore scale and size parameters, directly use the output size as
48+
# we only support static shapes currently
49+
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
50+
51+
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
52+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
53+
)
54+
55+
def in_int16_range(x):
56+
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
57+
58+
assert in_int16_range(scale_n_yx)
59+
assert in_int16_range(scale_d_yx)
60+
assert in_int16_range(border_yx)
61+
62+
attr = ts.TosaSerializerAttribute()
63+
attr.ResizeAttribute(
64+
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
65+
offset=offset_yx.tolist(),
66+
border=border_yx.tolist(),
67+
mode=ResizeMode.BILINEAR,
68+
)
69+
70+
if input_dtype == output.dtype == ts.DType.FP32:
71+
tosa_graph.addOperator(
72+
ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
73+
)
74+
return
75+
elif input_dtype == output.dtype == ts.DType.INT8:
76+
intermediate = tosa_graph.addIntermediate(
77+
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
78+
)
79+
80+
tosa_graph.addOperator(
81+
ts.TosaOp.Op().RESIZE, [inputs[0].name], [intermediate.name], attr
82+
)
83+
84+
final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
85+
86+
build_rescale(
87+
tosa_fb=tosa_graph,
88+
scale=[final_output_scale],
89+
input_node=intermediate,
90+
output_name=output.name,
91+
output_type=ts.DType.INT8,
92+
output_shape=output.shape,
93+
input_zp=0,
94+
output_zp=0,
95+
is_double_round=False,
96+
)
97+
else:
98+
raise ValueError(
99+
"Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}"
100+
)

backends/arm/quantizer/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from .quantization_config import QuantizationConfig # noqa # usort: skip
8+
from .arm_quantizer import ( # noqa
9+
EthosUQuantizer,
10+
get_symmetric_quantization_config,
11+
TOSAQuantizer,
12+
)

backends/arm/quantizer/arm_quantizer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
import torch
2020
from executorch.backends.arm._passes import ArmPassManager
2121

22-
from executorch.backends.arm.quantizer import arm_quantizer_utils
22+
from executorch.backends.arm.quantizer import arm_quantizer_utils, QuantizationConfig
2323
from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]
2424
mark_node_as_annotated,
2525
)
2626
from executorch.backends.arm.quantizer.quantization_annotator import ( # type: ignore[import-not-found]
2727
annotate_graph,
2828
)
2929

30-
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
3130
from executorch.backends.arm.tosa_specification import TosaSpecification
3231
from executorch.backends.arm.arm_backend import (
3332
get_tosa_spec,

backends/arm/quantizer/quantization_annotator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
import torch
1212
import torch.fx
13-
from executorch.backends.arm.quantizer import arm_quantizer_utils
14-
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
13+
from executorch.backends.arm.quantizer import arm_quantizer_utils, QuantizationConfig
1514
from executorch.backends.arm.tosa_utils import get_node_debug_info
1615
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
1716
from torch.ao.quantization.quantizer.utils import (
@@ -215,6 +214,7 @@ def _match_pattern(
215214
torch.ops.aten.flip.default,
216215
torch.ops.aten.chunk.default,
217216
torch.ops.aten.contiguous.default,
217+
torch.ops.aten.upsample_bilinear2d.vec,
218218
torch.ops.aten.upsample_nearest2d.vec,
219219
torch.ops.aten.pad.default,
220220
torch.ops.aten.amax.default,

backends/arm/test/ops/test_expand.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -16,7 +15,7 @@
1615

1716
import torch
1817

19-
from executorch.backends.arm.quantizer.arm_quantizer import (
18+
from executorch.backends.arm.quantizer import (
2019
EthosUQuantizer,
2120
get_symmetric_quantization_config,
2221
TOSAQuantizer,
@@ -37,14 +36,14 @@ class Expand(torch.nn.Module):
3736
# (input tensor, multiples)
3837
test_parameters = [
3938
(torch.rand(1), (2,)),
40-
(torch.randn(1, 4), (1, -1)),
4139
(torch.randn(1), (2, 2, 4)),
4240
(torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
43-
(torch.randn(1, 1, 192), (1, -1, -1)),
4441
(torch.randn(1, 1), (1, 2, 2, 4)),
4542
(torch.randn(1, 1), (2, 2, 2, 4)),
4643
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
4744
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
45+
(torch.randn(1, 4), (1, -1)),
46+
(torch.randn(1, 1, 192), (1, -1, -1)),
4847
]
4948

5049
def forward(self, x: torch.Tensor, m: Sequence):
@@ -117,34 +116,52 @@ def test_expand_tosa_MI(self, test_input, multiples):
117116
def test_expand_tosa_BI(self, test_input, multiples):
118117
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
119118

120-
@parameterized.expand(Expand.test_parameters[:-3])
119+
@parameterized.expand(Expand.test_parameters[:-5])
121120
@pytest.mark.corstone_fvp
122121
def test_expand_u55_BI(self, test_input, multiples):
123122
self._test_expand_ethosu_BI_pipeline(
124123
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
125124
)
126125

127126
# MLETORCH-629: Expand does not work on FVP with batch>1
128-
@parameterized.expand(Expand.test_parameters[-3:])
127+
@parameterized.expand(Expand.test_parameters[-5:-2])
129128
@pytest.mark.corstone_fvp
130129
@conftest.expectedFailureOnFVP
130+
def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples):
131+
self._test_expand_ethosu_BI_pipeline(
132+
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
133+
)
134+
135+
@parameterized.expand(Expand.test_parameters[-2:])
136+
@pytest.mark.xfail(
137+
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
138+
)
131139
def test_expand_u55_BI_xfails(self, test_input, multiples):
132140
self._test_expand_ethosu_BI_pipeline(
133141
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
134142
)
135143

136-
@parameterized.expand(Expand.test_parameters[:-3])
144+
@parameterized.expand(Expand.test_parameters[:-5])
137145
@pytest.mark.corstone_fvp
138146
def test_expand_u85_BI(self, test_input, multiples):
139147
self._test_expand_ethosu_BI_pipeline(
140148
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
141149
)
142150

143151
# MLETORCH-629: Expand does not work on FVP with batch>1
144-
@parameterized.expand(Expand.test_parameters[-3:])
152+
@parameterized.expand(Expand.test_parameters[-5:-2])
145153
@pytest.mark.corstone_fvp
146154
@conftest.expectedFailureOnFVP
147-
def test_expand_u85_BI_xfails(self, test_input, multiples):
155+
def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples):
156+
self._test_expand_ethosu_BI_pipeline(
157+
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
158+
)
159+
160+
@parameterized.expand(Expand.test_parameters[-2:])
161+
@pytest.mark.xfail(
162+
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
163+
)
164+
def test_expand_u85_xfails(self, test_input, multiples):
148165
self._test_expand_ethosu_BI_pipeline(
149166
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
150167
)

backends/arm/test/ops/test_hardtanh.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
16+
from executorch.backends.arm.quantizer import (
1717
EthosUQuantizer,
1818
get_symmetric_quantization_config,
1919
TOSAQuantizer,

backends/arm/test/ops/test_max_pool.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313

1414
import torch
15-
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
from executorch.backends.arm.quantizer import (
1616
EthosUQuantizer,
1717
get_symmetric_quantization_config,
1818
TOSAQuantizer,

backends/arm/test/ops/test_permute.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
16+
from executorch.backends.arm.quantizer import (
1717
EthosUQuantizer,
1818
get_symmetric_quantization_config,
1919
TOSAQuantizer,

0 commit comments

Comments
 (0)