
Commit f68ec4b

[cherry-pick] add loss info and skd distillation (#1612)
* add skd distillation. (#1587)
  * add skd distillation.
  * update skd's test.
* [ACT] add loss info (#1597)
  * add loss info on ACT training.
  * Add flops info.
1 parent d521460 commit f68ec4b

File tree

7 files changed: +189 -22 lines changed


paddleslim/auto_compression/analysis.py

+15 -2
@@ -28,7 +28,19 @@ def analysis_prune(eval_function,
                    params_filename,
                    analysis_file,
                    pruned_ratios,
-                   target_loss=None):
+                   target_loss=None,
+                   criterion='l1_norm'):
+    '''
+    Args:
+        eval_func(function): The callback function used to evaluate the model. It should accept an instance of `paddle.static.Program` as its argument and return a score on the test dataset.
+        model_dir(str): Directory path to load the model. If you want to load an ONNX model, just set ``model_dir=model.onnx``.
+        model_filename(str): Specify model_filename. If you want to load an ONNX model, the model filename should be None.
+        params_filename(str): Specify params_filename. If you want to load an ONNX model, the params filename should be None.
+        analysis_file(str): The file used to save the sensitivities. The latest computed sensitivities are appended to this file, and sensitivities already stored in it are not computed again. The file can be loaded with the `pickle` library.
+        pruned_ratios(list): The ratios to be pruned.
+        criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_norm, bn_scale, geometry_median. Default: l1_norm.
+    '''
+
     devices = paddle.device.get_device().split(':')[0]
     places = paddle.device._convert_to_place(devices)
     exe = paddle.static.Executor(places)
@@ -47,7 +59,8 @@ def analysis_prune(eval_function,
         eval_function,
         sensitivities_file=analysis_file,
         eval_args=[exe, feed_target_names, fetch_targets],
-        pruned_ratios=pruned_ratios)
+        pruned_ratios=pruned_ratios,
+        criterion=criterion)
 
     with open(analysis_file, 'rb') as f:
         if sys.version_info < (3, 0):
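A minimal, hypothetical call showing the new `criterion` argument. The model paths and the evaluation callback below are placeholders, not taken from this commit:

    from paddleslim.auto_compression.analysis import analysis_prune

    def eval_function(program, exe, feed_names, fetch_targets):
        # Placeholder: evaluate `program` on the test dataset and return a score.
        return 0.0

    analysis_prune(
        eval_function,
        model_dir='./inference_model',       # hypothetical path
        model_filename='model.pdmodel',      # hypothetical filename
        params_filename='model.pdiparams',   # hypothetical filename
        analysis_file='sensitivities.pkl',
        pruned_ratios=[0.1, 0.2, 0.3],
        criterion='geometry_median')         # or 'l1_norm' (default), 'bn_scale'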

paddleslim/auto_compression/compressor.py

+12 -6
@@ -783,13 +783,17 @@ def _start_train(self, train_program_info, test_program_info, strategy,
         total_epochs = train_config.epochs if train_config.epochs else 100
         total_train_iter = 0
         stop_training = False
+
+        loss_vars = [var for var in train_program_info.loss_dict.values()]
+        loss_names = [name for name in train_program_info.loss_dict.keys()]
+
         for epoch_id in range(total_epochs):
             if stop_training:
                 break
             for batch_id, data in enumerate(self.train_dataloader()):
-                np_probs_float, = self._exe.run(train_program_info.program, \
+                loss = self._exe.run(train_program_info.program, \
                     feed=data, \
-                    fetch_list=train_program_info.fetch_targets)
+                    fetch_list=train_program_info.fetch_targets+loss_vars)
                 if not isinstance(train_program_info.learning_rate, float):
                     train_program_info.learning_rate.step()
                 if 'unstructure' in strategy:
@@ -800,10 +804,12 @@ def _start_train(self, train_program_info, test_program_info, strategy,
                 else:
                     logging_iter = train_config.logging_iter
                 if batch_id % int(logging_iter) == 0:
-                    _logger.info(
-                        "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
-                            total_train_iter, epoch_id, batch_id,
-                            np_probs_float))
+                    print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
+                        total_train_iter, epoch_id, batch_id, loss[0])
+                    for idx, loss_value in enumerate(loss[1:]):
+                        print_info += '{}: {} '.format(loss_names[idx],
+                                                       loss_value)
+                    _logger.info(print_info)
                 total_train_iter += 1
                 if total_train_iter % int(
                         train_config.eval_iter) == 0 and total_train_iter != 0:
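A small, self-contained sketch of how the fetched results map onto the new log line. The loss names and numeric values are invented; the point is that the first fetched value is the overall training loss and the remaining entries line up with the keys of `train_program_info.loss_dict`:

    import numpy as np

    # Hypothetical values standing in for what exe.run would return.
    loss_names = ['soft_label', 'skd']
    loss = [np.array(0.82), np.array(0.55), np.array(0.27)]  # [total, soft_label, skd]

    print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
        0, 0, 0, loss[0])
    for idx, loss_value in enumerate(loss[1:]):
        print_info += ' {}: {}'.format(loss_names[idx], loss_value)
    print(print_info)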

paddleslim/auto_compression/create_compressed_program.py

+20 -10
@@ -24,6 +24,7 @@
 from ..common import get_logger
 from .strategy_config import ProgramInfo
 from ..common.load_model import load_inference_model
+from ..analysis import flops
 
 _logger = get_logger(__name__, level=logging.INFO)
 __all__ = [
@@ -118,7 +119,7 @@ def _parse_distill_loss(distill_node_pair,
                          distill_lambda=1.0):
     """parse distill loss config"""
     loss_dist = 0.0
-    losses = []
+    losses = {}
     if isinstance(distill_node_pair[0], str):
         assert isinstance(distill_loss, str)
         assert isinstance(distill_lambda, float)
@@ -128,16 +129,17 @@ def _parse_distill_loss(distill_node_pair,
 
     assert len(distill_node_pair) == len(distill_loss)
     assert len(distill_node_pair) == len(distill_lambda)
-    for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
-        tmp_loss = 0.0
-        _logger.info("train config.distill_node_pair: {}".format(node, loss,
-                                                                 lam))
+    for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
+                                    distill_lambda):
+        tmp_loss = losses.get(loss_clas, 0.0)
+        _logger.info("train config.distill_node_pair: {}".format(
+            node, loss_clas, lam))
         assert len(node) % 2 == 0, \
             "distill_node_pair config wrong, the length needs to be an even number"
         for i in range(len(node) // 2):
-            tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
-        loss_dist += lam * tmp_loss
-        losses.append(tmp_loss)
+            tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
+        loss_dist += tmp_loss
+        losses[loss_clas] = tmp_loss
 
     return loss_dist, losses
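A self-contained sketch of the behavioural change in `_parse_distill_loss`: losses are now keyed by loss type in a dict instead of appended per node pair, and the lambda weight is folded into each component. The node names, loss types, and the stand-in loss function are invented for illustration:

    def fake_loss(teacher_name, student_name):
        # Stand-in for l2/soft_label/skd/... which build real loss variables.
        return 1.0

    distill_node_pair = [['teacher_out', 'student_out'],
                         ['teacher_feat', 'student_feat']]
    distill_loss = ['l2', 'skd']
    distill_lambda = [0.5, 2.0]

    loss_dist, losses = 0.0, {}
    for node, loss_clas, lam in zip(distill_node_pair, distill_loss, distill_lambda):
        tmp_loss = losses.get(loss_clas, 0.0)
        for i in range(len(node) // 2):
            tmp_loss += fake_loss(node[i * 2], node[i * 2 + 1]) * lam
        loss_dist += tmp_loss
        losses[loss_clas] = tmp_loss

    print(loss_dist, losses)   # 2.5 {'l2': 0.5, 'skd': 2.0}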

@@ -313,7 +315,7 @@ def build_distill_program(executor,
             use_dynamic_loss_scaling=True,
             **train_config['amp_config'])
 
-        distill_loss, losses = _parse_distill_loss(
+        distill_loss, loss_dict = _parse_distill_loss(
             distill_node_pair,
             config.get('loss') or 'l2',  ### default loss is l2
             config.get('alpha') or 1.0)  ### default alpha is 1.0
@@ -334,7 +336,7 @@ def build_distill_program(executor,
 
     train_program_info = ProgramInfo(startup_program, train_program,
                                      feed_target_names, train_fetch_list,
-                                     optimizer, learning_rate)
+                                     optimizer, learning_rate, loss_dict)
     test_program_info = ProgramInfo(startup_program, test_program,
                                     feed_target_names, fetch_targets)
     return train_program_info, test_program_info
@@ -469,6 +471,8 @@ def build_prune_program(executor,
                 params.append(param.name)
                 original_shapes[param.name] = param.shape
 
+        origin_flops = flops(train_program_info.program)
+
         pruned_program, _, _ = pruner.prune(
             train_program_info.program,
             paddle.static.global_scope(),
@@ -485,6 +489,12 @@ def build_prune_program(executor,
                     param.name, original_shapes[param.name], param.shape))
         _logger.info(
             "####################channel pruning end##########################")
+
+        final_flops = flops(pruned_program)
+        pruned_flops = abs(origin_flops - final_flops) / origin_flops
+        _logger.info("FLOPs before pruning: {}".format(origin_flops))
+        _logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
+            final_flops, round(pruned_flops * 100, 2)))
         train_program_info.program = pruned_program
 
     elif strategy.startswith('asp'):
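The pruned-FLOPs percentage in the new log lines is a simple relative reduction. With made-up FLOP counts:

    origin_flops, final_flops = 1_000_000, 650_000           # illustrative counts
    pruned_flops = abs(origin_flops - final_flops) / origin_flops
    print("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
        final_flops, round(pruned_flops * 100, 2)))           # -> 35.0%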

paddleslim/auto_compression/strategy_config.py

+4 -1
@@ -431,7 +431,8 @@ def __init__(self,
                  feed_target_names,
                  fetch_targets,
                  optimizer=None,
-                 learning_rate=None):
+                 learning_rate=None,
+                 loss_dict=None):
         """
         ProgramInfo Config.
         Args:
@@ -441,10 +442,12 @@ def __init__(self,
             fetch_targets(list(Variable)): The fetch variable in the program.
             optimizer(Optimizer, optional): Optimizer in training. Default: None.
             learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None.
+            loss_dict(dict): The components of losses.
         """
         self.startup_program = startup_program
         self.program = program
         self.feed_target_names = feed_target_names
         self.fetch_targets = fetch_targets
         self.optimizer = optimizer
         self.learning_rate = learning_rate
+        self.loss_dict = loss_dict
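A hedged construction example of the extended ProgramInfo. The programs are empty placeholders and None stands in for the real optimizer, learning rate, and loss variables; only the extra `loss_dict` parameter comes from this commit:

    import paddle
    from paddleslim.auto_compression.strategy_config import ProgramInfo

    paddle.enable_static()
    startup_program = paddle.static.Program()
    train_program = paddle.static.Program()

    # loss_dict maps a distillation loss type to the Variable holding that component.
    info = ProgramInfo(startup_program, train_program,
                       feed_target_names=['image'], fetch_targets=[],
                       optimizer=None, learning_rate=None,
                       loss_dict={'l2': None, 'skd': None})
    print(info.loss_dict)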

paddleslim/dist/__init__.py

+1 -1
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .single_distiller import merge, fsp, l2, soft_label, loss, dkd
+from .single_distiller import merge, fsp, l2, soft_label, loss, dkd, skd
 from .dml import DML

paddleslim/dist/single_distiller.py

+56 -2
@@ -15,6 +15,7 @@
 import numpy as np
 import paddle
 from paddleslim.core import GraphWrapper
+import paddle.nn.functional as F
 
 
 def merge(teacher_program,
@@ -203,8 +204,11 @@ def soft_label(teacher_var_name,
         teacher_var = paddle.nn.functional.softmax(teacher_var /
                                                    teacher_temperature)
     soft_label_loss = paddle.mean(
-        paddle.fluid.layers.cross_entropy(
-            student_var, teacher_var, soft_label=True))
+        paddle.nn.functional.cross_entropy(
+            input=student_var,
+            label=teacher_var,
+            soft_label=True,
+            use_softmax=False))
     return soft_label_loss
 
 
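A quick dygraph check of the replacement call, with illustrative random tensors: passing `use_softmax=False` makes `paddle.nn.functional.cross_entropy` treat the inputs as probabilities, which is how the removed `paddle.fluid.layers.cross_entropy(..., soft_label=True)` behaved:

    import paddle
    import paddle.nn.functional as F

    student = F.softmax(paddle.rand([4, 10]), axis=1)
    teacher = F.softmax(paddle.rand([4, 10]), axis=1)
    loss = paddle.mean(F.cross_entropy(input=student, label=teacher,
                                       soft_label=True, use_softmax=False))
    print(float(loss))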

@@ -305,3 +309,53 @@ def dkd(teacher_var_name,
         temperature=temperature,
         alpha=alpha,
         beta=beta)
+
+
+def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
+    """Combine variables from student model and teacher model
+    by Spherical Knowledge Distillation loss (aka. skd-loss).
+    Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
+    Args:
+        teacher_var_name(str): The name of teacher_var.
+        student_var_name(str): The name of student_var.
+        program(Program): The input distiller program. If not specified,
+            the default program will be used. Default: None
+        multiplier(float): The multiplier to recover its norm to the original
+            level. When it's None, the appropriate multiplier can be computed by
+            teacher's logits with paddle.std(output_t, axis=1). Default: None.
+
+    Returns:
+        Variable: skd distiller loss.
+    """
+    if program == None:
+        program = paddle.static.default_main_program()
+
+    student_var = program.global_block().var(student_var_name)
+    teacher_var = program.global_block().var(teacher_var_name)
+    teacher_var.stop_gradient = True
+
+    if multiplier is None:
+        multiplier = paddle.std(teacher_var, axis=1, keepdim=True)
+
+    logits_student = F.layer_norm(
+        student_var,
+        student_var.shape[1:],
+        weight=None,
+        bias=None,
+        epsilon=1e-7) * multiplier
+    logits_teacher = F.layer_norm(
+        teacher_var,
+        teacher_var.shape[1:],
+        weight=None,
+        bias=None,
+        epsilon=1e-7) * multiplier
+
+    student_out = F.softmax(logits_student, axis=1)
+    teacher_out = F.softmax(logits_teacher, axis=1)
+    skd_loss = paddle.mean(
+        F.cross_entropy(
+            input=student_out,
+            label=teacher_out,
+            soft_label=True,
+            use_softmax=False))
+    return skd_loss
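A hedged usage sketch of the new skd loss in a bare static program. The variable names and shapes are invented; in auto-compression it is wired up through the distill_node_pair config and merged teacher/student variables instead:

    import paddle
    from paddleslim.dist import skd

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # Stand-ins for the merged teacher/student logits.
        student_logits = paddle.static.data('student_logits', shape=[None, 10])
        teacher_logits = paddle.static.data('teacher_logits', shape=[None, 10])
        # program=None adds the loss ops to the current default program;
        # multiplier=None derives the scale from the teacher's std.
        loss = skd(teacher_logits.name, student_logits.name,
                   program=None, multiplier=None)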

tests/test_skd_loss.py

+81
@@ -0,0 +1,81 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle
+from paddleslim.dist import merge, skd
+from layers import conv_bn_layer
+from static_case import StaticCase
+
+
+class TestSKDLoss(StaticCase):
+    def test_skd_loss(self):
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+
+        student_program = paddle.static.Program()
+        student_startup = paddle.static.Program()
+        with paddle.static.program_guard(student_program, student_startup):
+            with paddle.utils.unique_name.guard():
+                input = paddle.static.data(
+                    name="image", shape=[None, 3, 224, 224])
+                conv1 = conv_bn_layer(input, 8, 3, "conv1")
+                conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+                student_predict = conv1 + conv2
+
+        teacher_program = paddle.static.Program()
+        teacher_startup = paddle.static.Program()
+        with paddle.static.program_guard(teacher_program, teacher_startup):
+            with paddle.utils.unique_name.guard():
+                input = paddle.static.data(
+                    name="image", shape=[None, 3, 224, 224])
+                conv1 = conv_bn_layer(input, 8, 3, "conv1")
+                conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+                sum1 = conv1 + conv2
+                conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+                conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+                sum2 = conv4 + sum1
+                conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+                teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6")
+
+        exe.run(teacher_startup)
+        exe.run(student_startup)
+
+        data_name_map = {'image': 'image'}
+        merge(teacher_program, student_program, data_name_map, place)
+        merged_ops = []
+        for block in student_program.blocks:
+            for op in block.ops:
+                merged_ops.append(op.type)
+        with paddle.static.program_guard(student_program, student_startup):
+            distill_loss = skd('teacher_' + teacher_predict.name,
+                               student_predict.name,
+                               program=None,
+                               multiplier=None)
+
+        loss_ops = []
+        for block in student_program.blocks:
+            for op in block.ops:
+                loss_ops.append(op.type)
+        print(f"ret: {set(loss_ops).difference(set(merged_ops))}")
+        self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())
+
+        self.assertTrue({
+            'softmax_with_cross_entropy', 'softmax', 'reduce_mean', 'layer_norm'
+        }.issubset(set(loss_ops).difference(set(merged_ops))))
+
+
+if __name__ == '__main__':
+    unittest.main()
