
Commit 969939e

Author: lilong12
Add fleet for transformer benchmark (#5164)
* add fleet, test=develop
1 parent 8510560 commit 969939e

File tree

3 files changed, +65 -19 lines changed


PaddleNLP/benchmark/transformer/configs/transformer.big.yaml

Lines changed: 7 additions & 0 deletions
@@ -96,4 +96,11 @@ dropout: 0.1
 # Vocabularies in source and target should be same for weight sharing.
 weight_sharing: True
 
+# Use amp or not
+use_amp: False
+scale_loss: 1.0
+
+# Whether to use multi-card/multi-node distributed training.
+is_distributed: True
+
 max_iter: None
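
The three new keys sit next to the existing hyper-parameters, so they reach the training script through the same config object as everything else. As a rough illustration only (not the benchmark's own loader, which may use a different YAML reader), the new options could be read with PyYAML into attribute-style args like this:

# Illustrative only: assumes PyYAML and a simple attribute-style namespace;
# the benchmark's real config loader may differ.
import yaml
from types import SimpleNamespace

with open("PaddleNLP/benchmark/transformer/configs/transformer.big.yaml") as f:
    args = SimpleNamespace(**yaml.safe_load(f))

print(args.use_amp)         # False: mixed precision off by default
print(args.scale_loss)      # 1.0: initial loss scaling used when AMP is enabled
print(args.is_distributed)  # True: multi-card/multi-node training via fleet
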
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+
+python -m paddle.distributed.launch \
+    --gpus="0,1" \
+    train.py
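
The new launch script starts one worker process per GPU listed in --gpus, and each worker learns which device it owns from environment variables that train.py later reads. A minimal sketch of what a launched worker sees (the script name is hypothetical; the variable name is the one used in the train.py change below):

# probe_env.py -- hypothetical helper, not part of the commit.
# Launch with: python -m paddle.distributed.launch --gpus="0,1" probe_env.py
import os

# Each worker spawned by paddle.distributed.launch is pinned to one GPU;
# FLAGS_selected_gpus is the same variable train.py uses to pick its CUDAPlace.
print("this worker trains on GPU:", os.getenv("FLAGS_selected_gpus", "0"))
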

PaddleNLP/benchmark/transformer/static/train.py

Lines changed: 54 additions & 19 deletions
@@ -10,6 +10,7 @@
 from pprint import pprint
 
 import paddle
+import paddle.distributed.fleet as fleet
 import paddle.distributed as dist
 
 from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
@@ -36,8 +37,14 @@ def parse_args():
 
 def do_train(args):
     paddle.enable_static()
-    places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
-    trainer_count = len(places)
+    if args.is_distributed:
+        fleet.init(is_collective=True)
+        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+        places = paddle.CUDAPlace(gpu_id) if args.use_gpu else paddle.static.cpu_places()
+        trainer_count = 1 if args.use_gpu else len(places)
+    else:
+        places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
+        trainer_count = len(places)
 
     # Set seed for CE
     random_seed = eval(str(args.random_seed))
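
In the distributed branch each process only drives the single GPU the launcher assigned to it, so trainer_count collapses to 1, while the non-distributed branch keeps the old behaviour of one process feeding every visible device. A condensed sketch of that selection logic, with the reasoning in comments (the helper name is illustrative, not part of the commit):

# Illustrative helper, mirroring the hunk above.
import os
import paddle
import paddle.distributed.fleet as fleet


def select_places(is_distributed, use_gpu):
    if is_distributed:
        # Collective mode: every launched process joins the fleet ...
        fleet.init(is_collective=True)
        # ... and trains only on the GPU the launcher assigned to it.
        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        places = paddle.CUDAPlace(gpu_id) if use_gpu else paddle.static.cpu_places()
        trainer_count = 1 if use_gpu else len(places)
    else:
        # Single-process mode: one executor feeds all visible devices.
        places = paddle.static.cuda_places() if use_gpu else paddle.static.cpu_places()
        trainer_count = len(places)
    return places, trainer_count
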
@@ -88,19 +95,38 @@ def do_train(args):
         epsilon=float(args.eps),
         parameters=transformer.parameters())
 
+    if args.is_distributed:
+        build_strategy = paddle.static.BuildStrategy()
+        exec_strategy = paddle.static.ExecutionStrategy()
+        dist_strategy = fleet.DistributedStrategy()
+        dist_strategy.build_strategy = build_strategy
+        dist_strategy.execution_strategy = exec_strategy
+        dist_strategy.fuse_grad_size_in_MB = 16
+
+        if args.use_amp:
+            dist_strategy.amp = True
+            dist_strategy.amp_configs = {
+                'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
+                'init_loss_scaling': args.scale_loss,
+            }
+
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
     optimizer.minimize(avg_cost)
 
-    exe = paddle.static.Executor()
+    if args.is_distributed:
+        exe = paddle.static.Executor(places)
+    else:
+        exe = paddle.static.Executor()
+        build_strategy = paddle.static.BuildStrategy()
+        exec_strategy = paddle.static.ExecutionStrategy()
+
+        compiled_train_program = paddle.static.CompiledProgram(
+            train_program).with_data_parallel(
+                loss_name=avg_cost.name,
+                build_strategy=build_strategy,
+                exec_strategy=exec_strategy)
     exe.run(startup_program)
 
-    build_strategy = paddle.static.BuildStrategy()
-    exec_strategy = paddle.static.ExecutionStrategy()
-
-    compiled_train_program = paddle.static.CompiledProgram(
-        train_program).with_data_parallel(
-            loss_name=avg_cost.name,
-            build_strategy=build_strategy,
-            exec_strategy=exec_strategy)
 
     # the best cross-entropy value with label smoothing
     loss_normalizer = -(
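
The core of the change: when is_distributed is set, the plain Adam optimizer is wrapped by fleet.distributed_optimizer before minimize() is called, so the collective gradient communication (and, optionally, AMP with the configured white list and initial loss scaling) is applied to the static program, while the old CompiledProgram/with_data_parallel path is kept only for the single-process case. Isolated as a sketch (the function name is illustrative, and fleet.init(is_collective=True) is assumed to have run already in each worker):

# Illustrative wrapper around the strategy wiring shown in the hunk above.
import paddle
import paddle.distributed.fleet as fleet


def wrap_with_fleet(optimizer, use_amp=False, scale_loss=1.0):
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.build_strategy = paddle.static.BuildStrategy()
    dist_strategy.execution_strategy = paddle.static.ExecutionStrategy()
    # Fuse small gradient tensors into ~16 MB buckets before all-reduce.
    dist_strategy.fuse_grad_size_in_MB = 16

    if use_amp:
        # Mixed precision: run the listed ops in FP16 and start loss scaling
        # at the value taken from scale_loss in the YAML config.
        dist_strategy.amp = True
        dist_strategy.amp_configs = {
            'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
            'init_loss_scaling': scale_loss,
        }

    # Calling minimize() on the returned optimizer inserts the collective ops.
    return fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
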
@@ -127,13 +153,22 @@ def do_train(args):
                 data = [data]
             train_reader_cost = time.time() - batch_start
 
-            outs = exe.run(compiled_train_program,
-                           feed=[{
-                               'src_word': data[i][0],
-                               'trg_word': data[i][1],
-                               'lbl_word': data[i][2],
-                           } for i in range(trainer_count)],
-                           fetch_list=[sum_cost.name, token_num.name])
+            if args.is_distributed:
+                outs = exe.run(train_program,
+                               feed=[{
+                                   'src_word': data[i][0],
+                                   'trg_word': data[i][1],
+                                   'lbl_word': data[i][2],
+                               } for i in range(trainer_count)],
+                               fetch_list=[sum_cost.name, token_num.name])
+            else:
+                outs = exe.run(compiled_train_program,
+                               feed=[{
+                                   'src_word': data[i][0],
+                                   'trg_word': data[i][1],
+                                   'lbl_word': data[i][2],
+                               } for i in range(trainer_count)],
+                               fetch_list=[sum_cost.name, token_num.name])
             scheduler.step()
 
             train_batch_cost = time.time() - batch_start
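
Note the asymmetry in what gets executed: under fleet the worker runs train_program directly, since distributed_optimizer has already applied the strategy to that program, while the single-process path still runs the compiled_train_program built with with_data_parallel. The feed has the same shape in both branches, one dict per device the executor drives; in the distributed case trainer_count is 1, so the list holds a single entry. A small sketch of that feed construction (the helper name is illustrative):

# Illustrative: one feed dict per device the executor drives.
def build_feed(data, trainer_count):
    return [{
        'src_word': data[i][0],
        'trg_word': data[i][1],
        'lbl_word': data[i][2],
    } for i in range(trainer_count)]
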
@@ -176,7 +211,7 @@ def do_train(args):
                 batch_ips_avg.reset()
 
             if step_idx % args.save_step == 0 and step_idx != 0:
-                if args.save_model:
+                if args.save_model and dist.get_rank() == 0:
                     model_path = os.path.join(
                         args.save_model, "step_" + str(step_idx), "transformer")
                     paddle.static.save(train_program, model_path)
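
The last tweak gates checkpointing on the trainer rank: with several workers sharing the same save_model directory, only rank 0 should call paddle.static.save, otherwise the processes would race on the same files. As a standalone sketch (the function name is illustrative; dist is paddle.distributed, imported at the top of train.py):

# Illustrative rank-gated checkpointing, as in the hunk above.
import os

import paddle
import paddle.distributed as dist


def save_checkpoint(train_program, save_dir, step_idx):
    # Only the first worker writes; the other ranks skip the save entirely.
    if dist.get_rank() == 0:
        model_path = os.path.join(save_dir, "step_" + str(step_idx), "transformer")
        paddle.static.save(train_program, model_path)
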
