# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
import os
import sys
import numpy as np
import argparse
import pprint
import pdb
import time
import cv2
import cPickle
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim

import torchvision.transforms as transforms

from roi_data_layer.roidb import combined_roidb
from roi_data_layer.roibatchLoader import roibatchLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.faster_rcnn.faster_rcnn import _fasterRCNN
from model.rpn.bbox_transform import clip_boxes
from model.nms.nms_wrapper import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.network import save_net, load_net, vis_detections

def parse_args():
  """
  Parse input arguments
  """
  parser = argparse.ArgumentParser(description='Test a Faster R-CNN network')
  parser.add_argument('--cfg', dest='cfg_file',
                      help='optional config file',
                      default='cfgs/vgg16.yml', type=str)
  parser.add_argument('--imdb', dest='imdb_name',
                      help='dataset to train on',
                      default='voc_2007_trainval', type=str)
  parser.add_argument('--imdbval', dest='imdbval_name',
                      help='dataset to validate on',
                      default='voc_2007_test', type=str)
  parser.add_argument('--net', dest='net',
                      help='vgg16, res50, res101, res152',
                      default='vgg16', type=str)
  parser.add_argument('--set', dest='set_cfgs',
                      help='set config keys', default=None,
                      nargs=argparse.REMAINDER)
  parser.add_argument('--load_dir', dest='load_dir',
                      help='directory to load models', default="models",
                      type=str)
  parser.add_argument('--ngpu', dest='ngpu',
                      help='number of gpu',
                      default=1, type=int)
  parser.add_argument('--checksession', dest='checksession',
                      help='checksession to load model',
                      default=1, type=int)
  parser.add_argument('--checkepoch', dest='checkepoch',
                      help='checkepoch to load network',
                      default=1, type=int)
  parser.add_argument('--checkpoint', dest='checkpoint',
                      help='checkpoint to load network',
                      default=10000, type=int)

  args = parser.parse_args()
  return args

lr = cfg.TRAIN.LEARNING_RATE
momentum = cfg.TRAIN.MOMENTUM
weight_decay = cfg.TRAIN.WEIGHT_DECAY
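# note: these training hyper-parameters are read from the config for reference
# only; they are not used anywhere in this evaluation script.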

if __name__ == '__main__':

  args = parse_args()

  print('Called with args:')
  print(args)

  if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)
  if args.set_cfgs is not None:
    cfg_from_list(args.set_cfgs)

  print('Using config:')
  pprint.pprint(cfg)
  np.random.seed(cfg.RNG_SEED)

  # test set
  # -- Note: use the validation/test set and disable flipping to enable faster loading.
  cfg.TRAIN.USE_FLIPPED = False
  imdb, roidb = combined_roidb(args.imdbval_name)
  imdb.competition_mode(on=True)

  print('{:d} roidb entries'.format(len(roidb)))

  input_dir = args.load_dir + "/" + args.net
  if not os.path.exists(input_dir):
    raise Exception('There is no input directory for loading network')
  load_name = os.path.join(input_dir,
                           'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
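  # the file name encodes the training session, epoch and iteration of the
  # checkpoint, matching the naming used when the model was saved.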

  checkpoint = torch.load(load_name)
  # assumption: the checkpoint stores the whole model object under the 'model' key,
  # as saved at training time
  fasterRCNN = checkpoint['model']

  print("load checkpoint %s" % (load_name))

  # initialize the tensor holder here.
  im_data = torch.FloatTensor(1)
  im_info = torch.FloatTensor(1)
  num_boxes = torch.LongTensor(1)
  gt_boxes = torch.FloatTensor(1)

  # ship to cuda
  if args.ngpu > 0:
    im_data = im_data.cuda()
    im_info = im_info.cuda()
    num_boxes = num_boxes.cuda()
    gt_boxes = gt_boxes.cuda()

  # make variable
  im_data = Variable(im_data, volatile=True)
  im_info = Variable(im_info, volatile=True)
  num_boxes = Variable(num_boxes, volatile=True)
  gt_boxes = Variable(gt_boxes, volatile=True)
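  # volatile=True marks these Variables as inference-only so no autograd history
  # is recorded (pre-0.4 PyTorch API); the holders are resized per batch below.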

  if args.ngpu > 0:
    cfg.CUDA = True

  print('load model successfully!')

  if args.ngpu > 0:
    fasterRCNN.cuda()

  fasterRCNN.eval()

  start = time.time()
  max_per_image = 100
  thresh = 0.05
  vis = False
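  # max_per_image caps how many detections are kept per image across all classes,
  # thresh drops low-confidence class scores before NMS, and vis toggles on-screen
  # visualization of the detections.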

  save_name = 'faster_rcnn_10'
  num_images = len(imdb.image_index)
  all_boxes = [[[] for _ in xrange(num_images)]
               for _ in xrange(imdb.num_classes)]
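  # all detections are collected into:
  #   all_boxes[cls][image] = N x 5 array of detections in (x1, y1, x2, y2, score)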

  output_dir = get_output_dir(imdb, save_name)

  dataset = roibatchLoader(roidb, imdb.num_classes, training=False,
                           normalize=transforms.Normalize(
                               mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225]))

  dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                           shuffle=False, num_workers=0,
                                           pin_memory=True)
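  # batch_size stays at 1 so each iteration yields a single test image; the scale
  # factor stored in im_info (data[1][0][2]) is used below to map the boxes back
  # to the original image resolution.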

  data_iter = iter(dataloader)

  _t = {'im_detect': time.time(), 'misc': time.time()}
  det_file = os.path.join(output_dir, 'detections.pkl')

  for i in range(num_images):

    data = data_iter.next()
    im_data.data.resize_(data[0].size()).copy_(data[0])
    im_info.data.resize_(data[1].size()).copy_(data[1])
    gt_boxes.data.resize_(data[2].size()).copy_(data[2])
    num_boxes.data.resize_(data[3].size()).copy_(data[3])

    det_tic = time.time()
    rois, cls_prob, bbox_pred, rpn_loss, rcnn_loss = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
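    # rois has shape (batch, num_rois, 5) with the batch index in column 0;
    # cls_prob holds the per-class scores and bbox_pred the class-wise box
    # regression deltas; the two loss outputs are unused at test time.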
    scores = cls_prob.data
    # work on the underlying tensor and rescale the proposals from the resized
    # input back to the original image size
    boxes = rois.data[:, :, 1:5] / data[1][0][2]

    if cfg.TEST.BBOX_REG:
      # Apply bounding-box regression deltas
      box_deltas = bbox_pred.data
      pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
      pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
    else:
      # Simply repeat the boxes, once for each class
      pred_boxes = np.tile(boxes, (1, scores.shape[1]))

    scores = scores.squeeze().cpu().numpy()
    pred_boxes = pred_boxes.squeeze().cpu().numpy()
    # _t['im_detect'].tic()
    det_toc = time.time()
    detect_time = det_toc - det_tic

    misc_tic = time.time()

    if vis:
      im = cv2.imread(imdb.image_path_at(i))
      im2show = np.copy(im)

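    # skip j = 0, because it's the background class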
    for j in xrange(1, imdb.num_classes):
      inds = np.where(scores[:, j] > thresh)[0]
      cls_scores = scores[inds, j]
      cls_boxes = pred_boxes[inds, j * 4:(j + 1) * 4]
      cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
          .astype(np.float32, copy=False)
      keep = nms(cls_dets, cfg.TEST.NMS)
      cls_dets = cls_dets[keep, :]
      if vis:
        im2show = vis_detections(im2show, imdb.classes[j], cls_dets)
      all_boxes[j][i] = cls_dets

    # Limit to max_per_image detections *over all classes*
    if max_per_image > 0:
      image_scores = np.hstack([all_boxes[j][i][:, -1]
                                for j in xrange(1, imdb.num_classes)])
      if len(image_scores) > max_per_image:
        image_thresh = np.sort(image_scores)[-max_per_image]
        for j in xrange(1, imdb.num_classes):
          keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
          all_boxes[j][i] = all_boxes[j][i][keep, :]

    misc_toc = time.time()
    nms_time = misc_toc - misc_tic

    sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r' \
        .format(i + 1, num_images, detect_time, nms_time))
    sys.stdout.flush()

    if vis:
      cv2.imshow('test', im2show)
      cv2.waitKey(0)

  with open(det_file, 'wb') as f:
    cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)
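  # detections.pkl keeps a copy of the raw all_boxes structure on disk for later
  # inspection; the evaluation below works directly on the in-memory all_boxes.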

  print('Evaluating detections')
  imdb.evaluate_detections(all_boxes, output_dir)

  end = time.time()
  print("test time: %0.4fs" % (end - start))