@@ -19,6 +19,7 @@ def parse_args():
19
19
parser .add_argument ('--model' , help = 'model file path' )
20
20
parser .add_argument ('--path' , default = './demo' , help = 'path to images or video' )
21
21
parser .add_argument ('--camid' , type = int , default = 0 , help = 'webcam demo camera id' )
22
+ parser .add_argument ('--save_result' , action = 'store_true' , help = 'whether to save the inference result of image/video' )
22
23
args = parser .parse_args ()
23
24
return args
24
25
@@ -61,8 +62,9 @@ def inference(self, img):
61
62
62
63
def visualize (self , dets , meta , class_names , score_thres , wait = 0 ):
63
64
time1 = time .time ()
64
- self .model .head .show_result (meta ['raw_img' ], dets , class_names , score_thres = score_thres , show = True )
65
+ result_img = self .model .head .show_result (meta ['raw_img' ], dets , class_names , score_thres = score_thres , show = True )
65
66
print ('viz time: {:.3f}s' .format (time .time ()- time1 ))
67
+ return result_img
66
68
67
69
68
70
def get_image_list (path ):
@@ -85,6 +87,7 @@ def main():
85
87
logger = Logger (- 1 , use_tensorboard = False )
86
88
predictor = Predictor (cfg , args .model , logger , device = 'cuda:0' )
87
89
logger .log ('Press "Esc", "q" or "Q" to exit.' )
90
+ current_time = time .localtime ()
88
91
if args .demo == 'image' :
89
92
if os .path .isdir (args .path ):
90
93
files = get_image_list (args .path )
@@ -93,18 +96,38 @@ def main():
93
96
files .sort ()
94
97
for image_name in files :
95
98
meta , res = predictor .inference (image_name )
96
- predictor .visualize (res , meta , cfg .class_names , 0.35 )
99
+ result_image = predictor .visualize (res , meta , cfg .class_names , 0.35 )
100
+ if args .save_result :
101
+ save_folder = os .path .join (cfg .save_dir , time .strftime ("%Y_%m_%d_%H_%M_%S" , current_time ))
102
+ if not os .path .exists (save_folder ):
103
+ os .mkdir (save_folder )
104
+ save_file_name = os .path .join (save_folder , os .path .basename (image_name ))
105
+ cv2 .imwrite (save_file_name , result_image )
97
106
ch = cv2 .waitKey (0 )
98
107
if ch == 27 or ch == ord ('q' ) or ch == ord ('Q' ):
99
108
break
100
109
elif args .demo == 'video' or args .demo == 'webcam' :
101
110
cap = cv2 .VideoCapture (args .path if args .demo == 'video' else args .camid )
111
+ width = cap .get (cv2 .CAP_PROP_FRAME_WIDTH ) # float
112
+ height = cap .get (cv2 .CAP_PROP_FRAME_HEIGHT ) # float
113
+ fps = cap .get (cv2 .CAP_PROP_FPS )
114
+ save_folder = os .path .join (cfg .save_dir , time .strftime ("%Y_%m_%d_%H_%M_%S" , current_time ))
115
+ if not os .path .exists (save_folder ):
116
+ os .mkdir (save_folder )
117
+ save_path = os .path .join (save_folder , args .path .split ('/' )[- 1 ]) if args .demo == 'video' else os .path .join (save_folder , 'camera.mp4' )
118
+ print (f'save_path is { save_path } ' )
119
+ vid_writer = cv2 .VideoWriter (save_path , cv2 .VideoWriter_fourcc (* 'mp4v' ), fps , (int (width ), int (height )))
102
120
while True :
103
121
ret_val , frame = cap .read ()
104
- meta , res = predictor .inference (frame )
105
- predictor .visualize (res , meta , cfg .class_names , 0.35 )
106
- ch = cv2 .waitKey (1 )
107
- if ch == 27 or ch == ord ('q' ) or ch == ord ('Q' ):
122
+ if ret_val :
123
+ meta , res = predictor .inference (frame )
124
+ result_frame = predictor .visualize (res , meta , cfg .class_names , 0.35 )
125
+ if args .save_result :
126
+ vid_writer .write (result_frame )
127
+ ch = cv2 .waitKey (1 )
128
+ if ch == 27 or ch == ord ('q' ) or ch == ord ('Q' ):
129
+ break
130
+ else :
108
131
break
109
132
110
133
0 commit comments