
[fix] Fix the auto-compression example for NLP models #1839

Open
wants to merge 11 commits into base: develop
67 changes: 49 additions & 18 deletions example/auto_compression/nlp/README.md
@@ -56,16 +56,16 @@

#### 3.1 Environment Preparation
- python >= 3.6
- PaddlePaddle >= 2.4 (download and install from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim >= 2.4
- PaddleNLP >= 2.3
- PaddlePaddle ==2.5 (download and install from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim ==2.5
- PaddleNLP ==2.6

Install paddlepaddle:
```shell
# CPU
pip install paddlepaddle==2.4.1
pip install paddlepaddle==2.5.0
# GPU, using Ubuntu as an example; pick the wheel matching your CUDA version (post112 = CUDA 11.2, post116 = CUDA 11.6)
python -m pip install paddlepaddle-gpu==2.4.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
python -m pip install paddlepaddle-gpu==2.5.0.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
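
To confirm the installation works before moving on, you can run PaddlePaddle's built-in self-check (optional; `run_check` ships with PaddlePaddle itself):

```python
# Quick sanity check of the PaddlePaddle installation.
import paddle

paddle.utils.run_check()      # reports whether Paddle can run a small program on CPU/GPU
print(paddle.__version__)     # should print the version installed above, e.g. 2.5.0
```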

Install paddleslim:
@@ -95,7 +95,6 @@ pip install paddlenlp
|:------:|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|
| PP-MiniLM | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/afqmc.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/tnews.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/iflytek.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/cmnli.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/ocnli.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/cluewsc.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/csl.tar) |
| ERNIE 3.0-Medium | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/AFQMC.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/TNEWS.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/IFLYTEK.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CMNLI.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/OCNLI.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CLUEWSC2020.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CSL.tar) |
| UIE-base | [Reimbursement work orders](https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base.tar) |

Get the model link from the table above and download the inference model files with the following command:

@@ -119,11 +118,6 @@
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path='./configs/pp-minilm/auto/afqmc.yaml' --save_dir='./save_afqmc_pruned/'
```

Auto-compression of the UIE-series models is launched with the run_uie.py script, which uses the ```paddleslim.auto_compression.AutoCompression``` interface to compress the model automatically. Configure the training-related parameters in the config file, passing in the task name, model type, dataset name, and compression parameters; once the config is complete, the model can be trained with distillation and quantization.
```shell
export CUDA_VISIBLE_DEVICES=0
python run_uie.py --config_path='./configs/uie/uie_base.yaml' --save_dir='./save_uie_qat/'
```
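
Under the hood, run_uie.py (like run.py) builds an `AutoCompression` object from the YAML config and calls `compress()`. The following is only a rough sketch of that flow; the argument names follow the public ACT examples, and the dataloader and eval callback are placeholders rather than code from this script:

```python
# Rough sketch of the ACT flow driven by the YAML config above (not the exact run_uie.py code).
from paddleslim.auto_compression import AutoCompression

ac = AutoCompression(
    model_dir="./uie_base",                # directory of the inference model to compress
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    save_dir="./save_uie_qat",             # where the compressed model is written
    config="./configs/uie/uie_base.yaml",  # task, dataset and compression settings
    train_dataloader=train_dataloader,     # placeholder: dataloader over the task's training set
    eval_callback=eval_function)           # placeholder: function returning the eval metric
ac.compress()
```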

If you only need to evaluate the accuracy of a model, or of a model after compression, then when launching the ```run.py``` script, change the model directory ```model_dir``` in the config file to the folder saved after compression, e.g. ```./save_afqmc_pruned```, and add ```--eval True``` to the command:
```shell
@@ -212,12 +206,29 @@ QuantPost:

## 5. Inference Deployment


Quantized models can be accelerated with TensorRT on GPU and with MKLDNN on CPU.

The following fields configure the inference parameters (a Python sketch of how they map onto the Paddle Inference config follows the table):

| Parameter | Description |
|:------:|:------:|
| model_path | Directory of the inference model; it must contain the files model.pdmodel and model.pdiparams |
| model_filename | Name of the model file, default: inference.pdmodel |
| params_filename | Name of the parameters file, default: inference.pdiparams |
| task_name | Name of the task to run, default: afqmc |
| dataset | Dataset used by the model, default: clue |
| device | Device used for inference, default: gpu; options: cpu or gpu |
| batch_size | Batch size for inference, default: 32 |
| max_seq_len | Maximum input sequence length after tokenization, default: 128; longer sequences are truncated and shorter ones are padded |
| perf_warmup_steps | Number of warmup steps for performance measurement, default: 20 |
| use_trt | Flag controlling whether to use TensorRT for inference |
| precision | Inference precision, default: fp32; options: fp16 or int8 |
| use_mkldnn | Flag controlling whether to use MKLDNN for inference |
| cpu_threads | Number of CPU threads, default: 1 |
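
For reference, the flags above map onto the Paddle Inference `Config` roughly as follows. This is a simplified sketch of what paddle_inference_eval.py does, not its exact code; the TensorRT settings (workspace size, subgraph size, batch size) are illustrative values:

```python
import paddle.inference as paddle_infer

def build_predictor(model_path, model_filename, params_filename,
                    device="gpu", use_trt=False, precision="fp32",
                    use_mkldnn=False, cpu_threads=1):
    """Simplified sketch; see paddle_inference_eval.py for the full logic."""
    config = paddle_infer.Config(
        f"{model_path}/{model_filename}", f"{model_path}/{params_filename}")
    if device == "gpu":
        config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool, GPU id 0
        if use_trt:
            precision_map = {
                "fp32": paddle_infer.PrecisionType.Float32,
                "fp16": paddle_infer.PrecisionType.Half,
                "int8": paddle_infer.PrecisionType.Int8,
            }
            config.enable_tensorrt_engine(
                workspace_size=1 << 30,        # illustrative value
                max_batch_size=32,
                min_subgraph_size=5,
                precision_mode=precision_map[precision],
                use_static=True,
                use_calib_mode=False)
    else:
        config.set_cpu_math_library_num_threads(cpu_threads)
        if use_mkldnn:
            config.enable_mkldnn()
    return paddle_infer.create_predictor(config)
```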

- TensorRT inference:

Environment setup: to use the TensorRT inference engine, install a Paddle build compiled with ```WITH_TRT=ON```; download it from the [Python inference libraries](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python).
#### 5.1 TensorRT Inference

First, download the quantized model:
```shell
@@ -227,10 +238,30 @@
tar -xf save_ppminilm_afqmc_new_calib.tar
```

```shell
python paddle_inference_eval.py \
--model_path=save_ernie3_afqmc_new_cablib \
--model_path=save_ppminilm_afqmc_new_calib \
--model_filename=inference.pdmodel \
--params_filename=inference.pdiparams \
--task_name='afqmc' \
--use_trt \
--precision=int8
```

- ERNIE 3.0-Medium:
```shell
python paddle_inference_eval.py \
--model_path=TNEWS \
--model_filename=infer.pdmodel \
--params_filename=infer.pdiparams \
--task_name='afqmc' \
--task_name='tnews' \
--use_trt \
--precision=fp32
```
```shell
python paddle_inference_eval.py \
--model_path=save_tnews_pruned \
--model_filename=infer.pdmodel \
--params_filename=infer.pdiparams \
--task_name='tnews' \
--use_trt \
--precision=int8
```
@@ -239,9 +270,9 @@

```shell
python paddle_inference_eval.py \
--model_path=save_ernie3_afqmc_new_cablib \
--model_filename=infer.pdmodel \
--params_filename=infer.pdiparams \
--model_path=save_ppminilm_afqmc_new_calib \
--model_filename=inference.pdmodel \
--params_filename=inference.pdiparams \
--task_name='afqmc' \
--device=cpu \
--use_mkldnn=True \
23 changes: 14 additions & 9 deletions example/auto_compression/nlp/configs/ernie3.0/tnews.yaml
@@ -6,12 +6,17 @@ Global:
dataset: clue
batch_size: 16
max_seq_length: 128
TrainConfig:
epochs: 6
eval_iter: 1110
learning_rate: 2.0e-5
optimizer_builder:
optimizer:
type: AdamW
weight_decay: 0.01
origin_metric: 0.5700

# Pruning
Prune:
prune_algo: transformer_pruner
pruned_ratio: 0.25

# Post-training (offline) quantization
QuantPost:
activation_bits: 8
quantize_op_types:
- depthwise_conv2d
- conv2d
weight_bits: 8

20 changes: 7 additions & 13 deletions example/auto_compression/nlp/configs/pp-minilm/auto/afqmc.yaml
@@ -6,17 +6,11 @@ Global:
dataset: clue
batch_size: 16
max_seq_length: 128
TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:

# Post-training (offline) quantization
QuantPost:
TrainConfig:
epochs: 6
eval_iter: 1070
learning_rate: 2.0e-5
optimizer_builder:
optimizer:
type: AdamW
weight_decay: 0.01
origin_metric: 0.7403
activation_bits: 8
quantize_op_types:
- conv2d
- depthwise_conv2d
weight_bits: 8
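
To inspect a config like the one above before launching training, it can be loaded as a plain dictionary. A small sketch, assuming `load_config` is exposed from `paddleslim.common` as in the ACT example scripts:

```python
# Load and inspect an ACT YAML config (sketch; the path is an example).
from paddleslim.common import load_config

cfg = load_config("./configs/pp-minilm/auto/afqmc.yaml")
print(list(cfg.keys()))   # top-level sections, e.g. 'Global' and 'QuantPost'
print(cfg["Global"])      # model_dir, dataset, batch_size, max_seq_length, ...
```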
21 changes: 14 additions & 7 deletions example/auto_compression/nlp/paddle_inference_eval.py
@@ -91,7 +91,8 @@ def parse_args():
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
help=
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--perf_warmup_steps",
@@ -107,7 +108,8 @@
type=str,
default="fp32",
choices=["fp32", "fp16", "int8"],
help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
help=
"The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp32'.",
)
parser.add_argument(
"--use_mkldnn",
@@ -156,8 +158,7 @@ def _convert_example(example,
}
elif "target" in example: # wsc
text, query, pronoun, query_idx, pronoun_idx = (
example["text"],
example["target"]["span1_text"],
example["text"], example["target"]["span1_text"],
example["target"]["span2_text"],
example["target"]["span1_index"],
example["target"]["span2_index"], )
@@ -209,6 +210,12 @@ def create_predictor(cls, args):
config = paddle.inference.Config(
os.path.join(args.model_path, args.model_filename),
os.path.join(args.model_path, args.params_filename))
# config.switch_ir_debug(True)
# Applicable to the ERNIE 3.0-Medium model
# config.exp_disable_tensorrt_ops(["elementwise_add"])
# config.exp_disable_tensorrt_ops(["fused_embedding_eltwise_layernorm"])
# config.exp_disable_tensorrt_ops(["tmp_3"])

if args.device == "gpu":
# set GPU configs accordingly
config.enable_use_gpu(100, 0)
@@ -239,8 +246,8 @@ def create_predictor(cls, args):
dynamic_shape_file = os.path.join(args.model_path,
"dynamic_shape.txt")
if os.path.exists(dynamic_shape_file):
config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
True)
config.enable_tuned_tensorrt_dynamic_shape(
dynamic_shape_file, True)
print("trt set dynamic shape done!")
else:
config.collect_shape_range_info(dynamic_shape_file)
@@ -365,4 +372,4 @@ def main():

if __name__ == "__main__":
paddle.set_device("cpu")
main()
main()
58 changes: 36 additions & 22 deletions paddleslim/quant/advanced/auto_clip.py
@@ -21,14 +21,15 @@
from .metrics import mse_loss
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
RowParallelLinear, )
__all__ = ['AutoClip']


class AutoClip(nn.Layer):
"""
AutoClip from AWQ[https://arxiv.org/abs/2306.00978]
"""

def __init__(
self,
model,
@@ -39,8 +40,7 @@ def __init__(
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
group_size=128,
):
group_size=128, ):
super(AutoClip, self).__init__()
self.model = model
self.weight_bits = weight_bits
@@ -59,15 +59,17 @@ def __init__(
def _apply_hook(self):
self._forward_hook_list = []
for _, sub_layer in self.model.named_sublayers():
if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
if type(sub_layer) in [
ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear
]:
forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)

def _forward_pre_hook(self, layer, input):
self._sample_scale(input, layer.full_name())
return input

def _sample_scale(self, input, name):
input = input[0] if type(input) == tuple else input
input.stop_gradient = True
@@ -80,7 +82,6 @@ def _sample_scale(self, input, name):
else:
self.sampled_inputs[name] = input


def auto_clip(self, group_size=128, oc_batch_size=256):
"""
search clip scale for each layer and update the layer's weight
@@ -89,7 +90,7 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
name = sub_layer.full_name()
if name not in self.sampled_inputs or 'out_linear' in sub_name:
continue

weight = sub_layer.weight.cast('float16')
weight_t = paddle.transpose(weight, perm=[1, 0])
x = self.sampled_inputs[name].cast('float16')
@@ -98,33 +99,41 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
x = x.reshape([1, x.shape[0], -1, group_size])
x = x[:, 0::x.shape[1] // self.n_sample_token]
weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128 # prevent OOM
oc_batch_size = oc_batch_size if weight_t.shape[
0] % oc_batch_size == 0 else 128 # prevent OOM
assert weight_t.shape[0] % oc_batch_size == 0

w_all = weight_t
best_max_val_all = []

for i_b in range(weight_t.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]

org_max_val = w.abs().max(axis=-1, keepdim=True) # co, 1, n_group, 1
org_max_val = w.abs().max(
axis=-1, keepdim=True) # co, 1, n_group, 1
best_max_val = org_max_val.clone()
min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
org_out = (x * w).sum(axis=-1) # co, n_token, n_group
for i_s in range(int(self.max_shrink * self.n_grid)):
max_val = org_max_val * (1 - i_s / self.n_grid)
max_val_tmp = max_val
cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
cur_w = paddle.where(cur_w < - max_val_tmp, - max_val_tmp, cur_w)
cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp,
cur_w)
org_w_shape = cur_w.shape
cur_w_r = cur_w.reshape([-1, self.group_size]).transpose([1, 0])
quant_dequant_weight = fake_quant(cur_w_r, method='abs_max_channel_wise', weight_bits=4)
quant_dequant_weight = quant_dequant_weight.transpose([1, 0]).reshape(org_w_shape)
cur_w_r = cur_w.reshape([-1,
self.group_size]).transpose([1, 0])
quant_dequant_weight = fake_quant(
cur_w_r, method='abs_max_channel_wise', weight_bits=4)
quant_dequant_weight = quant_dequant_weight.transpose(
[1, 0]).reshape(org_w_shape)
cur_out = (x * quant_dequant_weight).sum(axis=-1)
# co, 1, n_group, 1
tmp = (cur_out - org_out).detach().clone()
err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item()))
err = paddle.pow(tmp,
2).mean(axis=1).reshape(min_errs.shape)
print('block {} search s {} err {}'.format(
i_b, i_s, err.mean().item()))
del cur_w, cur_out, quant_dequant_weight, tmp, cur_w_r
paddle.device.cuda.empty_cache()

Expand All @@ -143,16 +152,21 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
if 'w_0' in param.name:
param_tmp = param.transpose(perm=[1, 0]).cast('float16')
tmp_shape = param_tmp.shape
param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1])
best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
param_tmp = paddle.where(param_tmp < - best_max_val, - best_max_val, param_tmp)
param_tmp = param_tmp.reshape(
[best_max_val.shape[0], best_max_val.shape[1], -1])
best_max_val = paddle.tile(
best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
param_tmp = paddle.where(param_tmp > best_max_val,
best_max_val, param_tmp)
param_tmp = paddle.where(param_tmp < -best_max_val,
-best_max_val, param_tmp)
param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
param_tmp = param_tmp.transpose(perm=[1, 0])
paddle.assign(param_tmp, output=param)
del param_tmp
paddle.device.cuda.empty_cache()
break

del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
del best_max_val, weight_t, x, weight, self.sampled_inputs[
name], w_all, best_max_val_all
paddle.device.cuda.empty_cache()
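
In equation form, the grid search above picks, for each output-channel group with absolute maximum m = max|w|, the clipping scale that minimizes the fake-quantized reconstruction error over the sampled activations (notation is ours, not from the code):

```latex
\alpha^{*} \;=\; \arg\min_{\alpha \,\in\, \{\,1 - i/n_{\mathrm{grid}} \;:\; 0 \le i < s_{\max}\, n_{\mathrm{grid}}\,\}}
\;\mathbb{E}_{x}\!\left[\Big(x \cdot Q\big(\mathrm{clip}(w,\,-\alpha m,\,\alpha m)\big) \;-\; x \cdot w\Big)^{2}\right]
```

Here Q is the 4-bit abs-max channel-wise fake quantization used in the loop, and the winning threshold α·m is what the final `paddle.where` calls apply to the layer weight in place.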
2 changes: 1 addition & 1 deletion paddleslim/quant/advanced/gptq.py
@@ -189,4 +189,4 @@ def fasterquant(self,

self.quantized = True
del H, Q, Hinv, W, Losses
paddle.device.cuda.empty_cache()
paddle.device.cuda.empty_cache()
2 changes: 0 additions & 2 deletions paddleslim/quant/advanced/piecewise_search.py
@@ -71,7 +71,6 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
origin_out = paddle.matmul(act, weight)
w_abs_max = weight.abs().max(axis=-1, keepdim=True)
rw_abs_max = w_abs_max.reshape(act_abs_max.shape)

smooth_scale_out = None
global_loss = float('inf')
best_scale = None
@@ -184,5 +183,4 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
print('Find Better K-Piece {}'.format(k_piece))
if not self.search_piece:
break

return best_scale