
Commit e8d6f63

replace with_data_parallel with fleet (#1626)

1 parent 8bf2df5

2 files changed, +47 -40 lines changed

demo/distillation/README.md (+3, -3)

````diff
@@ -13,22 +13,22 @@
 Default configuration:
 
 ```yaml
-batch_size: 256
+batch_size: 64
 init_lr: 0.1
 lr_strategy: piecewise_decay
 l2_decay: 3e-5
 momentum_rate: 0.9
 num_epochs: 120
 data: imagenet
 ```
-Training can simply be launched with the default configuration.
+Training can simply be launched with the default configuration. Here, batch_size refers to the batch size on each card.
 
 ### 2. Launch training
 
 After the ImageNet dataset has been configured, launch training with the following command:
 
 ```shell
-CUDA_VISIBLE_DEVICES=0,1,2,3 python distill.py
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch distill.py
 ```
 
 ### 3. Training results
````
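The launch command above is the user-visible half of this change: `python -m paddle.distributed.launch` starts one trainer process per device listed in `CUDA_VISIBLE_DEVICES` and exports the rank/endpoint environment variables that `fleet.init` later reads. A minimal sketch of a script written for this launcher; it is illustrative only, not code from this commit:

```python
import paddle
from paddle.distributed import fleet

paddle.enable_static()

# Reads the rank/endpoint environment variables exported by
# `python -m paddle.distributed.launch`; run as a single process it
# simply falls back to one trainer.
fleet.init(is_collective=True)

print("trainer %d of %d" % (fleet.worker_index(), fleet.worker_num()))
```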

demo/distillation/distill.py (+44, -37)

```diff
@@ -15,6 +15,9 @@
 from utility import add_arguments, print_arguments, _download, _decompress
 from paddleslim.dist import merge, l2, soft_label
 
+from paddle.distributed import fleet
+from paddle.distributed.fleet import DistributedStrategy
+
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
@@ -76,6 +79,9 @@ def create_optimizer(args):
 
 
 def compress(args):
+
+    fleet.init(is_collective=True)
+
     if args.data == "cifar10":
         train_dataset = paddle.vision.datasets.Cifar10(mode='train')
         val_dataset = paddle.vision.datasets.Cifar10(mode='test')
```
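`fleet.init(is_collective=True)` switches the process into collective training, where every launched process runs the same script and gradients are synchronized by all-reduce rather than through a parameter server. A common consequence, sketched below as an assumed usage pattern rather than code from this commit, is that once-per-job side effects get guarded to trainer 0:

```python
from paddle.distributed import fleet

fleet.init(is_collective=True)

# In collective mode every trainer executes the full script, so work that
# should happen once per job (logging, checkpoints) is usually rank-guarded:
if fleet.worker_index() == 0:
    print("job uses", fleet.worker_num(), "trainers")
```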
```diff
@@ -103,38 +109,38 @@ def compress(args):
     else:
         devices_num = int(os.environ.get('CPU_NUM', 1))
     with paddle.static.program_guard(student_program, s_startup):
-        with paddle.utils.unique_name.guard():
-            image = paddle.static.data(
-                name='image', shape=[None] + image_shape, dtype='float32')
-            label = paddle.static.data(
-                name='label', shape=[None, 1], dtype='int64')
-            train_loader = paddle.io.DataLoader(
-                train_dataset,
-                places=places,
-                feed_list=[image, label],
-                drop_last=True,
-                batch_size=int(args.batch_size / devices_num),
-                return_list=False,
-                shuffle=True,
-                use_shared_memory=True,
-                num_workers=4)
-            valid_loader = paddle.io.DataLoader(
-                val_dataset,
-                places=place,
-                feed_list=[image, label],
-                drop_last=False,
-                return_list=False,
-                use_shared_memory=True,
-                batch_size=args.batch_size,
-                shuffle=False)
-            # model definition
-            model = models.__dict__[args.model]()
-            out = model.net(input=image, class_dim=class_dim)
-            cost = paddle.nn.functional.loss.cross_entropy(
-                input=out, label=label)
-            avg_cost = paddle.mean(x=cost)
-            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
-            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
+        image = paddle.static.data(
+            name='image', shape=[None] + image_shape, dtype='float32')
+        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+        sampler = paddle.io.DistributedBatchSampler(
+            train_dataset,
+            shuffle=False,
+            drop_last=True,
+            batch_size=args.batch_size)
+        train_loader = paddle.io.DataLoader(
+            train_dataset,
+            places=places,
+            feed_list=[image, label],
+            batch_sampler=sampler,
+            return_list=False,
+            use_shared_memory=False,
+            num_workers=4)
+        valid_loader = paddle.io.DataLoader(
+            val_dataset,
+            places=place,
+            feed_list=[image, label],
+            drop_last=False,
+            return_list=False,
+            use_shared_memory=False,
+            batch_size=args.batch_size,
+            shuffle=False)
+        # model definition
+        model = models.__dict__[args.model]()
+        out = model.net(input=image, class_dim=class_dim)
+        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
+        avg_cost = paddle.mean(x=cost)
+        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+        acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
 
     val_program = student_program.clone(for_test=True)
     exe = paddle.static.Executor(place)
```
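The old loader divided the global batch by hand (`batch_size=int(args.batch_size / devices_num)`); the new code delegates sharding to `paddle.io.DistributedBatchSampler`, which gives each trainer its own slice of the sample indices, so `args.batch_size` becomes a per-card value. That is also why the README default drops from 256 to 64: with the four cards in the example launch command, 4 × 64 reproduces the old global batch of 256. A rough standalone sketch, using a CIFAR-10 dataset purely for illustration:

```python
import paddle

dataset = paddle.vision.datasets.Cifar10(mode='train')

# batch_size here is per trainer; under paddle.distributed.launch each of
# the N processes iterates over a disjoint 1/N shard of the indices.
sampler = paddle.io.DistributedBatchSampler(
    dataset, batch_size=64, shuffle=False, drop_last=True)

print(len(sampler))  # per-card batches per epoch on this trainer
```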
```diff
@@ -172,18 +178,19 @@ def if_exist(var):
     data_name_map = {'image': 'image'}
     merge(teacher_program, student_program, data_name_map, place)
 
+    build_strategy = paddle.static.BuildStrategy()
+    dist_strategy = DistributedStrategy()
+    dist_strategy.build_strategy = build_strategy
+
     with paddle.static.program_guard(student_program, s_startup):
         distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                   student_program)
         loss = avg_cost + distill_loss
         lr, opt = create_optimizer(args)
+        opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
         opt.minimize(loss)
     exe.run(s_startup)
-    build_strategy = paddle.static.BuildStrategy()
-    build_strategy.fuse_all_reduce_ops = False
-    parallel_main = paddle.static.CompiledProgram(
-        student_program).with_data_parallel(
-            loss_name=loss.name, build_strategy=build_strategy)
+    parallel_main = student_program
 
     for epoch_id in range(args.num_epochs):
         for step_id, data in enumerate(train_loader):
```
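Wrapping the optimizer is what makes the deleted `with_data_parallel` step unnecessary: calling `minimize` on `fleet.distributed_optimizer(opt, strategy=...)` rewrites the program so the gradient all-reduce ops live inside `student_program` itself, which is why the plain program can now be handed to the executor (`parallel_main = student_program`). A condensed sketch of the pattern, with a toy network standing in for the demo's student model:

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy

paddle.enable_static()
fleet.init(is_collective=True)

dist_strategy = DistributedStrategy()
dist_strategy.build_strategy = paddle.static.BuildStrategy()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    opt = paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
    # minimize() on the wrapped optimizer inserts the collective
    # communication into `main`; no CompiledProgram/with_data_parallel step.
    fleet.distributed_optimizer(opt, strategy=dist_strategy).minimize(loss)
```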
