30
30
from python .infer import Detector , DetectorPicoDet
31
31
from python .mot_sde_infer import SDE_Detector
32
32
from python .attr_infer import AttrDetector
33
+ from python .keypoint_infer import KeyPointDetector
34
+ from python .keypoint_postprocess import translate_to_ori_images
35
+ from python .action_infer import ActionRecognizer
36
+ from python .action_utils import KeyPointCollector , ActionVisualCollector
37
+
33
38
from pipe_utils import argsparser , print_arguments , merge_cfg , PipeTimer
34
- from pipe_utils import get_test_images , crop_image_with_det , crop_image_with_mot , parse_mot_res
39
+ from pipe_utils import get_test_images , crop_image_with_det , crop_image_with_mot , parse_mot_res , parse_mot_keypoint
35
40
from python .preprocess import decode_image
36
- from python .visualize import visualize_box_mask , visualize_attr
41
+ from python .visualize import visualize_box_mask , visualize_attr , visualize_pose , visualize_action
37
42
from pptracking .python .visualize import plot_tracking
38
43
39
44
@@ -299,9 +304,45 @@ def __init__(self,
299
304
trt_max_shape , trt_opt_shape , trt_calib_mode , cpu_threads ,
300
305
enable_mkldnn )
301
306
if self .with_action :
302
- self .kpt_predictor = KeyPointDetector ()
303
- self .kpt_collector = KeyPointCollector ()
304
- self .action_predictor = ActionDetector ()
307
+ kpt_cfg = self .cfg ['KPT' ]
308
+ kpt_model_dir = kpt_cfg ['model_dir' ]
309
+ kpt_batch_size = kpt_cfg ['batch_size' ]
310
+ action_cfg = self .cfg ['ACTION' ]
311
+ action_model_dir = action_cfg ['model_dir' ]
312
+ action_batch_size = action_cfg ['batch_size' ]
313
+ action_frames = action_cfg ['max_frames' ]
314
+ display_frames = action_cfg ['display_frames' ]
315
+ self .coord_size = action_cfg ['coord_size' ]
316
+
317
+ self .kpt_predictor = KeyPointDetector (
318
+ kpt_model_dir ,
319
+ device ,
320
+ run_mode ,
321
+ kpt_batch_size ,
322
+ trt_min_shape ,
323
+ trt_max_shape ,
324
+ trt_opt_shape ,
325
+ trt_calib_mode ,
326
+ cpu_threads ,
327
+ enable_mkldnn ,
328
+ use_dark = False )
329
+ self .kpt_collector = KeyPointCollector (action_frames )
330
+
331
+ self .action_predictor = ActionRecognizer (
332
+ action_model_dir ,
333
+ device ,
334
+ run_mode ,
335
+ action_batch_size ,
336
+ trt_min_shape ,
337
+ trt_max_shape ,
338
+ trt_opt_shape ,
339
+ trt_calib_mode ,
340
+ cpu_threads ,
341
+ enable_mkldnn ,
342
+ window_size = action_frames )
343
+
344
+ self .action_visual_collector = ActionVisualCollector (
345
+ display_frames )
305
346
306
347
def set_file_name (self , path ):
307
348
self .file_name = os .path .split (path )[- 1 ]
@@ -412,7 +453,8 @@ def predict_video(self, capture):
412
453
413
454
self .pipeline_res .update (mot_res , 'mot' )
414
455
if self .with_attr or self .with_action :
415
- crop_input = crop_image_with_mot (frame , mot_res )
456
+ crop_input , new_bboxes , ori_bboxes = crop_image_with_mot (
457
+ frame , mot_res )
416
458
417
459
if self .with_attr :
418
460
if frame_id > self .warmup_frame :
@@ -424,17 +466,34 @@ def predict_video(self, capture):
424
466
self .pipeline_res .update (attr_res , 'attr' )
425
467
426
468
if self .with_action :
427
- kpt_result = self .kpt_predictor .predict_image (crop_input )
428
- self .pipeline_res .update (kpt_result , 'kpt' )
429
-
430
- self .kpt_collector .update (kpt_result ) # collect kpt output
431
- state = self .kpt_collector .state () # whether frame num is enough
432
-
469
+ kpt_pred = self .kpt_predictor .predict_image (
470
+ crop_input , visual = False )
471
+ keypoint_vector , score_vector = translate_to_ori_images (
472
+ kpt_pred , np .array (new_bboxes ))
473
+ kpt_res = {}
474
+ kpt_res ['keypoint' ] = [
475
+ keypoint_vector .tolist (), score_vector .tolist ()
476
+ ] if len (keypoint_vector ) > 0 else [[], []]
477
+ kpt_res ['bbox' ] = ori_bboxes
478
+ self .pipeline_res .update (kpt_res , 'kpt' )
479
+
480
+ self .kpt_collector .update (kpt_res ,
481
+ mot_res ) # collect kpt output
482
+ state = self .kpt_collector .get_state (
483
+ ) # whether the frame count is sufficient or a tracker was lost
484
+
485
+ action_res = {}
433
486
if state :
434
- action_input = self .kpt_collector .collate (
435
- ) # reorgnize kpt output in ID
436
- action_res = self .action_predictor .predict_kpt (action_input )
437
- self .pipeline_res .update (action , 'action' )
487
+ collected_keypoint = self .kpt_collector .get_collected_keypoint (
488
+ ) # reorganize kpt output with ID
489
+ action_input = parse_mot_keypoint (collected_keypoint ,
490
+ self .coord_size )
491
+ action_res = self .action_predictor .predict_skeleton_with_mot (
492
+ action_input )
493
+ self .pipeline_res .update (action_res , 'action' )
494
+
495
+ if self .cfg ['visual' ]:
496
+ self .action_visual_collector .update (action_res )
438
497
439
498
if frame_id > self .warmup_frame :
440
499
self .pipe_timer .img_num += 1
@@ -474,6 +533,19 @@ def visualize_video(self, image, result, frame_id, fps):
474
533
image = visualize_attr (image , attr_res , boxes )
475
534
image = np .array (image )
476
535
536
+ kpt_res = result .get ('kpt' )
537
+ if kpt_res is not None :
538
+ image = visualize_pose (
539
+ image ,
540
+ kpt_res ,
541
+ visual_thresh = self .cfg ['kpt_thresh' ],
542
+ returnimg = True )
543
+
544
+ action_res = result .get ('action' )
545
+ if action_res is not None :
546
+ image = visualize_action (image , mot_res ['boxes' ],
547
+ self .action_visual_collector , "Falling" )
548
+
477
549
return image
478
550
479
551
def visualize_image (self , im_files , images , result ):
0 commit comments