Skip to content

Commit 0b3f626

Browse files
authored
add new learing rate strategy to reduce lr when loss reach on plateau (PaddlePaddle#24322) (PaddlePaddle#24979)
添加loss自适应的学习率衰减策略。
1 parent 1185a96 commit 0b3f626

File tree

3 files changed

+304
-5
lines changed

3 files changed

+304
-5
lines changed

python/paddle/fluid/dygraph/learning_rate_scheduler.py

+195-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717
import math
1818

1919
from .. import unique_name
20+
from ..framework import Variable
21+
from ..data_feeder import check_type
2022

2123
__all__ = [
2224
'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
23-
'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay'
25+
'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup',
26+
'ReduceLROnPlateau'
2427
]
2528

2629

@@ -619,7 +622,7 @@ class LinearLrWarmup(LearningRateDecay):
619622
620623
learning_rate = 0.1
621624
warmup_steps = 50
622-
start_lr = 1. / 3.
625+
start_lr = 0
623626
end_lr = 0.1
624627
625628
with fluid.dygraph.guard():
@@ -660,3 +663,193 @@ def step(self):
660663
return self.lr_ratio_before_warmup * self.step_num
661664
else:
662665
return base_lr
666+
667+
668+
class ReduceLROnPlateau(LearningRateDecay):
669+
"""
670+
Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate
671+
by 2 to 10 times once model performance has no longer improvement.
672+
673+
The ``loss`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``loss``
674+
stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * decay_rate`` .
675+
(Specially, ``mode`` can also be set to ``'max`` , in this case, when ``loss`` stop ascending for a ``patience`` number
676+
of epochs, the learning rate will be reduced.)
677+
678+
In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming normal operation.
679+
680+
Args:
681+
learning_rate (Variable|float|int): The initial learning rate. It can be set to python float or int number.
682+
If the type is Variable, it should be 1-D Tensor with shape [1], the data type can be 'float32' or 'float64'.
683+
mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
684+
learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning
685+
rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
686+
decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
687+
It should be less than 1.0. Default: 0.1.
688+
patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
689+
Default: 10.
690+
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
691+
threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
692+
This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
693+
threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
694+
is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
695+
change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
696+
cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
697+
min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
698+
eps (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is
699+
ignored. Default: 1e-8.
700+
dtype (str, optional): The data type used to create the learning rate variable. The data type can be set as
701+
'float32', 'float64'. Default: 'float32'.
702+
703+
Returns:
704+
Reduced learning rate.
705+
706+
Examples:
707+
708+
.. code-block:: python
709+
710+
import paddle.fluid as fluid
711+
import numpy as np
712+
713+
with fluid.dygraph.guard():
714+
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
715+
linear = fluid.dygraph.Linear(10, 10)
716+
input = fluid.dygraph.to_variable(x)
717+
718+
reduce_lr = fluid.dygraph.ReduceLROnPlateau(
719+
learning_rate = 1.0,
720+
decay_rate = 0.5,
721+
patience = 5,
722+
verbose = True,
723+
cooldown = 3)
724+
adam = fluid.optimizer.Adam(
725+
learning_rate = reduce_lr,
726+
parameter_list = linear.parameters())
727+
728+
for epoch in range(10):
729+
total_loss = 0
730+
for bath_id in range(5):
731+
out = linear(input)
732+
loss = fluid.layers.reduce_mean(out)
733+
total_loss += loss
734+
adam.minimize(loss)
735+
736+
avg_loss = total_loss/5
737+
738+
# adjust learning rate according to avg_loss
739+
reduce_lr.step(avg_loss)
740+
lr = adam.current_step_lr()
741+
print("current avg_loss is %s, current lr is %s" % (avg_loss.numpy()[0], lr))
742+
743+
"""
744+
745+
def __init__(self,
746+
learning_rate,
747+
mode='min',
748+
decay_rate=0.1,
749+
patience=10,
750+
verbose=False,
751+
threshold=1e-4,
752+
threshold_mode='rel',
753+
cooldown=0,
754+
min_lr=0,
755+
eps=1e-8,
756+
dtype='float32'):
757+
super(ReduceLROnPlateau, self).__init__(dtype=dtype)
758+
mode = mode.lower()
759+
if mode not in ['min', 'max']:
760+
raise ValueError('mode ' + mode + ' is unknown!')
761+
self.mode = mode
762+
763+
if decay_rate >= 1.0:
764+
raise ValueError(
765+
'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
766+
)
767+
self.decay_rate = decay_rate
768+
769+
threshold_mode = threshold_mode.lower()
770+
if threshold_mode not in ['rel', 'abs']:
771+
raise ValueError('threshold mode ' + threshold_mode +
772+
' is unknown!')
773+
self.threshold_mode = threshold_mode
774+
775+
check_type(learning_rate, 'learning_rate', (float, int, Variable),
776+
'ReduceLROnPlateau')
777+
if isinstance(learning_rate, (float, int)):
778+
learning_rate = self.create_lr_var(learning_rate)
779+
780+
self.learning_rate = learning_rate
781+
self.verbose = verbose
782+
self.patience = patience
783+
self.threshold = threshold
784+
self.threshold_mode = threshold_mode
785+
self.cooldown = cooldown
786+
self.min_lr = self.create_lr_var(min_lr)
787+
self.eps = eps
788+
789+
self.cooldown_counter = 0
790+
self.best_loss = None
791+
self.num_bad_epochs = 0
792+
self.epoch = 0
793+
794+
def __call__(self):
795+
return self.learning_rate
796+
797+
def step(self, loss):
798+
"""
799+
It should be invoked on each epoch. Update the learning rate in optimizer according to ``loss`` .
800+
The new learning rate will take effect on next call to ``optimizer.minimize`` .
801+
802+
Args:
803+
loss (Variable): A ``Variable`` that will be monitored to determine whether the learning rate will reduce.
804+
If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. It should
805+
be 1-D Tensor with shape [1].
806+
Specially, if ``mode`` has been set to ``'max'`` , the learning rate will reduce when it stops ascending.
807+
Returns:
808+
None
809+
810+
Examples:
811+
Please refer to the example of current LearningRateDecay.
812+
"""
813+
814+
# loss must be 1-D Tensor with shape [1]
815+
check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step')
816+
assert len(loss.shape) == 1 and loss.shape[0] == 1, "the loss.shape " \
817+
"should be (1L,), but the current loss.shape is {}. Maybe that " \
818+
"you should call fluid.layers.mean to process it first.".format(loss.shape)
819+
820+
self.epoch += 1
821+
if self.cooldown_counter > 0:
822+
self.cooldown_counter -= 1
823+
else:
824+
if self.best_loss is None or self._is_better(loss, self.best_loss):
825+
self.best_loss = loss
826+
self.num_bad_epochs = 0
827+
else:
828+
self.num_bad_epochs += 1
829+
830+
if self.num_bad_epochs > self.patience:
831+
from .. import layers
832+
self.cooldown_counter = self.cooldown
833+
self.num_bad_epochs = 0
834+
new_lr = layers.elementwise_max(self.learning_rate *
835+
self.decay_rate, self.min_lr)
836+
if self.learning_rate - new_lr > self.eps:
837+
if self.verbose:
838+
print('Epoch {}: reducing learning rate from {} to {}.'.
839+
format(self.epoch,
840+
self.learning_rate.numpy()[0],
841+
new_lr.numpy()[0]))
842+
self.learning_rate = new_lr
843+
844+
def _is_better(self, current, best):
845+
if self.mode == 'min' and self.threshold_mode == 'rel':
846+
return current < best - best * self.threshold
847+
848+
elif self.mode == 'min' and self.threshold_mode == 'abs':
849+
return current < best - self.threshold
850+
851+
elif self.mode == 'max' and self.threshold_mode == 'rel':
852+
return current > best + best * self.threshold
853+
854+
else:
855+
return current > best + self.threshold

python/paddle/fluid/optimizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ def apply_gradients(self, params_grads):
708708
params_grads, table_param_and_grad, table_optimize_op = \
709709
self._process_distribute_lookuptable(params_grads)
710710

711-
# 'minimize(grad_clip)' or 'set_gradient_clip'
711+
# 'optimizer(grad_clip)' or 'set_gradient_clip'
712712
if self._grad_clip is not None:
713713
params_grads = self._grad_clip(params_grads)
714714
else:
@@ -1460,7 +1460,7 @@ def apply_gradients(self, params_grads):
14601460
else:
14611461
dgc_params_grads.append((param, grad))
14621462

1463-
# 'minimize(grad_clip)' or 'set_gradient_clip'
1463+
# 'optimizer(grad_clip)' or 'set_gradient_clip'
14641464
if self._grad_clip is not None:
14651465
not_dgc_params_grads = self._grad_clip(not_dgc_params_grads)
14661466
else:

python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py

+107-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_decay(self):
199199
]
200200

201201
for py_decay_fn, fluid_decay_fn, kwargs in decay_fns:
202-
print("class=" + self.__class__.__name__ + "decay_fn=" +
202+
print("class=" + self.__class__.__name__ + " decay_fn=" +
203203
py_decay_fn.__name__ + " kwargs=" + str(kwargs))
204204
main_program = framework.Program()
205205
startup_program = framework.Program()
@@ -335,5 +335,111 @@ def test_dygraph_mode(self):
335335
end_lr=1.0)
336336

337337

338+
def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,
339+
var_list):
340+
def is_better(current, best, m, n):
341+
if m == 'min' and n == 'rel':
342+
return current < best - best * threshold
343+
elif m == 'min' and n == 'abs':
344+
return current < best - threshold
345+
elif m == 'max' and n == 'rel':
346+
return current > best + best * threshold
347+
else: # mode == 'max' and epsilon_mode == 'abs':
348+
return current > best + threshold
349+
350+
if var_list[2] > 0:
351+
var_list[2] -= 1
352+
return var_list[1]
353+
354+
if is_better(loss, var_list[0], m, n):
355+
var_list[0] = loss
356+
var_list[3] = 0
357+
else:
358+
var_list[3] += 1
359+
if var_list[3] > patience:
360+
var_list[2] = cooldown
361+
var_list[3] = 0
362+
new_lr = var_list[1] * decay_rate
363+
var_list[1] = new_lr if var_list[1] - new_lr > 1e-8 else var_list[1]
364+
365+
return var_list[1]
366+
367+
368+
class TestReduceLROnPlateauDecay(unittest.TestCase):
369+
def test_dygraph_mode(self):
370+
with fluid.dygraph.guard():
371+
# the decay rate must be less than 1.0
372+
with self.assertRaises(ValueError):
373+
fluid.dygraph.ReduceLROnPlateau(
374+
learning_rate=1.0, decay_rate=2.0)
375+
# the mode must be "min" or "max"
376+
with self.assertRaises(ValueError):
377+
fluid.dygraph.ReduceLROnPlateau(learning_rate=1.0, mode="test")
378+
# the threshold_mode must be "rel" or "abs"
379+
with self.assertRaises(ValueError):
380+
fluid.dygraph.ReduceLROnPlateau(
381+
learning_rate=1.0, threshold_mode="test")
382+
383+
base_lr = 1.0
384+
patience = 3
385+
cooldown = 1
386+
decay_rate = 0.5
387+
threshold = 1e-4
388+
linear = fluid.dygraph.Linear(10, 10)
389+
390+
for m, n in zip(['min', 'max', 'min', 'max'],
391+
['rel', 'rel', 'abs', 'abs']):
392+
kwargs = {
393+
'learning_rate': base_lr,
394+
'decay_rate': decay_rate,
395+
'threshold': threshold,
396+
'verbose': True,
397+
'patience': patience,
398+
'cooldown': cooldown,
399+
'mode': m,
400+
'threshold_mode': n,
401+
'eps': 1e-6
402+
}
403+
print("class=" + fluid.dygraph.ReduceLROnPlateau.__name__ +
404+
" kwargs=" + str(kwargs))
405+
lr = fluid.dygraph.ReduceLROnPlateau(**kwargs)
406+
sgd = fluid.optimizer.SGD(learning_rate=lr,
407+
parameter_list=linear.parameters())
408+
409+
best = float("-10000") if m == "max" else float("10000")
410+
expected_lr = 1.0
411+
cooldown_counter = 0
412+
num_bad_epochs = 0
413+
var_list = [best, expected_lr, cooldown_counter, num_bad_epochs]
414+
step_num = 0
415+
epoch_num = 0
416+
for epoch in range(30):
417+
total_loss = 0
418+
419+
for batch_id in range(2):
420+
step_num += 1
421+
x = fluid.dygraph.to_variable(
422+
np.array([step_num]).astype('float32'))
423+
loss = layers.sin(x)
424+
sgd.minimize(loss)
425+
total_loss += loss
426+
427+
epoch_num += 1
428+
# get expected lr from fluid
429+
avg_loss = total_loss / 1
430+
lr.step(avg_loss)
431+
actual_lr = lr().numpy()[0]
432+
433+
# get expected lr form python
434+
expected_lr = reduce_lr_on_plateau(decay_rate, threshold,
435+
cooldown, patience, m, n,
436+
avg_loss, var_list)
437+
self.assertEqual(
438+
expected_lr,
439+
actual_lr,
440+
msg='Failed reduce lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
441+
format(epoch_num, expected_lr, actual_lr))
442+
443+
338444
if __name__ == '__main__':
339445
unittest.main()

0 commit comments

Comments
 (0)