2121from ai_edge_quantizer import quantizer
2222from ai_edge_quantizer import recipe
2323from ai_edge_quantizer .utils import test_utils
24+ from ai_edge_quantizer .utils import tfl_interpreter_utils
2425
2526
# Base location that the model paths in this module are joined onto.
# NOTE(review): presumably the package's data root, since an empty relative
# path is passed to the helper -- confirm against test_utils.
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
@@ -30,33 +31,76 @@ class RecipeTest(parameterized.TestCase):
3031
def setUp(self):
  """Resolves the paths of the two model fixtures used by the tests."""
  super().setUp()
  # The weights have fewer than 1024 elements, so the legacy recipe will not
  # quantize them.
  self._small_model_path = os.path.join(
      _TEST_DATA_PREFIX_PATH,
      'tests/models/single_conv2d_transpose_bias.tflite',
  )
  self._test_model_path = os.path.join(
      _TEST_DATA_PREFIX_PATH,
      'tests/models/conv_fc_mnist.tflite',
  )
3743
def _quantize_with_recipe_func(self, recipe_func, test_model_path):
  """Quantizes a model with the recipe produced by `recipe_func`.

  Args:
    recipe_func: Zero-argument callable; its return value is loaded into the
      quantizer as the quantization recipe.
    test_model_path: Path of the model passed to `quantizer.Quantizer`.

  Returns:
    The quantization result, whose `quantized_model` has been asserted to be
    not None.
  """
  qt = quantizer.Quantizer(test_model_path)
  qt.load_quantization_recipe(recipe_func())
  # Loading a recipe alone must not have produced a quantized model yet.
  self.assertIsNone(qt._result.quantized_model)
  if qt.need_calibration:
    # The quantizer reports that calibration is needed: calibrate on a single
    # randomly generated input sample before quantizing.
    calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
        qt.float_model,
        num_samples=1,
    )
    calibration_result = qt.calibrate(calibration_data)
    quantization_result = qt.quantize(calibration_result)
  else:
    quantization_result = qt.quantize()
  self.assertIsNotNone(quantization_result.quantized_model)
  return quantization_result
4559
def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
  """The dynamic wi8/afp32 recipe yields a model smaller than the source."""
  quant_result = self._quantize_with_recipe_func(
      recipe.dynamic_wi8_afp32, self._test_model_path
  )
  # The serialized quantized model must be shorter than the original file.
  self.assertLess(
      len(quant_result.quantized_model),
      os.path.getsize(self._test_model_path),
  )
68+
def test_quantization_from_dynamic_wi4_afp32_func_succeeds(self):
  """The dynamic wi4/afp32 recipe yields a model smaller than the source."""
  quant_result = self._quantize_with_recipe_func(
      recipe.dynamic_wi4_afp32, self._test_model_path
  )
  # The serialized quantized model must be shorter than the original file.
  self.assertLess(
      len(quant_result.quantized_model),
      os.path.getsize(self._test_model_path),
  )
77+
def test_quantization_from_weight_only_wi8_afp32_func_succeeds(self):
  """The weight-only wi8/afp32 recipe yields a model smaller than the source."""
  quant_result = self._quantize_with_recipe_func(
      recipe.weight_only_wi8_afp32, self._test_model_path
  )
  # The serialized quantized model must be shorter than the original file.
  self.assertLess(
      len(quant_result.quantized_model),
      os.path.getsize(self._test_model_path),
  )
86+
def test_quantization_from_weight_only_wi4_afp32_func_succeeds(self):
  """The weight-only wi4/afp32 recipe yields a model smaller than the source."""
  quant_result = self._quantize_with_recipe_func(
      recipe.weight_only_wi4_afp32, self._test_model_path
  )
  # The serialized quantized model must be shorter than the original file.
  self.assertLess(
      len(quant_result.quantized_model),
      os.path.getsize(self._test_model_path),
  )
5295
def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
  """The legacy dynamic recipe leaves the small model's size unchanged."""
  quant_result = self._quantize_with_recipe_func(
      recipe.dynamic_legacy_wi8_afp32,
      self._small_model_path,
  )
  # The output must have exactly as many bytes as the input file, i.e. the
  # small model is expected to come through at its original size.
  self.assertLen(
      quant_result.quantized_model,
      os.path.getsize(self._small_model_path),
  )
61105
62106 @parameterized .named_parameters (
@@ -65,28 +109,54 @@ def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
65109 recipe_json_path = 'recipes/dynamic_wi8_afp32_recipe.json' ,
66110 recipe_func = recipe .dynamic_wi8_afp32 ,
67111 ),
112+ dict (
113+ testcase_name = 'weight_only_wi8_afp32' ,
114+ recipe_json_path = 'recipes/default_af32w8float_recipe.json' ,
115+ recipe_func = recipe .weight_only_wi8_afp32 ,
116+ ),
117+ dict (
118+ testcase_name = 'weight_only_wi4_afp32' ,
119+ recipe_json_path = 'recipes/default_af32w4float_recipe.json' ,
120+ recipe_func = recipe .weight_only_wi4_afp32 ,
121+ ),
68122 dict (
69123 testcase_name = 'dynamic_legacy_wi8_afp32' ,
70124 recipe_json_path = 'recipes/dynamic_legacy_wi8_afp32_recipe.json' ,
71125 recipe_func = recipe .dynamic_legacy_wi8_afp32 ,
72126 ),
127+ dict (
128+ testcase_name = 'a8w8' ,
129+ recipe_json_path = 'recipes/default_a8w8_recipe.json' ,
130+ recipe_func = recipe .static_wi8_ai8 ,
131+ ),
132+ dict (
133+ testcase_name = 'a16w8' ,
134+ recipe_json_path = 'recipes/default_a16w8_recipe.json' ,
135+ recipe_func = recipe .static_wi8_ai16 ,
136+ ),
73137 )
74138 def test_recipe_func_and_json_matches (self , recipe_json_path , recipe_func ):
75139 # Quantize with recipe from function in recipe module.
76- quant_result_from_func = self ._quantize_with_recipe_func (recipe_func )
140+ quant_result_from_func = self ._quantize_with_recipe_func (
141+ recipe_func , self ._test_model_path
142+ )
77143
78144 # Quantize with recipe from json file.
79145 qt_json = quantizer .Quantizer (self ._test_model_path )
80146 json_recipe_path = os .path .join (_TEST_DATA_PREFIX_PATH , recipe_json_path )
81147 qt_json .load_quantization_recipe (json_recipe_path )
82- quant_result_from_json = qt_json .quantize ()
148+ if qt_json .need_calibration :
149+ calibration_data = tfl_interpreter_utils .create_random_normal_input_data (
150+ qt_json .float_model ,
151+ num_samples = 1 ,
152+ )
153+ calibration_result = qt_json .calibrate (calibration_data )
154+ quant_result_from_json = qt_json .quantize (calibration_result )
155+ else :
156+ quant_result_from_json = qt_json .quantize ()
83157 self .assertIsNotNone (quant_result_from_json .quantized_model )
84158
85- # Check if the recipes and quantized models match.
86- self .assertEqual (
87- quant_result_from_func .recipe ,
88- quant_result_from_json .recipe ,
89- )
159+ # Check if the quantized models match.
90160 self .assertEqual (
91161 len (quant_result_from_func .quantized_model ),
92162 len (quant_result_from_json .quantized_model ),
0 commit comments