Commit 451ac5c

rewu93 authored and copybara-github committed

Add recipe building utils to AEQ.

PiperOrigin-RevId: 795129649
1 parent 2205e48 commit 451ac5c
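
For orientation, a minimal usage sketch of the new recipe helpers (the Quantizer calls mirror those exercised in recipe_test.py below; the model path is a placeholder, not part of this commit):

    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe

    qt = quantizer.Quantizer('model.tflite')  # placeholder path
    # Each new helper returns a ready-made recipe covering all supported ops.
    qt.load_quantization_recipe(recipe.dynamic_wi8_afp32())
    result = qt.quantize()  # dynamic recipes need no calibration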

File tree

2 files changed: +240 -58 lines


ai_edge_quantizer/recipe.py

Lines changed: 154 additions & 42 deletions
@@ -15,51 +15,163 @@
 
 """Quantization recipe module."""
 
+from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import recipe_manager
 
-def dynamic_wi8_afp32():
-  """Returns a dynamic quantization recipe with int8 weights and float32 activation."""
-  return [
-      dict({
-          'regex': '.*',
-          'operation': '*',
-          'algorithm_key': 'min_max_uniform_quantize',
-          'op_config': {
-              'weight_tensor_config': {
-                  'num_bits': 8,
-                  'symmetric': True,
-                  'granularity': 'CHANNELWISE',
-                  'dtype': 'INT',
-                  'block_size': 0,
-              },
-              'compute_precision': 'INTEGER',
-              'explicit_dequantize': False,
-              'skip_checks': False,
-          },
-      })
-  ]
+AlgorithmName = algorithm_manager.AlgorithmName
 
 
-def dynamic_wi4_afp32():
-  """Returns a dynamic quantization recipe with int4 weights and float32 activation."""
-  return [
-      dict({
-          'regex': '.*',
-          'operation': '*',
-          'algorithm_key': 'min_max_uniform_quantize',
-          'op_config': {
-              'weight_tensor_config': {
-                  'num_bits': 4,
-                  'symmetric': True,
-                  'granularity': 'CHANNELWISE',
-                  'dtype': 'INT',
-                  'block_size': 0,
-              },
-              'compute_precision': 'INTEGER',
-              'explicit_dequantize': False,
-              'skip_checks': False,
-          },
-      })
-  ]
+def dynamic_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations,
+  which will be dynamically quantized to int8 during inference to enable int8
+  compute. The model quality may suffer due to the on-the-fly quantization. If
+  quality is a concern, consider using weight-only quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def dynamic_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations,
+  which will be dynamically quantized to int4 during inference to enable int4
+  compute.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality. If latency is a concern,
+  consider using dynamic range quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai8(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int8 activations.
+
+  All supported ops will be quantized with int8 weights and int8 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=8,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai16(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int16 activations.
+
+  All supported ops will be quantized with int8 weights and int16 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=16,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
 
 
 def dynamic_legacy_wi8_afp32():
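
The six helpers above replace the hand-written list-of-dicts recipes with RecipeManager calls while still returning the same JSON-serializable recipe format. A sketch of the new algorithm_key override, assuming only what this diff shows (MIN_MAX_UNIFORM_QUANT is the sole AlgorithmName member confirmed here):

    from ai_edge_quantizer import recipe

    # Passing the default explicitly; recipe.AlgorithmName is the module-level
    # alias for algorithm_manager.AlgorithmName added in this change.
    w4 = recipe.weight_only_wi4_afp32(
        algorithm_key=recipe.AlgorithmName.MIN_MAX_UNIFORM_QUANT
    )
    # The result is a plain list of per-op config dicts with the same keys as
    # the removed hand-written recipes ('regex', 'operation', 'algorithm_key',
    # 'op_config').
    a8w8 = recipe.static_wi8_ai8()  # static recipes require calibration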

ai_edge_quantizer/recipe_test.py

Lines changed: 86 additions & 16 deletions
@@ -21,6 +21,7 @@
 from ai_edge_quantizer import quantizer
 from ai_edge_quantizer import recipe
 from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
 
 
 _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
@@ -30,33 +31,76 @@ class RecipeTest(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
-    self._test_model_path = os.path.join(
+    # Weights has < 1024 elements so legacy recipe will not quantize it.
+    self._small_model_path = os.path.join(
         _TEST_DATA_PREFIX_PATH,
         'tests/models/single_conv2d_transpose_bias.tflite',
     )
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH,
+        'tests/models/conv_fc_mnist.tflite',
+    )
 
-  def _quantize_with_recipe_func(self, recipe_func):
-    qt = quantizer.Quantizer(self._test_model_path)
+  def _quantize_with_recipe_func(self, recipe_func, test_model_path):
+    qt = quantizer.Quantizer(test_model_path)
     qt.load_quantization_recipe(recipe_func())
     self.assertIsNone(qt._result.quantized_model)
-    quant_result = qt.quantize()
-    self.assertIsNotNone(quant_result.quantized_model)
-    return quant_result
+    if qt.need_calibration:
+      calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
+          qt.float_model,
+          num_samples=1,
+      )
+      calibration_result = qt.calibrate(calibration_data)
+      quantization_result = qt.quantize(calibration_result)
+    else:
+      quantization_result = qt.quantize()
+    self.assertIsNotNone(quantization_result.quantized_model)
+    return quantization_result
 
   def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
-    quant_result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_wi8_afp32, self._test_model_path
+    )
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  def test_quantization_from_dynamic_wi4_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_wi4_afp32, self._test_model_path
+    )
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  def test_quantization_from_weight_only_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.weight_only_wi8_afp32, self._test_model_path
+    )
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  def test_quantization_from_weight_only_wi4_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.weight_only_wi4_afp32, self._test_model_path
+    )
     self.assertLess(
         len(quant_result.quantized_model),
         os.path.getsize(self._test_model_path),
     )
 
   def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
     quant_result = self._quantize_with_recipe_func(
-        recipe.dynamic_legacy_wi8_afp32
+        recipe.dynamic_legacy_wi8_afp32,
+        self._small_model_path,
     )
     self.assertLen(
         quant_result.quantized_model,
-        os.path.getsize(self._test_model_path),
+        os.path.getsize(self._small_model_path),
     )
 
   @parameterized.named_parameters(
@@ -65,28 +109,54 @@ def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
           recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
           recipe_func=recipe.dynamic_wi8_afp32,
       ),
+      dict(
+          testcase_name='weight_only_wi8_afp32',
+          recipe_json_path='recipes/default_af32w8float_recipe.json',
+          recipe_func=recipe.weight_only_wi8_afp32,
+      ),
+      dict(
+          testcase_name='weight_only_wi4_afp32',
+          recipe_json_path='recipes/default_af32w4float_recipe.json',
+          recipe_func=recipe.weight_only_wi4_afp32,
+      ),
       dict(
           testcase_name='dynamic_legacy_wi8_afp32',
           recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
           recipe_func=recipe.dynamic_legacy_wi8_afp32,
       ),
+      dict(
+          testcase_name='a8w8',
+          recipe_json_path='recipes/default_a8w8_recipe.json',
+          recipe_func=recipe.static_wi8_ai8,
+      ),
+      dict(
+          testcase_name='a16w8',
+          recipe_json_path='recipes/default_a16w8_recipe.json',
+          recipe_func=recipe.static_wi8_ai16,
+      ),
   )
   def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
     # Quantize with recipe from function in recipe module.
-    quant_result_from_func = self._quantize_with_recipe_func(recipe_func)
+    quant_result_from_func = self._quantize_with_recipe_func(
+        recipe_func, self._test_model_path
+    )
 
     # Quantize with recipe from json file.
     qt_json = quantizer.Quantizer(self._test_model_path)
     json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
     qt_json.load_quantization_recipe(json_recipe_path)
-    quant_result_from_json = qt_json.quantize()
+    if qt_json.need_calibration:
+      calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
+          qt_json.float_model,
+          num_samples=1,
+      )
+      calibration_result = qt_json.calibrate(calibration_data)
+      quant_result_from_json = qt_json.quantize(calibration_result)
+    else:
+      quant_result_from_json = qt_json.quantize()
     self.assertIsNotNone(quant_result_from_json.quantized_model)
 
-    # Check if the recipes and quantized models match.
-    self.assertEqual(
-        quant_result_from_func.recipe,
-        quant_result_from_json.recipe,
-    )
+    # Check if the quantized models match.
     self.assertEqual(
         len(quant_result_from_func.quantized_model),
         len(quant_result_from_json.quantized_model),
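
The tests branch on need_calibration because the new static recipes quantize activations and therefore require calibration before quantization. A condensed sketch of that flow, assuming a placeholder model path:

    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe
    from ai_edge_quantizer.utils import tfl_interpreter_utils

    qt = quantizer.Quantizer('model.tflite')  # placeholder path
    qt.load_quantization_recipe(recipe.static_wi8_ai8())
    if qt.need_calibration:
      # Random normal inputs suffice for a smoke test; real deployments
      # should calibrate with representative samples.
      data = tfl_interpreter_utils.create_random_normal_input_data(
          qt.float_model, num_samples=1
      )
      result = qt.quantize(qt.calibrate(data))
    else:
      result = qt.quantize()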
