@@ -65,7 +65,7 @@ class CenterTrack(Detector):
65
65
"""
66
66
Args:
67
67
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
68
- device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
68
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU , default is CPU
69
69
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
70
70
batch_size (int): size of pre batch in inference
71
71
trt_min_shape (int): min shape for dynamic shape in trt
@@ -130,7 +130,7 @@ def __init__(
130
130
vertical_ratio = vertical_ratio ,
131
131
track_thresh = track_thresh ,
132
132
pre_thresh = pre_thresh )
133
-
133
+
134
134
self .pre_image = None
135
135
136
136
def get_additional_inputs (self , dets , meta , with_hm = True ):
@@ -173,19 +173,18 @@ def preprocess(self, image_list):
173
173
#inputs = create_inputs(im, im_info)
174
174
inputs = {}
175
175
inputs ['image' ] = np .array ((im , )).astype ('float32' )
176
- inputs ['im_shape' ] = np .array (
177
- (im_info ['im_shape' ], )).astype ('float32' )
176
+ inputs ['im_shape' ] = np .array ((im_info ['im_shape' ], )).astype ('float32' )
178
177
inputs ['scale_factor' ] = np .array (
179
178
(im_info ['scale_factor' ], )).astype ('float32' )
180
-
179
+
181
180
inputs ['trans_input' ] = im_info ['trans_input' ]
182
181
inputs ['inp_width' ] = im_info ['inp_width' ]
183
182
inputs ['inp_height' ] = im_info ['inp_height' ]
184
183
inputs ['center' ] = im_info ['center' ]
185
184
inputs ['scale' ] = im_info ['scale' ]
186
185
inputs ['out_height' ] = im_info ['out_height' ]
187
186
inputs ['out_width' ] = im_info ['out_width' ]
188
-
187
+
189
188
if self .pre_image is None :
190
189
self .pre_image = inputs ['image' ]
191
190
# initializing tracker for the first frame
@@ -196,7 +195,7 @@ def preprocess(self, image_list):
196
195
# render input heatmap from tracker status
197
196
pre_hm = self .get_additional_inputs (
198
197
self .tracker .tracks , inputs , with_hm = True )
199
- inputs ['pre_hm' ] = pre_hm #.to_tensor(pre_hm)
198
+ inputs ['pre_hm' ] = pre_hm #.to_tensor(pre_hm)
200
199
201
200
input_names = self .predictor .get_input_names ()
202
201
for i in range (len (input_names )):
@@ -256,8 +255,8 @@ def centertrack_post_process(self, dets, meta, out_thresh):
256
255
return preds
257
256
258
257
def tracking (self , inputs , det_results ):
259
- result = self .centertrack_post_process (
260
- det_results , inputs , self .tracker .out_thresh )
258
+ result = self .centertrack_post_process (det_results , inputs ,
259
+ self .tracker .out_thresh )
261
260
online_targets = self .tracker .update (result )
262
261
263
262
online_tlwhs , online_scores , online_ids = [], [], []
@@ -292,10 +291,7 @@ def predict(self, repeats=1):
292
291
tracking_tensor = self .predictor .get_output_handle (output_names [2 ])
293
292
np_tracking = tracking_tensor .copy_to_cpu ()
294
293
295
- result = dict (
296
- bboxes = np_bboxes ,
297
- cts = np_cts ,
298
- tracking = np_tracking )
294
+ result = dict (bboxes = np_bboxes , cts = np_cts , tracking = np_tracking )
299
295
return result
300
296
301
297
def predict_image (self ,
@@ -333,8 +329,8 @@ def predict_image(self,
333
329
# tracking
334
330
result_warmup = self .tracking (inputs , det_result )
335
331
self .det_times .tracking_time_s .start ()
336
- online_tlwhs , online_scores , online_ids = self .tracking (inputs ,
337
- det_result )
332
+ online_tlwhs , online_scores , online_ids = self .tracking (
333
+ inputs , det_result )
338
334
self .det_times .tracking_time_s .end ()
339
335
self .det_times .img_num += 1
340
336
@@ -358,8 +354,8 @@ def predict_image(self,
358
354
359
355
# tracking process
360
356
self .det_times .tracking_time_s .start ()
361
- online_tlwhs , online_scores , online_ids = self .tracking (inputs ,
362
- det_result )
357
+ online_tlwhs , online_scores , online_ids = self .tracking (
358
+ inputs , det_result )
363
359
self .det_times .tracking_time_s .end ()
364
360
self .det_times .img_num += 1
365
361
@@ -499,7 +495,7 @@ def main():
499
495
FLAGS = parser .parse_args ()
500
496
print_arguments (FLAGS )
501
497
FLAGS .device = FLAGS .device .upper ()
502
- assert FLAGS .device in ['CPU' , 'GPU' , 'XPU'
503
- ], "device should be CPU, GPU or XPU"
498
+ assert FLAGS .device in ['CPU' , 'GPU' , 'XPU' , 'NPU'
499
+ ], "device should be CPU, GPU, NPU or XPU"
504
500
505
501
main ()
0 commit comments