
Commit ca74d46

Merge branch 'release/2.6' of https://github.com/lizexu123/PaddleSlim into v2.6
2 parents 40d48cc + 521157e

9 files changed (+444 -48 lines)


paddleslim/quant/advanced/__init__.py

+7 -1
@@ -19,6 +19,8 @@
 from . import sample
 from . import layerwise_quant_error
 from . import utils_layers
+from . import awq_search
+from . import auto_clip

 from .gptq import *
 from .smooth import *
@@ -27,6 +29,8 @@
 from .sample import *
 from .layerwise_quant_error import *
 from .utils_layers import *
+from .awq_search import *
+from .auto_clip import *

 __all__ = []
 __all__ += gptq.__all__
@@ -35,4 +39,6 @@
 __all__ += piecewise_search.__all__
 __all__ += sample.__all__
 __all__ += layerwise_quant_error.__all__
-__all__ += utils_layers.__all__
+__all__ += utils_layers.__all__
+__all__ += awq_search.__all__
+__all__ += auto_clip.__all__
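
These are pure re-export changes: once the two new modules are registered here, both classes resolve from the package's public namespace. A minimal sketch of the resulting import path (assuming a PaddleSlim build that includes this commit):

```python
# Assumes a PaddleSlim build from the release/2.6 branch with this commit.
from paddleslim.quant.advanced import AWQSearch, AutoClip
```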
paddleslim/quant/advanced/auto_clip.py

+155
@@ -0,0 +1,155 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import numpy as np
from .utils import fake_quant
from .metrics import mse_loss
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
)
__all__ = ['AutoClip']

class AutoClip(nn.Layer):
    """
    AutoClip from AWQ (https://arxiv.org/abs/2306.00978).
    """
    def __init__(
            self,
            model,
            weight_bits=8,
            weight_quant_method='groupwise',
            loss_function=mse_loss,
            sample_function=None,
            n_grid=20,
            max_shrink=0.5,
            n_sample_token=128,
            group_size=-1,
    ):
        super(AutoClip, self).__init__()
        self.model = model
        self.weight_bits = weight_bits
        self.weight_method = weight_quant_method
        self.loss_function = loss_function
        self.n_grid = n_grid
        self.max_shrink = max_shrink
        self.n_sample_token = n_sample_token
        self.bnt = (1 << (self.weight_bits - 1)) - 1
        self.sampled_inputs = {}
        self.sample_function = sample_function
        self.group_size = group_size

        self._apply_hook()

    def _apply_hook(self):
        self._forward_hook_list = []
        for _, sub_layer in self.model.named_sublayers():
            if type(sub_layer) in [
                    ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear
            ]:
                forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
                    self._forward_pre_hook)
                self._forward_hook_list.append(forward_pre_hook_handle)

    def _forward_pre_hook(self, layer, input):
        self._sample_scale(input, layer.full_name())
        return input

    def _sample_scale(self, input, name):
        input = input[0] if type(input) == tuple else input
        input.stop_gradient = True
        if name not in self.sampled_inputs:
            self.sampled_inputs[name] = input
        else:
            if self.sample_function is not None:
                self.sampled_inputs[name] = self.sample_function.sample(
                    input, self.sampled_inputs[name], name)
            else:
                self.sampled_inputs[name] = input

    def auto_clip(self, group_size=128, oc_batch_size=1024):
        """
        Search the clipping threshold for each layer and update the
        layer's weight in place.
        """
        for sub_name, sub_layer in self.model.named_sublayers():
            name = sub_layer.full_name()
            if name not in self.sampled_inputs:
                continue
            print('AutoClipping', sub_name, name)
            weight = sub_layer.weight.cast('float16')
            weight_t = paddle.transpose(weight, perm=[1, 0])
            x = self.sampled_inputs[name].cast('float16')
            x = x.reshape([-1, x.shape[-1]])
            x = x.reshape([1, x.shape[0], -1, group_size])
            x = x[:, 0::x.shape[1] // self.n_sample_token]
            weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
            # fast test
            # oc_batch_size = weight_t.shape[0] // 4
            oc_batch_size = oc_batch_size if weight_t.shape[
                0] % oc_batch_size == 0 else 128  # prevent OOM
            assert weight_t.shape[0] % oc_batch_size == 0

            w_all = weight_t
            best_max_val_all = []

            for i_b in range(weight_t.shape[0] // oc_batch_size):
                w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]

                org_max_val = w.abs().max(axis=-1, keepdim=True)  # co, 1, n_group, 1
                best_max_val = org_max_val.clone()
                min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
                org_out = (x * w).sum(axis=-1)  # co, n_token, n_group
                for i_s in range(int(self.max_shrink * self.n_grid)):
                    max_val = org_max_val * (1 - i_s / self.n_grid)
                    max_val_tmp = max_val
                    cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
                    cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp, cur_w)
                    quant_dequant_weight = fake_quant(
                        cur_w, method='abs_max', weight_bits=4)
                    cur_out = (x * quant_dequant_weight).sum(axis=-1)
                    # co, 1, n_group, 1
                    tmp = (cur_out - org_out).detach().clone()
                    err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
                    print('block {} search s {} err {}'.format(
                        i_b, i_s, err.mean().item()))
                    del cur_w, cur_out, quant_dequant_weight, tmp
                    paddle.device.cuda.empty_cache()

                    cur_best_idx = paddle.where(err < min_errs)
                    if cur_best_idx[0].shape[0] != 0:
                        min_errs[cur_best_idx] = err[cur_best_idx]
                        best_max_val[cur_best_idx] = max_val[cur_best_idx]
                best_max_val_all.append(best_max_val)

                del org_out, org_max_val, min_errs, best_max_val, err, cur_best_idx, max_val_tmp, max_val, w
                paddle.device.cuda.empty_cache()

            best_max_val = paddle.concat(best_max_val_all, axis=0)
            best_max_val = paddle.squeeze(best_max_val, axis=1)
            for param in sub_layer.parameters(include_sublayers=False):
                if 'w_0' in param.name:
                    param_tmp = param.transpose(perm=[1, 0]).cast('float16')
                    tmp_shape = param_tmp.shape
                    param_tmp = param_tmp.reshape(
                        [best_max_val.shape[0], best_max_val.shape[1], -1])
                    best_max_val = paddle.tile(
                        best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
                    param_tmp = paddle.where(param_tmp > best_max_val,
                                             best_max_val, param_tmp)
                    param_tmp = paddle.where(param_tmp < -best_max_val,
                                             -best_max_val, param_tmp)
                    param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
                    param_tmp = param_tmp.transpose(perm=[1, 0])
                    paddle.assign(param_tmp, output=param)
                    del param_tmp
                    paddle.device.cuda.empty_cache()
                    break

            del best_max_val, weight_t, x, weight, self.sampled_inputs[
                name], w_all, best_max_val_all
            paddle.device.cuda.empty_cache()
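
`AutoClip` registers forward pre-hooks on every `paddle.nn.Linear`, `ColumnParallelLinear`, and `RowParallelLinear` sublayer to cache sampled inputs; `auto_clip()` then grid-searches a shrunken per-group threshold `max_val = org_max_val * (1 - i_s / n_grid)` and keeps the one minimizing the squared error between the original and fake-quantized outputs, finally clamping the stored weights to the winning threshold. A hedged usage sketch follows; `model` and `calib_loader` are placeholders, not part of this commit:

```python
# Illustrative calibration loop; `model` and `calib_loader` are placeholders.
import paddle
from paddleslim.quant.advanced import AutoClip

clipper = AutoClip(model, weight_bits=4, n_grid=20, n_sample_token=128)

# 1) Forward a few calibration batches; the pre-hooks record layer inputs.
with paddle.no_grad():
    for batch in calib_loader:
        model(batch)

# 2) Search the per-group clipping threshold and rewrite weights in place.
clipper.auto_clip(group_size=128, oc_batch_size=1024)
```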
paddleslim/quant/advanced/awq_search.py

+78
@@ -0,0 +1,78 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .utils import compute_scales
from .metrics import mse_loss
__all__ = ['AWQSearch']

class AWQSearch():
    def __init__(self,
                 n_grid=20,
                 bits_length=4,
                 weight_quant_method='groupwise',
                 group_size=128,
                 loss_function=mse_loss):
        '''
        The implementation of AutoScale from AWQ (https://arxiv.org/pdf/2306.00978.pdf).
        '''
        self.n_grid = n_grid
        self.bits_length = bits_length
        self.weight_quant_method = weight_quant_method
        self.bnt = (1 << (bits_length - 1)) - 1
        self.group_size = group_size
        self.loss_function = loss_function

    def search(self, layer_name, sampled_input, act_abs_max, weight):
        act = sampled_input
        act.stop_gradient = True
        print('[awq search] search input of %s' % layer_name)
        dtype = weight.dtype
        origin_out = paddle.matmul(act, weight)
        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        for ratio in range(self.n_grid):
            ratio = ratio * 1 / self.n_grid
            act_abs_max_tmp = act_abs_max.detach().clone().cast('float32')
            scales = paddle.clip(paddle.pow(act_abs_max_tmp, ratio), min=1e-4)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales = scales.cast(dtype)
            new_weight = weight * scales.reshape([-1, 1])
            new_act = act / scales
            quant_scale = compute_scales(
                new_weight,
                method=self.weight_quant_method,
                group_size=self.group_size)
            if self.weight_quant_method == 'groupwise':
                quant_scale = paddle.repeat_interleave(
                    quant_scale.cast('float32'), self.group_size, 0).cast(dtype)
            quant_weight = paddle.clip(
                paddle.round(new_weight / quant_scale * self.bnt),
                -self.bnt - 1, self.bnt)
            quant_dequant_weight = quant_weight / self.bnt * quant_scale
            new_out = paddle.matmul(new_act, quant_dequant_weight)
            loss = self.loss_function(origin_out, new_out).numpy()
            is_best = loss < best_error
            if is_best:
                print('find better ratio: {}, loss: {}'.format(ratio, loss))
                best_error = loss
                best_ratio = ratio
                best_scales = scales

        if best_scales is None:
            best_scales = paddle.ones(scales.shape, dtype=dtype)
            print('Cannot find a better ratio.')
        else:
            print('Best ratio: {}, minimal loss: {}.'.format(best_ratio, best_error))
        return best_scales
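
`search` sweeps `ratio` over {0, 1/n_grid, ..., (n_grid-1)/n_grid}, forms candidate scales `s = clip(act_abs_max ** ratio, min=1e-4)` normalized by `sqrt(s.max() * s.min())`, and keeps whichever scales bring the quant-dequant output closest (by `loss_function`) to `act @ weight`. A standalone sketch with synthetic tensors; the shapes and layer name are illustrative only:

```python
# Standalone sketch; in PaddleSlim this is normally driven by the PTQ
# pipeline, which supplies real sampled activations.
import paddle
from paddleslim.quant.advanced import AWQSearch

searcher = AWQSearch(n_grid=20, bits_length=4, group_size=128)

act = paddle.randn([256, 4096], dtype='float16')      # sampled layer input
weight = paddle.randn([4096, 1024], dtype='float16')  # layer weight [in, out]
act_abs_max = act.abs().max(axis=0)                   # per-channel max |x|

# Per-input-channel scales that minimize the quantization loss.
scales = searcher.search('linear_0', act, act_abs_max, weight)
```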

paddleslim/quant/advanced/gptq.py

+13 -6
@@ -106,8 +106,9 @@ def fasterquant(self,
         H = self.hessian
         del self.hessian
         dead = paddle.where(paddle.diag(H) == 0)
-        H[dead, dead] = 1
-        W[:, dead] = 0
+        if dead[0].shape[0] != 0:
+            H[dead, dead] = 1
+            W[:, dead] = 0
         del dead
         if actorder:
             perm = paddle.argsort(paddle.diag(H), descending=True)
@@ -122,9 +123,15 @@ def fasterquant(self,
         damp = percdamp * paddle.mean(paddle.diag(H))
         diag = paddle.arange(self.columns)
         H[diag, diag] += damp
-
-        H = paddle.inverse(H)
-        H = paddle.linalg.cholesky(H, upper=True)
+        try:
+            H = paddle.inverse(H)
+            H = paddle.linalg.cholesky(H, upper=True)
+        except:
+            print('Skipping GPTQ for this layer.')
+            print(
+                'To apply GPTQ to this layer, try setting damp_percent larger or increasing the number of samples.'
+            )
+            return
         Hinv = H

         for i1 in range(0, self.columns, blocksize):
@@ -182,4 +189,4 @@ def fasterquant(self,

         self.quantized = True
         del H, Q, Hinv, W, Losses
-        paddle.device.cuda.empty_cache()
+        paddle.device.cuda.empty_cache()
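
The second hunk hardens the Hessian step: even after the dampening term `percdamp * mean(diag(H))` is added to the diagonal, `H` can remain numerically singular for ill-conditioned layers, so the inversion and Cholesky factorization are wrapped in a try/except and the layer is skipped rather than aborting the whole run. A minimal sketch of that guard in isolation; `damped_cholesky_inverse` is an illustrative helper, not part of the commit:

```python
# Illustrative helper mirroring the commit's guarded inversion.
import paddle

def damped_cholesky_inverse(H, percdamp=0.01):
    """Return the upper Cholesky factor of H^-1, or None if H is singular."""
    damp = percdamp * paddle.mean(paddle.diag(H))
    idx = paddle.arange(H.shape[0])
    H[idx, idx] += damp  # regularize the diagonal, as GPTQ's damp step does
    try:
        Hinv = paddle.inverse(H)
        return paddle.linalg.cholesky(Hinv, upper=True)
    except Exception:
        return None  # caller skips GPTQ for this layer
```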
