Skip to content

Commit e33dc48

Browse files
authored
remove with_data_parallel (PaddlePaddle#1658)
* remove with_data_parallel * ACT adapts fleet * ACT'demo adapts fleet * fix bugs
1 parent 65c776d commit e33dc48

File tree

30 files changed

+276
-303
lines changed

30 files changed

+276
-303
lines changed

demo/nas/block_sa_nas_mobilenetv2.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from paddleslim.analysis import flops
1414
from paddleslim.nas import SANAS
1515
from paddleslim.common import get_logger
16-
from optimizer import create_optimizer
1716
import imagenet_reader
1817

1918
_logger = get_logger(__name__, level=logging.INFO)
@@ -157,15 +156,13 @@ def search_mobilenetv2_block(config, args, image_size):
157156

158157
build_strategy = static.BuildStrategy()
159158
train_compiled_program = static.CompiledProgram(
160-
train_program).with_data_parallel(
161-
loss_name=avg_cost.name, build_strategy=build_strategy)
159+
train_program, build_strategy=build_strategy)
162160
for epoch_id in range(args.retain_epoch):
163161
for batch_id, data in enumerate(train_loader()):
164162
fetches = [avg_cost.name]
165163
s_time = time.time()
166-
outs = exe.run(train_compiled_program,
167-
feed=data,
168-
fetch_list=fetches)[0]
164+
outs = exe.run(
165+
train_compiled_program, feed=data, fetch_list=fetches)[0]
169166
batch_time = time.time() - s_time
170167
if batch_id % 10 == 0:
171168
_logger.info(
@@ -175,9 +172,8 @@ def search_mobilenetv2_block(config, args, image_size):
175172
reward = []
176173
for batch_id, data in enumerate(val_loader()):
177174
test_fetches = [avg_cost.name, acc_top1.name, acc_top5.name]
178-
batch_reward = exe.run(test_program,
179-
feed=data,
180-
fetch_list=test_fetches)
175+
batch_reward = exe.run(
176+
test_program, feed=data, fetch_list=test_fetches)
181177
reward_avg = np.mean(np.array(batch_reward), axis=1)
182178
reward.append(reward_avg)
183179

demo/nas/rl_nas_mobilenetv2.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,13 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
141141

142142
build_strategy = static.BuildStrategy()
143143
train_compiled_program = static.CompiledProgram(
144-
train_program).with_data_parallel(
145-
loss_name=avg_cost.name, build_strategy=build_strategy)
144+
train_program, build_strategy=build_strategy)
146145
for epoch_id in range(args.retain_epoch):
147146
for batch_id, data in enumerate(train_loader()):
148147
fetches = [avg_cost.name]
149148
s_time = time.time()
150-
outs = exe.run(train_compiled_program,
151-
feed=data,
152-
fetch_list=fetches)[0]
149+
outs = exe.run(
150+
train_compiled_program, feed=data, fetch_list=fetches)[0]
153151
batch_time = time.time() - s_time
154152
if batch_id % 10 == 0:
155153
_logger.info(
@@ -161,9 +159,8 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
161159
test_fetches = [
162160
test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
163161
]
164-
batch_reward = exe.run(test_program,
165-
feed=data,
166-
fetch_list=test_fetches)
162+
batch_reward = exe.run(
163+
test_program, feed=data, fetch_list=test_fetches)
167164
reward_avg = np.mean(np.array(batch_reward), axis=1)
168165
reward.append(reward_avg)
169166

demo/nas/sa_nas_mobilenetv2.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,13 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
134134

135135
build_strategy = static.BuildStrategy()
136136
train_compiled_program = static.CompiledProgram(
137-
train_program).with_data_parallel(
138-
loss_name=avg_cost.name, build_strategy=build_strategy)
137+
train_program, build_strategy=build_strategy)
139138
for epoch_id in range(args.retain_epoch):
140139
for batch_id, data in enumerate(train_loader()):
141140
fetches = [avg_cost.name]
142141
s_time = time.time()
143-
outs = exe.run(train_compiled_program,
144-
feed=data,
145-
fetch_list=fetches)[0]
142+
outs = exe.run(
143+
train_compiled_program, feed=data, fetch_list=fetches)[0]
146144
batch_time = time.time() - s_time
147145
if batch_id % 10 == 0:
148146
_logger.info(
@@ -154,9 +152,8 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
154152
test_fetches = [
155153
test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
156154
]
157-
batch_reward = exe.run(test_program,
158-
feed=data,
159-
fetch_list=test_fetches)
155+
batch_reward = exe.run(
156+
test_program, feed=data, fetch_list=test_fetches)
160157
reward_avg = np.mean(np.array(batch_reward), axis=1)
161158
reward.append(reward_avg)
162159

@@ -223,15 +220,13 @@ def test_search_result(tokens, image_size, args, config):
223220

224221
build_strategy = static.BuildStrategy()
225222
train_compiled_program = static.CompiledProgram(
226-
train_program).with_data_parallel(
227-
loss_name=avg_cost.name, build_strategy=build_strategy)
223+
train_program, build_strategy=build_strategy)
228224
for epoch_id in range(args.retain_epoch):
229225
for batch_id, data in enumerate(train_loader()):
230226
fetches = [avg_cost.name]
231227
s_time = time.time()
232-
outs = exe.run(train_compiled_program,
233-
feed=data,
234-
fetch_list=fetches)[0]
228+
outs = exe.run(
229+
train_compiled_program, feed=data, fetch_list=fetches)[0]
235230
batch_time = time.time() - s_time
236231
if batch_id % 10 == 0:
237232
_logger.info(
@@ -243,9 +238,8 @@ def test_search_result(tokens, image_size, args, config):
243238
test_fetches = [
244239
test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
245240
]
246-
batch_reward = exe.run(test_program,
247-
feed=data,
248-
fetch_list=test_fetches)
241+
batch_reward = exe.run(
242+
test_program, feed=data, fetch_list=test_fetches)
249243
reward_avg = np.mean(np.array(batch_reward), axis=1)
250244
reward.append(reward_avg)
251245

demo/nas/sanas_darts_space.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ def train(main_prog, exe, epoch_id, train_loader, fetch_list, args):
119119
[[drop_path_probility * epoch_id / args.retain_epoch]
120120
for i in range(args.batch_size)]).astype(np.float32)
121121
drop_path_mask = 1 - np.random.binomial(
122-
1, drop_path_prob[0],
123-
size=[args.batch_size, 20, 4, 2]).astype(np.float32)
122+
1, drop_path_prob[0], size=[args.batch_size, 20, 4, 2
123+
]).astype(np.float32)
124124
feed.append({
125125
"image": image,
126126
"label": label,
@@ -195,8 +195,8 @@ def search(config, args, image_size, is_server=True):
195195

196196
current_params = count_parameters_in_MB(
197197
train_program.global_block().all_parameters(), 'cifar10')
198-
_logger.info('step: {}, current_params: {}M'.format(step,
199-
current_params))
198+
_logger.info(
199+
'step: {}, current_params: {}M'.format(step, current_params))
200200
if current_params > float(3.77):
201201
continue
202202

@@ -222,9 +222,7 @@ def search(config, args, image_size, is_server=True):
222222

223223
build_strategy = static.BuildStrategy()
224224
train_compiled_program = static.CompiledProgram(
225-
train_program).with_data_parallel(
226-
loss_name=train_fetch_list[0].name,
227-
build_strategy=build_strategy)
225+
train_program, build_strategy=build_strategy)
228226

229227
valid_top1_list = []
230228
for epoch_id in range(args.retain_epoch):
@@ -234,8 +232,8 @@ def search(config, args, image_size, is_server=True):
234232
step, epoch_id, train_top1))
235233
valid_top1 = valid(test_program, exe, epoch_id, test_loader,
236234
test_fetch_list, args)
237-
_logger.info("TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id,
238-
valid_top1))
235+
_logger.info(
236+
"TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id, valid_top1))
239237
valid_top1_list.append(valid_top1)
240238
sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
241239

@@ -276,19 +274,18 @@ def final_test(config, args, image_size, token=None):
276274

277275
build_strategy = static.BuildStrategy()
278276
train_compiled_program = static.CompiledProgram(
279-
train_program).with_data_parallel(
280-
loss_name=train_fetch_list[0].name, build_strategy=build_strategy)
277+
train_program, build_strategy=build_strategy)
281278

282279
valid_top1_list = []
283280
for epoch_id in range(args.retain_epoch):
284281
train_top1 = train(train_compiled_program, exe, epoch_id, train_loader,
285282
train_fetch_list, args)
286-
_logger.info("TRAIN: Epoch {}, train_acc {:.6f}".format(epoch_id,
287-
train_top1))
283+
_logger.info(
284+
"TRAIN: Epoch {}, train_acc {:.6f}".format(epoch_id, train_top1))
288285
valid_top1 = valid(test_program, exe, epoch_id, test_loader,
289286
test_fetch_list, args)
290-
_logger.info("TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id,
291-
valid_top1))
287+
_logger.info(
288+
"TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id, valid_top1))
292289
valid_top1_list.append(valid_top1)
293290

294291
output_dir = os.path.join('darts_output', str(epoch_id))

demo/prune/README.md

+13-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ tar -xf MobileNetV1_pretrained.tar
3434

3535
通过以下命令启动裁剪任务:
3636

37+
- 单卡启动:
3738
```
3839
export CUDA_VISIBLE_DEVICES=0
3940
python train.py \
@@ -43,9 +44,18 @@ python train.py \
4344
--criterion "l1_norm"
4445
```
4546

46-
其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集。
47-
`criterion` 选项用于指定所使用的剪裁算法策略,现在支持`l1_norm`, `bn_scale`, `geometry_median`。默认为`l1_norm`。可以
48-
设置该参数以改变剪裁算法策略。该目录下的四个shell脚本文件是在ResNet34, MobileNetV1, MobileNetV2等三个模型上进行的四组
47+
- 多卡启动:
48+
```
49+
export CUDA_VISIBLE_DEVICES=0, 1
50+
python -m paddle.distributed.launch train.py \
51+
--model "MobileNet" \
52+
--pruned_ratio 0.31 \
53+
--data "mnist" \
54+
--criterion "l1_norm" \
55+
--fleet
56+
```
57+
58+
其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集。`criterion` 选项用于指定所使用的剪裁算法策略,现在支持`l1_norm`, `bn_scale`, `geometry_median`,默认为`l1_norm``fleet` 用于开启多卡训练,在多卡启动时需要调用该参数。该目录下的四个shell脚本文件是在ResNet34, MobileNetV1, MobileNetV2等三个模型上进行的四组
4959
`criterion`设置为`geometry_median`的实验,可以直接运行脚本文件启动剪裁实验。
5060

5161
执行`python train.py --help`查看更多选项。

demo/prune/fpgm_mobilenetv1_f-50_train.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
export CUDA_VISIBLE_DEVICES=0,1
33
export FLAGS_fraction_of_gpu_memory_to_use=0.98
4-
python train.py \
4+
python -m paddle.distributed.launch train.py \
55
--model="MobileNet" \
66
--pretrained_model="/workspace/models/MobileNetV1_pretrained" \
77
--data="imagenet" \
@@ -14,4 +14,5 @@ python train.py \
1414
--lr_strategy="piecewise_decay" \
1515
--criterion="geometry_median" \
1616
--model_path="./fpgm_mobilenetv1_models" \
17+
--fleet \
1718
2>&1 | tee fpgm_mobilenetv1_train.log

demo/prune/fpgm_mobilenetv2_f-50_train.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
export CUDA_VISIBLE_DEVICES=0,1
33
export FLAGS_fraction_of_gpu_memory_to_use=0.98
4-
python train.py \
4+
python -m paddle.distributed.launch train.py \
55
--model="MobileNetV2" \
66
--pretrained_model="/workspace/models/MobileNetV2_pretrained" \
77
--data="imagenet" \
@@ -14,4 +14,5 @@ python train.py \
1414
--lr_strategy="piecewise_decay" \
1515
--criterion="geometry_median" \
1616
--model_path="./fpgm_mobilenetv2_models" \
17+
--fleet \
1718
2>&1 | tee fpgm_mobilenetv2_train.log
+2-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#!/bin/bash
22
export CUDA_VISIBLE_DEVICES=0,1,2,3
33
export FLAGS_fraction_of_gpu_memory_to_use=0.98
4-
python train.py \
4+
python -m paddle.distributed.launch train.py \
55
--model="ResNet34" \
66
--pretrained_model="/workspace/models/ResNet34_pretrained" \
77
--data="imagenet" \
88
--pruned_ratio=0.25 \
99
--lr_strategy="cosine_decay" \
1010
--criterion="geometry_median" \
1111
--model_path="./fpgm_resnet34_025_120_models" \
12+
--fleet \
1213
2>&1 | tee fpgm_resnet025_120_train.log

demo/prune/fpgm_resnet34_f-50_train.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
export CUDA_VISIBLE_DEVICES=0,1
33
export FLAGS_fraction_of_gpu_memory_to_use=0.98
4-
python train.py \
4+
python -m paddle.distributed.launch train.py \
55
--model="ResNet34" \
66
--pretrained_model="/workspace/models/ResNet34_pretrained" \
77
--data="imagenet" \
@@ -14,4 +14,5 @@ python train.py \
1414
--lr_strategy="piecewise_decay" \
1515
--criterion="geometry_median" \
1616
--model_path="./fpgm_resnet34_models" \
17+
--fleet \
1718
2>&1 | tee fpgm_resnet03_train.log

demo/prune/train.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import models
1616
from utility import add_arguments, print_arguments
1717
import paddle.vision.transforms as T
18+
from paddle.distributed import fleet
1819

1920
_logger = get_logger(__name__, level=logging.INFO)
2021

@@ -40,6 +41,7 @@
4041
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
4142
add_arg('save_inference', bool, False, "Whether to save inference model.")
4243
add_arg('ce_test', bool, False, "Whether to CE test.")
44+
parser.add_argument('fleet', action='store_true', help="Whether to turn on distributed training.")
4345
# yapf: enable
4446

4547
model_list = models.__all__
@@ -96,6 +98,8 @@ def create_optimizer(args, step_per_epoch):
9698

9799

98100
def compress(args):
101+
if args.fleet:
102+
fleet.init(is_collective=True)
99103

100104
num_workers = 4
101105
shuffle = True
@@ -130,8 +134,8 @@ def compress(args):
130134
else:
131135
raise ValueError("{} is not supported.".format(args.data))
132136
image_shape = [int(m) for m in image_shape.split(",")]
133-
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
134-
model_list)
137+
assert args.model in model_list, "{} is not in lists: {}".format(
138+
args.model, model_list)
135139
places = paddle.static.cuda_places(
136140
) if args.use_gpu else paddle.static.cpu_places()
137141
place = places[0]
@@ -140,13 +144,16 @@ def compress(args):
140144
name='image', shape=[None] + image_shape, dtype='float32')
141145
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
142146
batch_size_per_card = int(args.batch_size / len(places))
147+
sampler = paddle.io.DistributedBatchSampler(
148+
train_dataset,
149+
shuffle=shuffle,
150+
drop_last=True,
151+
batch_size=batch_size_per_card)
143152
train_loader = paddle.io.DataLoader(
144153
train_dataset,
145154
places=places,
146155
feed_list=[image, label],
147-
drop_last=True,
148-
batch_size=batch_size_per_card,
149-
shuffle=shuffle,
156+
batch_sampler=sampler,
150157
return_list=False,
151158
use_shared_memory=True,
152159
num_workers=num_workers)
@@ -171,6 +178,8 @@ def compress(args):
171178
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
172179
val_program = paddle.static.default_main_program().clone(for_test=True)
173180
opt, learning_rate = create_optimizer(args, step_per_epoch)
181+
if args.fleet:
182+
opt = fleet.distributed_optimizer(opt)
174183
opt.minimize(avg_cost)
175184

176185
exe.run(paddle.static.default_startup_program())
@@ -180,8 +189,8 @@ def compress(args):
180189
def if_exist(var):
181190
return os.path.exists(os.path.join(args.pretrained_model, var.name))
182191

183-
_logger.info("Load pretrained model from {}".format(
184-
args.pretrained_model))
192+
_logger.info(
193+
"Load pretrained model from {}".format(args.pretrained_model))
185194
paddle.static.load(paddle.static.default_main_program(),
186195
args.pretrained_model, exe)
187196

@@ -247,13 +256,10 @@ def train(epoch, program):
247256
place=place)
248257
_logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))
249258

250-
build_strategy = paddle.static.BuildStrategy()
251-
exec_strategy = paddle.static.ExecutionStrategy()
252-
train_program = paddle.static.CompiledProgram(
253-
pruned_program).with_data_parallel(
254-
loss_name=avg_cost.name,
255-
build_strategy=build_strategy,
256-
exec_strategy=exec_strategy)
259+
if args.fleet:
260+
train_program = paddle.static.CompiledProgram(pruned_program)
261+
else:
262+
train_program = pruned_program
257263

258264
for i in range(args.num_epochs):
259265
train(i, train_program)
@@ -268,8 +274,8 @@ def train(epoch, program):
268274
infer_model_path, [image], [out],
269275
exe,
270276
program=pruned_val_program)
271-
_logger.info("Saved inference model into [{}]".format(
272-
infer_model_path))
277+
_logger.info(
278+
"Saved inference model into [{}]".format(infer_model_path))
273279

274280

275281
def main():

0 commit comments

Comments
 (0)