@@ -54,6 +54,7 @@ def __init__(self,
             if type(cur_layer) == LayerWiseQuantError:
                 print(cur_name, cur_layer.losses.mean())
         '''
+        super(LayerWiseQuantError, self).__init__()
         self.layer = layer
         self.weight = layer.weight
         self.weight_bits = weight_bits
@@ -62,14 +63,13 @@ def __init__(self,
         self.act_method = act_quant_method
         self.loss_function = loss_function
         self.losses = []
+        self.loss = None
 
     def forward(self, input):
         act = input[0] if type(input) == tuple else input
         origin_out = paddle.matmul(act, self.weight)
         bnt = (1 << (self.weight_bits - 1)) - 1
-        quant_scale = compute_scales(
-            self.weight.cast('float32'),
-            method=self.weight_method).cast(self.weight.dtype)
+        quant_scale = compute_scales(self.weight, method=self.weight_method)
         quant_weight = paddle.clip(
             paddle.round(self.weight / quant_scale * bnt), -bnt - 1, bnt)
         quant_dequant_weight = quant_weight / bnt * quant_scale
@@ -80,6 +80,7 @@ def forward(self, input):
             paddle.round(act / quant_scale * bnt), -bnt - 1, bnt)
         quant_dequant_act = quant_act / bnt * quant_scale
         quant_out = paddle.matmul(quant_dequant_act, quant_dequant_weight)
-        loss = self.loss_function(origin_out, quant_out)
+        loss = self.loss_function(origin_out, quant_out).cast('float32')
         self.losses.append(loss)
+        self.loss = paddle.to_tensor(self.losses, dtype='float32').mean()
         return self.layer(input)
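For reference, a minimal sketch of the symmetric fake-quant round-trip that forward() applies to both the weight and the activation. The abs-max scale below is an assumption standing in for compute_scales(..., method='abs_max'); it is not the library's actual implementation:

import paddle

def fake_quant(x, bits=8):
    # bnt is the largest positive quantized value, e.g. 127 for 8 bits
    bnt = (1 << (bits - 1)) - 1
    # abs-max scale (assumed stand-in for compute_scales with 'abs_max')
    scale = paddle.abs(x).max()
    # quantize: map onto the integer grid, round, clip to the signed range
    q = paddle.clip(paddle.round(x / scale * bnt), -bnt - 1, bnt)
    # dequantize back to float; the gap to x is the quantization error
    return q / bnt * scale

x = paddle.rand([4, 4]) - 0.5
print(float(paddle.nn.MSELoss()(x, fake_quant(x))))  # small but nonzero error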
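With self.loss now maintained as the running mean of self.losses, the reduction shown in the docstring can be read straight from the attribute. A hypothetical usage sketch, where model and the step that wraps its sublayers in LayerWiseQuantError are assumed:

for cur_name, cur_layer in model.named_sublayers():
    if type(cur_layer) == LayerWiseQuantError:
        # cur_layer.loss is the mean quantization error over all recorded batches
        print(cur_name, float(cur_layer.loss))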