
Commit dcf79e9

Add GroupWiseQuant & AWQ & AutoClip (#1821)
1 parent d4ac0ef commit dcf79e9

File tree: 8 files changed, +428 -42 lines changed

paddleslim/quant/advanced/__init__.py (+7 -1)

@@ -19,6 +19,8 @@
 from . import sample
 from . import layerwise_quant_error
 from . import utils_layers
+from . import awq_search
+from . import auto_clip
 
 from .gptq import *
 from .smooth import *
@@ -27,6 +29,8 @@
 from .sample import *
 from .layerwise_quant_error import *
 from .utils_layers import *
+from .awq_search import *
+from .auto_clip import *
 
 __all__ = []
 __all__ += gptq.__all__
@@ -35,4 +39,6 @@
 __all__ += piecewise_search.__all__
 __all__ += sample.__all__
 __all__ += layerwise_quant_error.__all__
-__all__ += utils_layers.__all__
+__all__ += utils_layers.__all__
+__all__ += awq_search.__all__
+__all__ += auto_clip.__all__
paddleslim/quant/advanced/auto_clip.py (new file, +155)

@@ -0,0 +1,155 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+from .utils import fake_quant
+from .metrics import mse_loss
+from paddle.distributed.fleet.meta_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
+__all__ = ['AutoClip']
+
+class AutoClip(nn.Layer):
+    """
+    AutoClip from AWQ[https://arxiv.org/abs/2306.00978]
+    """
+    def __init__(
+            self,
+            model,
+            weight_bits=8,
+            weight_quant_method='groupwise',
+            loss_function=mse_loss,
+            sample_function=None,
+            n_grid=20,
+            max_shrink=0.5,
+            n_sample_token=128,
+            group_size=-1,
+    ):
+        super(AutoClip, self).__init__()
+        self.model = model
+        self.weight_bits = weight_bits
+        self.weight_method = weight_quant_method
+        self.loss_function = loss_function
+        self.n_grid = n_grid
+        self.max_shrink = max_shrink
+        self.n_sample_token = n_sample_token
+        self.bnt = (1 << (self.weight_bits - 1)) - 1
+        self.sampled_inputs = {}
+        self.sample_function = sample_function
+        self.group_size = group_size
+
+        self._apply_hook()
+
+    def _apply_hook(self):
+        self._forward_hook_list = []
+        for _, sub_layer in self.model.named_sublayers():
+            if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
+                forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
+                    self._forward_pre_hook)
+                self._forward_hook_list.append(forward_pre_hook_handle)
+
+    def _forward_pre_hook(self, layer, input):
+        self._sample_scale(input, layer.full_name())
+        return input
+
+    def _sample_scale(self, input, name):
+        input = input[0] if type(input) == tuple else input
+        input.stop_gradient = True
+        if name not in self.sampled_inputs:
+            self.sampled_inputs[name] = input
+        else:
+            if self.sample_function is not None:
+                self.sampled_inputs[name] = self.sample_function.sample(
+                    input, self.sampled_inputs[name], name)
+            else:
+                self.sampled_inputs[name] = input
+
+
+    def auto_clip(self, group_size=128, oc_batch_size=1024):
+        """
+        search clip scale for each layer and update the layer's weight
+        """
+        for sub_name, sub_layer in self.model.named_sublayers():
+            name = sub_layer.full_name()
+            if name not in self.sampled_inputs:
+                continue
+            print('AutoClipping', sub_name, name)
+            weight = sub_layer.weight.cast('float16')
+            weight_t = paddle.transpose(weight, perm=[1, 0])
+            x = self.sampled_inputs[name].cast('float16')
+            x = x.reshape([-1, x.shape[-1]])
+            x = x.reshape([1, x.shape[0], -1, group_size])
+            x = x[:, 0::x.shape[1] // self.n_sample_token]
+            weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
+            # fast test
+            # oc_batch_size = weight_t.shape[0] // 4
+            oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128  # prevent OOM
+            assert weight_t.shape[0] % oc_batch_size == 0
+
+            w_all = weight_t
+            best_max_val_all = []
+
+            for i_b in range(weight_t.shape[0] // oc_batch_size):
+                w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+
+                org_max_val = w.abs().max(axis=-1, keepdim=True)  # co, 1, n_group, 1
+                best_max_val = org_max_val.clone()
+                min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
+                org_out = (x * w).sum(axis=-1)  # co, n_token, n_group
+                for i_s in range(int(self.max_shrink * self.n_grid)):
+                    max_val = org_max_val * (1 - i_s / self.n_grid)
+                    max_val_tmp = max_val
+                    cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
+                    cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp, cur_w)
+                    quant_dequant_weight = fake_quant(cur_w, method='abs_max', weight_bits=4)
+                    cur_out = (x * quant_dequant_weight).sum(axis=-1)
+                    # co, 1, n_group, 1
+                    tmp = (cur_out - org_out).detach().clone()
+                    err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
+                    print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item()))
+                    del cur_w, cur_out, quant_dequant_weight, tmp
+                    paddle.device.cuda.empty_cache()
+
+                    cur_best_idx = paddle.where(err < min_errs)
+                    if cur_best_idx[0].shape[0] != 0:
+                        min_errs[cur_best_idx] = err[cur_best_idx]
+                        best_max_val[cur_best_idx] = max_val[cur_best_idx]
+                best_max_val_all.append(best_max_val)
+
+                del org_out, org_max_val, min_errs, best_max_val, err, cur_best_idx, max_val_tmp, max_val, w
+                paddle.device.cuda.empty_cache()
+
+            best_max_val = paddle.concat(best_max_val_all, axis=0)
+            best_max_val = paddle.squeeze(best_max_val, axis=1)
+            for param in sub_layer.parameters(include_sublayers=False):
+                if 'w_0' in param.name:
+                    param_tmp = param.transpose(perm=[1, 0]).cast('float16')
+                    tmp_shape = param_tmp.shape
+                    param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1])
+                    best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
+                    param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
+                    param_tmp = paddle.where(param_tmp < -best_max_val, -best_max_val, param_tmp)
+                    param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
+                    param_tmp = param_tmp.transpose(perm=[1, 0])
+                    paddle.assign(param_tmp, output=param)
+                    del param_tmp
+                    paddle.device.cuda.empty_cache()
+                    break
+
+            del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
+            paddle.device.cuda.empty_cache()
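The new AutoClip pass collects each Linear layer's input through forward pre-hooks and then, per output-channel group, searches for the weight clipping threshold that minimizes quantization error on those inputs. Below is a minimal usage sketch; build_model() and calib_loader are hypothetical placeholders and the argument values are illustrative, not taken from this commit.

# Sketch only: "build_model" and "calib_loader" are hypothetical stand-ins.
import paddle
from paddleslim.quant.advanced import AutoClip

model = build_model()                    # any paddle.nn.Layer containing Linear sublayers
auto_clip = AutoClip(
    model,
    weight_bits=4,                       # assumption: 4-bit weight-only quantization
    weight_quant_method='groupwise',
    sample_function=None,                # None keeps the most recent batch per layer
    n_sample_token=128,
    group_size=128)

model.eval()
with paddle.no_grad():
    for batch in calib_loader:           # hypothetical calibration data
        model(**batch)                   # pre-hooks record each Linear layer's input

# Search clipping thresholds and overwrite the layers' weights in place.
auto_clip.auto_clip(group_size=128)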
paddleslim/quant/advanced/awq_search.py (new file, +78)

@@ -0,0 +1,78 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import numpy as np
+from .utils import compute_scales
+from .metrics import mse_loss
+__all__ = ['AWQSearch']
+
+class AWQSearch():
+    def __init__(self,
+                 n_grid=20,
+                 bits_length=4,
+                 weight_quant_method='groupwise',
+                 group_size=128,
+                 loss_function=mse_loss):
+        '''
+        The implementation of AutoScale from AWQ(https://arxiv.org/pdf/2306.00978.pdf).
+        '''
+        self.n_grid = n_grid
+        self.bits_length = bits_length
+        self.weight_quant_method = weight_quant_method
+        self.bnt = (1 << (bits_length - 1)) - 1
+        self.group_size = group_size
+        self.loss_function = loss_function
+
+    def search(self, layer_name, sampled_input, act_abs_max, weight):
+        act = sampled_input
+        act.stop_gradient = True
+        print('[awq search] search input of %s' % layer_name)
+        dtype = weight.dtype
+        origin_out = paddle.matmul(act, weight)
+        best_error = float('inf')
+        best_ratio = -1
+        best_scales = None
+
+        for ratio in range(self.n_grid):
+            ratio = ratio * 1 / self.n_grid
+            act_abs_max_tmp = act_abs_max.detach().clone().cast('float32')
+            scales = paddle.clip(paddle.pow(act_abs_max_tmp, ratio), min=1e-4)
+            scales = scales / (scales.max() * scales.min()).sqrt()
+            scales = scales.cast(dtype)
+            new_weight = weight * scales.reshape([-1, 1])
+            new_act = act / scales
+            quant_scale = compute_scales(
+                new_weight, method=self.weight_quant_method, group_size=self.group_size)
+            if self.weight_quant_method == 'groupwise':
+                quant_scale = paddle.repeat_interleave(quant_scale.cast('float32'), self.group_size, 0).cast(dtype)
+            quant_weight = paddle.clip(
+                paddle.round(new_weight / quant_scale * self.bnt),
+                -self.bnt - 1, self.bnt)
+            quant_dequant_weight = quant_weight / self.bnt * quant_scale
+            new_out = paddle.matmul(new_act,
+                                    quant_dequant_weight)
+            loss = self.loss_function(origin_out, new_out).numpy()
+            is_best = loss < best_error
+            if is_best:
+                print('find better ratio: {}, loss: {}'.format(ratio, loss))
+                best_error = loss
+                best_ratio = ratio
+                best_scales = scales
+
+        if best_scales is None:
+            best_scales = paddle.ones(scales.shape, dtype=dtype)
+            print('Cannot find better ratio.')
+        else:
+            print('Best ratio :{}, minimal loss : {}.'.format(best_ratio, best_error))
+        return best_scales
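AWQSearch.search() grid-searches the AWQ scaling exponent over the sampled activation abs-max and returns the per-input-channel smoothing scales that minimize the quant-dequant output error (activations are divided by the scales, weights multiplied by them). A self-contained sketch with synthetic tensors follows; the shapes, sizes, and the layer name 'linear_0' are illustrative assumptions, not part of this commit.

# Sketch with random data; in practice the inputs come from calibration hooks.
import paddle
from paddleslim.quant.advanced import AWQSearch

in_features, out_features, n_token = 512, 1024, 32
weight = paddle.randn([in_features, out_features], dtype='float32')
sampled_input = paddle.randn([n_token, in_features], dtype='float32')
act_abs_max = sampled_input.abs().max(axis=0)      # per-input-channel activation abs-max

search = AWQSearch(
    n_grid=20,
    bits_length=4,
    weight_quant_method='groupwise',
    group_size=128)                                 # group_size must divide in_features

best_scales = search.search('linear_0', sampled_input, act_abs_max, weight)
print(best_scales.shape)                            # [in_features]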

paddleslim/quant/advanced/piecewise_search.py (+26 -13)

@@ -31,6 +31,8 @@ def __init__(self,
                  search_scale_max=5.,
                  weight_quant_method='abs_max_channel_wise',
                  act_quant_method='abs_max',
+                 use_clip=False,
+                 search_clip=False,
                  loss_function=mse_loss):
         '''
         PieceWiseSearch provides to search k_piece, alpha and scale.
@@ -58,31 +60,36 @@ def __init__(self,
         self.act_quant_method = act_quant_method
         self.bnt = (1 << (bits_length - 1)) - 1
         self.loss_function = loss_function
+        self.use_clip = use_clip
+        self.search_clip = search_clip
 
     def search(self, layer_name, sampled_input, act_abs_max, weight):
         act = sampled_input
         act.stop_gradient = True
         print('[smooth search] search input of %s' % layer_name)
-
+        dtype = weight.dtype
         origin_out = paddle.matmul(act, weight)
         w_abs_max = weight.abs().max(axis=-1, keepdim=True)
         rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
-        np_act_abs_max = np.array(act_abs_max)
-        np_rw_abs_max = np.array(rw_abs_max)
-
+
         smooth_scale_out = None
         global_loss = float('inf')
         best_scale = None
 
-        for k_piece in range(1, self.k_piece + 1):
+        if self.search_clip:
+            piece_range = [1] + list(range(1, self.k_piece + 1))
+        else:
+            piece_range = list(range(1, self.k_piece + 1))
+
+        for k_idx, k_piece in enumerate(piece_range):
             if not self.search_piece:
                 k_piece = self.k_piece
             print('Search {} Piece'.format(k_piece))
             centroids, labels = k_means(act_abs_max, k_piece)
             piece = ['piece_{}'.format(a) for a in range(len(centroids))]
             for i in range(len(centroids)):
-                # print('search for piece {}; centroids value is {}'.format(
-                #     piece[i], centroids[centroids.argsort()[i]].numpy()))
+                print('search for piece {}; centroids value is {}'.format(
+                    piece[i], float(centroids[centroids.argsort()[i: i + 1]].cast('float32'))))
                 alpha = self.search_alpha_min
                 alpha_max = self.search_scale_max if self.search_scale_max is not None else self.search_alpha_max
                 calibration_loss = float('inf')
@@ -104,12 +111,16 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                     alpha = round(alpha, 2)
 
                     if alpha < 1:
-                        s = (np.power(np_act_abs_max, alpha) / np.power(
-                            np_rw_abs_max, 1. - alpha)).clip(min=1e-5)
-                        s = paddle.to_tensor(s, dtype='float32')
+                        act_abs_max_tmp = act_abs_max.detach().clone()
+                        s = paddle.clip(paddle.pow(act_abs_max_tmp, alpha) / paddle.pow(
+                            rw_abs_max, 1 - alpha), min=1e-5)
+
+                        if self.use_clip or (k_piece == 1 and k_idx == 1 and self.search_clip):
+                            s = paddle.clip(act_abs_max_tmp / paddle.max(act_abs_max / s), min=1)
+                        del act_abs_max_tmp
                         smooth_scale = s * mask_for_search
                     else:
-                        smooth_scale = alpha * mask_for_search
+                        smooth_scale = paddle.to_tensor(alpha, dtype=dtype) * mask_for_search
 
                     if smooth_scale_out is not None:
                         mask_for_ones_new = paddle.where(
@@ -145,9 +156,10 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                         calibration_loss = cur_loss
                         final_smooth_scale = smooth_scale
                         final_alpha = alpha
+                        # print('Better alpha: {} loss: {}'.format(alpha, calibration_loss.cast('float32')))
 
-                # print("Layer {} Piece {}, loss: {}, alpha : {}".format(
-                #     layer_name, piece[i], float(calibration_loss), final_alpha))
+                print("Layer {} Piece {}, loss: {}, alpha : {}".format(
+                    layer_name, piece[i], float(calibration_loss.cast('float32')), final_alpha))
                 if smooth_scale_out is None:
                     smooth_scale_out = final_smooth_scale
                 else:
@@ -160,4 +172,5 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
             print('Find Better K-Piece {}'.format(k_piece))
             if not self.search_piece:
                 break
+
         return best_scale
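The two new PieceWiseSearch flags work as follows: use_clip rescales the searched smooth scale against the largest smoothed activation and floors it at 1 (the alpha < 1 branch above), while search_clip additionally evaluates that clipped variant on a single-piece split (the extra leading 1 in piece_range) before the regular k-piece search. A hedged construction sketch; only the arguments visible in this diff are certain, and k_piece and bits_length are assumed from how the class uses them.

# Sketch: argument values are illustrative; see the class definition for actual defaults.
from paddleslim.quant.advanced import PieceWiseSearch
from paddleslim.quant.advanced.metrics import mse_loss

search = PieceWiseSearch(
    k_piece=3,                                     # assumed: split activation range into up to 3 pieces
    bits_length=8,
    search_scale_max=5.,
    weight_quant_method='abs_max_channel_wise',
    act_quant_method='abs_max',
    use_clip=False,                                # new in this commit: always clip the smooth scale when True
    search_clip=True,                              # new in this commit: also try a clipped one-piece candidate
    loss_function=mse_loss)

# best_scale = search.search(layer_name, sampled_input, act_abs_max, weight)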
