Commit 078fb23

Qualcomm AI Engine Direct - GA Albert, Bert, Distilbert, Eurobert (#11546)
### Summary
This PR consists of 4 encoder-only models. The following stats are based on SM8750.

1. Albert (16a16w)
   - Accuracy: ~22% (NOTE: nn.Module accuracy is around 24%, so the similarity between QNN and nn.Module is around 92%)
   - Speed: 11ms/inf
   - Script: `python examples/qualcomm/oss_scripts/albert.py -b build-android -s $DEVICE -m SM8750 --dataset ../wikipedia-sentences/wikisent2.txt`
2. Bert (16a8w)
   - Accuracy: ~60%
   - Speed: 9ms/inf
   - Script: `python examples/qualcomm/oss_scripts/bert.py -b build-android -s $DEVICE -m SM8750 --dataset ../wikipedia-sentences/wikisent2.txt`
3. Distilbert (16a8w)
   - Accuracy: ~59%
   - Speed: 8ms/inf
   - Script: `python examples/qualcomm/oss_scripts/distilbert.py -b build-android -s $DEVICE -m SM8750 --dataset ../wikipedia-sentences/wikisent2.txt`
4. Eurobert (16a16w)
   - Accuracy: ~54%
   - Speed: 40ms/inf
   - Script: `python examples/qualcomm/oss_scripts/eurobert.py -b build-android -s $DEVICE -m SM8750 --dataset ../wikipedia-sentences/wikisent2.txt`

### Test plan
- E2E scripts under `test_qnn_delegate.py`
- Example: `python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleOssScript.test_{BERT_MODEL} --model SM8750 -s $DEVICE --build_folder build-android/ -r ./ -a ./test --sentence_dataset ../wikipedia-sentences/wikisent2.txt`
- Mainline CI

Author: @haowhsu-quic, @chunit-quic, @winskuo-quic
1 parent 0286927 commit 078fb23

File tree

14 files changed: +1007 −39 lines changed

.ci/scripts/test_model.sh

Lines changed: 27 additions & 1 deletion
@@ -188,6 +188,14 @@ test_model_with_qnn() {
     EXPORT_SCRIPT=edsr
     # Additional deps for edsr
     pip install piq
+  elif [[ "${MODEL_NAME}" == "albert" ]]; then
+    EXPORT_SCRIPT=albert
+  elif [[ "${MODEL_NAME}" == "bert" ]]; then
+    EXPORT_SCRIPT=bert
+  elif [[ "${MODEL_NAME}" == "distilbert" ]]; then
+    EXPORT_SCRIPT=distilbert
+  elif [[ "${MODEL_NAME}" == "eurobert" ]]; then
+    EXPORT_SCRIPT=eurobert
   else
     echo "Unsupported model $MODEL_NAME"
     exit 1
@@ -197,7 +205,25 @@ test_model_with_qnn() {
   # TODO(guangyang): Make QNN chipset matches the target device
   QNN_CHIPSET=SM8450

-  "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
+  SCRIPT_FOLDER=""
+  case "${MODEL_NAME}" in
+    "dl3"|"mv3"|"mv2"|"ic4"|"ic3"|"vit"|"mb"|"w2l")
+      SCRIPT_FOLDER=scripts
+      ;;
+    "albert"|"bert"|"distilbert")
+      pip install evaluate
+      SCRIPT_FOLDER=oss_scripts
+      # Bert models running in 16bit will encounter op validation fail on some operations,
+      # which requires CHIPSET >= SM8550.
+      QNN_CHIPSET=SM8550
+      ;;
+    *)
+      echo "Unsupported model $MODEL_NAME"
+      exit 1
+      ;;
+  esac
+
+  "${PYTHON_EXECUTABLE}" -m examples.qualcomm.${SCRIPT_FOLDER}.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
   EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
 }

.github/workflows/trunk.yml

Lines changed: 26 additions & 0 deletions
@@ -480,6 +480,32 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
       PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh ${{ matrix.model }} "cmake" "qnn"

+  test-qnn-optimum-model:
+    name: test-qnn-optimum-model
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    strategy:
+      matrix:
+        dtype: [fp32]
+        model: [albert, bert, distilbert] # eurobert requires transfomer >= 4.48.0, skip for now
+      fail-fast: false
+    with:
+      runner: linux.2xlarge
+      docker-image: executorch-ubuntu-22.04-qnn-sdk
+      submodules: 'recursive'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 900
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
+        PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
+        PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh ${{ matrix.model }} "cmake" "qnn"
+
   test-apple-model:
     name: test-apple-model
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
+from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
 from .fold_qdq import FoldQDQ
@@ -56,6 +57,7 @@
     DecomposeLinalgVectorNorm,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,
backends/qualcomm/_passes/decompose_wrap_with_autocast.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import _operator
+from typing import Dict, Tuple
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_nn_module_stack
+
+
+class DecomposeWrapWithAutocast(ExportPass):
+    """
+    Decompose the _higher_order_ops WrapWithAutocast
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _get_submod(
+        self, gm: torch.fx.GraphModule, node: torch.fx.Node
+    ) -> Tuple[torch.fx.GraphModule, str]:
+        for a in node.args:
+            if isinstance(a, torch.fx.Node) and "submod" in a.target:
+                return getattr(gm, a.target), a.target
+
+    def _replace_output(
+        self, wwac_node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict
+    ):
+        for user in wwac_node.users.copy():
+            arg_idx = 0
+            is_user_getitem = False
+
+            if user.target == _operator.getitem:
+                arg_idx = user.args[1]
+                is_user_getitem = True
+
+            user.replace_input_with(
+                wwac_node,
+                remap[output_node.args[0][arg_idx]],
+            )
+
+            if is_user_getitem:
+                for user_user in user.users.copy():
+                    user_user.replace_input_with(user, user.args[0])
+
+    def _replace(self, gm: torch.fx.GraphModule) -> None:
+        graph = gm.graph
+        for node in graph.nodes:
+            if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast):
+                submod, submod_name = self._get_submod(gm, node)
+                n_args = node.args
+                input_submod = n_args[4]
+                decomposed_module = submod
+                with graph.inserting_before(node):
+                    # remap is used to map original node values to new node values,
+                    # which ensures that reference to nodes are correctly updated in the new graph
+                    # remap = {"expand_1": node.args[5], "to_4": node.args[6]}
+                    remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))}
+
+                    for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
+                        # no need to copy existent 'output'
+                        if decomposed_node.op == "output":
+                            self._replace_output(node, decomposed_node, remap)
+                        # no need to copy existent placeholders
+                        elif decomposed_node.op == "placeholder":
+                            # replace node map from string to graph node
+                            remap[decomposed_node] = remap.pop(decomposed_node.name)
+                        else:
+                            remap[decomposed_node] = graph.node_copy(
+                                decomposed_node,
+                                arg_transform=lambda x, remap=remap: remap[x],
+                            )
+
+                graph.erase_node(node)
+
+        graph.erase_node(input_submod)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        self._replace(graph_module)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
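For context, a minimal sketch of the graph shape this pass targets: on recent PyTorch, `torch.export` captures an `autocast` region as a `wrap_with_autocast` higher-order op whose body lives in a hidden submodule, and the pass inlines that body back into the parent graph. The module name, shapes, and `device_type` below are illustrative assumptions, not taken from this PR:

```python
import torch

from executorch.backends.qualcomm._passes import DecomposeWrapWithAutocast


class MixedPrecisionBlock(torch.nn.Module):
    def forward(self, x):
        # torch.export keeps this region as a wrap_with_autocast
        # higher-order op that calls into a submodule
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            return torch.matmul(x, x)


ep = torch.export.export(MixedPrecisionBlock(), (torch.randn(4, 4),))
# Running the pass copies the submodule's nodes in front of the HOP call,
# rewires its users, then erases the wrap_with_autocast node itself.
result = DecomposeWrapWithAutocast()(ep.graph_module)
print(result.graph_module.graph)
```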

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@
     DecomposeLinalgVectorNorm,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,
@@ -194,6 +195,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
+        self.add_pass(DecomposeWrapWithAutocast())
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
@@ -207,6 +209,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
+        self.add_pass(DecomposeWrapWithAutocast())
         # this pass will rewrite state_dict, it needs to be accomplished before
         # to_edge_transform_and_lower
         self.add_pass(ConvertConv1dToConv2d(exported_program))

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 11 additions & 13 deletions
@@ -43,6 +43,8 @@ def _dim_order_op_condition(self, node):
         dim_order = node.kwargs.get("dim_order")
         # skip if there contains layout hint
         # e.g. (0, 2, 3, 1) != (0, 1, 2, 3)
+        if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
+            return False
         return dim_order != list(range(len(dim_order)))

     def _to_copy_op_condition(self, node):
@@ -53,19 +55,15 @@ def _default_condition(self, ndoe):

     def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         for n in graph_module.graph.nodes:
-            if n.target not in self.redundant_ops or not self.redundant_ops[n.target](
-                n
-            ):
-                continue
-
-            to_be_remove = n
-            # assert_tensor_metadata op has no user
-            if len(n.users.keys()) == 0:
-                n.args = ()
-            # normal case
-            for user_n in list(n.users.keys()):
-                user_n.replace_input_with(n, n.args[0])
-            graph_module.graph.erase_node(to_be_remove)
+            if n.target in self.redundant_ops and self.redundant_ops[n.target](n):
+                to_be_remove = n
+                # assert_tensor_metadata op has no user
+                if len(n.users.keys()) == 0:
+                    n.args = ()
+                # normal case
+                for user_n in list(n.users.keys()):
+                    user_n.replace_input_with(n, n.args[0])
+                graph_module.graph.erase_node(to_be_remove)

     def call(self, graph_module: torch.fx.GraphModule):
         self._remove(graph_module)
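The added dtype check closes a hole where a `_to_dim_order_copy` that both permutes layout and casts dtype could be stripped as redundant, silently dropping the cast. A minimal sketch of the updated predicate, using `SimpleNamespace` stand-ins for real fx nodes (hypothetical shapes and dtypes):

```python
from types import SimpleNamespace

import torch


def dim_order_op_condition(node):
    # mirrors the updated predicate: True means "redundant, remove the copy"
    dim_order = node.kwargs.get("dim_order")
    # a copy that also casts dtype carries real work and must be kept
    if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
        return False
    return dim_order != list(range(len(dim_order)))


src = SimpleNamespace(meta={"val": torch.empty(1, 3, 2, 2, dtype=torch.float32)})
copy = SimpleNamespace(
    args=(src,),
    kwargs={"dim_order": [0, 2, 3, 1]},  # layout hint alone -> removable
    meta={"val": torch.empty(1, 2, 2, 3, dtype=torch.float16)},  # but dtype differs
)
print(dim_order_op_condition(copy))  # False -> the cast-carrying copy survives
```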

backends/qualcomm/_passes/replace_inf_values.py

Lines changed: 20 additions & 2 deletions
@@ -9,13 +9,13 @@

 class ReplaceInfValues(ExportPass):
     """
-    Due to limitation in Qnn, we need to change inf or -inf to arbitrary value in quantization.
+    Due to limitation in QNN, change inf or -inf to arbitrary value in quantization.
     """

     def __init__(self):
         super(ReplaceInfValues, self).__init__()

-    def call(self, graph_module: torch.fx.GraphModule):
+    def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
         for buf_name, tensor in graph_module.named_buffers():
             if tensor.is_floating_point():
                 # 255 here is mainly for attention_mask in Llama for reasonable quant scale
@@ -38,5 +38,23 @@ def call(self, graph_module: torch.fx.GraphModule):
                     arg_list[2] = -255
                 node.args = tuple(arg_list)

+            if node.target in [
+                torch.ops.aten.masked_fill.Tensor,
+                torch.ops.aten.masked_fill.Scalar,
+            ]:
+                assert (
+                    len(node.args) == 3
+                ), f"Expecting {node.name} to have 3 arguments."
+                val = node.args[2]
+                if node.args[2] > torch.finfo(torch.float16).max:
+                    val = 255
+                elif node.args[2] < torch.finfo(torch.float16).min:
+                    val = -255
+                node.args = (
+                    node.args[0],
+                    node.args[1],
+                    val,
+                )
+
         graph_module.recompile()
         return PassResult(graph_module, True)
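The new branch covers `masked_fill`, whose fill value is commonly `float("-inf")` in attention masking; an infinite value yields no usable quantization scale, so anything outside the float16 range is clamped to ±255. A minimal sketch of the pattern it rewrites, assuming export keeps `masked_fill` as an aten op (hypothetical module, not from this PR):

```python
import torch

from executorch.backends.qualcomm._passes.replace_inf_values import ReplaceInfValues


class MaskedScores(torch.nn.Module):
    def forward(self, scores, mask):
        # float("-inf") is outside the float16 range; after ReplaceInfValues
        # the fill value in the captured graph becomes -255
        return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)


ep = torch.export.export(
    MaskedScores(),
    (torch.randn(1, 8, 8), torch.zeros(1, 8, 8, dtype=torch.bool)),
)
result = ReplaceInfValues()(ep.graph_module)  # rewrites the -inf fill to -255
```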

backends/qualcomm/quantizer/annotators.py

Lines changed: 1 addition & 6 deletions
@@ -462,7 +462,7 @@ def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None
     annotate_single_in_single_out(node, quantization_config)


-@register_annotator([torch.ops.aten.mean.default])
+@register_annotator([torch.ops.aten.mean.default, torch.ops.aten.mean.dim])
 def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)

@@ -604,11 +604,6 @@ def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None
     annotate_single_in_single_out(node, quantization_config)


-@register_annotator([torch.ops.aten.mean.dim])
-def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_single_in_single_out(node, quantization_config)
-
-
 @register_annotator([torch.ops.aten.slice.Tensor])
 def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 29 additions & 0 deletions
@@ -26,6 +26,35 @@
 )


+def annotate_eurobert(gm: torch.fx.GraphModule):
+    """
+    QNN does not support int32 -> signed 16bit quant
+    We need to first annotate this to_fp node as 8bit quant, so it will perform requantize
+    Final graph should look like: int32 -> convert -> cast -> matmul.args[1]
+
+    """
+    quantization_config_8a8w = get_8a8w_qnn_ptq_config()
+    for node in gm.graph.nodes:
+        # A little tricky here. This matmul node is wrapped inside a submodule after 1st torch.export.
+        # There are actually 2 'to' op that is redundant.
+        # It will look like: int64 -> to_fp -> to_fp -> matmul.args[1]
+        # Draw out the graph after the 1st export will help visualize the submodule.
+
+        if node.target == torch.ops.aten.matmul.default and node.args[1].args[0].args[
+            0
+        ].meta["val"].dtype in [torch.int64, torch.int32]:
+            to_node = node.args[1]
+            input_qspec_map = {}
+            assert isinstance(to_node, Node)
+            input_spec = quantization_config_8a8w.input_activation
+            input_qspec_map[to_node] = input_spec
+            to_node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=quantization_config_8a8w.output_activation,
+                _annotated=True,
+            )
+
+
 def annotate_mimi_decoder(gm: torch.fx.GraphModule):
     """
     The 1st transpose conv in mimi decoder is really sensitive to scale/offset in 16a8w, which causes execution failure.
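Custom annotations like `annotate_eurobert` are registered on the quantizer before capture. A minimal sketch, assuming the `QnnQuantizer.add_custom_quant_annotations` hook used by the Qualcomm example scripts:

```python
from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_eurobert
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()
# register the EuroBert-specific annotation alongside the built-in annotators,
# so the int32 'to' feeding matmul picks up an 8-bit spec and gets requantized
quantizer.add_custom_quant_annotations((annotate_eurobert,))
```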
