@@ -21,14 +21,15 @@
 from .metrics import mse_loss
 from paddle.distributed.fleet.meta_parallel import (
     ColumnParallelLinear,
-    RowParallelLinear,
-)
+    RowParallelLinear, )
 __all__ = ['AutoClip']
 
+
 class AutoClip(nn.Layer):
     """
     AutoClip from AWQ[https://arxiv.org/abs/2306.00978]
     """
+
     def __init__(
             self,
             model,
@@ -39,8 +40,7 @@ def __init__(
             n_grid=20,
             max_shrink=0.5,
             n_sample_token=512,
-            group_size=128,
-    ):
+            group_size=128, ):
         super(AutoClip, self).__init__()
         self.model = model
         self.weight_bits = weight_bits
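
For orientation: `AutoClip` hooks every linear layer, samples its inputs during calibration forward passes, and then searches per-group weight clipping thresholds. A minimal sketch of that flow, assuming a Paddle model and calibration loader are already in hand (`weight_bits` is inferred from the assignment above; the loop itself is illustrative, not part of this patch):

```python
# Illustrative calibration flow; `model` and `calib_loader` are stand-ins.
auto_clip = AutoClip(
    model,
    weight_bits=4,       # bit-width used by the fake quantization
    n_grid=20,           # number of shrink steps to try per group
    max_shrink=0.5,      # shrink the max by at most this fraction
    n_sample_token=512,  # tokens kept per sampled layer input
    group_size=128, )
# If hooks are not installed by __init__, register them explicitly:
# auto_clip._apply_hook()

for inputs in calib_loader:
    model(inputs)        # pre-hooks record each linear layer's input

auto_clip.auto_clip()    # search clip scales and rewrite the weights
```
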
@@ -59,15 +59,17 @@ def __init__(
     def _apply_hook(self):
         self._forward_hook_list = []
         for _, sub_layer in self.model.named_sublayers():
-            if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
+            if type(sub_layer) in [
+                    ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear
+            ]:
                 forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
                     self._forward_pre_hook)
                 self._forward_hook_list.append(forward_pre_hook_handle)
 
     def _forward_pre_hook(self, layer, input):
         self._sample_scale(input, layer.full_name())
         return input
-
+
     def _sample_scale(self, input, name):
         input = input[0] if type(input) == tuple else input
         input.stop_gradient = True
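
The reflowed `_apply_hook` relies on `paddle.nn.Layer.register_forward_pre_hook`: the hook is called with `(layer, input)` before each forward pass, and returning the input unchanged, as `_forward_pre_hook` does, makes it a pure observer. A self-contained illustration of that mechanism:

```python
import paddle

linear = paddle.nn.Linear(4, 4)
sampled = {}

def record_input(layer, input):
    # `input` is the tuple of positional arguments passed to forward()
    sampled[layer.full_name()] = input[0]
    return input

handle = linear.register_forward_pre_hook(record_input)
linear(paddle.rand([2, 4]))  # hook fires and records the activation
handle.remove()              # detach the hook once calibration is done
```
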
@@ -80,7 +82,6 @@ def _sample_scale(self, input, name):
         else:
             self.sampled_inputs[name] = input
 
-
     def auto_clip(self, group_size=128, oc_batch_size=256):
         """
         search clip scale for each layer and update the layer's weight
@@ -89,7 +90,7 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
             name = sub_layer.full_name()
             if name not in self.sampled_inputs or 'out_linear' in sub_name:
                 continue
-
+
             weight = sub_layer.weight.cast('float16')
             weight_t = paddle.transpose(weight, perm=[1, 0])
             x = self.sampled_inputs[name].cast('float16')
@@ -98,33 +99,41 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
             x = x.reshape([1, x.shape[0], -1, group_size])
             x = x[:, 0::x.shape[1] // self.n_sample_token]
             weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
-            oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128  # prevent OOM
+            oc_batch_size = oc_batch_size if weight_t.shape[
+                0] % oc_batch_size == 0 else 128  # prevent OOM
             assert weight_t.shape[0] % oc_batch_size == 0
 
             w_all = weight_t
             best_max_val_all = []
 
             for i_b in range(weight_t.shape[0] // oc_batch_size):
-                w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+                w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]
 
-                org_max_val = w.abs().max(axis=-1, keepdim=True)  # co, 1, n_group, 1
+                org_max_val = w.abs().max(
+                    axis=-1, keepdim=True)  # co, 1, n_group, 1
                 best_max_val = org_max_val.clone()
                 min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
                 org_out = (x * w).sum(axis=-1)  # co, n_token, n_group
                 for i_s in range(int(self.max_shrink * self.n_grid)):
                     max_val = org_max_val * (1 - i_s / self.n_grid)
                     max_val_tmp = max_val
                     cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
-                    cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp, cur_w)
+                    cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp,
+                                         cur_w)
                     org_w_shape = cur_w.shape
-                    cur_w_r = cur_w.reshape([-1, self.group_size]).transpose([1, 0])
-                    quant_dequant_weight = fake_quant(cur_w_r, method='abs_max_channel_wise', weight_bits=4)
-                    quant_dequant_weight = quant_dequant_weight.transpose([1, 0]).reshape(org_w_shape)
+                    cur_w_r = cur_w.reshape([-1,
+                                             self.group_size]).transpose([1, 0])
+                    quant_dequant_weight = fake_quant(
+                        cur_w_r, method='abs_max_channel_wise', weight_bits=4)
+                    quant_dequant_weight = quant_dequant_weight.transpose(
+                        [1, 0]).reshape(org_w_shape)
                     cur_out = (x * quant_dequant_weight).sum(axis=-1)
                     # co, 1, n_group, 1
                     tmp = (cur_out - org_out).detach().clone()
-                    err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
-                    print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item()))
+                    err = paddle.pow(tmp,
+                                     2).mean(axis=1).reshape(min_errs.shape)
+                    print('block {} search s {} err {}'.format(
+                        i_b, i_s, err.mean().item()))
                     del cur_w, cur_out, quant_dequant_weight, tmp, cur_w_r
                     paddle.device.cuda.empty_cache()
 
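
The hunk above is the core of the method: for each shrink step, clamp the weights to the candidate threshold, fake-quantize them, and keep the threshold whose quantized output stays closest (in MSE over the sampled tokens) to the original output. A stripped-down NumPy restatement of that search, with plain per-group abs-max quant-dequant standing in for `fake_quant` and simplified shapes standing in for the 4-D layout used above:

```python
import numpy as np

def search_clip(w, x, n_grid=20, max_shrink=0.5, bits=4):
    """Grid-search a per-group clip threshold.

    w: [n_group, group_size] weights
    x: [n_token, n_group, group_size] sampled activations
    """
    org_max = np.abs(w).max(axis=-1, keepdims=True)  # [n_group, 1]
    org_out = (x * w).sum(axis=-1)                   # [n_token, n_group]
    best_max = org_max.copy()
    min_err = np.full_like(org_max, np.inf)
    for i_s in range(int(max_shrink * n_grid)):
        max_val = org_max * (1 - i_s / n_grid)
        cur_w = np.clip(w, -max_val, max_val)
        # per-group abs-max fake quant-dequant (stand-in for fake_quant)
        scale = np.abs(cur_w).max(axis=-1, keepdims=True)
        scale = np.maximum(scale, 1e-8) / (2**(bits - 1) - 1)
        q_w = np.round(cur_w / scale) * scale
        # output MSE over sampled tokens, per group
        err = (((x * q_w).sum(axis=-1) - org_out)**2).mean(axis=0)
        err = err.reshape(min_err.shape)
        best_max = np.where(err < min_err, max_val, best_max)
        min_err = np.minimum(err, min_err)
    return best_max
```
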
@@ -143,16 +152,21 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
                 if 'w_0' in param.name:
                     param_tmp = param.transpose(perm=[1, 0]).cast('float16')
                     tmp_shape = param_tmp.shape
-                    param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1])
-                    best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
-                    param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
-                    param_tmp = paddle.where(param_tmp < -best_max_val, -best_max_val, param_tmp)
+                    param_tmp = param_tmp.reshape(
+                        [best_max_val.shape[0], best_max_val.shape[1], -1])
+                    best_max_val = paddle.tile(
+                        best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
+                    param_tmp = paddle.where(param_tmp > best_max_val,
+                                             best_max_val, param_tmp)
+                    param_tmp = paddle.where(param_tmp < -best_max_val,
+                                             -best_max_val, param_tmp)
                     param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
                     param_tmp = param_tmp.transpose(perm=[1, 0])
                     paddle.assign(param_tmp, output=param)
                     del param_tmp
                     paddle.device.cuda.empty_cache()
                     break
 
-            del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
+            del best_max_val, weight_t, x, weight, self.sampled_inputs[
+                name], w_all, best_max_val_all
             paddle.device.cuda.empty_cache()
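
The final hunk only rewraps the weight update: the searched `best_max_val` is tiled out to the weight's group layout and the weight is clamped from both sides with two `paddle.where` calls before being cast back and assigned. The same clamp in isolation, on toy shapes:

```python
import paddle

w = paddle.uniform([8, 4], min=-2.0, max=2.0)  # [n_group, group_size]
t = paddle.full([8, 1], 1.0)                   # searched threshold per group
t = paddle.tile(t, repeat_times=(1, w.shape[-1]))
w = paddle.where(w > t, t, w)                  # clamp from above
w = paddle.where(w < -t, -t, w)                # clamp from below
```
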