Skip to content

Commit 66187ae

Browse files
authored
Merge pull request #191 from wwdok/main
add feature of saving demo inference result
2 parents 847dae1 + ff68de3 commit 66187ae

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

demo/demo.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def parse_args():
1919
parser.add_argument('--model', help='model file path')
2020
parser.add_argument('--path', default='./demo', help='path to images or video')
2121
parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
22+
parser.add_argument('--save_result', action='store_true', help='whether to save the inference result of image/video')
2223
args = parser.parse_args()
2324
return args
2425

@@ -61,8 +62,9 @@ def inference(self, img):
6162

6263
def visualize(self, dets, meta, class_names, score_thres, wait=0):
6364
time1 = time.time()
64-
self.model.head.show_result(meta['raw_img'], dets, class_names, score_thres=score_thres, show=True)
65+
result_img = self.model.head.show_result(meta['raw_img'], dets, class_names, score_thres=score_thres, show=True)
6566
print('viz time: {:.3f}s'.format(time.time()-time1))
67+
return result_img
6668

6769

6870
def get_image_list(path):
@@ -85,6 +87,7 @@ def main():
8587
logger = Logger(-1, use_tensorboard=False)
8688
predictor = Predictor(cfg, args.model, logger, device='cuda:0')
8789
logger.log('Press "Esc", "q" or "Q" to exit.')
90+
current_time = time.localtime()
8891
if args.demo == 'image':
8992
if os.path.isdir(args.path):
9093
files = get_image_list(args.path)
@@ -93,18 +96,38 @@ def main():
9396
files.sort()
9497
for image_name in files:
9598
meta, res = predictor.inference(image_name)
96-
predictor.visualize(res, meta, cfg.class_names, 0.35)
99+
result_image = predictor.visualize(res, meta, cfg.class_names, 0.35)
100+
if args.save_result:
101+
save_folder = os.path.join(cfg.save_dir, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
102+
if not os.path.exists(save_folder):
103+
os.mkdir(save_folder)
104+
save_file_name = os.path.join(save_folder, os.path.basename(image_name))
105+
cv2.imwrite(save_file_name, result_image)
97106
ch = cv2.waitKey(0)
98107
if ch == 27 or ch == ord('q') or ch == ord('Q'):
99108
break
100109
elif args.demo == 'video' or args.demo == 'webcam':
101110
cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
111+
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # float
112+
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
113+
fps = cap.get(cv2.CAP_PROP_FPS)
114+
save_folder = os.path.join(cfg.save_dir, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
115+
if not os.path.exists(save_folder):
116+
os.mkdir(save_folder)
117+
save_path = os.path.join(save_folder, args.path.split('/')[-1]) if args.demo == 'video' else os.path.join(save_folder, 'camera.mp4')
118+
print(f'save_path is {save_path}')
119+
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height)))
102120
while True:
103121
ret_val, frame = cap.read()
104-
meta, res = predictor.inference(frame)
105-
predictor.visualize(res, meta, cfg.class_names, 0.35)
106-
ch = cv2.waitKey(1)
107-
if ch == 27 or ch == ord('q') or ch == ord('Q'):
122+
if ret_val:
123+
meta, res = predictor.inference(frame)
124+
result_frame = predictor.visualize(res, meta, cfg.class_names, 0.35)
125+
if args.save_result:
126+
vid_writer.write(result_frame)
127+
ch = cv2.waitKey(1)
128+
if ch == 27 or ch == ord('q') or ch == ord('Q'):
129+
break
130+
else:
108131
break
109132

110133

nanodet/model/head/gfl_head.py

+1
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def show_result(self, img, dets, class_names, score_thres=0.3, show=True, save_p
457457
result = overlay_bbox_cv(img, dets, class_names, score_thresh=score_thres)
458458
if show:
459459
cv2.imshow('det', result)
460+
return result
460461

461462
def get_bboxes(self,
462463
cls_scores,

0 commit comments

Comments
 (0)