Skip to content

Commit 5471a17

Browse files
authored
Update pphuman (PaddlePaddle#5393)
* refine attr vis & refine model_dir in config * support model_dir in command line
1 parent 6e1fa92 commit 5471a17

File tree

5 files changed

+36
-16
lines changed

5 files changed

+36
-16
lines changed

deploy/pphuman/config/infer_cfg.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ ATTR:
1111
batch_size: 8
1212

1313
MOT:
14-
model_dir: output_inference/pedestrian_yolov3_darknet/
15-
tracker_config: deploy/pphuman/tracker_config.yml
14+
model_dir: output_inference/mot_ppyolov3/
15+
tracker_config: deploy/pphuman/config/tracker_config.yml
1616
batch_size: 1
File renamed without changes.

deploy/pphuman/pipe_utils.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def argsparser():
4545
default=None,
4646
help="Path of video file, `video_file` or `camera_id` has a highest priority."
4747
)
48+
parser.add_argument(
49+
"--model_dir", nargs='*', help="set model dir in pipeline")
4850
parser.add_argument(
4951
"--camera_id",
5052
type=int,
@@ -182,6 +184,21 @@ def report(self, average=False):
182184
return dic
183185

184186

187+
def merge_model_dir(args, model_dir):
188+
# set --model_dir DET=ppyoloe/ to overwrite the model_dir in config file
189+
task_set = ['DET', 'ATTR', 'MOT', 'KPT', 'ACTION']
190+
if not model_dir:
191+
return args
192+
for md in model_dir:
193+
md = md.strip()
194+
k, v = md.split('=', 1)
195+
k_upper = k.upper()
196+
assert k_upper in task_set, 'Illegal type of task, expect task are: {}, but received {}'.format(
197+
task_set, k)
198+
args[k_upper].update({'model_dir': v})
199+
return args
200+
201+
185202
def merge_cfg(args):
186203
with open(args.config) as f:
187204
pred_config = yaml.safe_load(f)
@@ -196,14 +213,17 @@ def merge(cfg, arg):
196213
merge_cfg[k] = merge(v, arg)
197214
return merge_cfg
198215

199-
pred_config = merge(pred_config, vars(args))
216+
args_dict = vars(args)
217+
model_dir = args_dict.pop('model_dir')
218+
pred_config = merge_model_dir(pred_config, model_dir)
219+
pred_config = merge(pred_config, args_dict)
200220
return pred_config
201221

202222

203223
def print_arguments(cfg):
204224
print('----------- Running Arguments -----------')
205-
for arg, value in sorted(cfg.items()):
206-
print('%s: %s' % (arg, value))
225+
buffer = yaml.dump(cfg)
226+
print(buffer)
207227
print('------------------------------------------')
208228

209229

deploy/python/attr_infer.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ def postprocess(self, inputs, result):
9696
age_list = ['AgeLess18', 'Age18-60', 'AgeOver60']
9797
direct_list = ['Front', 'Side', 'Back']
9898
bag_list = ['HandBag', 'ShoulderBag', 'Backpack']
99-
upper_list = [
100-
'UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice', 'LongCoat'
101-
]
99+
upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
102100
lower_list = [
103-
'LowerStripe', 'LowerPattern', 'Trousers', 'Shorts', 'Skirt&Dress'
101+
'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
102+
'Skirt&Dress'
104103
]
105-
104+
glasses_threshold = 0.3
105+
hold_threshold = 0.6
106106
batch_res = []
107107
for res in im_results:
108108
res = res.tolist()
@@ -118,7 +118,7 @@ def postprocess(self, inputs, result):
118118
label_res.append(direction)
119119
# glasses
120120
glasses = 'Glasses: '
121-
if res[1] > self.threshold:
121+
if res[1] > glasses_threshold:
122122
glasses += 'True'
123123
else:
124124
glasses += 'False'
@@ -132,7 +132,7 @@ def postprocess(self, inputs, result):
132132
label_res.append(hat)
133133
# hold obj
134134
hold_obj = 'HoldObjectsInFront: '
135-
if res[18] > self.threshold:
135+
if res[18] > hold_threshold:
136136
hold_obj += 'True'
137137
else:
138138
hold_obj += 'False'
@@ -143,7 +143,7 @@ def postprocess(self, inputs, result):
143143
bag_label = bag if bag_score > self.threshold else 'No bag'
144144
label_res.append(bag_label)
145145
# upper
146-
upper_res = res[4:8] + res[10:11]
146+
upper_res = res[4:8]
147147
upper_label = 'Upper:'
148148
sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
149149
upper_label += ' {}'.format(sleeve)
@@ -152,7 +152,7 @@ def postprocess(self, inputs, result):
152152
upper_label += ' {}'.format(upper_list[i])
153153
label_res.append(upper_label)
154154
# lower
155-
lower_res = res[8:10] + res[11:14]
155+
lower_res = res[8:14]
156156
lower_label = 'Lower: '
157157
has_lower = False
158158
for i, l in enumerate(lower_res):

deploy/python/visualize.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@ def visualize_attr(im, results, boxes=None):
338338
im = np.ascontiguousarray(np.copy(im))
339339

340340
im_h, im_w = im.shape[:2]
341-
text_scale = max(1, int(im.shape[0] / 1600.))
342-
text_thickness = 2
341+
text_scale = max(1, int(im.shape[0] / 1200.))
342+
text_thickness = 3
343343

344344
line_inter = im.shape[0] / 50.
345345
for i, res in enumerate(results):

0 commit comments

Comments
 (0)