15 | 15 | from utility import add_arguments, print_arguments, _download, _decompress
16 | 16 | from paddleslim.dist import merge, l2, soft_label
17 | 17 |
| 18 | +from paddle.distributed import fleet
| 19 | +from paddle.distributed.fleet import DistributedStrategy
| 20 | +
18 | 21 | logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
19 | 22 | _logger = logging.getLogger(__name__)
20 | 23 | _logger.setLevel(logging.INFO)
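These two imports move the script onto Paddle's fleet API, which the rest of the diff uses to replace the older with_data_parallel multi-card path. A minimal sketch of the collective setup they enable (assuming paddle >= 2.0 in static-graph mode; this standalone snippet is illustrative, not part of the diff):

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet import DistributedStrategy

    paddle.enable_static()
    # Collective (NCCL) mode; rank and endpoint info is read from the
    # environment variables set by the paddle.distributed.launch utility.
    fleet.init(is_collective=True)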
@@ -76,6 +79,9 @@ def create_optimizer(args):
76 | 79 |
77 | 80 |
78 | 81 | def compress(args):
| 82 | +
| 83 | +    fleet.init(is_collective=True)
| 84 | +
79 | 85 |     if args.data == "cifar10":
80 | 86 |         train_dataset = paddle.vision.datasets.Cifar10(mode='train')
81 | 87 |         val_dataset = paddle.vision.datasets.Cifar10(mode='test')
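fleet.init(is_collective=True) runs at the top of compress(), before any program or optimizer is built, so everything constructed later can be wrapped by fleet. After init, per-process identity is available; a small sketch (worker_index/worker_num are standard fleet calls, the usage here is illustrative):

    rank = fleet.worker_index()    # this trainer's id, 0..N-1
    nranks = fleet.worker_num()    # total number of trainers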
@@ -103,38 +109,38 @@ def compress(args):
103 | 109 |     else:
104 | 110 |         devices_num = int(os.environ.get('CPU_NUM', 1))
105 | 111 |     with paddle.static.program_guard(student_program, s_startup):
106 | | -        with paddle.utils.unique_name.guard():
107 | | -            image = paddle.static.data(
108 | | -                name='image', shape=[None] + image_shape, dtype='float32')
109 | | -            label = paddle.static.data(
110 | | -                name='label', shape=[None, 1], dtype='int64')
111 | | -            train_loader = paddle.io.DataLoader(
112 | | -                train_dataset,
113 | | -                places=places,
114 | | -                feed_list=[image, label],
115 | | -                drop_last=True,
116 | | -                batch_size=int(args.batch_size / devices_num),
117 | | -                return_list=False,
118 | | -                shuffle=True,
119 | | -                use_shared_memory=True,
120 | | -                num_workers=4)
121 | | -            valid_loader = paddle.io.DataLoader(
122 | | -                val_dataset,
123 | | -                places=place,
124 | | -                feed_list=[image, label],
125 | | -                drop_last=False,
126 | | -                return_list=False,
127 | | -                use_shared_memory=True,
128 | | -                batch_size=args.batch_size,
129 | | -                shuffle=False)
130 | | -            # model definition
131 | | -            model = models.__dict__[args.model]()
132 | | -            out = model.net(input=image, class_dim=class_dim)
133 | | -            cost = paddle.nn.functional.loss.cross_entropy(
134 | | -                input=out, label=label)
135 | | -            avg_cost = paddle.mean(x=cost)
136 | | -            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
137 | | -            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
| 112 | +        image = paddle.static.data(
| 113 | +            name='image', shape=[None] + image_shape, dtype='float32')
| 114 | +        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
| 115 | +        sampler = paddle.io.DistributedBatchSampler(
| 116 | +            train_dataset,
| 117 | +            shuffle=False,
| 118 | +            drop_last=True,
| 119 | +            batch_size=args.batch_size)
| 120 | +        train_loader = paddle.io.DataLoader(
| 121 | +            train_dataset,
| 122 | +            places=places,
| 123 | +            feed_list=[image, label],
| 124 | +            batch_sampler=sampler,
| 125 | +            return_list=False,
| 126 | +            use_shared_memory=False,
| 127 | +            num_workers=4)
| 128 | +        valid_loader = paddle.io.DataLoader(
| 129 | +            val_dataset,
| 130 | +            places=place,
| 131 | +            feed_list=[image, label],
| 132 | +            drop_last=False,
| 133 | +            return_list=False,
| 134 | +            use_shared_memory=False,
| 135 | +            batch_size=args.batch_size,
| 136 | +            shuffle=False)
| 137 | +        # model definition
| 138 | +        model = models.__dict__[args.model]()
| 139 | +        out = model.net(input=image, class_dim=class_dim)
| 140 | +        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
| 141 | +        avg_cost = paddle.mean(x=cost)
| 142 | +        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
| 143 | +        acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
138 | 144 |
139 | 145 |     val_program = student_program.clone(for_test=True)
140 | 146 |     exe = paddle.static.Executor(place)
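Note the batching semantics change in this hunk: the removed loader split the global batch across devices (batch_size=int(args.batch_size / devices_num)) and shuffled, while the new DistributedBatchSampler takes a per-rank batch_size, so the effective global batch becomes args.batch_size times the number of trainers, and this version sets shuffle=False. A sketch of the sampler pattern (shuffle=True below is an assumption about typical training, not what the diff does):

    # Each rank iterates a disjoint shard of the dataset;
    # batch_size counts samples per rank, not globally.
    sampler = paddle.io.DistributedBatchSampler(
        train_dataset,
        batch_size=args.batch_size,  # global batch = batch_size * worker_num
        shuffle=True,                # assumption; this diff keeps shuffle=False
        drop_last=True)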
@@ -172,18 +178,19 @@ def if_exist(var):
172 | 178 |     data_name_map = {'image': 'image'}
173 | 179 |     merge(teacher_program, student_program, data_name_map, place)
174 | 180 |
| 181 | +    build_strategy = paddle.static.BuildStrategy()
| 182 | +    dist_strategy = DistributedStrategy()
| 183 | +    dist_strategy.build_strategy = build_strategy
| 184 | +
175 | 185 |     with paddle.static.program_guard(student_program, s_startup):
176 | 186 |         distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
177 | 187 |                                   student_program)
178 | 188 |         loss = avg_cost + distill_loss
179 | 189 |         lr, opt = create_optimizer(args)
| 190 | +        opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
180 | 191 |         opt.minimize(loss)
181 | 192 |     exe.run(s_startup)
182 | | -    build_strategy = paddle.static.BuildStrategy()
183 | | -    build_strategy.fuse_all_reduce_ops = False
184 | | -    parallel_main = paddle.static.CompiledProgram(
185 | | -        student_program).with_data_parallel(
186 | | -            loss_name=loss.name, build_strategy=build_strategy)
| 193 | +    parallel_main = student_program
187 | 194 |
188 | 195 |     for epoch_id in range(args.num_epochs):
189 | 196 |         for step_id, data in enumerate(train_loader):
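Because fleet.distributed_optimizer(...).minimize(loss) inserts the gradient all-reduce ops into student_program itself, the old CompiledProgram(...).with_data_parallel(...) wrapper becomes unnecessary and the plain program is run directly; parallel_main is apparently kept only so the training loop below stays unchanged. A self-contained sketch of the pattern with an illustrative stand-in model (assumes a GPU build; start it with python -m paddle.distributed.launch --gpus "0,1" sketch.py):

    import numpy as np
    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet import DistributedStrategy

    paddle.enable_static()
    fleet.init(is_collective=True)

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
        y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
        loss = paddle.nn.functional.cross_entropy(
            paddle.static.nn.fc(x, size=2), y)
        opt = fleet.distributed_optimizer(
            paddle.optimizer.SGD(learning_rate=0.01),
            strategy=DistributedStrategy())
        opt.minimize(loss)  # adds collective all-reduce ops to `main`

    # Each process drives its own GPU; no with_data_parallel wrapper needed.
    exe = paddle.static.Executor(
        paddle.CUDAPlace(paddle.distributed.ParallelEnv().device_id))
    exe.run(startup)
    feed = {'x': np.random.rand(8, 16).astype('float32'),
            'y': np.random.randint(0, 2, size=(8, 1)).astype('int64')}
    print(exe.run(main, feed=feed, fetch_list=[loss.name]))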