Skip to content

Commit e277190

Browse files
authored
[Enhancement] Update data reading process (PaddlePaddle#1719)
1 parent fb22907 commit e277190

File tree

12 files changed

+307
-578
lines changed

12 files changed

+307
-578
lines changed

contrib/PP-HumanSeg/README_cn.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ python data/download_data.py
127127
```bash
128128
# 通过电脑摄像头进行实时分割处理
129129
python bg_replace.py \
130-
--config export_model/ppseg_lite_portrait_398x224_with_softmax/deploy.yaml
130+
--config export_model/pphumanseg_lite_portrait_398x224_with_softmax/deploy.yaml
131131

132132
# 对人像视频进行分割处理
133133
python bg_replace.py \
@@ -143,7 +143,7 @@ python bg_replace.py \
143143
```bash
144144
# 增加光流后处理
145145
python bg_replace.py \
146-
--config export_model/ppseg_lite_portrait_398x224_with_softmax/deploy.yaml \
146+
--config export_model/pphumanseg_lite_portrait_398x224_with_softmax/deploy.yaml \
147147
--use_optic_flow
148148
```
149149

@@ -152,7 +152,7 @@ python bg_replace.py \
152152
```bash
153153
# 通过电脑摄像头进行实时背景替换处理。可通过'--background_video_path'传入背景视频
154154
python bg_replace.py \
155-
--config export_model/ppseg_lite_portrait_398x224_with_softmax/deploy.yaml \
155+
--config export_model/pphumanseg_lite_portrait_398x224_with_softmax/deploy.yaml \
156156
--input_shape 224 398 \
157157
--bg_img_path data/background.jpg
158158

@@ -164,7 +164,7 @@ python bg_replace.py \
164164

165165
# 对单张图像进行背景替换
166166
python bg_replace.py \
167-
--config export_model/ppseg_lite_portrait_398x224_with_softmax/deploy.yaml \
167+
--config export_model/pphumanseg_lite_portrait_398x224_with_softmax/deploy.yaml \
168168
--input_shape 224 398 \
169169
--img_path data/human_image.jpg \
170170
--bg_img_path data/background.jpg

contrib/PP-HumanSeg/scripts/train.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,25 @@ def check_logits_losses(logits_list, losses):
3333
.format(len_logits, len_losses))
3434

3535

36-
def loss_computation(logits_list, labels, losses, edges=None):
36+
def loss_computation(logits_list, label_dict, losses):
3737
check_logits_losses(logits_list, losses)
3838
loss_list = []
3939
for i in range(len(logits_list)):
4040
logits = logits_list[i]
4141
loss_i = losses['types'][i]
42-
# Whether to use edges as labels, according to the loss type.
42+
coef_i = losses['coef'][i]
4343
if loss_i.__class__.__name__ in ('BCELoss', ) and loss_i.edge_label:
44-
loss_list.append(losses['coef'][i] * loss_i(logits, edges))
44+
# Use edges as labels, according to the loss type.
45+
loss_list.append(coef_i * loss_i(logits, label_dict['edge']))
46+
elif loss_i.__class__.__name__ == 'MixedLoss':
47+
mixed_loss_list = loss_i(logits, label_dict['label'])
48+
for mixed_loss in mixed_loss_list:
49+
loss_list.append(coef_i * mixed_loss)
50+
elif loss_i.__class__.__name__ in ("KLLoss", ):
51+
loss_list.append(coef_i *
52+
loss_i(logits_list[0], logits_list[1].detach()))
4553
else:
46-
loss_list.append(losses['coef'][i] * loss_i(logits, labels))
54+
loss_list.append(coef_i * loss_i(logits, label_dict['label']))
4755
return loss_list
4856

4957

@@ -132,21 +140,18 @@ def train(model,
132140
if iter > iters:
133141
break
134142
reader_cost_averager.record(time.time() - batch_start)
135-
images = data[0]
136-
labels = data[1].astype('int64')
143+
images = data['img']
144+
labels = data['label'].astype('int64')
137145
edges = None
138-
if len(data) == 3:
139-
edges = data[2].astype('int64')
146+
if 'edge' in data.keys():
147+
edges = data['edge'].astype('int64')
140148

141149
if nranks > 1:
142150
logits_list = ddp_model(images)
143151
else:
144152
logits_list = model(images)
145153
loss_list = loss_computation(
146-
logits_list=logits_list,
147-
labels=labels,
148-
losses=losses,
149-
edges=edges)
154+
logits_list=logits_list, label_dict=data, losses=losses)
150155
loss = sum(loss_list)
151156
loss.backward()
152157

deploy/python/infer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,9 @@ def run(self, imgs_path):
387387
logger.info("Finish")
388388

389389
def _preprocess(self, img):
390-
return self.cfg.transforms(img)[0]
390+
data = {}
391+
data['img'] = img
392+
return self.cfg.transforms(data)['img']
391393

392394
def _postprocess(self, results):
393395
if self.args.with_argmax:

deploy/python/infer_benchmark.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,9 @@ def run(self, img_path):
168168
logger.info("Average time: %.3f ms/img" % avg_time)
169169

170170
def _preprocess(self, img_path):
171+
data = {'img': img_path}
171172
if self.args.resize_width == 0 and self.args.resize_height == 0:
172-
return self.cfg.transforms(img_path)[0]
173+
return self.cfg.transforms(data)['img']
173174
else:
174175
assert args.resize_width > 0 and args.resize_height > 0
175176
with codecs.open(args.cfg, 'r', 'utf-8') as file:
@@ -180,7 +181,7 @@ def _preprocess(self, img_path):
180181
'target_size': [args.resize_width, args.resize_height]
181182
})
182183
transforms = DeployConfig.load_transforms(transforms_dic)
183-
return transforms(img_path)[0]
184+
return transforms(data)['img']
184185

185186
def _save_imgs(self, results):
186187
for i in range(results.shape[0]):

paddleseg/core/infer.py

+15-92
Original file line numberDiff line numberDiff line change
@@ -21,96 +21,24 @@
2121
import paddle.nn.functional as F
2222

2323

24-
def get_reverse_list(ori_shape, transforms):
25-
"""
26-
get reverse list of transform.
27-
28-
Args:
29-
ori_shape (list): Origin shape of image.
30-
transforms (list): List of transform.
31-
32-
Returns:
33-
list: List of tuple, there are two format:
34-
('resize', (h, w)) The image shape before resize,
35-
('padding', (h, w)) The image shape before padding.
36-
"""
37-
reverse_list = []
38-
h, w = ori_shape[0], ori_shape[1]
39-
for op in transforms:
40-
if op.__class__.__name__ in ['Resize']:
41-
reverse_list.append(('resize', (h, w)))
42-
h, w = op.target_size[0], op.target_size[1]
43-
if op.__class__.__name__ in ['ResizeByLong']:
44-
reverse_list.append(('resize', (h, w)))
45-
long_edge = max(h, w)
46-
short_edge = min(h, w)
47-
short_edge = int(round(short_edge * op.long_size / long_edge))
48-
long_edge = op.long_size
49-
if h > w:
50-
h = long_edge
51-
w = short_edge
52-
else:
53-
w = long_edge
54-
h = short_edge
55-
if op.__class__.__name__ in ['ResizeByShort']:
56-
reverse_list.append(('resize', (h, w)))
57-
long_edge = max(h, w)
58-
short_edge = min(h, w)
59-
long_edge = int(round(long_edge * op.short_size / short_edge))
60-
short_edge = op.short_size
61-
if h > w:
62-
h = long_edge
63-
w = short_edge
64-
else:
65-
w = long_edge
66-
h = short_edge
67-
if op.__class__.__name__ in ['Padding']:
68-
reverse_list.append(('padding', (h, w)))
69-
w, h = op.target_size[0], op.target_size[1]
70-
if op.__class__.__name__ in ['PaddingByAspectRatio']:
71-
reverse_list.append(('padding', (h, w)))
72-
ratio = w / h
73-
if ratio == op.aspect_ratio:
74-
pass
75-
elif ratio > op.aspect_ratio:
76-
h = int(w / op.aspect_ratio)
77-
else:
78-
w = int(h * op.aspect_ratio)
79-
if op.__class__.__name__ in ['LimitLong']:
80-
long_edge = max(h, w)
81-
short_edge = min(h, w)
82-
if ((op.max_long is not None) and (long_edge > op.max_long)):
83-
reverse_list.append(('resize', (h, w)))
84-
long_edge = op.max_long
85-
short_edge = int(round(short_edge * op.max_long / long_edge))
86-
elif ((op.min_long is not None) and (long_edge < op.min_long)):
87-
reverse_list.append(('resize', (h, w)))
88-
long_edge = op.min_long
89-
short_edge = int(round(short_edge * op.min_long / long_edge))
90-
if h > w:
91-
h = long_edge
92-
w = short_edge
93-
else:
94-
w = long_edge
95-
h = short_edge
96-
return reverse_list
97-
98-
99-
def reverse_transform(pred, ori_shape, transforms, mode='nearest'):
24+
def reverse_transform(pred, trans_info, mode='nearest'):
10025
"""recover pred to origin shape"""
101-
reverse_list = get_reverse_list(ori_shape, transforms)
10226
intTypeList = [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
10327
dtype = pred.dtype
104-
for item in reverse_list[::-1]:
105-
if item[0] == 'resize':
28+
for item in trans_info[::-1]:
29+
if isinstance(item[0], list):
30+
trans_mode = item[0][0]
31+
else:
32+
trans_mode = item[0]
33+
if trans_mode == 'resize':
10634
h, w = item[1][0], item[1][1]
10735
if paddle.get_device() == 'cpu' and dtype in intTypeList:
10836
pred = paddle.cast(pred, 'float32')
10937
pred = F.interpolate(pred, (h, w), mode=mode)
11038
pred = paddle.cast(pred, dtype)
11139
else:
11240
pred = F.interpolate(pred, (h, w), mode=mode)
113-
elif item[0] == 'padding':
41+
elif trans_mode == 'padding':
11442
h, w = item[1][0], item[1][1]
11543
pred = pred[:, :, 0:h, 0:w]
11644
else:
@@ -205,8 +133,7 @@ def slide_inference(model, im, crop_size, stride):
205133

206134
def inference(model,
207135
im,
208-
ori_shape=None,
209-
transforms=None,
136+
trans_info=None,
210137
is_slide=False,
211138
stride=None,
212139
crop_size=None):
@@ -216,8 +143,7 @@ def inference(model,
216143
Args:
217144
model (paddle.nn.Layer): model to get logits of image.
218145
im (Tensor): the input image.
219-
ori_shape (list): Origin shape of image.
220-
transforms (list): Transforms for image.
146+
trans_info (list): Records of how the image shape was changed by transforms, used to recover the original shape. Default: None.
221147
is_slide (bool): Whether to infer by sliding window. Default: False.
222148
crop_size (tuple|list). The size of sliding window, (w, h). It should be provided if is_slide is True.
223149
stride (tuple|list). The size of stride, (w, h). It should be provided if is_slide is True.
@@ -239,8 +165,8 @@ def inference(model,
239165
logit = slide_inference(model, im, crop_size=crop_size, stride=stride)
240166
if hasattr(model, 'data_format') and model.data_format == 'NHWC':
241167
logit = logit.transpose((0, 3, 1, 2))
242-
if ori_shape is not None:
243-
logit = reverse_transform(logit, ori_shape, transforms, mode='bilinear')
168+
if trans_info is not None:
169+
logit = reverse_transform(logit, trans_info, mode='bilinear')
244170
pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
245171
return pred, logit
246172
else:
@@ -249,8 +175,7 @@ def inference(model,
249175

250176
def aug_inference(model,
251177
im,
252-
ori_shape,
253-
transforms,
178+
trans_info,
254179
scales=1.0,
255180
flip_horizontal=False,
256181
flip_vertical=False,
@@ -263,8 +188,7 @@ def aug_inference(model,
263188
Args:
264189
model (paddle.nn.Layer): model to get logits of image.
265190
im (Tensor): the input image.
266-
ori_shape (list): Origin shape of image.
267-
transforms (list): Transforms for image.
191+
trans_info (list): Records of how the image shape was changed by transforms, used to recover the original shape.
268192
scales (float|tuple|list): Scales for resize. Default: 1.
269193
flip_horizontal (bool): Whether to flip horizontally. Default: False.
270194
flip_vertical (bool): Whether to flip vertically. Default: False.
@@ -302,8 +226,7 @@ def aug_inference(model,
302226
logit = F.softmax(logit, axis=1)
303227
final_logit = final_logit + logit
304228

305-
final_logit = reverse_transform(
306-
final_logit, ori_shape, transforms, mode='bilinear')
229+
final_logit = reverse_transform(final_logit, trans_info, mode='bilinear')
307230
pred = paddle.argmax(final_logit, axis=1, keepdim=True, dtype='int32')
308231

309232
return pred, final_logit

paddleseg/core/predict.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ def partition_list(arr, m):
3636
return [arr[i:i + n] for i in range(0, len(arr), n)]
3737

3838

39+
def preprocess(im_path, transforms):
40+
data = {}
41+
data['img'] = im_path
42+
data = transforms(data)
43+
data['img'] = data['img'][np.newaxis, ...]
44+
data['img'] = paddle.to_tensor(data['img'])
45+
return data
46+
47+
3948
def predict(model,
4049
model_path,
4150
transforms,
@@ -89,18 +98,13 @@ def predict(model,
8998
color_map = visualize.get_color_map_list(256, custom_color=custom_color)
9099
with paddle.no_grad():
91100
for i, im_path in enumerate(img_lists[local_rank]):
92-
im = cv2.imread(im_path)
93-
ori_shape = im.shape[:2]
94-
im, _ = transforms(im)
95-
im = im[np.newaxis, ...]
96-
im = paddle.to_tensor(im)
101+
data = preprocess(im_path, transforms)
97102

98103
if aug_pred:
99104
pred, _ = infer.aug_inference(
100105
model,
101-
im,
102-
ori_shape=ori_shape,
103-
transforms=transforms.transforms,
106+
data['img'],
107+
trans_info=data['trans_info'],
104108
scales=scales,
105109
flip_horizontal=flip_horizontal,
106110
flip_vertical=flip_vertical,
@@ -110,9 +114,8 @@ def predict(model,
110114
else:
111115
pred, _ = infer.inference(
112116
model,
113-
im,
114-
ori_shape=ori_shape,
115-
transforms=transforms.transforms,
117+
data['img'],
118+
trans_info=data['trans_info'],
116119
is_slide=is_slide,
117120
stride=stride,
118121
crop_size=crop_size)
@@ -141,9 +144,4 @@ def predict(model,
141144
mkdir(pred_saved_path)
142145
pred_mask.save(pred_saved_path)
143146

144-
# pred_im = utils.visualize(im_path, pred, weight=0.0)
145-
# pred_saved_path = os.path.join(pred_saved_dir, im_file)
146-
# mkdir(pred_saved_path)
147-
# cv2.imwrite(pred_saved_path, pred_im)
148-
149147
progbar_pred.update(i + 1)

0 commit comments

Comments
 (0)