
Commit 60c1617

Merge benchmark
1 parent dbf88d0 commit 60c1617

32 files changed, +1092 -580 lines changed

benchmark/README.md

+65
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,65 @@
1+
# PaddleSeg Benchmark with AMP
2+
3+
## Dynamic graph
4+
5+
Place the Cityscapes dataset in the data directory.
6+
7+
Pass **--fp16** to enable AMP (automatic mixed precision) training.
8+
9+
To train on a single machine with a single GPU, run:
10+
```
11+
export CUDA_VISIBLE_DEVICES=0
12+
python train.py --config benchmark/hrnet.yml --iters 2000 --log_iters 10 --fp16
13+
```
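The flags in the command above could be parsed roughly as follows. This is a minimal sketch and not PaddleSeg's actual `train.py`: the `build_parser` helper, the defaults, and the help strings are assumptions. Only the flag names `--config`, `--iters`, `--log_iters`, and `--fp16` come from the command itself.

```python
import argparse


def build_parser():
    # Hypothetical parser that mirrors only the flags used in the command above.
    parser = argparse.ArgumentParser(description="Benchmark training (sketch)")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the YAML config")
    parser.add_argument("--iters", type=int, default=None,
                        help="Override the iteration count from the config")
    parser.add_argument("--log_iters", type=int, default=10,
                        help="Log every N iterations")
    # store_true: AMP stays off unless --fp16 is passed explicitly.
    parser.add_argument("--fp16", action="store_true",
                        help="Enable AMP training")
    return parser


# Same arguments as the single-GPU command above.
args = build_parser().parse_args(
    ["--config", "benchmark/hrnet.yml", "--iters", "2000",
     "--log_iters", "10", "--fp16"]
)
print(args.config, args.iters, args.log_iters, args.fp16)
```

Modelling `--fp16` as a boolean switch means that dropping the flag from the command falls back to full-precision training without any other change.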
14+
15+
To train on a single machine with multiple GPUs, run:
16+
```
17+
export CUDA_VISIBLE_DEVICES=0,1
18+
python -m paddle.distributed.launch train.py --config benchmark/hrnet.yml --iters 2000 --log_iters 10 --fp16
19+
# Launch multi-GPU training with fleet
20+
fleetrun train.py --config benchmark/hrnet.yml --iters 2000 --log_iters 10 --fp16
21+
```
22+
23+
The configuration file for the DeepLabv3+ model is:
24+
benchmark/deeplabv3p.yml
25+
26+
**Note**
27+
28+
* In dynamic graph mode, batch_size is the batch size per GPU.
29+
* DeepLabv3+ supports training in the 'NHWC' data format by passing **--data_format NHWC**.
30+
31+
32+
33+
## Static graph
34+
Place the Cityscapes dataset in the legacy/dataset directory.
35+
36+
Pass **MODEL.FP16 True** to enable AMP training.
37+
To train on a single machine with a single GPU, run:
38+
```
39+
cd legacy
40+
export CUDA_VISIBLE_DEVICES=0
41+
python pdseg/train.py --cfg configs/benchmark/hrnetw18_cityscapes_1024x512_215.yaml --use_gpu --use_mpio --log_steps 10 BATCH_SIZE 2 SOLVER.NUM_EPOCHS 3 MODEL.FP16 True
42+
```
43+
44+
To train on a single machine with multiple GPUs, run:
45+
```
46+
export CUDA_VISIBLE_DEVICES=0,1
47+
fleetrun pdseg/train.py --cfg configs/benchmark/hrnetw18_cityscapes_1024x512_215.yaml --use_gpu --use_mpio --log_steps 10 BATCH_SIZE 4 SOLVER.NUM_EPOCHS 3 MODEL.FP16 True
48+
```
49+
50+
The configuration file for the deeplabv3p model is:
51+
configs/benchmark/deeplabv3p_resnet50_vd_cityscapes.yaml
52+
53+
**Note**
54+
In static graph mode, BATCH_SIZE is the total batch size.
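The two batch-size conventions in this README (per GPU in dynamic graph mode, total in static graph mode) can be reconciled with a small conversion. The sketch below is illustrative only: both helper names are hypothetical, and the requirement that the total divide evenly across GPUs is an assumption rather than something the README states.

```python
def global_batch_size(per_gpu_batch_size, num_gpus):
    """Dynamic graph: the config value is per GPU, so the total scales with GPUs."""
    return per_gpu_batch_size * num_gpus


def per_gpu_batch_size(total_batch_size, num_gpus):
    """Static graph: BATCH_SIZE is the total, so each GPU gets an equal share."""
    # Assumption: the total must split evenly across GPUs.
    if total_batch_size % num_gpus != 0:
        raise ValueError("total batch size must be divisible by the number of GPUs")
    return total_batch_size // num_gpus


# Two-GPU commands from this README:
print(global_batch_size(2, 2))   # dynamic graph, batch_size: 2 in the YAML
print(per_gpu_batch_size(4, 2))  # static graph, BATCH_SIZE 4 on the command line
```

With these numbers, both two-GPU commands process two images per GPU and four images per step, which is what makes the dynamic and static runs comparable.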
55+
56+
## Competitor
57+
The competitor is [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
58+
59+
The corresponding competitor configuration files are:
60+
61+
configs/hrnet/fcn_hr18_512x1024_80k_cityscapes.py
62+
63+
configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py
64+
65+
For how to run them, refer to its official repository.

benchmark/deeplabv3p.yml

+2-1
@@ -1,5 +1,5 @@
11
batch_size: 2
2-
iters: 500
2+
iters: 80000
33

44
train_dataset:
55
type: Cityscapes
@@ -29,6 +29,7 @@ model:
2929
type: ResNet50_vd
3030
output_stride: 8
3131
multi_grid: [1, 2, 4]
32+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
3233
num_classes: 19
3334
backbone_indices: [0, 3]
3435
aspp_ratios: [1, 12, 24, 36]

benchmark/hrnet.yml

+4-1
@@ -1,5 +1,5 @@
11
batch_size: 2
2-
iters: 500
2+
iters: 80000
33

44
train_dataset:
55
type: Cityscapes
@@ -27,8 +27,11 @@ model:
2727
type: FCN
2828
backbone:
2929
type: HRNet_W18
30+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
31+
padding_same: False
3032
num_classes: 19
3133
backbone_indices: [-1]
34+
bias: False
3235

3336
optimizer:
3437
type: sgd

benchmark/hrnet48.yml

+51
@@ -0,0 +1,51 @@
1+
batch_size: 2
2+
iters: 80000
3+
4+
train_dataset:
5+
type: Cityscapes
6+
dataset_root: data/cityscapes
7+
transforms:
8+
- type: ResizeStepScaling
9+
min_scale_factor: 0.5
10+
max_scale_factor: 2.0
11+
scale_step_size: 0.25
12+
- type: RandomPaddingCrop
13+
crop_size: [1024, 512]
14+
- type: RandomHorizontalFlip
15+
- type: RandomDistort
16+
- type: Normalize
17+
mode: train
18+
19+
val_dataset:
20+
type: Cityscapes
21+
dataset_root: data/cityscapes
22+
transforms:
23+
- type: Normalize
24+
mode: val
25+
26+
model:
27+
type: FCN
28+
backbone:
29+
type: HRNet_W48
30+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz
31+
padding_same: False
32+
num_classes: 19
33+
backbone_indices: [-1]
34+
bias: False
35+
36+
optimizer:
37+
type: sgd
38+
weight_decay: 0.0005
39+
40+
learning_rate:
41+
value: 0.01
42+
decay:
43+
type: poly
44+
power: 0.9
45+
end_lr: 0.0
46+
47+
loss:
48+
types:
49+
- type: CrossEntropyLoss
50+
ignore_index: 255
51+
coef: [1]
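The `learning_rate` block in this config (value 0.01, poly decay with power 0.9 and end_lr 0.0, over 80000 iterations) corresponds to the usual polynomial decay formula. The function below is a sketch of that formula, assuming the decay runs over the full `iters` count; the name `poly_lr` and the clamping after the last step are illustrative choices, not taken from the commit.

```python
def poly_lr(step, total_steps=80000, base_lr=0.01, end_lr=0.0, power=0.9):
    """Polynomial decay from base_lr to end_lr over total_steps."""
    # Clamp so the rate stays at end_lr once training passes total_steps.
    step = min(step, total_steps)
    factor = (1.0 - step / total_steps) ** power
    return (base_lr - end_lr) * factor + end_lr


# Start, midpoint, and end of the schedule.
for step in (0, 40000, 80000):
    print(step, poly_lr(step))
```

Because the power is below 1, the rate falls slightly slower than linearly at first: halfway through it is still a bit above half of the base value.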
@@ -0,0 +1,51 @@
1+
EVAL_CROP_SIZE: (2048, 1024) # (width, height), for unpadding rangescaling and stepscaling
2+
TRAIN_CROP_SIZE: (1024, 512) # (width, height), for unpadding rangescaling and stepscaling
3+
AUG:
4+
AUG_METHOD: "stepscaling" # choice unpadding rangescaling and stepscaling
5+
FIX_RESIZE_SIZE: (2048, 1024) # (width, height), for unpadding
6+
INF_RESIZE_VALUE: 500 # for rangescaling
7+
MAX_RESIZE_VALUE: 600 # for rangescaling
8+
MIN_RESIZE_VALUE: 400 # for rangescaling
9+
MAX_SCALE_FACTOR: 2.0 # for stepscaling
10+
MIN_SCALE_FACTOR: 0.5 # for stepscaling
11+
SCALE_STEP_SIZE: 0.25 # for stepscaling
12+
MIRROR: True
13+
TO_RGB: True
14+
BATCH_SIZE: 8
15+
DATASET:
16+
DATA_DIR: "./dataset/cityscapes/"
17+
IMAGE_TYPE: "rgb" # choice rgb or rgba
18+
NUM_CLASSES: 19
19+
TEST_FILE_LIST: "dataset/cityscapes/val.list"
20+
TRAIN_FILE_LIST: "dataset/cityscapes/train.list"
21+
VAL_FILE_LIST: "dataset/cityscapes/val.list"
22+
IGNORE_INDEX: 255
23+
SEPARATOR: " "
24+
FREEZE:
25+
MODEL_FILENAME: "__model__"
26+
PARAMS_FILENAME: "__params__"
27+
MODEL:
28+
DEFAULT_NORM_TYPE: "bn"
29+
MODEL_NAME: "deeplabv3p"
30+
DEEPLAB:
31+
ASPP_WITH_SEP_CONV: True
32+
DECODER_USE_SEP_CONV: True
33+
BACKBONE: "resnet_vd_50"
34+
OUTPUT_STRIDE: 8
35+
BIAS: null
36+
ALIGN_CORNERS: False
37+
BENCHMARK: True
38+
DECODER:
39+
ACT: False
40+
TRAIN:
41+
PRETRAINED_MODEL_DIR: u"pretrained_model/resnet50_vd_imagenet"
42+
MODEL_SAVE_DIR: "output/deeplabv3p_resnet50_vd_bn_cityscapes"
43+
SNAPSHOT_EPOCH: 10
44+
SYNC_BATCH_NORM: True
45+
TEST:
46+
TEST_MODEL: "output/deeplabv3p_resnet50_vd_bn_cityscapes/final"
47+
SOLVER:
48+
LR: 0.01
49+
LR_POLICY: "poly"
50+
OPTIMIZER: "sgd"
51+
NUM_EPOCHS: 215
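The stepscaling settings in this config (MIN_SCALE_FACTOR 0.5, MAX_SCALE_FACTOR 2.0, SCALE_STEP_SIZE 0.25) describe a discrete set of resize factors. The sketch below shows one plausible reading: enumerate the factors and draw one uniformly per image. The function names and the uniform draw are assumptions and not the repository's augmentation code.

```python
import random


def candidate_scales(min_scale=0.5, max_scale=2.0, step=0.25):
    """List the discrete scale factors from min_scale to max_scale."""
    num_steps = int(round((max_scale - min_scale) / step)) + 1
    return [min_scale + i * step for i in range(num_steps)]


def sample_scale(rng=random):
    """Pick one scale factor uniformly at random for the current image."""
    return rng.choice(candidate_scales())


print(candidate_scales())
print(sample_scale())
```

With these settings there are seven candidate factors, so an image is rescaled to between half and twice its size before the training crop is taken.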
@@ -0,0 +1,53 @@
1+
EVAL_CROP_SIZE: (2048, 1024) # (width, height), for unpadding rangescaling and stepscaling
2+
TRAIN_CROP_SIZE: (1024, 512) # (width, height), for unpadding rangescaling and stepscaling
3+
AUG:
4+
# AUG_METHOD: "unpadding" # choice unpadding rangescaling and stepscaling
5+
AUG_METHOD: "stepscaling" # choice unpadding rangescaling and stepscaling
6+
FIX_RESIZE_SIZE: (1024, 512) # (width, height), for unpadding
7+
INF_RESIZE_VALUE: 500 # for rangescaling
8+
MAX_RESIZE_VALUE: 600 # for rangescaling
9+
MIN_RESIZE_VALUE: 400 # for rangescaling
10+
MAX_SCALE_FACTOR: 2.0 # for stepscaling
11+
MIN_SCALE_FACTOR: 0.5 # for stepscaling
12+
SCALE_STEP_SIZE: 0.25 # for stepscaling
13+
MIRROR: True
14+
BATCH_SIZE: 8
15+
16+
DATASET:
17+
DATA_DIR: "./dataset/cityscapes/"
18+
IMAGE_TYPE: "rgb" # choice rgb or rgba
19+
NUM_CLASSES: 19
20+
TEST_FILE_LIST: "./dataset/cityscapes/val.list"
21+
TRAIN_FILE_LIST: "./dataset/cityscapes/train.list"
22+
VAL_FILE_LIST: "./dataset/cityscapes/val.list"
23+
IGNORE_INDEX: 255
24+
SEPARATOR: " "
25+
26+
MODEL:
27+
MODEL_NAME: "hrnet"
28+
DEFAULT_NORM_TYPE: "bn"
29+
HRNET:
30+
STAGE2:
31+
NUM_CHANNELS: [18, 36]
32+
STAGE3:
33+
NUM_CHANNELS: [18, 36, 72]
34+
STAGE4:
35+
NUM_CHANNELS: [18, 36, 72, 144]
36+
BIAS: False
37+
ALIGN_CORNERS: False
38+
39+
TRAIN:
40+
PRETRAINED_MODEL_DIR: u"./pretrained_model/hrnet_w18_ssld"
41+
MODEL_SAVE_DIR: "output/hrnetw18_bn_cityscapes"
42+
SNAPSHOT_EPOCH: 10
43+
SYNC_BATCH_NORM: True
44+
45+
TEST:
46+
TEST_MODEL: "output/hrnetw18_bn_cityscapes/best_model"
47+
48+
SOLVER:
49+
LR: 0.01
50+
LR_POLICY: "poly"
51+
WEIGHT_DECAY: 5.0e-4
52+
OPTIMIZER: "sgd"
53+
NUM_EPOCHS: 215
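The static configs count training in epochs (NUM_EPOCHS: 215 with BATCH_SIZE: 8), while the dynamic configs count iterations (iters: 80000). The arithmetic below sketches how the two units relate. It rests on assumptions not stated in the commit: that the Cityscapes training split has 2975 images, and that a partial final batch counts as a step.

```python
import math

# Assumption: size of the Cityscapes fine-annotation training split.
CITYSCAPES_TRAIN_IMAGES = 2975


def epochs_to_iters(num_epochs, total_batch_size,
                    num_images=CITYSCAPES_TRAIN_IMAGES):
    """Convert an epoch budget into a step count for a given total batch size."""
    # Assumption: the last partial batch of each epoch still counts as one step.
    steps_per_epoch = math.ceil(num_images / total_batch_size)
    return num_epochs * steps_per_epoch


print(epochs_to_iters(215, 8))
```

Under these assumptions, 215 epochs come out close to the 80000 iterations used in the dynamic graph configs, which would explain the choice of 215, though the commit does not say so.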

legacy/pdseg/eval.py

+25-19
@@ -24,13 +24,15 @@
2424
import sys
2525
import argparse
2626
import pprint
27+
import time
28+
2729
import numpy as np
2830
import paddle
29-
import paddle.fluid as fluid
31+
import paddle.static as static
3032

3133
from utils import paddle_utils
3234
from utils.config import cfg
33-
from utils.timer import Timer, calculate_eta
35+
from utils.timer import TimeAverager, calculate_eta
3436
from models.model_builder import build_model
3537
from models.model_builder import ModelPhase
3638
from reader import SegDataset
@@ -82,8 +84,8 @@ def evaluate(cfg,
8284
**kwargs):
8385
np.set_printoptions(precision=5, suppress=True)
8486

85-
startup_prog = fluid.Program()
86-
test_prog = fluid.Program()
87+
startup_prog = static.Program()
88+
test_prog = static.Program()
8789
dataset = SegDataset(
8890
file_list=cfg.DATASET.VAL_FILE_LIST,
8991
mode=ModelPhase.EVAL,
@@ -109,17 +111,17 @@ def data_generator():
109111

110112
# Get device environment
111113
if use_gpu:
112-
places = fluid.cuda_places()
114+
places = static.cuda_places()
113115
elif use_xpu:
114116
xpu_id = int(os.environ.get('FLAGS_selected_xpus', 0))
115-
places = [fluid.XPUPlace(xpu_id)]
117+
places = [paddle.XPUPlace(xpu_id)]
116118
else:
117-
places = fluid.cpu_places()
119+
places = static.cpu_places()
118120
place = places[0]
119121
dev_count = len(places)
120122
print("#Device count: {}".format(dev_count))
121123

122-
exe = fluid.Executor(place)
124+
exe = static.Executor(place)
123125
exe.run(startup_prog)
124126

125127
test_prog = test_prog.clone(for_test=True)
@@ -132,9 +134,9 @@ def data_generator():
132134
if ckpt_dir is not None:
133135
print('load test model:', ckpt_dir)
134136
try:
135-
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
137+
static.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
136138
except:
137-
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
139+
paddle.fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
138140

139141
# Use streaming confusion matrix to calculate mean_iou
140142
np.set_printoptions(
@@ -144,11 +146,13 @@ def data_generator():
144146
num_images = 0
145147
step = 0
146148
all_step = cfg.DATASET.TEST_TOTAL_IMAGES // cfg.BATCH_SIZE + 1
147-
timer = Timer()
148-
timer.start()
149+
reader_cost_averager = TimeAverager()
150+
batch_cost_averager = TimeAverager()
151+
batch_start = time.time()
149152
data_loader.start()
150153
while True:
151154
try:
155+
reader_cost_averager.record(time.time() - batch_start)
152156
step += 1
153157
loss, pred, grts, masks = exe.run(
154158
test_prog, fetch_list=fetch_list, return_numpy=True)
@@ -160,15 +164,17 @@ def data_generator():
160164
_, iou = conf_mat.mean_iou()
161165
_, acc = conf_mat.accuracy()
162166

163-
speed = 1.0 / timer.elapsed_time()
164-
167+
batch_cost_averager.record(
168+
time.time() - batch_start, num_samples=cfg.BATCH_SIZE)
169+
batch_cost = batch_cost_averager.get_average()
170+
reader_cost = reader_cost_averager.get_average()
171+
eta = calculate_eta(all_step - step, batch_cost)
165172
print(
166-
"[EVAL]step: {} loss: {:.5f} acc: {:.4f} IoU: {:.4f} step/sec: {:.2f} | ETA {}"
167-
.format(step, loss, acc, iou, speed,
168-
calculate_eta(all_step - step, speed)))
169-
timer.restart()
173+
"[EVAL]step: {} loss: {:.5f} acc: {:.4f} IoU: {:.4f} batch_cost: {:.4f}, reader_cost: {:.5f} | ETA {}"
174+
.format(step, loss, acc, iou, batch_cost, reader_cost, eta))
175+
batch_start = time.time()
170176
sys.stdout.flush()
171-
except fluid.core.EOFException:
177+
except paddle.fluid.core.EOFException:
172178
break
173179

174180
category_iou, avg_iou = conf_mat.mean_iou()
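The hunks above replace the step/sec `Timer` with two `TimeAverager` instances and pass the averaged `batch_cost` (seconds per step) to `calculate_eta`. The sketch below shows what such helpers might look like. Only the names `TimeAverager`, `record`, `get_average`, `num_samples`, and `calculate_eta` appear in the diff; every internal detail here is an assumption, not the contents of `utils/timer.py`.

```python
class TimeAverager:
    """Accumulates durations and reports their running average (sketch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._total_time = 0.0
        self._count = 0
        self._total_samples = 0

    def record(self, duration, num_samples=None):
        # Each call adds one measurement; num_samples is tracked separately.
        self._total_time += duration
        self._count += 1
        if num_samples:
            self._total_samples += num_samples

    def get_average(self):
        # Average seconds per recorded step; 0.0 before anything is recorded.
        return self._total_time / self._count if self._count else 0.0


def calculate_eta(remaining_steps, seconds_per_step):
    """Format the estimated remaining time as HH:MM:SS."""
    remaining = int(max(remaining_steps, 0) * seconds_per_step)
    hours, rest = divmod(remaining, 3600)
    minutes, seconds = divmod(rest, 60)
    return "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds)


averager = TimeAverager()
for cost in (0.5, 0.25, 0.75):
    averager.record(cost, num_samples=2)
print(averager.get_average(), calculate_eta(100, averager.get_average()))
```

Averaging over all recorded steps smooths out the per-step jitter that a single-step `1.0 / elapsed` speed would show, and keeping a separate reader averager makes it visible when data loading, rather than the model, dominates the batch time.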
