14
14
import onnx
15
15
from onnx import helper
16
16
17
- BITS_TO_NUMPY_TYPE = {8 : np .uint8 , 16 : np .uint16 }
17
+ BITS_TO_NUMPY_TYPE = {8 : np .int8 , 16 : np .int16 }
18
18
19
19
20
- SUPPORTED_OPS = {
21
- "Conv"
22
- }
20
+ SUPPORTED_OPS = {"Conv" , "Gemm" , "MatMul" }
23
21
24
22
ONNX_OPSET = 21
25
23
@@ -43,12 +41,6 @@ class BlockQuantizeResult:
43
41
quantization_error : np .ndarray = field (default_factory = lambda : np .array ([]))
44
42
45
43
46
- @dataclass
47
- class LayerParams :
48
- weights : np .ndarray = field (default_factory = lambda : np .array ([]))
49
- bias : Optional [np .ndarray ] = None
50
-
51
-
52
44
def closest_divisor (number : int , divisor : int ) -> int :
53
45
for d in range (divisor , 0 , - 1 ):
54
46
if number % d == 0 :
@@ -169,18 +161,6 @@ def get_initializer_tensor(self, name: str) -> Optional[np.ndarray]:
169
161
170
162
return None
171
163
172
- def get_layer_params (self , node : onnx .NodeProto ) -> LayerParams :
173
- params = LayerParams ()
174
-
175
- weights_name = node .input [1 ]
176
- params .weights = self .get_initializer_tensor (weights_name )
177
-
178
- if len (node .input ) > 2 :
179
- bias_name = node .input [2 ]
180
- params .bias = self .get_initializer_tensor (bias_name )
181
-
182
- return params
183
-
184
164
def compute_scale_zeropoint (
185
165
self , b_min : np .ndarray , b_max : np .ndarray
186
166
) -> Tuple [np .ndarray , np .ndarray ]:
@@ -208,24 +188,28 @@ def compute_scale_zeropoint(
208
188
209
189
def block_quantize (self , weight : np .ndarray ) -> BlockQuantizeResult :
210
190
original_shape = weight .shape
211
- weight = weight .reshape ((weight .shape [0 ], - 1 ))
212
191
213
- quantization_axis = 1
192
+ if weight .ndim > 1 :
193
+ weight = weight .reshape ((weight .shape [0 ], - 1 ))
194
+ quantization_axis = 1
195
+ else :
196
+ quantization_axis = 0
214
197
215
- block_size = closest_divisor (weight .shape [1 ], self .conf .block_size )
198
+ block_size = closest_divisor (
199
+ weight .shape [quantization_axis ], self .conf .block_size
200
+ )
216
201
217
202
assert (
218
- weight .shape [1 ] % block_size == 0
219
- ), f"weight shape ({ weight .shape [1 ]} ) must be divisible by block size ({ block_size } )"
203
+ weight .shape [quantization_axis ] % block_size == 0
204
+ ), f"weight shape ({ weight .shape [quantization_axis ]} ) must be divisible by block size ({ block_size } )"
220
205
221
- # Warning, axis = 1 specific instruction!
222
- blocked_weight = weight .reshape (
223
- (weight .shape [0 ], weight .shape [1 ] // block_size , - 1 )
224
- )
206
+ # Flattening the tensor after the quantization axis
207
+ new_shape = list (weight .shape [: quantization_axis + 1 ]) + [- 1 ]
208
+ new_shape [quantization_axis ] = new_shape [quantization_axis ] // block_size
209
+
210
+ blocked_weight = weight .reshape (new_shape )
225
211
226
- # Warning, axis = 1 specific instruction!
227
212
blocked_max = np .max (blocked_weight , - 1 )
228
- # Warning, axis = 1 specific instruction!
229
213
blocked_min = np .min (blocked_weight , - 1 )
230
214
231
215
scales , zeropoints = self .compute_scale_zeropoint (blocked_min , blocked_max )
@@ -273,93 +257,129 @@ def display_summary(self, sqe: List):
273
257
def run (self ):
274
258
print ("Quantizing the model..." )
275
259
276
- visited_nodes = []
260
+ quantized_inputs = []
277
261
sqe = []
278
262
279
- for node in self .model .graph .node :
280
- if node .name in visited_nodes :
281
- continue
263
+ node_idx = 0
264
+
265
+ while node_idx < len (self .model .graph .node ):
266
+ node = self .model .graph .node [node_idx ]
267
+
282
268
if node .op_type in SUPPORTED_OPS :
283
- conv_params = self .get_layer_params (node )
284
- block_quantize_res = self .block_quantize (conv_params .weights )
285
-
286
- quantized_weights_name = f"{ node .name } _quantized_weights"
287
- quantized_node_name = f"{ node .name } _quantized_node"
288
- dequantized_weights_name = f"{ node .name } _dequantized_weights"
289
- scales_name = f"{ node .name } _scales"
290
- zero_point_name = f"{ node .name } _zero_point"
291
-
292
- shape_node_name = f"{ node .name } _shape_node"
293
- shape_name = f"{ node .name } _shape"
294
- reshaped_weights_name = f"{ node .name } _reshaped_weights"
295
-
296
- dequantize_node = create_dequantize_node (
297
- quantized_node_name ,
298
- quantized_weights_name ,
299
- scales_name ,
300
- zero_point_name ,
301
- dequantized_weights_name ,
302
- block_quantize_res .block_size ,
303
- block_quantize_res .axis ,
304
- )
305
- reshape_node = create_reshape_node (
306
- shape_node_name ,
307
- dequantized_weights_name ,
308
- shape_name ,
309
- reshaped_weights_name ,
310
- )
311
-
312
- shape_tensor = onnx .numpy_helper .from_array (
313
- np .array (block_quantize_res .original_shape ), name = shape_name
314
- )
315
- scale_initializer = onnx .numpy_helper .from_array (
316
- block_quantize_res .scales , name = scales_name
317
- )
318
- zero_point_initializer = onnx .numpy_helper .from_array (
319
- block_quantize_res .zero_point , name = zero_point_name
320
- )
321
- quantized_weights_initializer = onnx .numpy_helper .from_array (
322
- block_quantize_res .quantized_weights , name = quantized_weights_name
323
- )
324
-
325
- dequantized_weights_info = helper .make_tensor_value_info (
326
- dequantized_weights_name ,
327
- onnx .TensorProto .FLOAT ,
328
- block_quantize_res .quantized_weights .shape ,
329
- )
330
- shape_info = helper .make_tensor_value_info (
331
- reshaped_weights_name ,
332
- onnx .TensorProto .FLOAT ,
333
- block_quantize_res .original_shape ,
334
- )
335
-
336
- self .graph .initializer .extend (
337
- [
338
- scale_initializer ,
339
- zero_point_initializer ,
340
- shape_tensor ,
341
- quantized_weights_initializer ,
342
- ]
343
- )
344
-
345
- # Removing fp32 weights
346
- self .graph .initializer .remove (
347
- next (
348
- init
349
- for init in self .graph .initializer
350
- if init .name == node .input [1 ]
269
+ for input_idx , input_name in enumerate (node .input ):
270
+ weight = self .get_initializer_tensor (input_name )
271
+
272
+ quantized_weights_name = f"{ input_name } _quantized"
273
+ quantized_node_name = f"{ input_name } _quantized_node"
274
+ dequantized_weights_name = f"{ input_name } _dequantized"
275
+ scales_name = f"{ input_name } _scales"
276
+ zero_point_name = f"{ input_name } _zero_point"
277
+
278
+ shape_node_name = f"{ input_name } _shape_node"
279
+ shape_name = f"{ input_name } _shape"
280
+ reshaped_weights_name = f"{ input_name } _reshaped"
281
+
282
+ # Skip quantization if weights are taken as external input
283
+ # or if they don't contain enough elements to create at least 1 block
284
+ if weight is None or weight .size < self .conf .block_size :
285
+ continue
286
+
287
+ reshape_needed = weight .ndim > 2
288
+
289
+ # In case of parameter sharing
290
+ if input_name in quantized_inputs :
291
+ node .input [input_idx ] = (
292
+ reshaped_weights_name
293
+ if reshape_needed
294
+ else dequantized_weights_name
295
+ )
296
+ continue
297
+
298
+ quantized_inputs .append (input_name )
299
+ block_quantize_res = self .block_quantize (weight )
300
+
301
+ dequantize_node = create_dequantize_node (
302
+ quantized_node_name ,
303
+ quantized_weights_name ,
304
+ scales_name ,
305
+ zero_point_name ,
306
+ dequantized_weights_name ,
307
+ block_quantize_res .block_size ,
308
+ block_quantize_res .axis ,
351
309
)
352
- )
353
- node .input [1 ] = reshaped_weights_name
354
310
355
- # Preserving the topological order of graph nodes
356
- self .graph .node .insert (0 , reshape_node )
357
- self .graph .node .insert (0 , dequantize_node )
358
- self .graph .value_info .insert (0 , shape_info )
359
- self .graph .value_info .insert (0 , dequantized_weights_info )
311
+ if reshape_needed :
312
+ reshape_node = create_reshape_node (
313
+ shape_node_name ,
314
+ dequantized_weights_name ,
315
+ shape_name ,
316
+ reshaped_weights_name ,
317
+ )
318
+
319
+ shape_tensor = onnx .numpy_helper .from_array (
320
+ np .array (block_quantize_res .original_shape ), name = shape_name
321
+ )
322
+ scale_initializer = onnx .numpy_helper .from_array (
323
+ block_quantize_res .scales , name = scales_name
324
+ )
325
+ zero_point_initializer = onnx .numpy_helper .from_array (
326
+ block_quantize_res .zero_point , name = zero_point_name
327
+ )
328
+ quantized_weights_initializer = onnx .numpy_helper .from_array (
329
+ block_quantize_res .quantized_weights ,
330
+ name = quantized_weights_name ,
331
+ )
332
+
333
+ dequantized_weights_info = helper .make_tensor_value_info (
334
+ dequantized_weights_name ,
335
+ onnx .TensorProto .FLOAT ,
336
+ block_quantize_res .quantized_weights .shape ,
337
+ )
338
+
339
+ if reshape_needed :
340
+ shape_info = helper .make_tensor_value_info (
341
+ reshaped_weights_name ,
342
+ onnx .TensorProto .FLOAT ,
343
+ block_quantize_res .original_shape ,
344
+ )
345
+
346
+ self .graph .initializer .extend (
347
+ [
348
+ scale_initializer ,
349
+ zero_point_initializer ,
350
+ shape_tensor ,
351
+ quantized_weights_initializer ,
352
+ ]
353
+ )
354
+
355
+ # Removing fp32 weights
356
+ self .graph .initializer .remove (
357
+ next (
358
+ init
359
+ for init in self .graph .initializer
360
+ if init .name == input_name
361
+ )
362
+ )
363
+
364
+ node .input [input_idx ] = (
365
+ reshaped_weights_name
366
+ if reshape_needed
367
+ else dequantized_weights_name
368
+ )
369
+
370
+ # Preserving graph nodes topological order
371
+ if reshape_needed :
372
+ self .graph .node .insert (0 , reshape_node )
373
+ node_idx += 1
374
+
375
+ self .graph .node .insert (0 , dequantize_node )
376
+ node_idx += 1
377
+ self .graph .value_info .insert (0 , shape_info )
378
+ self .graph .value_info .insert (0 , dequantized_weights_info )
360
379
361
- sqe .append (block_quantize_res .quantization_error ** 2 )
362
- visited_nodes .append (node .name )
380
+ sqe .append (block_quantize_res .quantization_error ** 2 )
381
+
382
+ node_idx += 1
363
383
364
384
onnx .checker .check_model (self .model , full_check = True )
365
385
onnx .save (self .model , self .conf .output_model_path )
0 commit comments