diff --git a/example/auto_compression/nlp/README.md b/example/auto_compression/nlp/README.md
index c98f1987e..11ffe7e08 100644
--- a/example/auto_compression/nlp/README.md
+++ b/example/auto_compression/nlp/README.md
@@ -56,16 +56,16 @@
 
 #### 3.1 准备环境
 - python >= 3.6
-- PaddlePaddle >= 2.4 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
-- PaddleSlim >= 2.4
-- PaddleNLP >= 2.3
+- PaddlePaddle ==2.5 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
+- PaddleSlim ==2.5
+- PaddleNLP ==2.6
 
 安装paddlepaddle:
 ```shell
 # CPU
-pip install paddlepaddle==2.4.1
+pip install paddlepaddle==2.5.0
 # GPU 以Ubuntu、CUDA 11.2为例
-python -m pip install paddlepaddle-gpu==2.4.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+python -m pip install paddlepaddle-gpu==2.5.0.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
 ```
 
 安装paddleslim:
@@ -95,7 +95,6 @@ pip install paddlenlp
 |:------:|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|
 | PP-MiniLM | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/afqmc.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/tnews.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/iflytek.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/cmnli.tar) | [ ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/ocnli.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/cluewsc.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/csl.tar) |
 | ERNIE 3.0-Medium | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/AFQMC.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/TNEWS.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/IFLYTEK.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CMNLI.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/OCNLI.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CLUEWSC2020.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CSL.tar) |
-| UIE-base | [报销工单](https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base.tar) |
 
 从上表获得模型超链接, 并用以下命令下载推理模型文件:
@@ -119,11 +118,6 @@ export CUDA_VISIBLE_DEVICES=0
 python run.py --config_path='./configs/pp-minilm/auto/afqmc.yaml' --save_dir='./save_afqmc_pruned/'
 ```
 
-自动压缩UIE系列模型需要使用 run_uie.py 脚本启动，会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中训练部分的参数，将任务名称、模型类型、数据集名称、压缩参数传入，配置完成后便可对模型进行蒸馏量化训练。
-```shell
-export CUDA_VISIBLE_DEVICES=0
-python run_uie.py --config_path='./configs/uie/uie_base.yaml' --save_dir='./save_uie_qat/'
-```
 
 如仅需验证模型精度，或验证压缩之后模型精度，在启动```run.py```脚本时，将配置文件中模型文件夹 ```model_dir``` 改为压缩之后保存的文件夹路径 ```./save_afqmc_pruned``` ，命令加上```--eval True```即可:
 ```shell
@@ -212,12 +206,29 @@ QuantPost:
 ## 5. 预测部署
+
 量化模型在GPU上可以使用TensorRT进行加速，在CPU上可以使用MKLDNN进行加速。
+以下字段用于配置预测参数:
+
+| 参数名 | 含义 |
+|:------:|:------:|
+| model_path | inference 模型文件所在目录，该目录下需要有文件 model.pdmodel 和 model.pdiparams 两个文件 |
+| model_filename | 模型文件的名称，默认值为inference.pdmodel |
+| params_filename | 参数文件的名称，默认值为inference.pdiparams |
+| task_name | 要执行的任务名称，默认值为afqmc |
+| dataset | 模型使用的数据集，默认值为clue |
+| device | 用于推理的设备，默认为gpu，可选cpu或gpu |
+| batch_size | 推理时的batch size，默认为32 |
+| max_seq_len | 输入序列在分词后的最大长度，默认值为128，如果序列长于此值，将会被截断；如果短于此值，将会被填充|
+| perf_warmup_steps | 性能测试的预热步数，默认值为20 |
+| use_trt | 一个标志(flag)，用于决定是否使用TensorRT推理 |
+| precision | 推理精度，默认为fp32，可选fp16或int8 |
+| use_mkldnn | 一个标志(flag)，用于决定是否使用MKLDNN推理 |
+| cpu_threads | CPU线程数，默认为1 |
 
-- TensorRT预测:
-环境配置：如果使用 TesorRT 预测引擎，需安装 ```WITH_TRT=ON``` 的Paddle，下载地址：[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python)
+#### 5.1 TensorRT预测:
 
 首先下载量化好的模型:
 ```shell
@@ -227,10 +238,30 @@ tar -xf save_ppminilm_afqmc_new_calib.tar
 
 ```shell
 python paddle_inference_eval.py \
-    --model_path=save_ernie3_afqmc_new_cablib \
+    --model_path=save_ppminilm_afqmc_new_calib \
+    --model_filename=inference.pdmodel \
+    --params_filename=inference.pdiparams \
+    --task_name='afqmc' \
+    --use_trt \
+    --precision=int8
+```
+
+- ERNIE 3.0-Medium:
+```shell
+python paddle_inference_eval.py \
+    --model_path=TNEWS \
     --model_filename=infer.pdmodel \
     --params_filename=infer.pdiparams \
-    --task_name='afqmc' \
+    --task_name='tnews' \
+    --use_trt \
+    --precision=fp32
+```
+```shell
+python paddle_inference_eval.py \
+    --model_path=save_tnews_pruned \
+    --model_filename=infer.pdmodel \
+    --params_filename=infer.pdiparams \
+    --task_name='tnews' \
+    --use_trt \
+    --precision=int8
 ```
@@ -239,9 +270,9 @@ python paddle_inference_eval.py \
 
 ```shell
 python paddle_inference_eval.py \
-    --model_path=save_ernie3_afqmc_new_cablib \
-    --model_filename=infer.pdmodel \
-    --params_filename=infer.pdiparams \
+    --model_path=save_ppminilm_afqmc_new_calib \
+    --model_filename=inference.pdmodel \
+    --params_filename=inference.pdiparams \
     --task_name='afqmc' \
     --device=cpu \
    --use_mkldnn=True \
diff --git a/example/auto_compression/nlp/configs/ernie3.0/tnews.yaml b/example/auto_compression/nlp/configs/ernie3.0/tnews.yaml
index 49093ab87..b90da628a 100644
--- a/example/auto_compression/nlp/configs/ernie3.0/tnews.yaml
+++ b/example/auto_compression/nlp/configs/ernie3.0/tnews.yaml
@@ -6,12 +6,17 @@ Global:
   dataset: clue
   batch_size: 16
   max_seq_length: 128
-TrainConfig:
-  epochs: 6
-  eval_iter: 1110
-  learning_rate: 2.0e-5
-  optimizer_builder:
-    optimizer:
-      type: AdamW
-    weight_decay: 0.01
-  origin_metric: 0.5700
+
+# 剪枝
+Prune:
+  prune_algo: transformer_pruner
+  pruned_ratio: 0.25
+
+# 离线量化
+QuantPost:
+  activation_bits: 8
+  quantize_op_types:
+  - depthwise_conv2d
+  - conv2d
+  weight_bits: 8
+  
\ No newline at end of file
diff --git a/example/auto_compression/nlp/configs/pp-minilm/auto/afqmc.yaml b/example/auto_compression/nlp/configs/pp-minilm/auto/afqmc.yaml
index 9c9f58826..fdf65673b 100644
--- a/example/auto_compression/nlp/configs/pp-minilm/auto/afqmc.yaml
+++ b/example/auto_compression/nlp/configs/pp-minilm/auto/afqmc.yaml
@@ -6,17 +6,11 @@ Global:
   dataset: clue
   batch_size: 16
   max_seq_length: 128
-TransformerPrune:
-  pruned_ratio: 0.25
-HyperParameterOptimization:
-Distillation:
+
+#离线量化
 QuantPost:
-TrainConfig:
-  epochs: 6
-  eval_iter: 1070
-  learning_rate: 2.0e-5
-  optimizer_builder:
-    optimizer:
-      type: AdamW
-    weight_decay: 0.01
-  origin_metric: 0.7403
+  activation_bits: 8
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+  weight_bits: 8
\ No newline at end of file
diff --git a/example/auto_compression/nlp/paddle_inference_eval.py b/example/auto_compression/nlp/paddle_inference_eval.py
index f48e20698..119a5ad8d 100644
--- a/example/auto_compression/nlp/paddle_inference_eval.py
+++ b/example/auto_compression/nlp/paddle_inference_eval.py
@@ -91,7 +91,8 @@ def parse_args():
         "--max_seq_length",
         default=128,
         type=int,
-        help="The maximum total input sequence length after tokenization. Sequences longer "
+        help=
+        "The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.", )
     parser.add_argument(
         "--perf_warmup_steps",
@@ -107,7 +108,8 @@
         type=str,
         default="fp32",
         choices=["fp32", "fp16", "int8"],
-        help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
+        help=
+        "The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
     )
     parser.add_argument(
         "--use_mkldnn",
@@ -156,8 +158,7 @@ def _convert_example(example,
         }
     elif "target" in example:  # wsc
         text, query, pronoun, query_idx, pronoun_idx = (
-            example["text"],
-            example["target"]["span1_text"],
+            example["text"], example["target"]["span1_text"],
             example["target"]["span2_text"],
             example["target"]["span1_index"],
             example["target"]["span2_index"], )
@@ -209,6 +210,12 @@ def create_predictor(cls, args):
         config = paddle.inference.Config(
             os.path.join(args.model_path, args.model_filename),
             os.path.join(args.model_path, args.params_filename))
+        # config.switch_ir_debug(True)
+        # 适用于ERNIE 3.0-Medium模型
+        # config.exp_disable_tensorrt_ops(["elementwise_add"])
+        # config.exp_disable_tensorrt_ops(["fused_embedding_eltwise_layernorm"])
+        # config.exp_disable_tensorrt_ops(["tmp_3"])
+
         if args.device == "gpu":
             # set GPU configs accordingly
             config.enable_use_gpu(100, 0)
@@ -239,8 +246,8 @@ def create_predictor(cls, args):
             dynamic_shape_file = os.path.join(args.model_path,
                                               "dynamic_shape.txt")
             if os.path.exists(dynamic_shape_file):
-                config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
-                                                           True)
+                config.enable_tuned_tensorrt_dynamic_shape(
+                    dynamic_shape_file, True)
                 print("trt set dynamic shape done!")
             else:
                 config.collect_shape_range_info(dynamic_shape_file)
@@ -365,4 +372,4 @@ def main():
 
 if __name__ == "__main__":
     paddle.set_device("cpu")
-    main()
+    main()
\ No newline at end of file
diff --git a/paddleslim/quant/advanced/auto_clip.py b/paddleslim/quant/advanced/auto_clip.py
index cf21ad30e..ac7166ed7 100644
--- a/paddleslim/quant/advanced/auto_clip.py
+++ b/paddleslim/quant/advanced/auto_clip.py
@@ -21,14 +21,15 @@ from .metrics import mse_loss
 from paddle.distributed.fleet.meta_parallel import (
     ColumnParallelLinear,
-    RowParallelLinear,
-)
+    RowParallelLinear, )
 
 __all__ = ['AutoClip']
 
+
 class AutoClip(nn.Layer):
     """
     AutoClip from AWQ[https://arxiv.org/abs/2306.00978]
     """
+
     def __init__(
             self,
             model,
@@ -39,8 +40,7 @@ def __init__(
             n_grid=20,
             max_shrink=0.5,
             n_sample_token=512,
-            group_size=128,
-    ):
+            group_size=128, ):
         super(AutoClip, self).__init__()
         self.model = model
         self.weight_bits = weight_bits
@@ -59,7 +59,9 @@ def __init__(
     def _apply_hook(self):
         self._forward_hook_list = []
         for _, sub_layer in self.model.named_sublayers():
-            if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
+            if type(sub_layer) in [
+                    ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear
+            ]:
                 forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
                     self._forward_pre_hook)
                 self._forward_hook_list.append(forward_pre_hook_handle)
@@ -67,7 +69,7 @@ def _apply_hook(self):
     def _forward_pre_hook(self, layer, input):
         self._sample_scale(input, layer.full_name())
         return input
-    
+
     def _sample_scale(self, input, name):
         input = input[0] if type(input) == tuple else input
         input.stop_gradient = True
@@ -80,7 +82,6 @@ def _sample_scale(self, input, name):
         else:
             self.sampled_inputs[name] = input
 
-
     def auto_clip(self, group_size=128, oc_batch_size=256):
         """
        search clip scale for each layer and update the layer's weight
        """
@@ -89,7 +90,7 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
             name = sub_layer.full_name()
             if name not in self.sampled_inputs or 'out_linear' in sub_name:
                 continue
-            
+
             weight = sub_layer.weight.cast('float16')
             weight_t = paddle.transpose(weight, perm=[1, 0])
             x = self.sampled_inputs[name].cast('float16')
@@ -98,16 +99,18 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
             x = x.reshape([1, x.shape[0], -1, group_size])
             x = x[:, 0::x.shape[1] // self.n_sample_token]
             weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
-            oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128 # prevent OOM
+            oc_batch_size = oc_batch_size if weight_t.shape[
+                0] % oc_batch_size == 0 else 128  # prevent OOM
             assert weight_t.shape[0] % oc_batch_size == 0
             w_all = weight_t
             best_max_val_all = []
 
             for i_b in range(weight_t.shape[0] // oc_batch_size):
-                w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+                w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]
 
-                org_max_val = w.abs().max(axis=-1, keepdim=True) # co, 1, n_group, 1
+                org_max_val = w.abs().max(
+                    axis=-1, keepdim=True)  # co, 1, n_group, 1
                 best_max_val = org_max_val.clone()
                 min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
                 org_out = (x * w).sum(axis=-1)  # co, n_token, n_group
@@ -115,16 +118,22 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
                     max_val = org_max_val * (1 - i_s / self.n_grid)
                     max_val_tmp = max_val
                     cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
-                    cur_w = paddle.where(cur_w < - max_val_tmp, - max_val_tmp, cur_w)
+                    cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp,
+                                         cur_w)
                     org_w_shape = cur_w.shape
-                    cur_w_r = cur_w.reshape([-1, self.group_size]).transpose([1, 0])
-                    quant_dequant_weight = fake_quant(cur_w_r, method='abs_max_channel_wise', weight_bits=4)
-                    quant_dequant_weight = quant_dequant_weight.transpose([1, 0]).reshape(org_w_shape)
+                    cur_w_r = cur_w.reshape([-1,
+                                             self.group_size]).transpose([1, 0])
+                    quant_dequant_weight = fake_quant(
+                        cur_w_r, method='abs_max_channel_wise', weight_bits=4)
+                    quant_dequant_weight = quant_dequant_weight.transpose(
+                        [1, 0]).reshape(org_w_shape)
                     cur_out = (x * quant_dequant_weight).sum(axis=-1)  # co, 1, n_group, 1
                     tmp = (cur_out - org_out).detach().clone()
-                    err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
-                    print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item()))
+                    err = paddle.pow(tmp,
+                                     2).mean(axis=1).reshape(min_errs.shape)
+                    print('block {} search s {} err {}'.format(
+                        i_b, i_s, err.mean().item()))
                     del cur_w, cur_out, quant_dequant_weight, tmp, cur_w_r
                     paddle.device.cuda.empty_cache()
@@ -143,10 +152,14 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
                 if 'w_0' in param.name:
                     param_tmp = param.transpose(perm=[1, 0]).cast('float16')
                     tmp_shape = param_tmp.shape
-                    param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1])
-                    best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
-                    param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
-                    param_tmp = paddle.where(param_tmp < - best_max_val, - best_max_val, param_tmp)
+                    param_tmp = param_tmp.reshape(
+                        [best_max_val.shape[0], best_max_val.shape[1], -1])
+                    best_max_val = paddle.tile(
+                        best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
+                    param_tmp = paddle.where(param_tmp > best_max_val,
+                                             best_max_val, param_tmp)
+                    param_tmp = paddle.where(param_tmp < -best_max_val,
+                                             -best_max_val, param_tmp)
                     param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
                     param_tmp = param_tmp.transpose(perm=[1, 0])
                     paddle.assign(param_tmp, output=param)
@@ -154,5 +167,6 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
                     paddle.device.cuda.empty_cache()
                     break
 
-            del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
+            del best_max_val, weight_t, x, weight, self.sampled_inputs[
+                name], w_all, best_max_val_all
             paddle.device.cuda.empty_cache()
diff --git a/paddleslim/quant/advanced/gptq.py b/paddleslim/quant/advanced/gptq.py
index 5194b71f5..5ae47205c 100644
--- a/paddleslim/quant/advanced/gptq.py
+++ b/paddleslim/quant/advanced/gptq.py
@@ -189,4 +189,4 @@ def fasterquant(self,
         self.quantized = True
 
         del H, Q, Hinv, W, Losses
-        paddle.device.cuda.empty_cache()
+        paddle.device.cuda.empty_cache()
\ No newline at end of file
diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py
index 4ec44e58b..8fd642ccc 100644
--- a/paddleslim/quant/advanced/piecewise_search.py
+++ b/paddleslim/quant/advanced/piecewise_search.py
@@ -71,7 +71,6 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
         origin_out = paddle.matmul(act, weight)
         w_abs_max = weight.abs().max(axis=-1, keepdim=True)
         rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
-        smooth_scale_out = None
         global_loss = float('inf')
         best_scale = None
 
@@ -184,5 +183,4 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
                     print('Find Better K-Piece {}'.format(k_piece))
             if not self.search_piece:
                 break
-
         return best_scale
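
For orientation on the deployment side of this patch, the sketch below mirrors the TensorRT int8 + tuned-dynamic-shape flow used by `paddle_inference_eval.py`: the first run records shape ranges into `dynamic_shape.txt`, later runs reuse the tuned file. This is an illustrative helper, not part of the repository; only the `paddle.inference` calls already used in the diff are assumed, and the defaults follow the parameter table added to the README.

```python
import os
from paddle.inference import Config, PrecisionType, create_predictor


def build_trt_int8_predictor(model_dir,
                             model_filename="inference.pdmodel",
                             params_filename="inference.pdiparams",
                             batch_size=32):
    """Hypothetical helper mirroring the dynamic-shape logic in paddle_inference_eval.py."""
    config = Config(
        os.path.join(model_dir, model_filename),
        os.path.join(model_dir, params_filename))
    config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool, GPU id 0
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=batch_size,
        min_subgraph_size=5,
        precision_mode=PrecisionType.Int8,
        use_static=False,
        use_calib_mode=False)
    shape_file = os.path.join(model_dir, "dynamic_shape.txt")
    if os.path.exists(shape_file):
        # Later runs: reuse the tuned min/max/opt shapes for variable-length inputs.
        config.enable_tuned_tensorrt_dynamic_shape(shape_file, True)
    else:
        # First run: record the shape ranges observed during inference.
        config.collect_shape_range_info(shape_file)
    return create_predictor(config)
```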
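The densest part of the patch is the re-wrapped search loop in `AutoClip.auto_clip`. As a reading aid only, here is a self-contained NumPy sketch of the underlying idea (per-group clip-threshold grid search, as in AWQ): shrink the clipping threshold step by step, fake-quantize the clipped weights, and keep the threshold that minimizes the output MSE. The per-tensor `fake_quant` below is a simplification of the channel-wise one used in PaddleSlim, and none of these names are the library's API.

```python
import numpy as np


def fake_quant(w, bits=4):
    """Symmetric fake quantization: quantize to `bits` then dequantize (simplified, per-tensor)."""
    qmax = 2**(bits - 1) - 1
    scale = np.abs(w).max() / qmax
    if scale == 0:
        return w
    return np.round(w / scale).clip(-qmax, qmax) * scale


def search_clip(x, w, n_grid=20, max_shrink=0.5, bits=4):
    """x: [n_token, group_size] sampled activations; w: [group_size] one weight group."""
    org_out = x @ w                       # reference output with full-precision weights
    org_max = np.abs(w).max()
    best_err, best_max = np.inf, org_max
    for i in range(int(n_grid * max_shrink)):
        max_val = org_max * (1 - i / n_grid)   # progressively smaller clip threshold
        cur_w = np.clip(w, -max_val, max_val)  # clip, then fake-quantize
        err = np.mean((x @ fake_quant(cur_w, bits) - org_out)**2)
        if err < best_err:
            best_err, best_max = err, max_val
    return best_max


rng = np.random.default_rng(0)
x = rng.standard_normal((512, 128)).astype("float32")
w = rng.standard_normal(128).astype("float32")
print("best clip threshold:", search_clip(x, w))
```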