# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
)

from .metrics import mse_loss
from .utils import fake_quant

__all__ = ['AutoClip']

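# For orientation, a minimal sketch of what an 'abs_max' fake quant-dequant step
# looks like; the real `fake_quant` helper is imported from `.utils` and its
# implementation may differ, so treat this only as an illustration:
#
#     def _abs_max_fake_quant(w, weight_bits=4):
#         bnt = (1 << (weight_bits - 1)) - 1            # e.g. 7 for 4-bit weights
#         scale = w.abs().max(axis=-1, keepdim=True) / bnt
#         return paddle.round(w / scale) * scale        # quantize, then dequantize
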
class AutoClip(nn.Layer):
    """
    AutoClip from AWQ (https://arxiv.org/abs/2306.00978).

    For every hooked linear layer, searches per output channel and per weight
    group for the clipping threshold that minimizes the quantization error of
    the layer output, then clips the weight in place.
    """

    def __init__(
        self,
        model,
        weight_bits=8,
        weight_quant_method='groupwise',
        loss_function=mse_loss,
        sample_function=None,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=128,
        group_size=-1,
    ):
        super(AutoClip, self).__init__()
        self.model = model
        self.weight_bits = weight_bits
        self.weight_method = weight_quant_method
        self.loss_function = loss_function
        self.n_grid = n_grid
        self.max_shrink = max_shrink
        self.n_sample_token = n_sample_token
        self.bnt = (1 << (self.weight_bits - 1)) - 1
        self.sampled_inputs = {}
        self.sample_function = sample_function
        self.group_size = group_size

        self._apply_hook()

    def _apply_hook(self):
        # Register forward pre-hooks on every (parallel) linear layer so that
        # calibration inputs are captured while the model runs forward passes.
        self._forward_hook_list = []
        for _, sub_layer in self.model.named_sublayers():
            if type(sub_layer) in [
                    ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear
            ]:
                forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
                    self._forward_pre_hook)
                self._forward_hook_list.append(forward_pre_hook_handle)

    def _forward_pre_hook(self, layer, input):
        self._sample_scale(input, layer.full_name())
        return input

    def _sample_scale(self, input, name):
        # Cache (or resample) the calibration input of the layer `name`.
        input = input[0] if isinstance(input, tuple) else input
        input.stop_gradient = True
        if name not in self.sampled_inputs:
            self.sampled_inputs[name] = input
        elif self.sample_function is not None:
            self.sampled_inputs[name] = self.sample_function.sample(
                input, self.sampled_inputs[name], name)
        else:
            self.sampled_inputs[name] = input

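    # Shape convention used by `auto_clip` below, derived from the reshapes in
    # its body: the calibration input is viewed as
    # [1, n_token, n_group, group_size] and the transposed weight as
    # [out_channels, 1, n_group, group_size], so `(x * w).sum(axis=-1)`
    # broadcasts to a per-output-channel, per-group view of the layer output
    # that the grid search compares against its fake-quantized counterpart.
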
    def auto_clip(self, group_size=128, oc_batch_size=1024):
        """
        Search the clipping threshold for each hooked linear layer and update
        the layer's weight in place.
        """
        for sub_name, sub_layer in self.model.named_sublayers():
            name = sub_layer.full_name()
            if name not in self.sampled_inputs:
                continue
            print('AutoClipping', sub_name, name)
            weight = sub_layer.weight.cast('float16')
            weight_t = paddle.transpose(weight, perm=[1, 0])
            x = self.sampled_inputs[name].cast('float16')
            # Group the calibration input as [1, n_token, n_group, group_size]
            # and keep roughly n_sample_token evenly strided token positions.
            x = x.reshape([-1, x.shape[-1]])
            x = x.reshape([1, x.shape[0], -1, group_size])
            x = x[:, 0::x.shape[1] // self.n_sample_token]
            # Group the transposed weight as [out_channels, 1, n_group, group_size].
            weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
            # Process output channels in batches to prevent OOM.
            oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128
            assert weight_t.shape[0] % oc_batch_size == 0

            w_all = weight_t
            best_max_val_all = []

            for i_b in range(weight_t.shape[0] // oc_batch_size):
                w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]

                org_max_val = w.abs().max(axis=-1, keepdim=True)  # [co, 1, n_group, 1]
                best_max_val = org_max_val.clone()
                min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
                org_out = (x * w).sum(axis=-1)  # [co, n_token, n_group]
                # Grid search: shrink the clipping threshold step by step and keep,
                # per group, the value that gives the smallest quantization error.
                for i_s in range(int(self.max_shrink * self.n_grid)):
                    max_val = org_max_val * (1 - i_s / self.n_grid)
                    cur_w = paddle.where(w > max_val, max_val, w)
                    cur_w = paddle.where(cur_w < -max_val, -max_val, cur_w)
                    # Quantize with the configured weight bit width.
                    quant_dequant_weight = fake_quant(
                        cur_w, method='abs_max', weight_bits=self.weight_bits)
                    cur_out = (x * quant_dequant_weight).sum(axis=-1)
                    tmp = (cur_out - org_out).detach().clone()
                    # Mean squared output error per group, reshaped to [co, 1, n_group, 1].
                    err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
                    print('block {} search s {} err {}'.format(
                        i_b, i_s, err.mean().item()))
                    del cur_w, cur_out, quant_dequant_weight, tmp
                    paddle.device.cuda.empty_cache()

                    cur_best_idx = paddle.where(err < min_errs)
                    if cur_best_idx[0].shape[0] != 0:
                        min_errs[cur_best_idx] = err[cur_best_idx]
                        best_max_val[cur_best_idx] = max_val[cur_best_idx]
                best_max_val_all.append(best_max_val)

                del org_out, org_max_val, min_errs, best_max_val, err, cur_best_idx, max_val, w
                paddle.device.cuda.empty_cache()

            # Concatenate per-batch results and drop the singleton axis:
            # [co, 1, n_group, 1] -> [co, n_group, 1].
            best_max_val = paddle.concat(best_max_val_all, axis=0)
            best_max_val = paddle.squeeze(best_max_val, axis=1)
            for param in sub_layer.parameters(include_sublayers=False):
                # The layer's weight parameter is the one named '*w_0'.
                if 'w_0' in param.name:
                    param_tmp = param.transpose(perm=[1, 0]).cast('float16')
                    tmp_shape = param_tmp.shape
                    param_tmp = param_tmp.reshape(
                        [best_max_val.shape[0], best_max_val.shape[1], -1])
                    best_max_val = paddle.tile(
                        best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
                    # Clip the weight to the searched thresholds and write it back.
                    param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
                    param_tmp = paddle.where(param_tmp < -best_max_val, -best_max_val, param_tmp)
                    param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
                    param_tmp = param_tmp.transpose(perm=[1, 0])
                    paddle.assign(param_tmp, output=param)
                    del param_tmp
                    paddle.device.cuda.empty_cache()
                    break

            del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
            paddle.device.cuda.empty_cache()
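
# A minimal usage sketch (`model` and `calib_dataloader` are hypothetical; the
# sample/loss helpers come from the surrounding quantization pipeline):
#
#     auto_clip = AutoClip(model, weight_bits=4, group_size=128)
#     for data in calib_dataloader:
#         model(**data)                    # forward passes trigger the pre-hooks
#                                          # that populate `auto_clip.sampled_inputs`
#     auto_clip.auto_clip(group_size=128)  # search thresholds, clip weights in place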