Skip to content

Commit 521157e

Browse files
authored
[Cherry-Pick]Cp fit paddle26 (#1823)
1 parent dcf79e9 commit 521157e

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

paddleslim/quant/advanced/gptq.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,9 @@ def fasterquant(self,
         H = self.hessian
         del self.hessian
         dead = paddle.where(paddle.diag(H) == 0)
-        H[dead, dead] = 1
-        W[:, dead] = 0
+        if dead[0].shape[0] != 0:
+            H[dead, dead] = 1
+            W[:, dead] = 0
         del dead
         if actorder:
             perm = paddle.argsort(paddle.diag(H), descending=True)
@@ -122,9 +123,15 @@ def fasterquant(self,
         damp = percdamp * paddle.mean(paddle.diag(H))
         diag = paddle.arange(self.columns)
         H[diag, diag] += damp
-
-        H = paddle.inverse(H)
-        H = paddle.linalg.cholesky(H, upper=True)
+        try:
+            H = paddle.inverse(H)
+            H = paddle.linalg.cholesky(H, upper=True)
+        except:
+            print('We skip GPTQ this layer now.')
+            print(
+                'If you want GPTQ this layer, please try setting damp_percent larger or increasing the number of samples.'
+            )
+            return
         Hinv = H

         for i1 in range(0, self.columns, blocksize):
@@ -182,4 +189,4 @@ def fasterquant(self,

         self.quantized = True
         del H, Q, Hinv, W, Losses
-        paddle.device.cuda.empty_cache()
+        paddle.device.cuda.empty_cache()

paddleslim/quant/advanced/piecewise_search.py

+3
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
             mask_for_search = paddle.where(labels == centroids.argsort()[i],
                                            1., 0.)
             mask_for_ones = paddle.where(mask_for_search == 0., 1., 0.)
+            mask_for_search = mask_for_search.cast(dtype)
+            mask_for_ones = mask_for_ones.cast(dtype)

             while alpha <= alpha_max:
                 if alpha < 1:
@@ -125,6 +127,7 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                 if smooth_scale_out is not None:
                     mask_for_ones_new = paddle.where(
                         smooth_scale_out == 0., 1., 0.)
+                    mask_for_ones_new = mask_for_ones_new.cast(dtype)
                     mask_for_ones *= mask_for_ones_new
                     smooth_scale_ = smooth_scale_out + smooth_scale
                     smooth_scale_tmp = smooth_scale_ + mask_for_ones

0 commit comments

Comments (0)