Skip to content

Commit 2205e48

Browse files
marialyu and copybara-github
authored and committed
Add int8 and int16 support for SELECT op to AEQ
PiperOrigin-RevId: 795111628
1 parent 89d38f0 commit 2205e48

File tree

10 files changed

+198
-1
lines changed

10 files changed

+198
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ The table below outlines the allowed configurations for available recipes.
139139
|SPLIT | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |
140140
|LOGISTIC | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |
141141
|SLICE | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |
142+
|SELECT | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |
142143
|SELECT_V2 | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |
143144
|SUM | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |
144145
|PAD | | |<div align="center"> &check; </div>| |<div align="center"> &check; </div>| | | |

ai_edge_quantizer/algorithm_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class AlgorithmName(str, enum.Enum):
102102
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
103103
_TFLOpName.SLICE: common_quantize.materialize_slice,
104104
_TFLOpName.SUM: common_quantize.materialize_sum,
105+
_TFLOpName.SELECT: common_quantize.materialize_select,
105106
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
106107
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
107108
common_quantize.materialize_dynamic_update_slice
@@ -250,6 +251,7 @@ class AlgorithmName(str, enum.Enum):
250251
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
251252
_TFLOpName.SLICE: common_quantize.materialize_slice,
252253
_TFLOpName.SUM: common_quantize.materialize_sum,
254+
_TFLOpName.SELECT: common_quantize.materialize_select,
253255
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
254256
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
255257
common_quantize.materialize_dynamic_update_slice

ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,25 @@ def materialize_slice(
371371
)
372372

373373

374+
def materialize_select(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes the tensors of a tfl.select op.

  This is a thin wrapper that forwards to
  `common_utils.materialize_standard_op` with the `SAME_AS_OUTPUT_SCALE`
  constraint while excluding input 0 (the condition tensor).

  Args:
    get_tensor_quant_params_fn: Function used to obtain the quantization
      parameters of a tensor; forwarded unchanged.
    op_info: Information about the select op being materialized.
    graph_info: Graph information forwarded to the standard-op materializer.
    tensor_name_to_qsv: Mapping from tensor name to its QSV entry (presumably
      calibration statistics; see `materialize_standard_op` for the exact
      contract).

  Returns:
    The tensor transformation parameters produced by
    `common_utils.materialize_standard_op`.
  """
  # Condition tensor does not need to be quantized.
  condition_input_index = 0
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
      inputs_to_ignore=[condition_input_index],
  )
391+
392+
374393
def materialize_select_v2(
375394
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
376395
op_info: qtyping.OpInfo,
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2024 The AI Edge Quantizer Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import os
17+
18+
from absl.testing import parameterized
19+
import numpy as np
20+
21+
from tensorflow.python.platform import googletest
22+
from ai_edge_quantizer import qtyping
23+
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
24+
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
25+
from ai_edge_quantizer.algorithms.uniform_quantize import octav
26+
from ai_edge_quantizer.algorithms.uniform_quantize.op_architecture_tests import test_utils as op_test_utils
27+
from ai_edge_quantizer.utils import test_utils
28+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
29+
30+
31+
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
32+
"../../../tests/models"
33+
)
34+
35+
36+
class SelectTest(op_test_utils.BaseQuantizeTest):
  """Tests `common_quantize.materialize_select` on `single_select.tflite`."""

  def setUp(self):
    """Loads the test model and builds the shared op/graph test fixtures."""
    super().setUp()
    np.random.seed(666)
    self._test_model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, "single_select.tflite"
    )
    self._op_test_info = op_test_utils.OpTestInfo(
        test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
        op_tensor_names={},
        input_range=(np.array([[-10]]), np.array([[10]])),
        output_range=(np.array([[-10]]), np.array([[10]])),
    )
    # The test model has one subgraph for now.
    loaded_model = self._op_test_info.test_model
    self._graph_info = qtyping.GraphInfo(
        subgraph_tensors=loaded_model.subgraphs[0].tensors,
        buffers=loaded_model.buffers,
    )

  @parameterized.parameters(
      # get_tensor_quant_params_func, activations_num_bits, symmetric
      (naive_min_max_quantize.get_tensor_quant_params, 8, True),
      (naive_min_max_quantize.get_tensor_quant_params, 8, False),
      (naive_min_max_quantize.get_tensor_quant_params, 16, True),
      (octav.get_tensor_quant_params, 8, True),
      (octav.get_tensor_quant_params, 16, True),
  )
  def test_materialize_select_succeeds(
      self, get_tensor_quant_params_func, activations_num_bits, symmetric
  ):
    """Runs the no-weights op check for SELECT, ignoring the condition input."""
    op_quant_config = test_utils.get_static_op_quant_config(
        test_utils.get_static_activation_quant_setting(
            activations_num_bits, symmetric
        )
    )

    # Read from Model Explorer.
    subgraph_op_id = 0
    first_subgraph = self._op_test_info.test_model.subgraphs[0]
    op_info = qtyping.OpInfo(
        op=first_subgraph.operators[subgraph_op_id],
        op_name=qtyping.TFLOperationName.SELECT,
        subgraph_op_index=subgraph_op_id,
        op_quant_config=op_quant_config,
    )

    # Test settings.
    self._op_test_info.op_tensor_names = {
        "input": "serving_default_condition:0",
        "input2": "serving_default_x:0",
        "input3": "serving_default_y:0",
        "output": "PartitionedCall:0",
    }
    self._test_no_weights_op(
        op_info,
        self._graph_info,
        self._op_test_info,
        common_quantize.materialize_select,
        get_tensor_quant_params_func,
        same_input_output_params=True,
        inputs_to_ignore=[0],  # Condition tensor does not need to be quantized.
    )
99+
100+
101+
if __name__ == "__main__":
102+
googletest.main()

ai_edge_quantizer/calibrator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def test_toy_gemma2_calibration_success(self):
302302
self._toy_gemma2_calibration_dataset,
303303
model_recipe_manager=recipe_mngr,
304304
)
305-
self.assertLen(calib.get_model_qsvs(), 288)
305+
self.assertLen(calib.get_model_qsvs(), 290)
306306

307307

308308
if __name__ == "__main__":

ai_edge_quantizer/default_policy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@
180180
"SLICE",
181181
"EMBEDDING_LOOKUP",
182182
"SUM",
183+
"SELECT",
183184
"SELECT_V2",
184185
"DYNAMIC_UPDATE_SLICE",
185186
"SELECT_V2",
@@ -222,6 +223,7 @@
222223
"SLICE",
223224
"EMBEDDING_LOOKUP",
224225
"SUM",
226+
"SELECT",
225227
"SELECT_V2",
226228
"DYNAMIC_UPDATE_SLICE",
227229
"SELECT_V2",

ai_edge_quantizer/qtyping.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class TFLOperationName(str, enum.Enum):
5959
LOGISTIC = 'LOGISTIC'
6060
SLICE = 'SLICE'
6161
SUM = 'SUM'
62+
SELECT = 'SELECT'
6263
SELECT_V2 = 'SELECT_V2'
6364
DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
6465
STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 The AI Edge Quantizer Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""E2E tests for the quantizer for model with select op."""
17+
18+
import os
19+
20+
from absl.testing import parameterized
21+
22+
from tensorflow.python.platform import googletest
23+
from ai_edge_quantizer import qtyping
24+
from ai_edge_quantizer import quantizer
25+
from ai_edge_quantizer.utils import test_utils
26+
27+
28+
_TEST_MODEL_FOLDER = test_utils.get_path_to_datafile('../models/')
29+
_QuantAlgo = quantizer.AlgorithmName
30+
31+
32+
class SelectTest(test_utils.BaseOpTestCase):
  """End-to-end quantization tests for a model containing a SELECT op."""

  def setUp(self):
    """Records the op under test for the assertions below."""
    super().setUp()
    self._op_name = qtyping.TFLOperationName.SELECT

  @parameterized.parameters(
      # algorithm_key, activations_num_bits, symmetric
      (_QuantAlgo.MIN_MAX_UNIFORM_QUANT, 8, True),
      (_QuantAlgo.MIN_MAX_UNIFORM_QUANT, 8, False),
      (_QuantAlgo.MIN_MAX_UNIFORM_QUANT, 16, True),
      (_QuantAlgo.OCTAV, 8, True),
      (_QuantAlgo.OCTAV, 16, True),
  )
  def test_select_static_quantization_accuracy_and_size_within_tolerance(
      self, algorithm_key, activations_num_bits, symmetric
  ):
    """Statically quantizes `single_select.tflite` and checks output accuracy."""
    op_config = test_utils.get_static_op_quant_config(
        test_utils.get_static_activation_quant_setting(
            activations_num_bits, symmetric
        )
    )
    self.assert_quantization_accuracy(
        algorithm_key=algorithm_key,
        model_path=os.path.join(_TEST_MODEL_FOLDER, 'single_select.tflite'),
        op_name=self._op_name,
        op_config=op_config,
        output_tolerance=5e-4,
        num_calibration_samples=1,
        num_validation_samples=1,
    )
66+
67+
68+
if __name__ == '__main__':
69+
googletest.main()
1.04 KB
Binary file not shown.

ai_edge_quantizer/utils/tfl_flatbuffer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
_TFLOpName.LOGISTIC: schema.BuiltinOperator.LOGISTIC,
5252
_TFLOpName.SLICE: schema.BuiltinOperator.SLICE,
5353
_TFLOpName.SUM: schema.BuiltinOperator.SUM,
54+
_TFLOpName.SELECT: schema.BuiltinOperator.SELECT,
5455
_TFLOpName.SELECT_V2: schema.BuiltinOperator.SELECT_V2,
5556
_TFLOpName.STABLEHLO_COMPOSITE: schema.BuiltinOperator.STABLEHLO_COMPOSITE,
5657
_TFLOpName.DYNAMIC_UPDATE_SLICE: (

0 commit comments

Comments
 (0)