 from pprint import pprint
 
 import paddle
+import paddle.distributed.fleet as fleet
 import paddle.distributed as dist
 
 from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
@@ -36,8 +37,14 @@ def parse_args():
 
 def do_train(args):
     paddle.enable_static()
-    places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
-    trainer_count = len(places)
+    if args.is_distributed:
+        fleet.init(is_collective=True)
+        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+        places = paddle.CUDAPlace(gpu_id) if args.use_gpu else paddle.static.cpu_places()
+        trainer_count = 1 if args.use_gpu else len(places)
+    else:
+        places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
+        trainer_count = len(places)
 
     # Set seed for CE
     random_seed = eval(str(args.random_seed))
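Note: the distributed branch above expects one trainer process per GPU, typically started with a launcher such as python -m paddle.distributed.launch (the exact launch flags are an assumption here, not part of this change). A minimal sketch of the device selection it relies on:

    # Sketch only, not taken from this commit. The launcher starts one process per
    # listed GPU and exports FLAGS_selected_gpus for each, which the code above reads.
    import os

    import paddle
    import paddle.distributed.fleet as fleet

    fleet.init(is_collective=True)                       # collective (NCCL) mode
    gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))  # set by the launcher
    place = paddle.CUDAPlace(gpu_id)                     # one GPU per trainer process
    print("trainer", fleet.worker_index(), "of", fleet.worker_num(), "-> GPU", gpu_id)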
@@ -88,19 +95,38 @@ def do_train(args):
         epsilon=float(args.eps),
         parameters=transformer.parameters())
 
+    if args.is_distributed:
+        build_strategy = paddle.static.BuildStrategy()
+        exec_strategy = paddle.static.ExecutionStrategy()
+        dist_strategy = fleet.DistributedStrategy()
+        dist_strategy.build_strategy = build_strategy
+        dist_strategy.execution_strategy = exec_strategy
+        dist_strategy.fuse_grad_size_in_MB = 16
+
+        if args.use_amp:
+            dist_strategy.amp = True
+            dist_strategy.amp_configs = {
+                'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
+                'init_loss_scaling': args.scale_loss,
+            }
+
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
     optimizer.minimize(avg_cost)
 
-    exe = paddle.static.Executor()
+    if args.is_distributed:
+        exe = paddle.static.Executor(places)
+    else:
+        exe = paddle.static.Executor()
+        build_strategy = paddle.static.BuildStrategy()
+        exec_strategy = paddle.static.ExecutionStrategy()
+
+        compiled_train_program = paddle.static.CompiledProgram(
+            train_program).with_data_parallel(
+                loss_name=avg_cost.name,
+                build_strategy=build_strategy,
+                exec_strategy=exec_strategy)
     exe.run(startup_program)
 
-    build_strategy = paddle.static.BuildStrategy()
-    exec_strategy = paddle.static.ExecutionStrategy()
-
-    compiled_train_program = paddle.static.CompiledProgram(
-        train_program).with_data_parallel(
-            loss_name=avg_cost.name,
-            build_strategy=build_strategy,
-            exec_strategy=exec_strategy)
 
     # the best cross-entropy value with label smoothing
     loss_normalizer = -(
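For reference, a self-contained sketch of the fleet static-graph pattern used above, on a toy model rather than this repo's Transformer; the layer sizes and AMP settings below are placeholders:

    # Minimal sketch, not this script: wrap a regular optimizer with
    # fleet.distributed_optimizer and let DistributedStrategy turn on AMP.
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 32], dtype="float32")
        y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
        pred = paddle.static.nn.fc(x, size=1)
        loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

        strategy = fleet.DistributedStrategy()
        strategy.amp = True                                 # mixed precision via the strategy
        strategy.amp_configs = {"init_loss_scaling": 32768}  # placeholder value
        opt = fleet.distributed_optimizer(
            paddle.optimizer.Adam(learning_rate=1e-3), strategy=strategy)
        opt.minimize(loss)  # rewrites main_prog with AMP casts and collective ops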
@@ -127,13 +153,22 @@ def do_train(args):
                 data = [data]
             train_reader_cost = time.time() - batch_start
 
-            outs = exe.run(compiled_train_program,
-                           feed=[{
-                               'src_word': data[i][0],
-                               'trg_word': data[i][1],
-                               'lbl_word': data[i][2],
-                           } for i in range(trainer_count)],
-                           fetch_list=[sum_cost.name, token_num.name])
+            if args.is_distributed:
+                outs = exe.run(train_program,
+                               feed=[{
+                                   'src_word': data[i][0],
+                                   'trg_word': data[i][1],
+                                   'lbl_word': data[i][2],
+                               } for i in range(trainer_count)],
+                               fetch_list=[sum_cost.name, token_num.name])
+            else:
+                outs = exe.run(compiled_train_program,
+                               feed=[{
+                                   'src_word': data[i][0],
+                                   'trg_word': data[i][1],
+                                   'lbl_word': data[i][2],
+                               } for i in range(trainer_count)],
+                               fetch_list=[sum_cost.name, token_num.name])
             scheduler.step()
 
             train_batch_cost = time.time() - batch_start
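Both branches above use the same feed convention: exe.run takes one feed dict per device, which is why the feed list is built over trainer_count (length 1 in the fleet path). A toy single-process sketch of the with_data_parallel path; the names, shapes, and data here are made up:

    # Sketch only: a two-CPU data-parallel run to show the per-place feed list.
    import numpy as np
    import paddle

    paddle.enable_static()
    places = paddle.static.cpu_places(2)

    x = paddle.static.data(name="x", shape=[None, 4], dtype="float32")
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = paddle.static.Executor()
    exe.run(paddle.static.default_startup_program())
    compiled = paddle.static.CompiledProgram(
        paddle.static.default_main_program()).with_data_parallel(
            loss_name=loss.name, places=places)

    # One feed dict per place, mirroring the `for i in range(trainer_count)` above.
    feed = [{"x": np.random.rand(8, 4).astype("float32")} for _ in places]
    out = exe.run(compiled, feed=feed, fetch_list=[loss.name])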
@@ -176,7 +211,7 @@ def do_train(args):
                 batch_ips_avg.reset()
 
             if step_idx % args.save_step == 0 and step_idx != 0:
-                if args.save_model:
+                if args.save_model and dist.get_rank() == 0:
                     model_path = os.path.join(
                         args.save_model, "step_" + str(step_idx), "transformer")
                     paddle.static.save(train_program, model_path)
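The rank gate added above keeps multiple trainers from writing the same checkpoint. A standalone sketch of the same pattern on a toy program, with a hypothetical save path:

    # Sketch only: save from trainer 0; dist.get_rank() returns 0 when not distributed.
    import paddle
    import paddle.distributed as dist

    paddle.enable_static()
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    pred = paddle.static.nn.fc(x, size=2)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    if dist.get_rank() == 0:
        paddle.static.save(paddle.static.default_main_program(),
                           "checkpoints/demo/transformer")  # hypothetical path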