Commit e13498b

[HumanSeg] Add PP-HumanSeg14K dataset (PaddlePaddle#1708)
1 parent 39403c6 commit e13498b

15 files changed

+183
-23
lines changed

README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ PaddleSeg is an end-to-end high-efficent development toolkit for image segmentat
129129
<ul>
130130
<li>Cross Entropy</li>
131131
<li>Binary CE</li>
132-
<li>Bootstrapped CE</li>
132+
<li>Bootstrapped CE</li>
133133
<li>Point CE</li>
134134
<li>OHEM CE</li>
135135
<li>Pixel Contrast CE</li>
@@ -167,7 +167,8 @@ PaddleSeg is an end-to-end high-efficent development toolkit for image segmentat
167167
<li>HRF</li>
168168
<li>DRIVE</li>
169169
<li>STARE</li>
170-
</ul>
170+
<li>PP-HumanSeg14K</li>
171+
</ul>
171172
<b>Data Augmentation</b><br>
172173
<ul>
173174
<li>Flipping</li>
@@ -182,7 +183,7 @@ PaddleSeg is an end-to-end high-efficent development toolkit for image segmentat
182183
<li>PaddingByAspectRatio</li>
183184
<li>RandomPaddingCrop</li>
184185
<li>RandomCenterCrop</li>
185-
<li>ScalePadding</li>
186+
<li>ScalePadding</li>
186187
<li>RandomNoise</li>
187188
<li>RandomBlur</li>
188189
<li>RandomRotation</li>

README_CN.md

+4-3
@@ -137,7 +137,7 @@ PaddleSeg是基于飞桨PaddlePaddle开发的端到端图像分割开发套件
137137
<ul>
138138
<li>Cross Entropy</li>
139139
<li>Binary CE</li>
140-
<li>Bootstrapped CE</li>
140+
<li>Bootstrapped CE</li>
141141
<li>Point CE</li>
142142
<li>OHEM CE</li>
143143
<li>Pixel Contrast CE</li>
@@ -175,7 +175,8 @@ PaddleSeg是基于飞桨PaddlePaddle开发的端到端图像分割开发套件
175175
<li>HRF</li>
176176
<li>DRIVE</li>
177177
<li>STARE</li>
178-
</ul>
178+
<li>PP-HumanSeg14K</li>
179+
</ul>
179180
<b>数据增强</b><br>
180181
<ul>
181182
<li>Flipping</li>
@@ -190,7 +191,7 @@ PaddleSeg是基于飞桨PaddlePaddle开发的端到端图像分割开发套件
190191
<li>PaddingByAspectRatio</li>
191192
<li>RandomPaddingCrop</li>
192193
<li>RandomCenterCrop</li>
193-
<li>ScalePadding</li>
194+
<li>ScalePadding</li>
194195
<li>RandomNoise</li>
195196
<li>RandomBlur</li>
196197
<li>RandomRotation</li>

configs/fastscnn/fastscnn_cityscapes_1024x1024_40k_SCL.yml

+1-1
@@ -8,7 +8,7 @@ loss:
88
- type: MixedLoss
99
losses:
1010
- type: CrossEntropyLoss
11-
- type: SemanticConnectivityLearning
11+
- type: SemanticConnectivityLoss
1212
coef: [1, 0.01]
1313
- type: CrossEntropyLoss
1414
coef: [1.0, 0.4]
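
The `MixedLoss` entries in these configs pair a list of losses with a `coef` list of the same length. As a rough illustration only, assuming the mixed loss is a coefficient-weighted sum of its members, the combination can be sketched in plain Python. The function `mixed_loss` and the toy loss callables are hypothetical and are not PaddleSeg code.

```python
from typing import Callable, Sequence


def mixed_loss(losses: Sequence[Callable], coefs: Sequence[float], logits, labels) -> float:
    """Weighted sum of several loss values (assumed behaviour, for illustration)."""
    if len(losses) != len(coefs):
        raise ValueError("`losses` and `coefs` must have the same length.")
    # Each loss sees the same logits and labels; its value is scaled by its coefficient.
    return sum(c * fn(logits, labels) for fn, c in zip(losses, coefs))


# Toy stand-ins for CrossEntropyLoss and SemanticConnectivityLoss.
def toy_cross_entropy(logits, labels):
    return 2.0


def toy_sc_loss(logits, labels):
    return 10.0


total = mixed_loss([toy_cross_entropy, toy_sc_loss], [1, 0.01], None, None)
```

With `coef: [1, 0.01]`, the connectivity term contributes one hundredth of its raw value, so it acts as a small regularizer on top of cross entropy.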

configs/fcn/fcn_hrnetw18_cityscapes_1024x512_80k_bs4_SCL.yml

+1-1
@@ -21,6 +21,6 @@ loss:
2121
- type: MixedLoss
2222
losses:
2323
- type: CrossEntropyLoss
24-
- type: SemanticConnectivityLearning
24+
- type: SemanticConnectivityLoss
2525
coef: [1, 0.05]
2626
coef: [1]
+53
@@ -0,0 +1,53 @@
+train_dataset:
+  type: PPHumanSeg14K
+  dataset_root: data/PP-HumanSeg14K
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [398, 224]
+    - type: RandomHorizontalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: PPHumanSeg14K
+  dataset_root: data/PP-HumanSeg14K
+  transforms:
+    - type: Normalize
+  mode: val
+
+model:
+  type: FCN
+  backbone:
+    type: HRNet_W18
+    align_corners: False
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
+  num_classes: 2
+  pretrained: Null
+  backbone_indices: [-1]
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 0.0005
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.05
+  end_lr: 0
+  power: 0.9
+
+loss:
+  types:
+    - type: CrossEntropyLoss
+  coef: [1]
+
+iters: 10000
+batch_size: 64
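
To make the `lr_scheduler` settings above concrete, here is a small sketch of the standard polynomial decay formula, `lr = (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr`. It assumes the decay runs over the full `iters: 10000` and that the step is clamped at the end. Both are assumptions for illustration, and the function name `polynomial_decay` is hypothetical.

```python
def polynomial_decay(step: int,
                     base_lr: float = 0.05,
                     end_lr: float = 0.0,
                     power: float = 0.9,
                     decay_steps: int = 10000) -> float:
    """Learning rate at `step` under polynomial decay (illustrative sketch)."""
    # Clamp so the rate stays at end_lr once decay_steps is reached (assumed).
    step = min(step, decay_steps)
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


# Learning rate at the start, halfway point, and end of training.
schedule = [polynomial_decay(s) for s in (0, 5000, 10000)]
```

With `power: 0.9` the curve is close to linear but decays slightly more slowly at first, so the rate at the halfway point is a little above half of the initial 0.05.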

configs/ocrnet/ocrnet_hrnetw48_cityscapes_1024x512_40k_SCL.yml

+1-1
@@ -26,7 +26,7 @@ loss:
2626
- type: MixedLoss
2727
losses:
2828
- type: CrossEntropyLoss
29-
- type: SemanticConnectivityLearning
29+
- type: SemanticConnectivityLoss
3030
coef: [1, 0.1]
3131
- type: CrossEntropyLoss
3232
coef: [1, 0.4]

contrib/PP-HumanSeg/paper.md

+26-4
@@ -2,11 +2,28 @@
22
Official resource for the paper *PP-HumanSeg: Connectivity-Aware Portrait Segmentation With a Large-Scale Teleconferencing Video Dataset*. [[Paper](https://arxiv.org/abs/2112.07146) | [Poster](https://paddleseg.bj.bcebos.com/dygraph/humanseg/paper/12-HAD-poster.pdf) | [YouTube](https://www.youtube.com/watch?v=FlK8R5cdD7E)]
33

44
## Semantic Connectivity-aware Learning
5-
SCL (Semantic Connectivity-aware Learning) framework, which introduces a SC Loss (Semantic Connectivity-aware Loss) to improve the quality of segmentation results from the perspective of connectivity. Support multi-class segmentation. [[Source code](../../paddleseg/models/losses/semantic_connectivity_learning.py)]
5+
SCL (Semantic Connectivity-aware Learning) framework, which introduces a SC Loss (Semantic Connectivity-aware Loss) to improve the quality of segmentation results from the perspective of connectivity. SCL can improve the integrity of segmentation objects and increase segmentation accuracy. Support multi-class segmentation. [[Source code](../../paddleseg/models/losses/semantic_connectivity_loss.py)]
6+
7+
<p align="center">
8+
<img src="https://user-images.githubusercontent.com/30695251/148921096-29a4f90f-2113-4f97-87b5-19364e83b454.png" width="40%" height="40%">
9+
</p>
10+
11+
### Connected Components Calculation and Matching
12+
<p align="center">
13+
<img src="https://user-images.githubusercontent.com/30695251/148931627-bfaeeecb-c260-4d00-9393-a7e52a56ce18.png" width="40%" height="40%">
14+
</p>
15+
(a) It indicates prediction and ground truth, i.e. P and G. (b) Connected components are generated through the CCL algorithm, respectively. (c) Connected components are matched using the IoU value.
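
The two steps described in this caption, connected component labeling followed by IoU matching, can be sketched in plain Python. This is a minimal illustration, not the code in `semantic_connectivity_loss.py`: the use of 4-connectivity, the breadth-first labeling, and the choice to report every overlapping pair are assumptions, and all function names are hypothetical.

```python
from collections import deque


def connected_components(mask):
    """Return the 4-connected foreground regions of a binary mask as sets of (row, col)."""
    height = len(mask)
    width = len(mask[0]) if height else 0
    seen = [[False] * width for _ in range(height)]
    components = []
    for r in range(height):
        for c in range(width):
            if not mask[r][c] or seen[r][c]:
                continue
            # Breadth-first flood fill from an unvisited foreground pixel.
            component = set()
            queue = deque([(r, c)])
            seen[r][c] = True
            while queue:
                y, x = queue.popleft()
                component.add((y, x))
                for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    ny, nx = y + dy, x + dx
                    if (0 <= ny < height and 0 <= nx < width
                            and mask[ny][nx] and not seen[ny][nx]):
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            components.append(component)
    return components


def iou(a, b):
    """Intersection over union of two pixel sets."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def match_components(pred_mask, gt_mask):
    """List (pred_index, gt_index, iou) for every overlapping pair of components."""
    pred = connected_components(pred_mask)
    gt = connected_components(gt_mask)
    pairs = []
    for i, p in enumerate(pred):
        for j, g in enumerate(gt):
            value = iou(p, g)
            if value > 0:
                pairs.append((i, j, value))
    return pairs


P = [[1, 1, 0, 0],
     [0, 0, 0, 1]]
G = [[1, 0, 0, 0],
     [0, 0, 0, 1]]
matches = match_components(P, G)
```

Here the prediction and the ground truth each contain two components. The first pair overlaps partially and the second pair coincides exactly, which is the kind of per-pair signal a connectivity-aware loss can build on.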
16+
17+
### Segmentation Results
18+
19+
<p align="center">
20+
<img src="https://user-images.githubusercontent.com/30695251/148931612-bfc5a7f2-f6b7-4666-b2dd-86926ea7bfd7.png" width="60%" height="60%">
21+
</p>
622

7-
SCL can improve the integrity of segmentation objects and increase segmentation accuracy. The experimental results on our Teleconferencing Video Dataset are shown in paper, and the experimental results on Cityscapes are as follows:
823

924
### Perfermance on Cityscapes
25+
The experimental results on our Teleconferencing Video Dataset are shown in paper, and the experimental results on Cityscapes are as follows:
26+
1027
| Model | Backbone | Learning Strategy | GPUs * Batch Size(Per Card)| Training Iters | mIoU (%) | Config |
1128
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
1229
|OCRNet|HRNet-W48|-|2*2|40000|76.23| [config](../../configs/ocrnet/ocrnet_hrnetw48_cityscapes_1024x512_40k.yml) |
@@ -16,10 +33,15 @@ SCL can improve the integrity of segmentation objects and increase segmentation
1633
|Fast SCNN|-|-|2*4|40000|56.41|[config](../../configs/fastscnn/fastscnn_cityscapes_1024x1024_40k.yml)|
1734
|Fast SCNN|-|SCL|2*4|40000|57.37(**+0.96**)|[config](../../configs/fastscnn/fastscnn_cityscapes_1024x1024_40k_SCL.yml)|
1835

19-
## Large-Scale Teleconferencing Video Dataset
20-
A large-scale video portrait dataset that contains 291 videos from 23 conference scenes with 14K fine-labeled frames. The data can be obtained by sending an application email to [email protected].
2136

2237

38+
## PP-HumanSeg14K: A Large-Scale Teleconferencing Video Dataset
39+
A large-scale video portrait dataset that contains 291 videos from 23 conference scenes with 14K fine-labeled frames. This dataset contains various teleconferencing scenes, various actions of the participants, interference of passers-by and illumination change. The data can be obtained by sending an application email to [email protected].
40+
41+
<p align="center">
42+
<img src="https://user-images.githubusercontent.com/30695251/148931684-cc10c994-3bd4-4d0c-9bcc-283f9bbc6ac9.png" width="80%" height="80%">
43+
</p>
44+
2345
## Citation
2446
If our project is useful in your research, please citing:
2547

docs/module/loss/SemanticConnectivityLearning_cn.md docs/module/loss/SemanticConnectivityLoss_cn.md

+3-3
@@ -1,5 +1,5 @@
1-
简体中文 | [English](SemanticConnectivityLearning_en.md)
2-
## [SemanticConnectivityLearning](../../../paddleseg/models/losses/semantic_connectivity_learning.py)
1+
简体中文 | [English](SemanticConnectivityLoss_en.md)
2+
## [SemanticConnectivityLoss](../../../paddleseg/models/losses/semantic_connectivity_loss.py)
33
SCL(Semantic Connectivity-aware Learning)框架,它引入了SC Loss (Semantic Connectivity-aware Loss),从连通性的角度提升分割结果的质量。支持多类别分割。
44

55
论文信息:
@@ -12,7 +12,7 @@ SCL(Semantic Connectivity-aware Learning)框架,它引入了SC Loss (Seman
1212
步骤1,连通域计算
1313
步骤2,连通域匹配与SC Loss计算
1414
```python
15-
class paddleseg.models.losses.SemanticConnectivityLearning(
15+
class paddleseg.models.losses.SemanticConnectivityLoss(
1616
ignore_index = 255,
1717
max_pred_num_conn = 10,
1818
use_argmax = True

docs/module/loss/SemanticConnectivityLearning_en.md docs/module/loss/SemanticConnectivityLoss_en.md

+3-3
@@ -1,5 +1,5 @@
1-
English | [简体中文](SemanticConnectivityLearning_cn.md)
2-
## [SemanticConnectivityLearning](../../../paddleseg/models/losses/semantic_connectivity_learning.py)
1+
English | [简体中文](SemanticConnectivityLoss_cn.md)
2+
## [SemanticConnectivityLoss](../../../paddleseg/models/losses/semantic_connectivity_loss.py)
33
SCL (Semantic Connectivity-aware Learning) framework, which introduces a SC Loss (Semantic Connectivity-aware Loss)
44
to improve the quality of segmentation results from the perspective of connectivity. Support multi-class segmentation.
55

@@ -14,7 +14,7 @@ Step 1. Connected Components Calculation
1414
Step 2. Connected Components Matching and SC Loss Calculation
1515

1616
```python
17-
class paddleseg.models.losses.SemanticConnectivityLearning(
17+
class paddleseg.models.losses.SemanticConnectivityLoss(
1818
ignore_index = 255,
1919
max_pred_num_conn = 10,
2020
use_argmax = True

docs/module/loss/losses_cn.md

+1-1
@@ -25,4 +25,4 @@
2525

2626
* ## [paddleseg.models.losses.ohem_edge_attention_loss](./OhemEdgeAttentionLoss_cn.md)
2727

28-
* ## [paddleseg.models.losses.semantic_connectivity_learning](./SemanticConnectivityLearning_cn.md)
28+
* ## [paddleseg.models.losses.semantic_connectivity_loss](./SemanticConnectivityLoss_cn.md)

docs/module/loss/losses_en.md

+1-1
@@ -25,4 +25,4 @@ English | [简体中文](losses_cn.md)
2525

2626
* ## [paddleseg.models.losses.ohem_edge_attention_loss](./OhemEdgeAttentionLoss_en.md)
2727

28-
* ## [paddleseg.models.losses.semantic_connectivity_learning](./SemanticConnectivityLearning_en.md)
28+
* ## [paddleseg.models.losses.semantic_connectivity_loss](./SemanticConnectivityLoss_en.md)

paddleseg/datasets/__init__.py

+1
@@ -26,3 +26,4 @@
2626
from .drive import DRIVE
2727
from .hrf import HRF
2828
from .chase_db1 import CHASEDB1
29+
from .pp_humanseg14k import PPHumanSeg14K

paddleseg/datasets/pp_humanseg14k.py

+82
@@ -0,0 +1,82 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from .dataset import Dataset
+from paddleseg.cvlibs import manager
+from paddleseg.transforms import Compose
+
+
+@manager.DATASETS.add_component
+class PPHumanSeg14K(Dataset):
+    """
+    This is the PP-HumanSeg14K Dataset.
+
+    This dataset was introduced in the work:
+    Chu, Lutao, et al. "PP-HumanSeg: Connectivity-Aware Portrait Segmentation with a Large-Scale Teleconferencing Video Dataset." Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2022.
+
+    This dataset is divided into training set, validation set and test set. The training set includes 8770 pictures, the validation set includes 2431 pictures, and the test set includes 2482 pictures.
+
+    Args:
+        dataset_root (str, optional): The dataset directory. Default: None.
+        transforms (list, optional): Transforms for image. Default: None.
+        mode (str, optional): Which part of dataset to use. It is one of ('train', 'val'). Default: 'train'.
+        edge (bool, optional): Whether to compute edge while training. Default: False.
+    """
+    NUM_CLASSES = 2
+
+    def __init__(self,
+                 dataset_root=None,
+                 transforms=None,
+                 mode='train',
+                 edge=False):
+        self.dataset_root = dataset_root
+        self.transforms = Compose(transforms)
+        mode = mode.lower()
+        self.mode = mode
+        self.file_list = list()
+        self.num_classes = self.NUM_CLASSES
+        self.ignore_index = 255
+        self.edge = edge
+
+        if mode not in ['train', 'val', 'test']:
+            raise ValueError(
+                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
+                    mode))
+
+        if self.transforms is None:
+            raise ValueError("`transforms` is necessary, but it is None.")
+
+        if mode == 'train':
+            file_path = os.path.join(self.dataset_root, 'train.txt')
+        elif mode == 'val':
+            file_path = os.path.join(self.dataset_root, 'val.txt')
+        else:
+            file_path = os.path.join(self.dataset_root, 'test.txt')
+
+        with open(file_path, 'r') as f:
+            for line in f:
+                items = line.strip().split(' ')
+                if len(items) != 2:
+                    if mode == 'train' or mode == 'val':
+                        raise Exception(
+                            "File list format incorrect! It should be"
+                            " image_name label_name\\n")
+                    image_path = os.path.join(self.dataset_root, items[0])
+                    grt_path = None
+                else:
+                    image_path = os.path.join(self.dataset_root, items[0])
+                    grt_path = os.path.join(self.dataset_root, items[1])
+                self.file_list.append([image_path, grt_path])
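
The file-list handling in the constructor above can be exercised without PaddleSeg installed. The sketch below mirrors that parsing logic in a standalone function: each line holds `image_name label_name` separated by one space, and the label may be missing only when labels are not required, as in `test` mode. The function name `read_file_list` and its `require_labels` flag are hypothetical and are not part of the PaddleSeg API.

```python
import os


def read_file_list(dataset_root, list_name, require_labels=True):
    """Parse a split file into [image_path, label_path_or_None] pairs."""
    samples = []
    with open(os.path.join(dataset_root, list_name), 'r') as f:
        for line in f:
            items = line.strip().split(' ')
            if len(items) == 2:
                # Normal case: an image path and its label path.
                samples.append([
                    os.path.join(dataset_root, items[0]),
                    os.path.join(dataset_root, items[1]),
                ])
            elif require_labels:
                # Mirrors the train/val behaviour: a malformed line is an error.
                raise ValueError(
                    "File list format incorrect! It should be image_name label_name")
            else:
                # Mirrors the test behaviour: keep the image, no ground truth.
                samples.append([os.path.join(dataset_root, items[0]), None])
    return samples
```

A call such as `read_file_list('data/PP-HumanSeg14K', 'val.txt')` would then return the same kind of pairs that the class stores in `self.file_list`, provided the split files follow the layout the constructor expects.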

paddleseg/models/losses/__init__.py

+1-1
@@ -33,4 +33,4 @@
3333
from .point_cross_entropy_loss import PointCrossEntropyLoss
3434
from .pixel_contrast_cross_entropy_loss import PixelContrastCrossEntropyLoss
3535
from .semantic_encode_cross_entropy_loss import SECrossEntropyLoss
36-
from .semantic_connectivity_learning import SemanticConnectivityLearning
36+
from .semantic_connectivity_loss import SemanticConnectivityLoss

paddleseg/models/losses/semantic_connectivity_learning.py paddleseg/models/losses/semantic_connectivity_loss.py

+1-1
@@ -22,7 +22,7 @@
2222

2323

2424
@manager.LOSSES.add_component
25-
class SemanticConnectivityLearning(nn.Layer):
25+
class SemanticConnectivityLoss(nn.Layer):
2626
'''
2727
SCL (Semantic Connectivity-aware Learning) framework, which introduces a SC Loss (Semantic Connectivity-aware Loss)
2828
to improve the quality of segmentation results from the perspective of connectivity. Support multi-class segmentation.
