Skip to content

Commit 2842d51

Browse files
paulinesho authored and copybara-github committed
Add plumbing for blockwise for TFLite interpreter
PiperOrigin-RevId: 803249284
1 parent 9597cfa commit 2842d51

File tree

4 files changed

+35
-8
lines changed

4 files changed

+35
-8
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def fix_quantization_params_rank(
119119
symmetric=quantization_params.symmetric,
120120
quantized_dimension=quantization_params.quantized_dimension,
121121
quantized_data=quantization_params.quantized_data,
122+
block_size=quantization_params.block_size,
122123
)
123124

124125

@@ -204,13 +205,16 @@ def _broadcast_scale_zp_for_blockwise(
204205
),
205206
tensor_content.shape,
206207
)
207-
expanded_zp = np.reshape(
208-
np.broadcast_to(
209-
np.expand_dims(quant_params.zero_point, quantized_dim + 1),
210-
expanded_tensor_shape,
211-
),
212-
tensor_content.shape,
213-
)
208+
if quant_params.zero_point is None or quant_params.zero_point.size == 0:
209+
expanded_zp = np.zeros_like(tensor_content, dtype=np.int32)
210+
else:
211+
expanded_zp = np.reshape(
212+
np.broadcast_to(
213+
np.expand_dims(quant_params.zero_point, quantized_dim + 1),
214+
expanded_tensor_shape,
215+
),
216+
tensor_content.shape,
217+
)
214218
return qtyping.UniformQuantParams(
215219
scale=expanded_scale,
216220
zero_point=expanded_zp,
@@ -290,6 +294,26 @@ def uniform_dequantize(
290294
Returns:
291295
The dequantized tensor.
292296
"""
297+
if quantization_params.block_size != 0:
298+
# b/443830202: The quantized dimension is currently increased by 1 because
299+
# AEQ expects 1 and XNNPack expects 0.
300+
quantization_params = dataclasses.replace(
301+
quantization_params,
302+
quantized_dimension=quantization_params.quantized_dimension + 1,
303+
)
304+
scale_shape = list(tensor_data.shape)
305+
scale_shape[quantization_params.quantized_dimension] = (
306+
scale_shape[quantization_params.quantized_dimension]
307+
// quantization_params.block_size
308+
)
309+
quantization_params = dataclasses.replace(
310+
quantization_params,
311+
scale=quantization_params.scale.reshape(scale_shape),
312+
)
313+
quantization_params = _broadcast_scale_zp_for_blockwise(
314+
tensor_data, quantization_params
315+
)
316+
293317
# quant params in flatbuffer is flattened, expand the rank to be the same
294318
# as the tensor rank to avoid ambiguous broadcasting.
295319
quantization_params = fix_quantization_params_rank(

ai_edge_quantizer/qtyping.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
212212
scale=quant_params['scales'],
213213
zero_point=quant_params['zero_points'],
214214
symmetric=symmetric,
215+
block_size=quant_params['block_size'],
215216
)
216217

217218
def __eq__(self, other):

ai_edge_quantizer/transformations/quantize_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,11 @@ def _perform_blockwise_quantization(
154154
transformation_input.buffers,
155155
)
156156
blockwise_details.scales = scale_tensor_id
157+
blockwise_details.zeroPoints = -1
157158
blockwise_details.blockSize = transformation_input.quant_params.block_size
158159
# TODO: b/404909258 - Add optional zero point to blockwise quantization.
159160
flatbuffer_quantization.details = blockwise_details
161+
flatbuffer_quantization.quantizedDimension = 0
160162
return flatbuffer_quantization
161163

162164

ai_edge_quantizer/transformations/quantize_tensor_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_blockwise_quantization_with_zero_point(self):
170170
# Check if the scale and zero point tensors are inserted correctly.
171171
self.assertEqual(quant_param.details.scales, 9)
172172
# So far we don't have zero point in blockwise quantization.
173-
self.assertEqual(quant_param.details.zeroPoints, 0)
173+
self.assertEqual(quant_param.details.zeroPoints, -1)
174174

175175
def test_int4_constant_packed_correctly(self):
176176
subgraph = self._model.subgraphs[0]

0 commit comments

Comments (0)