diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
index 2b3248ba..bf1b13a3 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
@@ -1165,39 +1165,36 @@ def init_tensor_min_max(
     A dictionary containing the min/max values for the tensor, or an empty
     dictionary if the tensor data is None.
   """
-  if tensor_data is None:
+  weight_tensor_config = op_info.op_quant_config.weight_tensor_config
+  if tensor_data is None or weight_tensor_config is None:
     return {}
   else:
-    weight_tensor_config = op_info.op_quant_config.weight_tensor_config
-    quantized_dim = None
-    if weight_tensor_config is not None and (
-        weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
-    ):
+    # Get reduce dimension for min/max calculation based on quantization
+    # granularity.
+    granularity = weight_tensor_config.granularity
+    if granularity == qtyping.QuantGranularity.TENSORWISE:
+      reduce_dims = None
+      keep_dims = True
+    elif granularity == qtyping.QuantGranularity.CHANNELWISE:
       quantized_dim = common_utils.get_weight_quantized_dim(
           op_info, tensor_data, weight_tensor_config.granularity
       )
-    if (
-        weight_tensor_config is not None
-        and weight_tensor_config.granularity
-        == qtyping.QuantGranularity.BLOCKWISE
-    ):
-      reshaped_data, reduce_dims = (
+      reduce_dims = common_utils.get_reduce_dims(
+          quantized_dim, tensor_data.shape
+      )
+      keep_dims = True
+    elif uniform_quantize_tensor.is_blockwise(granularity):
+      tensor_data, reduce_dims = (
           uniform_quantize_tensor.reshape_data_for_blockwise(
               tensor_data,
               op_info.op_name,
-              weight_tensor_config.block_size,
+              granularity,
           )
       )
-      return {
-          "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
-          "max": np.max(reshaped_data, axis=reduce_dims, keepdims=False),
-      }
-
+      keep_dims = False
     else:
-      reduce_dims = common_utils.get_reduce_dims(
-          quantized_dim, tensor_data.shape
-      )
-      return {
-          "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
-          "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
-      }
+      raise ValueError(f"Unsupported granularity: {granularity}")
+    return {
+        "min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+        "max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+    }
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py b/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py
index 91cb57bc..f532d81c 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py
@@ -158,7 +158,7 @@ def get_tensor_quant_params(
       op_info, tensor_quant_config, tensor_content, tensor_qsv
   )
 
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     raise ValueError(
         "Blockwise quantization is not supported for dequantized weight"
         " recovery."
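
[Editor's note - not part of the patch] A minimal sketch of what the rewritten
`init_tensor_min_max` computes per granularity. The shapes are assumptions for
illustration (a 2-D weight, channelwise quantized dim 0, blockwise quantized
dim 1, block size 32), not values taken from the change itself:

import numpy as np

data = np.arange(4 * 64, dtype=np.float32).reshape(4, 64)

# TENSORWISE: reduce over all axes, keepdims=True -> min/max of shape (1, 1).
t_min = np.min(data, axis=None, keepdims=True)

# CHANNELWISE: reduce every axis except the quantized dim -> shape (4, 1).
c_min = np.min(data, axis=1, keepdims=True)

# BLOCKWISE_32: reshape (4, 64) -> (4, 2, 32) and reduce the new block axis
# (quantized_dim + 1), keepdims=False -> one min/max per block, shape (4, 2).
b_min = np.min(data.reshape(4, 2, 32), axis=2, keepdims=False)
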
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py
index 38cba960..18e278c4 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py
@@ -147,8 +147,7 @@ def test_fully_connected_blockwise_supported(self):
         weight_tensor_config=_TensorQuantConfig(
             num_bits=8,
             symmetric=True,
-            granularity=qtyping.QuantGranularity.BLOCKWISE,
-            block_size=32,
+            granularity=qtyping.QuantGranularity.BLOCKWISE_32,
         ),
     ),
 )
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/mse.py b/ai_edge_quantizer/algorithms/uniform_quantize/mse.py
index 578bed80..714c7b2b 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/mse.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/mse.py
@@ -55,7 +55,7 @@ def get_tensor_quant_params(
     ValueError: `tensor_qsv` must contain min/max values, or `tensor_content`
       must be provided so that they can be inferred.
   """
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     raise ValueError(
         "Blockwise quantization is not supported for MSE quantization."
     )
@@ -113,13 +113,15 @@ def get_tensor_quant_params(
       num_bits=tensor_quant_config.num_bits,
       symmetric=tensor_quant_config.symmetric,
       quantized_dimension=quantized_dim,
-      block_size=tensor_quant_config.block_size,
+      block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
+          tensor_quant_config.granularity
+      ),
   )
 
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
       tensor_content,
       quant_params,
-      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
+      uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
   )
   return dataclasses.replace(quant_params, quantized_data=quantized_vars)
 
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py
index 84870dc0..c853d337 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py
@@ -84,7 +84,7 @@ def test_get_tensor_quant_params_raises_error_with_unsupported_granularity(
         tensor_quant_config=qtyping.TensorQuantizationConfig(
             num_bits=4,
             symmetric=True,
-            granularity=qtyping.QuantGranularity.BLOCKWISE,
+            granularity=qtyping.QuantGranularity.BLOCKWISE_32,
         ),
         tensor_content=test_data,
     )
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
index 478d2062..ca7dde9f 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -15,6 +15,7 @@
 
 """Performs naive min/max uniform quantization."""
 
+import dataclasses
 from typing import Any, Optional
 import numpy as np
 from ai_edge_quantizer import qtyping
@@ -91,7 +92,9 @@ def get_tensor_quant_params(
       num_bits=tensor_quant_config.num_bits,
       symmetric=tensor_quant_config.symmetric,
       quantized_dimension=quantized_dim,
-      block_size=tensor_quant_config.block_size,
+      block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
+          tensor_quant_config.granularity
+      ),
   )
   if tensor_content is None:
     return quant_params
@@ -99,18 +102,10 @@ def get_tensor_quant_params(
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
       tensor_content,
       quant_params,
-      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
+      uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
   )
   # Update with quantized values.
-  return qtyping.UniformQuantParams(
-      scale=scale,
-      zero_point=zp,
-      num_bits=tensor_quant_config.num_bits,
-      symmetric=tensor_quant_config.symmetric,
-      quantized_dimension=quantized_dim,
-      quantized_data=quantized_vars,
-      block_size=tensor_quant_config.block_size,
-  )
+  return dataclasses.replace(quant_params, quantized_data=quantized_vars)
 
 
 # TODO: b/333731147 - Use named tuple to store min/max.
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py
index fe96cedf..2a7497b3 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py
@@ -17,6 +17,7 @@
 from typing import cast
 
 from absl.testing import parameterized
+import ml_dtypes
 import numpy as np
 from tensorflow.python.platform import googletest
 
@@ -165,8 +166,7 @@ def test_get_tensor_quant_params_for_blockwise_weight(self):
     weight_tensor_config = _TensorQuantConfig(
         num_bits=4,
         symmetric=True,
-        granularity=qtyping.QuantGranularity.BLOCKWISE,
-        block_size=2,
+        granularity=qtyping.QuantGranularity.BLOCKWISE_32,
     )
     op_info = qtyping.OpInfo(
         op=fc_op,
@@ -176,28 +176,32 @@ def test_get_tensor_quant_params_for_blockwise_weight(self):
             weight_tensor_config=weight_tensor_config,
         ),
     )
-    test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
+    test_data = np.random.uniform(low=-10, high=10, size=(4, 32)).astype(
+        np.float32
+    )
     quant_params = naive_min_max_quantize.get_tensor_quant_params(
         op_info=op_info,
         tensor_quant_config=weight_tensor_config,
         tensor_content=test_data,
     )
 
-    scale = quant_params.scale
     zp = quant_params.zero_point
-    expected_scale = np.array([
-        [1],
-        [0.5703125],
-        [0.5703125],
-        [1],
-    ])
-    expected_zp = np.zeros([4, 1])
-    self.assertTrue(np.array_equal(zp, expected_zp))
-    self.assertTrue(np.array_equal(scale, expected_scale))
+    self.assertEqual(zp.shape, (4, 1))
+    self.assertTrue(np.array_equal(zp, np.zeros([4, 1])))
+
+    self.assertEqual(quant_params.scale.shape, (4, 1))
+    expected_scales = np.max(np.abs(test_data), axis=1, keepdims=True) / 7.0
+    expected_scales = (
+        expected_scales.astype(ml_dtypes.bfloat16)
+        .astype(np.float16)
+        .astype(np.float32)
+    )
+    self.assertTrue(np.allclose(quant_params.scale, expected_scales, atol=1e-5))
+    self.assertIsNotNone(quant_params.quantized_data)
     self.assertTupleEqual(
         cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
     )
-    self.assertEqual(quant_params.block_size, 2)
+    self.assertEqual(quant_params.block_size, 32)
     self.assertEqual(quant_params.quantized_dimension, 1)
 
   def test_calibrate_ignores_inf_min_max(self):
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/octav.py b/ai_edge_quantizer/algorithms/uniform_quantize/octav.py
index 84eda640..e494485a 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/octav.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/octav.py
@@ -131,12 +131,12 @@ def get_tensor_quant_params(
   quantized_dim = common_utils.get_weight_quantized_dim(
       op_info, tensor_content, tensor_quant_config.granularity
   )
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     reshaped_data, reduce_dims = (
         uniform_quantize_tensor.reshape_data_for_blockwise(
             tensor_content,
             op_info.op_name,
-            tensor_quant_config.block_size,
+            tensor_quant_config.granularity,
         )
     )
   else:
@@ -154,7 +154,7 @@ def get_tensor_quant_params(
   # We created a new dimension in order to reduce properly for blockwise
   # quantization, so we need to reshape the clipping constants back to the
   # min/max shape for the next step.
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)
 
   zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
@@ -172,13 +172,17 @@ def get_tensor_quant_params(
       num_bits=tensor_quant_config.num_bits,
       symmetric=tensor_quant_config.symmetric,
       quantized_dimension=quantized_dim,
-      block_size=tensor_quant_config.block_size,
+      block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
+          tensor_quant_config.granularity
+      ),
   )
 
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
       tensor_content,
       quant_params,
-      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
+      is_blockwise_quant=uniform_quantize_tensor.is_blockwise(
+          tensor_quant_config.granularity
+      ),
   )
   return dataclasses.replace(quant_params, quantized_data=quantized_vars)
 
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py
index ef554b19..3d847386 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py
@@ -196,8 +196,7 @@ def test_get_tensor_quant_params_sanity_blockwise(self):
     tensor_config = qtyping.TensorQuantizationConfig(
         num_bits=4,
         symmetric=True,
-        granularity=qtyping.QuantGranularity.BLOCKWISE,
-        block_size=32,
+        granularity=qtyping.QuantGranularity.BLOCKWISE_32,
     )
     fc_op_info = qtyping.OpInfo(
         op=self._fc_op,
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
index bda64748..900ded71 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
@@ -29,6 +29,11 @@ class IntType:
   signed: bool
 
 
+def is_blockwise(granularity: qtyping.QuantGranularity) -> bool:
+  """Checks if the quantization granularity is blockwise."""
+  return "BLOCKWISE" in str(granularity)
+
+
 def get_quantized_range(qtype: IntType) -> tuple[float, float]:
   """Calculates range of the quantized type."""
   if qtype.signed:
@@ -40,6 +45,22 @@ def get_quantized_range(qtype: IntType) -> tuple[float, float]:
   return float(qmin), float(qmax)
 
 
+def extract_block_size_from_granularity(
+    granularity: qtyping.QuantGranularity,
+) -> int:
+  """Get the block size for blockwise quantization."""
+  if granularity == qtyping.QuantGranularity.BLOCKWISE_32:
+    return 32
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_64:
+    return 64
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_128:
+    return 128
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_256:
+    return 256
+  else:
+    return 0
+
+
 def _round_and_clip(
     tensor: np.ndarray, qtype: IntType, narrow: bool
 ) -> np.ndarray:
@@ -157,14 +178,16 @@
 
 
 def reshape_data_for_blockwise(
-    tensor_data: np.ndarray, op_name: qtyping.TFLOperationName, block_size: int
+    tensor_data: np.ndarray,
+    op_name: qtyping.TFLOperationName,
+    granularity: qtyping.QuantGranularity,
 ) -> tuple[np.ndarray, int]:
   """Reshapes data for blockwise quantization.
 
   Args:
     tensor_data: The original tensor data.
     op_name: The name of the TFL op.
-    block_size: The size of the block.
+    granularity: The quantization granularity for the tensor.
 
   Returns:
     A tuple containing the reshaped tensor data and the new reduce dimension.
@@ -172,11 +195,11 @@ def reshape_data_for_blockwise(
   quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
       op_name
   ]
+  block_size = extract_block_size_from_granularity(granularity)
   new_shape = _get_tensor_shape_for_blockwise(
       tensor_data.shape, quantized_dim, block_size
   )
-  reshaped_data = tensor_data.reshape(new_shape)
-  return reshaped_data, quantized_dim + 1
+  return tensor_data.reshape(new_shape), quantized_dim + 1
 
 
 def _broadcast_scale_zp_for_blockwise(
@@ -233,21 +256,21 @@
 def uniform_quantize(
     tensor_data: np.ndarray,
     quantization_params: qtyping.UniformQuantParams,
-    is_blockwise: bool = False,
+    is_blockwise_quant: bool = False,
 ):
   """Uniform quantize a tensor.
 
   Args:
     tensor_data: The tensor to be quantized.
     quantization_params: The quantization parameters.
-    is_blockwise: Whether the tensor is blockwise quantized.
+    is_blockwise_quant: Whether the tensor is blockwise quantized.
 
   Returns:
     The quantized tensor.
   """
   # The reshaping for blockwise quantization is unique hence we do this here
   # to avoid unexpected broadcast behavior downstream.
-  if is_blockwise:
+  if is_blockwise_quant:
     quantization_params = _broadcast_scale_zp_for_blockwise(
         tensor_data, quantization_params
     )
@@ -435,6 +458,7 @@ def tensor_zp_scale_from_min_max(
   Returns:
     The zero point and scale of the tensor.
   """
+
   # TODO: b/332574603 - support unsigned data type.
   qtype = IntType(
       num_bits,
@@ -445,7 +469,7 @@ def tensor_zp_scale_from_min_max(
   pos_clipping_values = None if clipping_values is None else clipping_values
   neg_clipping_values = None if clipping_values is None else -clipping_values
 
-  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if is_blockwise(granularity):
     # Blockwise quantization uses float16 scale,
     # with 7 bit mantissa, so the maximum scale value is 65280 and maximum
     # representable range is [-65280 * (2 ** num_bits),
@@ -493,7 +517,7 @@ def tensor_zp_scale_from_min_max(
     zp = qmin - bound_min / scale
     zp = np.rint(zp)
 
-  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if is_blockwise(granularity):
     # Round the scale values to 7 bit mantissa.
     scale = (
         scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py
index 1c0654ba..1fde8ce7 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py
@@ -222,7 +222,7 @@ def test_uniform_quantize_quant_dim_not_divisible_by_block_size_raise(self):
             zero_point=np.array([-6]),
             symmetric=True,
         ),
-        is_blockwise=True,
+        is_blockwise_quant=True,
     )
 
   @parameterized.parameters(
diff --git a/ai_edge_quantizer/algorithms/utils/common_utils.py b/ai_edge_quantizer/algorithms/utils/common_utils.py
index 1dca3a90..9be178b6 100644
--- a/ai_edge_quantizer/algorithms/utils/common_utils.py
+++ b/ai_edge_quantizer/algorithms/utils/common_utils.py
@@ -51,8 +51,9 @@ def check_subchannel_config(
   """Checks the op quantization config for subchannel quantization."""
   if (
       op_quant_config.weight_tensor_config is not None
-      and op_quant_config.weight_tensor_config.granularity
-      == qtyping.QuantGranularity.BLOCKWISE
+      and uniform_quantize_tensor.is_blockwise(
+          op_quant_config.weight_tensor_config.granularity
+      )
   ):
     if op_name not in _SUPPORTED_SUBCHANNEL_OPS:
       raise ValueError(f"Unsupported op for blockwise quantization: {op_name}.")
@@ -66,10 +67,6 @@ def check_subchannel_config(
           "Blockwise quantization does not support for asymmetric weight"
           " quantization."
       )
-    if op_quant_config.weight_tensor_config.block_size <= 0:
-      raise ValueError(
-          "Blockwise quantization must have a non-zero block size."
-      )
 
 
 def check_if_valid_op_config(
@@ -993,7 +990,7 @@ def get_weight_quantized_dim(
     quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
         op_info.op_name, None
     )
-  elif granularity == qtyping.QuantGranularity.BLOCKWISE:
+  elif uniform_quantize_tensor.is_blockwise(granularity):
     quantized_dim = (
         tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
             op_info.op_name
diff --git a/ai_edge_quantizer/default_policy.py b/ai_edge_quantizer/default_policy.py
index aa4281d5..cbbc78fd 100644
--- a/ai_edge_quantizer/default_policy.py
+++ b/ai_edge_quantizer/default_policy.py
@@ -61,9 +61,8 @@
       "weight_tensor_config": {
         "num_bits": 4,
         "symmetric": [true],
-        "granularity": ["BLOCKWISE"],
-        "dtype": "INT",
-        "block_size": [32, 64, 96, 128, 256]
+        "granularity": ["BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"],
+        "dtype": "INT"
       },
       "explicit_dequantize": false,
       "compute_precision": "INTEGER"
@@ -320,16 +319,9 @@ def _unroll_json_config(
           "granularity": granularity,
           "dtype": json_config["weight_tensor_config"]["dtype"],
       }
-      if "block_size" in json_config["weight_tensor_config"]:
-        for block_size in json_config["weight_tensor_config"]["block_size"]:
-          tensor_config["block_size"] = block_size
-          weight_configs.append(
-              qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-          )
-      else:
-        weight_configs.append(
-            qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-        )
+      weight_configs.append(
+          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+      )
 
   if activation_configs:
     for activation_config in activation_configs:
diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py
index 19884d6b..2cedbd09 100644
--- a/ai_edge_quantizer/qtyping.py
+++ b/ai_edge_quantizer/qtyping.py
@@ -112,7 +112,11 @@ class TensorDataType(str, enum.Enum):
 
 class QuantGranularity(str, enum.Enum):
   TENSORWISE = 'TENSORWISE'
   CHANNELWISE = 'CHANNELWISE'
-  BLOCKWISE = 'BLOCKWISE'
+  # Blockwise quantization with various block sizes.
+  BLOCKWISE_32 = 'BLOCKWISE_32'
+  BLOCKWISE_64 = 'BLOCKWISE_64'
+  BLOCKWISE_128 = 'BLOCKWISE_128'
+  BLOCKWISE_256 = 'BLOCKWISE_256'
 
 
 class QuantTransformation(enum.Enum):
@@ -310,7 +314,6 @@ class TensorQuantizationConfig:
     granularity: Whether to perform per-tensor, per-channel or per-block
       quantization.
     dtype: The data type of the tensor.
-    block_size: The block size for blockwise quantization, ignored otherwise.
     algorithm_key: The algorithm key to use for quantization.
   """
 
@@ -318,7 +321,6 @@ num_bits: int
   symmetric: bool = True
   granularity: QuantGranularity = QuantGranularity.TENSORWISE
   dtype: TensorDataType = TensorDataType.INT
-  block_size: int = 0
 
   def to_dict(self) -> dict[str, Any]:
     """Converts ActivationQuantizationConfig to dict."""
@@ -336,9 +338,28 @@
   def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
     """Converts a given dict to TensorQuantizationConfig."""
     params_copy = copy.deepcopy(params)
+    # Process block_size config from legacy recipe.
+    params_copy = _process_block_size(params_copy)
     return cls(**params_copy)
 
 
+def _process_block_size(params: dict[str, Any]) -> dict[str, Any]:
+  """Processes block size in the params."""
+  block_size = params.pop('block_size', 0)
+  if block_size > 0:
+    if block_size == 32:
+      params['granularity'] = QuantGranularity.BLOCKWISE_32
+    elif block_size == 64:
+      params['granularity'] = QuantGranularity.BLOCKWISE_64
+    elif block_size == 128:
+      params['granularity'] = QuantGranularity.BLOCKWISE_128
+    elif block_size == 256:
+      params['granularity'] = QuantGranularity.BLOCKWISE_256
+    else:
+      raise ValueError(f'Unsupported block size: {block_size}')
+  return params
+
+
 @dataclasses.dataclass(frozen=True)
 class OpQuantizationConfig:
   """Configuration class to control the quantization process behavior.
diff --git a/ai_edge_quantizer/quantizer_test.py b/ai_edge_quantizer/quantizer_test.py
index 41dff66e..8c8f5f3a 100644
--- a/ai_edge_quantizer/quantizer_test.py
+++ b/ai_edge_quantizer/quantizer_test.py
@@ -309,6 +309,44 @@ def test_save_succeeds(self):
       saved_recipe = json.load(json_file)
     self.assertEqual(saved_recipe, self._test_recipe)
 
+  def test_saved_legacy_recipe_lacks_block_size(self):
+    model_name = 'test_model'
+    legacy_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/dynamic_legacy_wi8_afp32_recipe.json',
+    )
+    self._quantizer.load_quantization_recipe(legacy_recipe_path)
+    result = self._quantizer.quantize()
+    result.save(self._tmp_save_path, model_name)
+    saved_recipe_path = os.path.join(
+        self._tmp_save_path, model_name + '_recipe.json'
+    )
+    with open(saved_recipe_path) as json_file:
+      saved_recipe = json.load(json_file)
+    with open(legacy_recipe_path) as json_file:
+      legacy_recipe = json.load(json_file)
+
+    self.assertNotEqual(saved_recipe, legacy_recipe)
+
+    # Verify that the legacy recipe contains 'block_size'.
+    has_block_size = False
+    for config in legacy_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config and 'block_size' in weight_config:
+          has_block_size = True
+          break
+    self.assertTrue(has_block_size)
+
+    # Verify that the saved recipe does not have 'block_size'.
+    for config in saved_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config:
+          self.assertNotIn('block_size', weight_config)
+
   def test_save_no_quantize_raise_error(self):
     error_message = 'No quantized model to save.'
     with self.assertRaisesWithPredicateMatch(
@@ -535,14 +573,12 @@ def test_constant_buffer_shared_by_tensors_with_different_quantization_params_su
             'symmetric': False,
             'granularity': 'TENSORWISE',
             'dtype': 'INT',
-            'block_size': 0,
         },
         'weight_tensor_config': {
             'num_bits': 8,
             'symmetric': True,
             'granularity': 'CHANNELWISE',
             'dtype': 'INT',
-            'block_size': 0,
         },
         'compute_precision': 'INTEGER',
         'explicit_dequantize': False,
diff --git a/ai_edge_quantizer/recipe_manager_test.py b/ai_edge_quantizer/recipe_manager_test.py
index 16c50310..87c3be50 100644
--- a/ai_edge_quantizer/recipe_manager_test.py
+++ b/ai_edge_quantizer/recipe_manager_test.py
@@ -569,14 +569,12 @@ def test_get_full_quantization_config(self):
             'symmetric': False,
             'granularity': _QuantGranularity.TENSORWISE,
             'dtype': 'INT',
-            'block_size': 0,
         },
         'weight_tensor_config': {
            'num_bits': 8,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
            'dtype': 'INT',
-           'block_size': 0,
        },
        # WEIGHT_ONLY.
        'compute_precision': _ComputePrecision.INTEGER,
@@ -595,7 +593,6 @@ def test_get_full_quantization_config(self):
            'num_bits': 8,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-           'block_size': 0,
        },
        # WEIGHT_ONLY.
        'compute_precision': _ComputePrecision.FLOAT,
@@ -614,7 +611,6 @@ def test_get_full_quantization_config(self):
            'num_bits': 4,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-           'block_size': 0,
        },
        # WEIGHT_ONLY.
        'compute_precision': _ComputePrecision.FLOAT,
@@ -633,7 +629,6 @@ def test_get_full_quantization_config(self):
            'num_bits': 6,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-           'block_size': 0,
        },
        # WEIGHT_ONLY.
        'compute_precision': _ComputePrecision.FLOAT,
@@ -652,7 +647,6 @@ def test_get_full_quantization_config(self):
            'num_bits': 3,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-           'block_size': 0,
        },
        # WEIGHT_ONLY.
        'compute_precision': _ComputePrecision.FLOAT,
diff --git a/ai_edge_quantizer/recipes/default_a16w8_recipe.json b/ai_edge_quantizer/recipes/default_a16w8_recipe.json
index 5135470f..bf31ad36 100644
--- a/ai_edge_quantizer/recipes/default_a16w8_recipe.json
+++ b/ai_edge_quantizer/recipes/default_a16w8_recipe.json
@@ -8,15 +8,13 @@
         "num_bits": 16,
         "symmetric": true,
         "granularity": "TENSORWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
       },
       "weight_tensor_config": {
         "num_bits": 8,
         "symmetric": true,
         "granularity": "CHANNELWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
       },
       "compute_precision": "INTEGER",
       "explicit_dequantize": false,
diff --git a/ai_edge_quantizer/recipes/default_a8w8_recipe.json b/ai_edge_quantizer/recipes/default_a8w8_recipe.json
index 84395996..bbc02cf8 100644
--- a/ai_edge_quantizer/recipes/default_a8w8_recipe.json
+++ b/ai_edge_quantizer/recipes/default_a8w8_recipe.json
@@ -8,15 +8,13 @@
         "num_bits": 8,
         "symmetric": false,
         "granularity": "TENSORWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
       },
       "weight_tensor_config": {
         "num_bits": 8,
         "symmetric": true,
         "granularity": "CHANNELWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
      },
      "compute_precision": "INTEGER",
      "explicit_dequantize": false,
diff --git a/ai_edge_quantizer/recipes/default_af32w4float_recipe.json b/ai_edge_quantizer/recipes/default_af32w4float_recipe.json
index deb750e9..81c18099 100644
--- a/ai_edge_quantizer/recipes/default_af32w4float_recipe.json
+++ b/ai_edge_quantizer/recipes/default_af32w4float_recipe.json
@@ -8,8 +8,7 @@
         "num_bits": 4,
         "symmetric": false,
         "granularity": "CHANNELWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
       },
       "compute_precision": "FLOAT",
       "explicit_dequantize": true,
diff --git a/ai_edge_quantizer/recipes/default_af32w8float_recipe.json b/ai_edge_quantizer/recipes/default_af32w8float_recipe.json
index 2b169403..0de068a1 100644
--- a/ai_edge_quantizer/recipes/default_af32w8float_recipe.json
+++ b/ai_edge_quantizer/recipes/default_af32w8float_recipe.json
@@ -8,8 +8,7 @@
         "num_bits": 8,
         "symmetric": false,
         "granularity": "CHANNELWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
       },
       "compute_precision": "FLOAT",
       "explicit_dequantize": true,
diff --git a/ai_edge_quantizer/recipes/dynamic_wi8_afp32_recipe.json b/ai_edge_quantizer/recipes/dynamic_wi8_afp32_recipe.json
index bdb29361..2332a00d 100644
--- a/ai_edge_quantizer/recipes/dynamic_wi8_afp32_recipe.json
+++ b/ai_edge_quantizer/recipes/dynamic_wi8_afp32_recipe.json
@@ -8,8 +8,7 @@
         "num_bits": 8,
         "symmetric": true,
         "granularity": "CHANNELWISE",
-        "dtype": "INT",
-        "block_size": 0
+        "dtype": "INT"
       },
       "compute_precision": "INTEGER",
       "explicit_dequantize": false,
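
[Editor's note - not part of the patch] Backward compatibility in one sketch:
`TensorQuantizationConfig.from_dict` now routes legacy recipes through
`_process_block_size`, which pops `block_size` and folds it into the new
granularity enum. The dict below is a hypothetical legacy entry, not data
taken from this change:

from ai_edge_quantizer import qtyping

legacy_params = {
    'num_bits': 4,
    'symmetric': True,
    'granularity': 'BLOCKWISE',  # legacy granularity value
    'dtype': 'INT',
    'block_size': 64,  # legacy field; absorbed into the granularity below
}
config = qtyping.TensorQuantizationConfig.from_dict(legacy_params)
assert config.granularity == qtyping.QuantGranularity.BLOCKWISE_64
# Block sizes other than 32/64/128/256 now raise ValueError, and the saved
# config no longer carries 'block_size', which is what
# test_saved_legacy_recipe_lacks_block_size checks on a saved recipe.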