Skip to content

Commit bb02b10

Browse files
authored
Fix Bugs (#1792)
1 parent 73e1529 commit bb02b10

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

paddleslim/quant/advanced/layerwise_quant_error.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self,
         if type(cur_layer) == LayerWiseQuantError:
             print(cur_name, cur_layer.losses.mean())
         '''
+        super(LayerWiseQuantError, self).__init__()
         self.layer = layer
         self.weight = layer.weight
         self.weight_bits = weight_bits
@@ -62,14 +63,13 @@ def __init__(self,
         self.act_method = act_quant_method
         self.loss_function = loss_function
         self.losses = []
+        self.loss = None

     def forward(self, input):
         act = input[0] if type(input) == tuple else input
         origin_out = paddle.matmul(act, self.weight)
         bnt = (1 << (self.weight_bits - 1)) - 1
-        quant_scale = compute_scales(
-            self.weight.cast('float32'),
-            method=self.weight_method).cast(self.weight.dtype)
+        quant_scale = compute_scales(self.weight, method=self.weight_method)
         quant_weight = paddle.clip(
             paddle.round(self.weight / quant_scale * bnt), -bnt - 1, bnt)
         quant_dequant_weight = quant_weight / bnt * quant_scale
@@ -80,6 +80,7 @@ def forward(self, input):
             paddle.round(act / quant_scale * bnt), -bnt - 1, bnt)
         quant_dequant_act = quant_act / bnt * quant_scale
         quant_out = paddle.matmul(quant_dequant_act, quant_dequant_weight)
-        loss = self.loss_function(origin_out, quant_out)
+        loss = self.loss_function(origin_out, quant_out).cast('float32')
         self.losses.append(loss)
+        self.loss = paddle.to_tensor(self.losses, dtype='float32').mean()
         return self.layer(input)

paddleslim/quant/advanced/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ def compute_scales(x, method='abs_max'):
     elif method == 'abs_max_channel_wise':
         reduce_axis = tuple([i for i in range(len(x.shape)) if i != 1])
         quant_scale = paddle.max(paddle.abs(x), axis=reduce_axis)
-        quant_scale = paddle.where(quant_scale == np.float32(0.0),
-                                   np.float32(1e-8), quant_scale)
+        quant_scale = paddle.where(quant_scale == paddle.to_tensor(
+            0, dtype=x.dtype),
+                                   paddle.to_tensor(1e-8, dtype=x.dtype),
+                                   quant_scale)
     return quant_scale
paddleslim/quant/advanced/utils_layers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, layer):
         super(ShiftSmoothHelpLayer, self).__init__()
         self.weight = layer.weight
         shift_shape = self.weight.shape[0]
-        if hasattr(layer, "bias") or layer.bias is None:
+        if not hasattr(layer, "bias") or layer.bias is None:
             self.bias = paddle.create_parameter(
                 shape=[self.weight.shape[1]],
                 dtype=self.weight.dtype,

0 commit comments

Comments
 (0)