diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
index 70e400c6..34ae6b44 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
@@ -119,6 +119,7 @@ def fix_quantization_params_rank(
       symmetric=quantization_params.symmetric,
       quantized_dimension=quantization_params.quantized_dimension,
       quantized_data=quantization_params.quantized_data,
+      block_size=quantization_params.block_size,
   )
 
 
@@ -204,13 +205,16 @@ def _broadcast_scale_zp_for_blockwise(
       ),
       tensor_content.shape,
   )
-  expanded_zp = np.reshape(
-      np.broadcast_to(
-          np.expand_dims(quant_params.zero_point, quantized_dim + 1),
-          expanded_tensor_shape,
-      ),
-      tensor_content.shape,
-  )
+  if quant_params.zero_point is None or quant_params.zero_point.size == 0:
+    expanded_zp = np.zeros_like(tensor_content, dtype=np.int32)
+  else:
+    expanded_zp = np.reshape(
+        np.broadcast_to(
+            np.expand_dims(quant_params.zero_point, quantized_dim + 1),
+            expanded_tensor_shape,
+        ),
+        tensor_content.shape,
+    )
   return qtyping.UniformQuantParams(
       scale=expanded_scale,
       zero_point=expanded_zp,
@@ -290,6 +294,26 @@ def uniform_dequantize(
   Returns:
     The dequantized tensor.
   """
+  if quantization_params.block_size != 0:
+    # b/443830202: The quantized dimension is currently increased by 1 because
+    # AEQ expects 1 and XNNPack expects 0.
+    quantization_params = dataclasses.replace(
+        quantization_params,
+        quantized_dimension=quantization_params.quantized_dimension + 1,
+    )
+    scale_shape = list(tensor_data.shape)
+    scale_shape[quantization_params.quantized_dimension] = (
+        scale_shape[quantization_params.quantized_dimension]
+        // quantization_params.block_size
+    )
+    quantization_params = dataclasses.replace(
+        quantization_params,
+        scale=quantization_params.scale.reshape(scale_shape),
+    )
+    quantization_params = _broadcast_scale_zp_for_blockwise(
+        tensor_data, quantization_params
+    )
+
   # quant params in flatbuffer is flattened, expand the rank to be the same
   # as the tensor rank to avoid ambiguous broadcasting.
   quantization_params = fix_quantization_params_rank(
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py
index 811c621f..8e186073 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py
@@ -278,6 +278,30 @@ def test_uniform_dequantize_wrong_shape(self):
         ),
     )
 
+  def test_uniform_dequantize_blockwise(self):
+    quantized_tensor = np.array([[-8, -5, -4, 7], [-4, 7, -8, -5]])
+    expected_output_tensor = np.array([
+        [-10.1333336, -6.3333335, -5.0666668, 8.8666669],
+        [-5.0666668, 8.8666669, -10.1333336, -6.3333335],
+    ])
+    quant_params = qtyping.UniformQuantParams(
+        # b/443830202:
+        quantized_dimension=0,
+        num_bits=4,
+        scale=np.array([[[1.2666667, 1.2666667], [1.2666667, 1.2666667]]]),
+        zero_point=np.array([[0]]),
+        symmetric=True,
+        block_size=2,
+    )
+
+    dequantized_tensor = uniform_quantize_tensor.uniform_dequantize(
+        np.array(quantized_tensor), quant_params
+    )
+
+    self.assertSequenceAlmostEqual(
+        expected_output_tensor.flatten(), dequantized_tensor.flatten(), places=4
+    )
+
   @parameterized.parameters(
       (8, 8, True, True),
       (8, 4, False, True),
diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py
index 56e0ca6c..28eb5d9d 100644
--- a/ai_edge_quantizer/qtyping.py
+++ b/ai_edge_quantizer/qtyping.py
@@ -212,6 +212,7 @@ def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
         scale=quant_params['scales'],
         zero_point=quant_params['zero_points'],
         symmetric=symmetric,
+        block_size=quant_params['block_size'],
     )
 
   def __eq__(self, other):
diff --git a/ai_edge_quantizer/transformations/quantize_tensor.py b/ai_edge_quantizer/transformations/quantize_tensor.py
index 26a7fe6c..8a2996b9 100644
--- a/ai_edge_quantizer/transformations/quantize_tensor.py
+++ b/ai_edge_quantizer/transformations/quantize_tensor.py
@@ -154,9 +154,11 @@ def _perform_blockwise_quantization(
       transformation_input.buffers,
   )
   blockwise_details.scales = scale_tensor_id
+  blockwise_details.zeroPoints = -1
   blockwise_details.blockSize = transformation_input.quant_params.block_size
   # TODO: b/404909258 - Add optional zero point to blockwise quantization.
   flatbuffer_quantization.details = blockwise_details
+  flatbuffer_quantization.quantizedDimension = 0
   return flatbuffer_quantization
 
 
diff --git a/ai_edge_quantizer/transformations/quantize_tensor_test.py b/ai_edge_quantizer/transformations/quantize_tensor_test.py
index 5460907f..ba90ccdc 100644
--- a/ai_edge_quantizer/transformations/quantize_tensor_test.py
+++ b/ai_edge_quantizer/transformations/quantize_tensor_test.py
@@ -170,7 +170,7 @@ def test_blockwise_quantization_with_zero_point(self):
     # Check if the scale and zero point tensors are inserted correctly.
     self.assertEqual(quant_param.details.scales, 9)
     # So far we don't have zero point in blockwise quantization.
-    self.assertEqual(quant_param.details.zeroPoints, 0)
+    self.assertEqual(quant_param.details.zeroPoints, -1)
 
   def test_int4_constant_packed_correctly(self):
     subgraph = self._model.subgraphs[0]
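
Note: the expected values in test_uniform_dequantize_blockwise follow directly from the blockwise math that the new uniform_dequantize branch implements. Below is a minimal numpy sketch of that math, for illustration only (the variable names are mine, not part of AEQ): each group of block_size elements along the quantized dimension shares one scale, and with symmetric quantization (zero point 0) dequantization reduces to q * scale.

import numpy as np

quantized = np.array([[-8, -5, -4, 7], [-4, 7, -8, -5]])
block_size = 2
# One scale per (row, block) pair: shape (2, 4 // block_size) == (2, 2).
scale = np.full((2, 2), 1.2666667)

# Repeat each per-block scale across its block along the quantized
# dimension; _broadcast_scale_zp_for_blockwise achieves the same effect
# via expand_dims + broadcast_to + reshape.
expanded_scale = np.repeat(scale, block_size, axis=1)

# Symmetric blockwise dequantization: zero point is 0, so q * scale.
dequantized = quantized * expanded_scale
print(dequantized)
# Approximately:
# [[-10.1333336  -6.3333335  -5.0666668   8.8666669]
#  [ -5.0666668   8.8666669 -10.1333336  -6.3333335]]

These are exactly the values hard-coded in expected_output_tensor above, which is why the test passes once the scale is reshaped and broadcast per block.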