Commit dcd849d

[GSoC] Gemm and MatMul block quantization support (#268)
* Gemm and MatMul block quantization support
* refactoring
* fix indentation
* node name independent
1 parent ac83ef3 commit dcd849d
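
For context: block quantization stores one scale and zero point per fixed-size block of a weight tensor, rather than one per tensor or per output channel, so an outlier value only degrades the precision of its own block. A minimal NumPy sketch of the idea (illustration only; block_quantize_ref is a made-up name, and the script's own block_quantize/compute_scale_zeropoint in the diff below are the authoritative implementation):

import numpy as np

def block_quantize_ref(w2d, block_size):
    # Asymmetric int8 quantization with one scale/zero-point per block.
    qmin, qmax = -128, 127
    rows, cols = w2d.shape
    assert cols % block_size == 0, "columns must be divisible by block size"
    blocks = w2d.reshape(rows, cols // block_size, block_size)
    bmin, bmax = blocks.min(-1), blocks.max(-1)          # one range per block
    scale = (bmax - bmin) / (qmax - qmin)
    scale = np.where(scale == 0, 1.0, scale)             # guard constant blocks
    zp = np.clip(np.round(qmin - bmin / scale), qmin, qmax).astype(np.int8)
    q = np.clip(np.round(blocks / scale[..., None]) + zp[..., None], qmin, qmax)
    return q.astype(np.int8), scale, zp

w = np.random.randn(4, 8).astype(np.float32)
q, scale, zp = block_quantize_ref(w, block_size=4)
w_hat = ((q.astype(np.float32) - zp[..., None]) * scale[..., None]).reshape(w.shape)
print("max abs reconstruction error:", float(np.abs(w - w_hat).max()))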

1 file changed: +134 -114 lines changed


tools/quantize/block_quantize.py

@@ -14,12 +14,10 @@
 import onnx
 from onnx import helper

-BITS_TO_NUMPY_TYPE = {8: np.uint8, 16: np.uint16}
+BITS_TO_NUMPY_TYPE = {8: np.int8, 16: np.int16}


-SUPPORTED_OPS = {
-    "Conv"
-}
+SUPPORTED_OPS = {"Conv", "Gemm", "MatMul"}

 ONNX_OPSET = 21

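Two changes in this hunk: the storage types for quantized weights become signed (np.int8/np.int16 in place of np.uint8/np.uint16), and SUPPORTED_OPS grows from Conv alone to Conv, Gemm and MatMul. For reference, the value ranges the signed types imply (a quick illustrative check, not part of the script):

import numpy as np

for bits, dtype in {8: np.int8, 16: np.int16}.items():
    info = np.iinfo(dtype)
    print(f"{bits}-bit signed range: [{info.min}, {info.max}]")
# 8-bit signed range: [-128, 127]
# 16-bit signed range: [-32768, 32767]
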
@@ -43,12 +41,6 @@ class BlockQuantizeResult:
     quantization_error: np.ndarray = field(default_factory=lambda: np.array([]))


-@dataclass
-class LayerParams:
-    weights: np.ndarray = field(default_factory=lambda: np.array([]))
-    bias: Optional[np.ndarray] = None
-
-
 def closest_divisor(number: int, divisor: int) -> int:
     for d in range(divisor, 0, -1):
         if number % d == 0:
@@ -169,18 +161,6 @@ def get_initializer_tensor(self, name: str) -> Optional[np.ndarray]:

         return None

-    def get_layer_params(self, node: onnx.NodeProto) -> LayerParams:
-        params = LayerParams()
-
-        weights_name = node.input[1]
-        params.weights = self.get_initializer_tensor(weights_name)
-
-        if len(node.input) > 2:
-            bias_name = node.input[2]
-            params.bias = self.get_initializer_tensor(bias_name)
-
-        return params
-
     def compute_scale_zeropoint(
         self, b_min: np.ndarray, b_max: np.ndarray
     ) -> Tuple[np.ndarray, np.ndarray]:
@@ -208,24 +188,28 @@ def compute_scale_zeropoint(

     def block_quantize(self, weight: np.ndarray) -> BlockQuantizeResult:
         original_shape = weight.shape
-        weight = weight.reshape((weight.shape[0], -1))

-        quantization_axis = 1
+        if weight.ndim > 1:
+            weight = weight.reshape((weight.shape[0], -1))
+            quantization_axis = 1
+        else:
+            quantization_axis = 0

-        block_size = closest_divisor(weight.shape[1], self.conf.block_size)
+        block_size = closest_divisor(
+            weight.shape[quantization_axis], self.conf.block_size
+        )

         assert (
-            weight.shape[1] % block_size == 0
-        ), f"weight shape ({weight.shape[1]}) must be divisible by block size ({block_size})"
+            weight.shape[quantization_axis] % block_size == 0
+        ), f"weight shape ({weight.shape[quantization_axis]}) must be divisible by block size ({block_size})"

-        # Warning, axis = 1 specific instruction!
-        blocked_weight = weight.reshape(
-            (weight.shape[0], weight.shape[1] // block_size, -1)
-        )
+        # Flattening the tensor after the quantization axis
+        new_shape = list(weight.shape[: quantization_axis + 1]) + [-1]
+        new_shape[quantization_axis] = new_shape[quantization_axis] // block_size
+
+        blocked_weight = weight.reshape(new_shape)

-        # Warning, axis = 1 specific instruction!
         blocked_max = np.max(blocked_weight, -1)
-        # Warning, axis = 1 specific instruction!
         blocked_min = np.min(blocked_weight, -1)

         scales, zeropoints = self.compute_scale_zeropoint(blocked_min, blocked_max)
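
The rewrite drops the axis-1-only assumption: anything with more than one dimension is still flattened to (rows, cols) and blocked along axis 1, while 1-D tensors are now blocked along axis 0. A small sketch of the new_shape arithmetic under assumed block sizes (blocked_shape is a hypothetical helper mirroring the two lines above):

import numpy as np

def blocked_shape(shape, quantization_axis, block_size):
    # Split the quantization axis into (num_blocks, block) and let -1
    # absorb everything after it, exactly as in the hunk above.
    new_shape = list(shape[: quantization_axis + 1]) + [-1]
    new_shape[quantization_axis] = new_shape[quantization_axis] // block_size
    return new_shape

conv_w = np.zeros((8, 3, 3, 3)).reshape((8, -1))  # Conv weight flattened to (8, 27)
print(blocked_shape(conv_w.shape, 1, 3))          # [8, 9, -1]
gemm_w = np.zeros((16, 32))                       # 2-D Gemm/MatMul weight
print(blocked_shape(gemm_w.shape, 1, 4))          # [16, 8, -1]
bias = np.zeros(24)                               # 1-D input, blocked along axis 0
print(blocked_shape(bias.shape, 0, 4))            # [6, -1]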
@@ -273,93 +257,129 @@ def display_summary(self, sqe: List):
     def run(self):
         print("Quantizing the model...")

-        visited_nodes = []
+        quantized_inputs = []
         sqe = []

-        for node in self.model.graph.node:
-            if node.name in visited_nodes:
-                continue
+        node_idx = 0
+
+        while node_idx < len(self.model.graph.node):
+            node = self.model.graph.node[node_idx]
+
             if node.op_type in SUPPORTED_OPS:
-                conv_params = self.get_layer_params(node)
-                block_quantize_res = self.block_quantize(conv_params.weights)
-
-                quantized_weights_name = f"{node.name}_quantized_weights"
-                quantized_node_name = f"{node.name}_quantized_node"
-                dequantized_weights_name = f"{node.name}_dequantized_weights"
-                scales_name = f"{node.name}_scales"
-                zero_point_name = f"{node.name}_zero_point"
-
-                shape_node_name = f"{node.name}_shape_node"
-                shape_name = f"{node.name}_shape"
-                reshaped_weights_name = f"{node.name}_reshaped_weights"
-
-                dequantize_node = create_dequantize_node(
-                    quantized_node_name,
-                    quantized_weights_name,
-                    scales_name,
-                    zero_point_name,
-                    dequantized_weights_name,
-                    block_quantize_res.block_size,
-                    block_quantize_res.axis,
-                )
-                reshape_node = create_reshape_node(
-                    shape_node_name,
-                    dequantized_weights_name,
-                    shape_name,
-                    reshaped_weights_name,
-                )
-
-                shape_tensor = onnx.numpy_helper.from_array(
-                    np.array(block_quantize_res.original_shape), name=shape_name
-                )
-                scale_initializer = onnx.numpy_helper.from_array(
-                    block_quantize_res.scales, name=scales_name
-                )
-                zero_point_initializer = onnx.numpy_helper.from_array(
-                    block_quantize_res.zero_point, name=zero_point_name
-                )
-                quantized_weights_initializer = onnx.numpy_helper.from_array(
-                    block_quantize_res.quantized_weights, name=quantized_weights_name
-                )
-
-                dequantized_weights_info = helper.make_tensor_value_info(
-                    dequantized_weights_name,
-                    onnx.TensorProto.FLOAT,
-                    block_quantize_res.quantized_weights.shape,
-                )
-                shape_info = helper.make_tensor_value_info(
-                    reshaped_weights_name,
-                    onnx.TensorProto.FLOAT,
-                    block_quantize_res.original_shape,
-                )
-
-                self.graph.initializer.extend(
-                    [
-                        scale_initializer,
-                        zero_point_initializer,
-                        shape_tensor,
-                        quantized_weights_initializer,
-                    ]
-                )
-
-                # Removing fp32 weights
-                self.graph.initializer.remove(
-                    next(
-                        init
-                        for init in self.graph.initializer
-                        if init.name == node.input[1]
+                for input_idx, input_name in enumerate(node.input):
+                    weight = self.get_initializer_tensor(input_name)
+
+                    quantized_weights_name = f"{input_name}_quantized"
+                    quantized_node_name = f"{input_name}_quantized_node"
+                    dequantized_weights_name = f"{input_name}_dequantized"
+                    scales_name = f"{input_name}_scales"
+                    zero_point_name = f"{input_name}_zero_point"
+
+                    shape_node_name = f"{input_name}_shape_node"
+                    shape_name = f"{input_name}_shape"
+                    reshaped_weights_name = f"{input_name}_reshaped"
+
+                    # Skip quantization if weights are taken as external input
+                    # or if they don't contain enough elements to create at least 1 block
+                    if weight is None or weight.size < self.conf.block_size:
+                        continue
+
+                    reshape_needed = weight.ndim > 2
+
+                    # In case of parameter sharing
+                    if input_name in quantized_inputs:
+                        node.input[input_idx] = (
+                            reshaped_weights_name
+                            if reshape_needed
+                            else dequantized_weights_name
+                        )
+                        continue
+
+                    quantized_inputs.append(input_name)
+                    block_quantize_res = self.block_quantize(weight)
+
+                    dequantize_node = create_dequantize_node(
+                        quantized_node_name,
+                        quantized_weights_name,
+                        scales_name,
+                        zero_point_name,
+                        dequantized_weights_name,
+                        block_quantize_res.block_size,
+                        block_quantize_res.axis,
                     )
-                )
-                node.input[1] = reshaped_weights_name

-                # Preserving the topological order of graph nodes
-                self.graph.node.insert(0, reshape_node)
-                self.graph.node.insert(0, dequantize_node)
-                self.graph.value_info.insert(0, shape_info)
-                self.graph.value_info.insert(0, dequantized_weights_info)
+                    if reshape_needed:
+                        reshape_node = create_reshape_node(
+                            shape_node_name,
+                            dequantized_weights_name,
+                            shape_name,
+                            reshaped_weights_name,
+                        )
+
+                    shape_tensor = onnx.numpy_helper.from_array(
+                        np.array(block_quantize_res.original_shape), name=shape_name
+                    )
+                    scale_initializer = onnx.numpy_helper.from_array(
+                        block_quantize_res.scales, name=scales_name
+                    )
+                    zero_point_initializer = onnx.numpy_helper.from_array(
+                        block_quantize_res.zero_point, name=zero_point_name
+                    )
+                    quantized_weights_initializer = onnx.numpy_helper.from_array(
+                        block_quantize_res.quantized_weights,
+                        name=quantized_weights_name,
+                    )
+
+                    dequantized_weights_info = helper.make_tensor_value_info(
+                        dequantized_weights_name,
+                        onnx.TensorProto.FLOAT,
+                        block_quantize_res.quantized_weights.shape,
+                    )
+
+                    if reshape_needed:
+                        shape_info = helper.make_tensor_value_info(
+                            reshaped_weights_name,
+                            onnx.TensorProto.FLOAT,
+                            block_quantize_res.original_shape,
+                        )
+
+                    self.graph.initializer.extend(
+                        [
+                            scale_initializer,
+                            zero_point_initializer,
+                            shape_tensor,
+                            quantized_weights_initializer,
+                        ]
+                    )
+
+                    # Removing fp32 weights
+                    self.graph.initializer.remove(
+                        next(
+                            init
+                            for init in self.graph.initializer
+                            if init.name == input_name
+                        )
+                    )
+
+                    node.input[input_idx] = (
+                        reshaped_weights_name
+                        if reshape_needed
+                        else dequantized_weights_name
+                    )
+
+                    # Preserving graph nodes topological order
+                    if reshape_needed:
+                        self.graph.node.insert(0, reshape_node)
+                        node_idx += 1
+
+                    self.graph.node.insert(0, dequantize_node)
+                    node_idx += 1
+                    self.graph.value_info.insert(0, shape_info)
+                    self.graph.value_info.insert(0, dequantized_weights_info)

-                sqe.append(block_quantize_res.quantization_error**2)
-                visited_nodes.append(node.name)
+                    sqe.append(block_quantize_res.quantization_error**2)
+
+            node_idx += 1

         onnx.checker.check_model(self.model, full_check=True)
         onnx.save(self.model, self.conf.output_model_path)
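
Per quantizable input, the loop above prepends a DequantizeLinear node (plus a Reshape when the weight has more than two dimensions) and advances node_idx to compensate for each insertion at index 0, keeping the scan position and the graph's topological order intact. Roughly, the inserted node looks like the sketch below; the tensor names mirror the f-strings above, the block_size value is arbitrary, and create_dequantize_node in the script remains the authoritative builder:

from onnx import helper

# Hypothetical names for a 2-D MatMul weight "W" (reshape_needed is False):
dequantize_node = helper.make_node(
    "DequantizeLinear",
    inputs=["W_quantized", "W_scales", "W_zero_point"],
    outputs=["W_dequantized"],
    name="W_quantized_node",
    axis=1,          # quantization axis chosen in block_quantize
    block_size=32,   # blocked dequantization, an opset 21 attribute
)
# run() then rewires the MatMul input from "W" to "W_dequantized", registers
# the quantized weights, scales and zero points as initializers, and removes
# the original fp32 "W" initializer.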
