Skip to content

Commit 21855f5

Browse files
committed
add train,test
1 parent 671b463 commit 21855f5

7 files changed

+646
-0
lines changed

_init_paths.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os.path as osp
2+
import sys
3+
4+
def add_path(path):
5+
if path not in sys.path:
6+
sys.path.insert(0, path)
7+
8+
this_dir = osp.dirname(__file__)
9+
10+
# Add lib to PYTHONPATH
11+
lib_path = osp.join(this_dir, 'lib')
12+
add_path(lib_path)
13+
14+
coco_path = osp.join(this_dir, 'data', 'coco', 'PythonAPI')
15+
add_path(coco_path)

cfgs/res101-lg.yml

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
EXP_DIR: res101-lg
2+
TRAIN:
3+
HAS_RPN: True
4+
# IMS_PER_BATCH: 1
5+
BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
6+
RPN_POSITIVE_OVERLAP: 0.7
7+
RPN_BATCHSIZE: 256
8+
PROPOSAL_METHOD: gt
9+
BG_THRESH_LO: 0.0
10+
DISPLAY: 20
11+
BATCH_SIZE: 256
12+
WEIGHT_DECAY: 0.0001
13+
DOUBLE_BIAS: False
14+
SNAPSHOT_PREFIX: res101_faster_rcnn
15+
SCALES: [800]
16+
MAX_SIZE: 1333
17+
TEST:
18+
HAS_RPN: True
19+
SCALES: [800]
20+
MAX_SIZE: 1333
21+
RPN_POST_NMS_TOP_N: 1000
22+
POOLING_MODE: crop
23+
ANCHOR_SCALES: [2,4,8,16,32]

cfgs/res101.yml

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
EXP_DIR: res101
2+
TRAIN:
3+
HAS_RPN: True
4+
# IMS_PER_BATCH: 1
5+
BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
6+
RPN_POSITIVE_OVERLAP: 0.7
7+
RPN_BATCHSIZE: 256
8+
PROPOSAL_METHOD: gt
9+
BG_THRESH_LO: 0.0
10+
DISPLAY: 20
11+
BATCH_SIZE: 256
12+
WEIGHT_DECAY: 0.0001
13+
DOUBLE_BIAS: False
14+
SNAPSHOT_PREFIX: res101_faster_rcnn
15+
TEST:
16+
HAS_RPN: True
17+
POOLING_MODE: crop

cfgs/res50.yml

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
EXP_DIR: res50
2+
TRAIN:
3+
HAS_RPN: True
4+
# IMS_PER_BATCH: 1
5+
BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
6+
RPN_POSITIVE_OVERLAP: 0.7
7+
RPN_BATCHSIZE: 256
8+
PROPOSAL_METHOD: gt
9+
BG_THRESH_LO: 0.0
10+
DISPLAY: 20
11+
BATCH_SIZE: 256
12+
WEIGHT_DECAY: 0.0001
13+
DOUBLE_BIAS: False
14+
SNAPSHOT_PREFIX: res50_faster_rcnn
15+
TEST:
16+
HAS_RPN: True
17+
POOLING_MODE: crop

cfgs/vgg16.yml

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
EXP_DIR: vgg16
2+
TRAIN:
3+
HAS_RPN: True
4+
# IMS_PER_BATCH: 1
5+
BBOX_NORMALIZE_TARGETS_PRECOMPUTED: False
6+
RPN_POSITIVE_OVERLAP: 0.7
7+
RPN_BATCHSIZE: 256
8+
PROPOSAL_METHOD: gt
9+
BG_THRESH_LO: 0.0
10+
BATCH_SIZE: 256
11+
TEST:
12+
HAS_RPN: True

test_net.py

+248
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# --------------------------------------------------------
2+
# Tensorflow Faster R-CNN
3+
# Licensed under The MIT License [see LICENSE for details]
4+
# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
5+
# --------------------------------------------------------
6+
from __future__ import absolute_import
7+
from __future__ import division
8+
from __future__ import print_function
9+
10+
import _init_paths
11+
import os
12+
import sys
13+
import numpy as np
14+
import argparse
15+
import pprint
16+
import pdb
17+
import time
18+
import cv2
19+
import cPickle
20+
import torch
21+
from torch.autograd import Variable
22+
import torch.nn as nn
23+
import torch.optim as optim
24+
25+
import torchvision.transforms as transforms
26+
27+
from roi_data_layer.roidb import combined_roidb
28+
from roi_data_layer.roibatchLoader import roibatchLoader
29+
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
30+
from model.faster_rcnn.faster_rcnn import _fasterRCNN
31+
from model.rpn.bbox_transform import clip_boxes
32+
from model.nms.nms_wrapper import nms
33+
from model.fast_rcnn.nms_wrapper import nms
34+
from model.rpn.bbox_transform import bbox_transform_inv
35+
from model.utils.network import save_net, load_net, vis_detections
36+
import pdb
37+
38+
def parse_args():
39+
"""
40+
Parse input arguments
41+
"""
42+
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
43+
parser.add_argument('--cfg', dest='cfg_file',
44+
help='optional config file',
45+
default='cfgs/vgg16.yml', type=str)
46+
parser.add_argument('--imdb', dest='imdb_name',
47+
help='dataset to train on',
48+
default='voc_2007_trainval', type=str)
49+
parser.add_argument('--imdbval', dest='imdbval_name',
50+
help='dataset to validate on',
51+
default='voc_2007_test', type=str)
52+
parser.add_argument('--net', dest='net',
53+
help='vgg16, res50, res101, res152',
54+
default='vgg16', type=str)
55+
parser.add_argument('--set', dest='set_cfgs',
56+
help='set config keys', default=None,
57+
nargs=argparse.REMAINDER)
58+
parser.add_argument('--load_dir', dest='load_dir',
59+
help='directory to load models', default="models",
60+
nargs=argparse.REMAINDER)
61+
parser.add_argument('--ngpu', dest='ngpu',
62+
help='number of gpu',
63+
default=1, type=int)
64+
parser.add_argument('--checksession', dest='checksession',
65+
help='checksession to load model',
66+
default=1, type=int)
67+
parser.add_argument('--checkepoch', dest='checkepoch',
68+
help='checkepoch to load network',
69+
default=1, type=int)
70+
parser.add_argument('--checkpoint', dest='checkpoint',
71+
help='checkpoint to load network',
72+
default=10000, type=int)
73+
74+
args = parser.parse_args()
75+
return args
76+
77+
lr = cfg.TRAIN.LEARNING_RATE
78+
momentum = cfg.TRAIN.MOMENTUM
79+
weight_decay = cfg.TRAIN.WEIGHT_DECAY
80+
81+
if __name__ == '__main__':
82+
83+
args = parse_args()
84+
85+
print('Called with args:')
86+
print(args)
87+
88+
if args.cfg_file is not None:
89+
cfg_from_file(args.cfg_file)
90+
if args.set_cfgs is not None:
91+
cfg_from_list(args.set_cfgs)
92+
93+
print('Using config:')
94+
pprint.pprint(cfg)
95+
np.random.seed(cfg.RNG_SEED)
96+
97+
# train set
98+
# -- Note: Use validation set and disable the flipped to enable faster loading.
99+
cfg.TRAIN.USE_FLIPPED = False
100+
imdb, roidb = combined_roidb(args.imdbval_name)
101+
imdb.competition_mode(on=True)
102+
103+
print('{:d} roidb entries'.format(len(roidb)))
104+
105+
input_dir = args.load_dir + "/" + args.net
106+
if not os.path.exists(input_dir):
107+
raise Exception('There is no input directory for loading network')
108+
load_name = os.path.join(input_dir,
109+
'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
110+
111+
pdb.set_trace()
112+
checkpoint = torch.load(load_name)
113+
fasterRCNN = checkpoint['model']
114+
115+
print("load checkpoint %s" % (load_name))
116+
117+
# initilize the tensor holder here.
118+
im_data = torch.FloatTensor(1)
119+
im_info = torch.FloatTensor(1)
120+
num_boxes = torch.LongTensor(1)
121+
gt_boxes = torch.FloatTensor(1)
122+
123+
# ship to cuda
124+
if args.ngpu > 0:
125+
im_data = im_data.cuda()
126+
im_info = im_info.cuda()
127+
num_boxes = num_boxes.cuda()
128+
gt_boxes = gt_boxes.cuda()
129+
130+
# make variable
131+
im_data = Variable(im_data, volatile=True)
132+
im_info = Variable(im_info, volatile=True)
133+
num_boxes = Variable(num_boxes, volatile=True)
134+
gt_boxes = Variable(gt_boxes, volatile=True)
135+
136+
if args.ngpu > 0:
137+
cfg.CUDA = True
138+
139+
fasterRCNN = torch.load(load_name)
140+
print('load model successfully!')
141+
142+
if args.ngpu > 0:
143+
fasterRCNN.cuda()
144+
145+
fasterRCNN.eval()
146+
147+
start = time.time()
148+
max_per_image = 100
149+
thresh = 0.05
150+
vis = False
151+
152+
save_name = 'faster_rcnn_10'
153+
num_images = len(imdb.image_index)
154+
all_boxes = [[[] for _ in xrange(num_images)]
155+
for _ in xrange(imdb.num_classes)]
156+
157+
output_dir = get_output_dir(imdb, save_name)
158+
159+
160+
dataset = roibatchLoader(roidb, imdb.num_classes, training=False,
161+
normalize = transforms.Normalize(
162+
mean=[0.485, 0.456, 0.406],
163+
std=[0.229, 0.224, 0.225]))
164+
165+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
166+
shuffle=False, num_workers=0,
167+
pin_memory=True)
168+
169+
data_iter = iter(dataloader)
170+
171+
_t = {'im_detect': time.time(), 'misc': time.time()}
172+
det_file = os.path.join(output_dir, 'detections.pkl')
173+
174+
for i in range(num_images):
175+
176+
data = data_iter.next()
177+
im_data.data.resize_(data[0].size()).copy_(data[0])
178+
im_info.data.resize_(data[1].size()).copy_(data[1])
179+
gt_boxes.data.resize_(data[2].size()).copy_(data[2])
180+
num_boxes.data.resize_(data[3].size()).copy_(data[3])
181+
182+
det_tic = time.time()
183+
rois, cls_prob, bbox_pred, rpn_loss, rcnn_loss = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
184+
scores = cls_prob.data
185+
boxes = rois[:, :, 1:5] / data[1][0][2]
186+
187+
if cfg.TEST.BBOX_REG:
188+
# Apply bounding-box regression deltas
189+
box_deltas = bbox_pred.data
190+
pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
191+
pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
192+
else:
193+
# Simply repeat the boxes, once for each class
194+
pred_boxes = np.tile(boxes, (1, scores.shape[1]))
195+
196+
scores = scores.squeeze().cpu().numpy()
197+
pred_boxes = pred_boxes.squeeze().cpu().numpy()
198+
# _t['im_detect'].tic()
199+
det_toc = time.time()
200+
detect_time = det_toc - det_tic
201+
202+
misc_tic = time.time()
203+
204+
if vis:
205+
im = cv2.imread(imdb.image_path_at(i))
206+
im2show = np.copy(im)
207+
208+
for j in xrange(1, imdb.num_classes):
209+
inds = np.where(scores[:, j] > thresh)[0]
210+
cls_scores = scores[inds, j]
211+
cls_boxes = pred_boxes[inds, j * 4:(j + 1) * 4]
212+
cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
213+
.astype(np.float32, copy=False)
214+
keep = nms(cls_dets, cfg.TEST.NMS)
215+
cls_dets = cls_dets[keep, :]
216+
if vis:
217+
im2show = vis_detections(im2show, imdb.classes[j], cls_dets)
218+
all_boxes[j][i] = cls_dets
219+
220+
# Limit to max_per_image detections *over all classes*
221+
if max_per_image > 0:
222+
image_scores = np.hstack([all_boxes[j][i][:, -1]
223+
for j in xrange(1, imdb.num_classes)])
224+
if len(image_scores) > max_per_image:
225+
image_thresh = np.sort(image_scores)[-max_per_image]
226+
for j in xrange(1, imdb.num_classes):
227+
keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
228+
all_boxes[j][i] = all_boxes[j][i][keep, :]
229+
230+
misc_toc = time.time()
231+
nms_time = misc_toc - misc_tic
232+
233+
sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r' \
234+
.format(i + 1, num_images, detect_time, nms_time))
235+
sys.stdout.flush()
236+
237+
if vis:
238+
cv2.imshow('test', im2show)
239+
cv2.waitKey(0)
240+
241+
with open(det_file, 'wb') as f:
242+
cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)
243+
244+
print('Evaluating detections')
245+
imdb.evaluate_detections(all_boxes, output_dir)
246+
247+
end = time.time()
248+
print("test time: %0.4fs" % (end - start))

0 commit comments

Comments
 (0)