Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def fix_quantization_params_rank(
symmetric=quantization_params.symmetric,
quantized_dimension=quantization_params.quantized_dimension,
quantized_data=quantization_params.quantized_data,
block_size=quantization_params.block_size,
)


Expand Down Expand Up @@ -204,13 +205,16 @@ def _broadcast_scale_zp_for_blockwise(
),
tensor_content.shape,
)
expanded_zp = np.reshape(
np.broadcast_to(
np.expand_dims(quant_params.zero_point, quantized_dim + 1),
expanded_tensor_shape,
),
tensor_content.shape,
)
if quant_params.zero_point is None or quant_params.zero_point.size == 0:
expanded_zp = np.zeros_like(tensor_content, dtype=np.int32)
else:
expanded_zp = np.reshape(
np.broadcast_to(
np.expand_dims(quant_params.zero_point, quantized_dim + 1),
expanded_tensor_shape,
),
tensor_content.shape,
)
return qtyping.UniformQuantParams(
scale=expanded_scale,
zero_point=expanded_zp,
Expand Down Expand Up @@ -290,6 +294,26 @@ def uniform_dequantize(
Returns:
The dequantized tensor.
"""
if quantization_params.block_size != 0:
# b/443830202: The quantized dimension is currently increased by 1 because
# AEQ expects 1 and XNNPack expects 0.
quantization_params = dataclasses.replace(
quantization_params,
quantized_dimension=quantization_params.quantized_dimension + 1,
)
scale_shape = list(tensor_data.shape)
scale_shape[quantization_params.quantized_dimension] = (
scale_shape[quantization_params.quantized_dimension]
// quantization_params.block_size
)
quantization_params = dataclasses.replace(
quantization_params,
scale=quantization_params.scale.reshape(scale_shape),
)
quantization_params = _broadcast_scale_zp_for_blockwise(
tensor_data, quantization_params
)

# quant params in flatbuffer is flattened, expand the rank to be the same
# as the tensor rank to avoid ambiguous broadcasting.
quantization_params = fix_quantization_params_rank(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,30 @@ def test_uniform_dequantize_wrong_shape(self):
),
)

def test_uniform_dequantize_blockwise(self):
  """Checks blockwise dequantization of a 2x4 tensor with block size 2.

  Every block uses the same scale (1.2666667) and a zero point of 0, so each
  expected value is the quantized value multiplied by that scale. Results are
  compared element-wise to 4 decimal places.
  """
  blockwise_params = qtyping.UniformQuantParams(
      # b/443830202: `uniform_dequantize` currently adds 1 to the quantized
      # dimension of blockwise params, so 0 is passed here.
      quantized_dimension=0,
      num_bits=4,
      scale=np.array([[[1.2666667, 1.2666667], [1.2666667, 1.2666667]]]),
      zero_point=np.array([[0]]),
      symmetric=True,
      block_size=2,
  )
  quantized_values = np.array([[-8, -5, -4, 7], [-4, 7, -8, -5]])
  expected_values = np.array([
      [-10.1333336, -6.3333335, -5.0666668, 8.8666669],
      [-5.0666668, 8.8666669, -10.1333336, -6.3333335],
  ])

  actual_values = uniform_quantize_tensor.uniform_dequantize(
      np.array(quantized_values), blockwise_params
  )

  # Compare as flat sequences; only the element values matter here.
  self.assertSequenceAlmostEqual(
      expected_values.flatten(), actual_values.flatten(), places=4
  )

@parameterized.parameters(
(8, 8, True, True),
(8, 4, False, True),
Expand Down
1 change: 1 addition & 0 deletions ai_edge_quantizer/qtyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
scale=quant_params['scales'],
zero_point=quant_params['zero_points'],
symmetric=symmetric,
block_size=quant_params['block_size'],
)

def __eq__(self, other):
Expand Down
2 changes: 2 additions & 0 deletions ai_edge_quantizer/transformations/quantize_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ def _perform_blockwise_quantization(
transformation_input.buffers,
)
blockwise_details.scales = scale_tensor_id
blockwise_details.zeroPoints = -1
blockwise_details.blockSize = transformation_input.quant_params.block_size
# TODO: b/404909258 - Add optional zero point to blockwise quantization.
flatbuffer_quantization.details = blockwise_details
flatbuffer_quantization.quantizedDimension = 0
return flatbuffer_quantization


Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/quantize_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_blockwise_quantization_with_zero_point(self):
# Check if the scale and zero point tensors are inserted correctly.
self.assertEqual(quant_param.details.scales, 9)
# So far we don't have zero point in blockwise quantization.
self.assertEqual(quant_param.details.zeroPoints, 0)
self.assertEqual(quant_param.details.zeroPoints, -1)

def test_int4_constant_packed_correctly(self):
subgraph = self._model.subgraphs[0]
Expand Down
Loading