| 1 | +""" |
| 2 | +some instructions |
| 3 | +1. Fill the models that needs to be checked in the modelzoo_dict |
| 4 | +2. Arange the structure of the directory as follows, the script will find the |
| 5 | + corresponding config itself: |
| 6 | + model_dir/model_family/checkpoints |
| 7 | + e.g.: models/faster_rcnn/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth |
| 8 | + models/faster_rcnn/faster_rcnn_r101_fpn_1x_coco_20200130-047c8118.pth |
| 9 | +3. Excute the batch_test.sh |
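
For reference, a single non-distributed run of this script looks like (the
script file name is assumed here):
    python batch_test.py <model_dir> <json_out> --launcher none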
| 10 | +""" |

import argparse
import json
import os
import subprocess

import mmcv
import torch
from mmcv import Config, get_logger
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.core import wrap_fp16_model
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
modelzoo_dict = {
    'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py': {
        'bbox': 0.374
    },
    'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py': {
        'bbox': 0.382,
        'segm': 0.347
    },
    'configs/rpn/rpn_r50_fpn_1x_coco.py': {
        'AR@1000': 0.582
    }
}


def parse_args():
    parser = argparse.ArgumentParser(
        description='The script used for checking the correctness '
        'of batch inference')
    parser.add_argument('model_dir', help='directory of models')
    parser.add_argument(
        'json_out',
        help='output json file that records test information such as mAP')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def check_finish(all_model_dict, result_file):
    # return the next config that has not been tested yet; if every model
    # has been checked, append 'finished' to the result file
    tested_cfgs = []
    with open(result_file, 'r') as f:
        for line in f:
            line = json.loads(line)
            tested_cfgs.append(line['cfg'])
    for cfg in sorted(all_model_dict.keys()):
        if cfg not in tested_cfgs:
            return cfg
    with open(result_file, 'a+') as f:
        f.write('finished\n')
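
# Note: the driver script batch_test.sh (not shown here) is expected to
# re-run this script until 'finished' appears in the result file; each
# invocation evaluates the next untested config.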


def dump_dict(record_dict, json_out):
    # append one json record per line to the output file
    with open(json_out, 'a+') as f:
        mmcv.dump(record_dict, f, file_format='json')
        f.write('\n')
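
# A dumped record line looks like, e.g.:
# {"cfg": "configs/rpn/rpn_r50_fpn_1x_coco.py",
#  "cpt": "models/rpn/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth",
#  "is_normal": true}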


def main():
    args = parse_args()
    # touch the output json if it does not exist
    with open(args.json_out, 'a+'):
        pass
    # init distributed env first, since the logger depends on the dist info
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, backend='nccl')
    rank, world_size = get_dist_info()

    logger = get_logger('root')

    # read info of checkpoints and config
    result_dict = dict()
    for model_family_dir in os.listdir(args.model_dir):
        for model in os.listdir(
                os.path.join(args.model_dir, model_family_dir)):
            # cpt: rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth
            # cfg: rpn_r50_fpn_1x_coco.py
            # checkpoint names end with an 18-character '_<yyyymmdd>-<hash>'
            # suffix, which is stripped to recover the config name
            cfg = model.split('.')[0][:-18] + '.py'
            cfg_path = os.path.join('configs', model_family_dir, cfg)
            assert os.path.isfile(
                cfg_path), f'{cfg_path} is not a valid config path'
            cpt_path = os.path.join(args.model_dir, model_family_dir, model)
            result_dict[cfg_path] = cpt_path
            assert cfg_path in modelzoo_dict, f'please fill in the ' \
                f'performance of cfg: {cfg_path}'
    cfg = check_finish(result_dict, args.json_out)
    if cfg is None:
        # all models have been tested and 'finished' has been recorded
        return
    cpt = result_dict[cfg]
    try:
        cfg_name = cfg
        logger.info(f'evaluate {cfg}')
        record = dict(cfg=cfg, cpt=cpt)
        cfg = Config.fromfile(cfg)
        # cfg.data.test.ann_file = 'data/val_0_10.json'
        # set cudnn_benchmark
        if cfg.get('cudnn_benchmark', False):
            torch.backends.cudnn.benchmark = True
        cfg.model.pretrained = None
        if cfg.model.get('neck'):
            if isinstance(cfg.model.neck, list):
                for neck_cfg in cfg.model.neck:
                    if neck_cfg.get('rfp_backbone'):
                        if neck_cfg.rfp_backbone.get('pretrained'):
                            neck_cfg.rfp_backbone.pretrained = None
            elif cfg.model.neck.get('rfp_backbone'):
                if cfg.model.neck.rfp_backbone.get('pretrained'):
                    cfg.model.neck.rfp_backbone.pretrained = None

        # in case the test dataset is concatenated
        if isinstance(cfg.data.test, dict):
            cfg.data.test.test_mode = True
        elif isinstance(cfg.data.test, list):
            for ds_cfg in cfg.data.test:
                ds_cfg.test_mode = True

        # build the dataloader
        samples_per_gpu = 2  # hack: test with 2 images per gpu
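        # when samples_per_gpu > 1, images in a batch may have different
        # shapes, so the test pipeline must not use the bare 'ImageToTensor'
        # transform; replace_ImageToTensor below handles this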
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
        dataset = build_dataset(cfg.data.test)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)

        # build the model and load checkpoint
        model = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            # patch the model to run inference in half precision
            wrap_fp16_model(model)

        checkpoint = load_checkpoint(model, cpt, map_location='cpu')
        # old versions did not save class info in checkpoints;
        # this workaround is for backward compatibility
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            model.CLASSES = dataset.CLASSES

        if not distributed:
            model = MMDataParallel(model, device_ids=[0])
            outputs = single_gpu_test(model, data_loader)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
            # 'tmp' is the tmpdir used to collect results across ranks
            outputs = multi_gpu_test(model, data_loader, 'tmp')
        if rank == 0:
            ref_mAP_dict = modelzoo_dict[cfg_name]
            metrics = list(ref_mAP_dict.keys())
            # 'AR@1000' is reported under the 'proposal_fast' metric
            metrics = [
                m if m != 'AR@1000' else 'proposal_fast' for m in metrics
            ]
            eval_results = dataset.evaluate(outputs, metrics)
            print(eval_results)
            record['is_normal'] = True
            for metric in metrics:
                if metric == 'proposal_fast':
                    ref_metric = modelzoo_dict[cfg_name]['AR@1000']
                    eval_metric = eval_results['AR@1000']
                else:
                    ref_metric = modelzoo_dict[cfg_name][metric]
                    eval_metric = eval_results[f'{metric}_mAP']
                # flag models whose metric deviates from the reference
                # by more than 0.003
                if abs(ref_metric - eval_metric) > 0.003:
                    record['is_normal'] = False
            dump_dict(record, args.json_out)
            check_finish(result_dict, args.json_out)
    except Exception as e:
        logger.error(f'rank: {rank} test failed with error: {e}')
        record['terminate'] = True
        dump_dict(record, args.json_out)
        check_finish(result_dict, args.json_out)
        # deliberately call a non-existent command so the process raises an
        # error instead of hanging
        subprocess.call('xxx')


if __name__ == '__main__':
    main()