# Structured Pruning Sensitivity Analysis

Using the MobileNetV1 model from the auto-compression examples, this example shows how to quickly adapt the example code so that the structured pruning sensitivity analysis tool can measure the sensitivity of each model parameter. With those results you can choose suitable pruning ratios and decide which parameters to prune, so the model is pruned as aggressively as possible while its accuracy is preserved.
For image classification models other than MobileNetV1, the [run.py](./run.py) script can be used directly: pass any compression yaml file of another model as `config_path` to run sensitivity analysis on that model.

## Computing Channel Pruning Sensitivity

The following explains each step of the example code. If you are a user of ACT (the Auto Compression Toolkit), the text in bold describes how to turn an auto-compression example into a sensitivity-analysis example.

### 1. Import dependencies

Import the required dependencies; the code below can be reused directly. To analyze a model from another scenario, also import the dependencies that are specific to the ``run.py`` file in that scenario's directory. **Alternatively, add the last import to your auto-compression example code.**

```python
import os
import sys
import argparse
import pickle
import functools
from functools import partial
import math
from tqdm import tqdm

import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
import paddleslim
from imagenet_reader import ImageNetDataset
from paddleslim.common import load_config as load_slim_config
from paddleslim.auto_compression.analysis import analysis_prune
```

### 2. Define command-line arguments

Define the arguments that can be passed in on the command line. **This code needs no modification no matter which scenario's model you analyze; just copy it over to replace the original argument definitions.**

```python
def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.",
        required=True)
    parser.add_argument(
        '--analysis_file',
        type=str,
        default='sensitivity_0.data',
        help="file path to save the sensitivity analysis result.")
    parser.add_argument(
        '--pruned_ratios',
        nargs='+',
        type=float,
        default=[0.1, 0.2, 0.3, 0.4],
        help="The ratios to be pruned when computing sensitivity.")
    parser.add_argument(
        '--target_loss',
        type=float,
        default=0.2,
        help="use the target loss to get the prune ratio of each parameter")

    return parser
```
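
Below is a minimal wiring sketch (an assumption, not part of the original snippet) showing how the parser above is typically used: the arguments are parsed once at startup, and the resulting ``args`` object is what steps 4 and 5 read from. ``paddle.enable_static()`` is included because the evaluation code below runs through the static-graph Executor API.

```python
# Hedged sketch: parse the command-line flags defined above.
if __name__ == '__main__':
    paddle.enable_static()      # the evaluation below uses the static-graph Executor
    parser = argsparser()
    args = parser.parse_args()  # e.g. --config_path <your_compression_yaml>
```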

### 3. Define the eval_function

A complete evaluation pipeline is required; the evaluation code in the ``run.py`` file of the corresponding scenario's directory can be used directly. **In the evaluation callback of the auto-compression example, change this line:**

```python
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
```
**to:**
```python
def eval_function(compiled_test_program, exe, test_feed_names, test_fetch_list):
```
that is, swap the order of the first two arguments, which is the order the sensitivity analysis tool uses when it calls the evaluation function.

The final evaluation code is as follows:
```python
def eval_reader(data_dir, batch_size, crop_size, resize_size, place=None):
    val_reader = ImageNetDataset(
        mode='val',
        data_dir=data_dir,
        crop_size=crop_size,
        resize_size=resize_size)
    val_loader = DataLoader(
        val_reader,
        places=[place] if place is not None else None,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    return val_loader


def eval_function(compiled_test_program, exe, test_feed_names, test_fetch_list):
    val_loader = eval_reader(
        global_config['data_dir'],
        batch_size=global_config['batch_size'],
        crop_size=img_size,
        resize_size=resize_size)

    results = []
    with tqdm(
            total=len(val_loader),
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
        for batch_id, (image, label) in enumerate(val_loader):
            image = np.array(image)
            label = np.array(label).astype('int64')
            if len(test_feed_names) == 1:
                # The model only takes the image as input; compute top-1 and
                # top-5 accuracy from the predicted logits.
                pred = exe.run(compiled_test_program,
                               feed={test_feed_names[0]: image},
                               fetch_list=test_fetch_list)
                pred = np.array(pred[0])
                sort_array = pred.argsort(axis=1)
                top_1_pred = sort_array[:, -1:][:, ::-1]
                top_1 = np.mean(label == top_1_pred)
                top_5_pred = sort_array[:, -5:][:, ::-1]
                acc_num = 0
                for i in range(len(label)):
                    if label[i][0] in top_5_pred[i]:
                        acc_num += 1
                top_5 = float(acc_num) / len(label)
                results.append([top_1, top_5])
            else:
                # The "eval model" takes both image and label as inputs and
                # directly outputs top-1 and top-5 accuracy.
                result = exe.run(compiled_test_program,
                                 feed={
                                     test_feed_names[0]: image,
                                     test_feed_names[1]: label
                                 },
                                 fetch_list=test_fetch_list)
                result = [np.mean(r) for r in result]
                results.append(result)
            t.update()
    result = np.mean(np.array(results), axis=0)
    # Return top-1 accuracy as the metric used by the sensitivity analysis.
    return result[0]
```
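
The evaluation code above (and the analysis call in step 5) reads several keys from the ``Global`` section of the compression yaml, which is loaded in the next step. The sketch below only illustrates those keys; the values are placeholders, not taken from any real config file.

```python
# Illustrative placeholder values; only the key names come from the code in this example.
example_global_config = {
    'model_dir': './MobileNetV1_infer',        # directory holding the inference model
    'model_filename': 'inference.pdmodel',     # model structure file name
    'params_filename': 'inference.pdiparams',  # model weights file name
    'data_dir': './ILSVRC2012/',               # ImageNet-format validation data
    'batch_size': 128,
    'img_size': 224,                           # optional, defaults to 224 in step 4
    'resize_size': 256,                        # optional, defaults to 256 in step 4
}
```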

### 4. Load the configuration file

Load the configuration file to obtain the data-loading related settings. **The code from the original auto-compression example can be used as-is.**

```python
global global_config
all_config = load_slim_config(args.config_path)

assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]

global img_size, resize_size
img_size = global_config['img_size'] if 'img_size' in global_config else 224
resize_size = global_config[
    'resize_size'] if 'resize_size' in global_config else 256
```
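
Once the configuration is loaded, the evaluation function can optionally be run once on the original, unpruned model to record a baseline accuracy to compare the pruned results against. The sketch below is an optional addition (not part of the original example) and assumes the model is a standard static-graph inference model saved under the filenames given in the config.

```python
# Optional baseline check (an assumption, not from the original example).
paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
[program, feed_names, fetch_targets] = paddle.static.load_inference_model(
    global_config['model_dir'],
    exe,
    model_filename=global_config['model_filename'],
    params_filename=global_config['params_filename'])
baseline_top1 = eval_function(program, exe, feed_names, fetch_targets)
print(f"Baseline top-1 accuracy: {baseline_top1:.4f}")
```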

### 5. Run the sensitivity analysis

Pass in the evaluation callback, the configuration (mainly the model location and file names), the path where the analysis file will be saved, the pruning ratios to analyze, and the acceptable target accuracy loss. If no target loss is passed, only the sensitivity analysis results are returned. **Replace the code that calls AutoCompression and ac.compress in the auto-compression example with the following:**

```python
analysis_prune(eval_function, global_config['model_dir'],
               global_config['model_filename'],
               global_config['params_filename'], args.analysis_file,
               args.pruned_ratios, args.target_loss)
```
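
The sensitivity results are saved to ``args.analysis_file``. Assuming the file uses PaddleSlim's standard sensitivities format (an assumption here, not stated in the original example), it can be reloaded later with the ``paddleslim.prune`` utilities to derive per-parameter ratios for a different accuracy-loss budget without rerunning the evaluation:

```python
# Hedged follow-up: reload the saved sensitivity data and derive ratios for another loss target.
from paddleslim.prune import load_sensitivities, get_ratios_by_loss

sens = load_sensitivities(args.analysis_file)
ratios = get_ratios_by_loss(sens, 0.1)  # 0.1 is an illustrative loss budget
print(ratios)                           # {parameter_name: prune_ratio, ...}
```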