未验证 提交 0b3f6265 编写于 作者: Z Zhou Wei 提交者: GitHub

add new learing rate strategy to reduce lr when loss reach on plateau (#24322) (#24979)

添加loss自适应的学习率衰减策略。
上级 1185a96f
......@@ -17,10 +17,13 @@ from __future__ import print_function
import math
from .. import unique_name
from ..framework import Variable
from ..data_feeder import check_type
__all__ = [
'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay'
'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup',
'ReduceLROnPlateau'
]
......@@ -619,7 +622,7 @@ class LinearLrWarmup(LearningRateDecay):
learning_rate = 0.1
warmup_steps = 50
start_lr = 1. / 3.
start_lr = 0
end_lr = 0.1
with fluid.dygraph.guard():
......@@ -660,3 +663,193 @@ class LinearLrWarmup(LearningRateDecay):
return self.lr_ratio_before_warmup * self.step_num
else:
return base_lr
class ReduceLROnPlateau(LearningRateDecay):
"""
Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate
by 2 to 10 times once model performance has no longer improvement.
The ``loss`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``loss``
stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * decay_rate`` .
(Specially, ``mode`` can also be set to ``'max`` , in this case, when ``loss`` stop ascending for a ``patience`` number
of epochs, the learning rate will be reduced.)
In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming normal operation.
Args:
learning_rate (Variable|float|int): The initial learning rate. It can be set to python float or int number.
If the type is Variable, it should be 1-D Tensor with shape [1], the data type can be 'float32' or 'float64'.
mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning
rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
It should be less than 1.0. Default: 0.1.
patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
Default: 10.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
eps (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
dtype (str, optional): The data type used to create the learning rate variable. The data type can be set as
'float32', 'float64'. Default: 'float32'.
Returns:
Reduced learning rate.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard():
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = fluid.dygraph.Linear(10, 10)
input = fluid.dygraph.to_variable(x)
reduce_lr = fluid.dygraph.ReduceLROnPlateau(
learning_rate = 1.0,
decay_rate = 0.5,
patience = 5,
verbose = True,
cooldown = 3)
adam = fluid.optimizer.Adam(
learning_rate = reduce_lr,
parameter_list = linear.parameters())
for epoch in range(10):
total_loss = 0
for bath_id in range(5):
out = linear(input)
loss = fluid.layers.reduce_mean(out)
total_loss += loss
adam.minimize(loss)
avg_loss = total_loss/5
# adjust learning rate according to avg_loss
reduce_lr.step(avg_loss)
lr = adam.current_step_lr()
print("current avg_loss is %s, current lr is %s" % (avg_loss.numpy()[0], lr))
"""
def __init__(self,
learning_rate,
mode='min',
decay_rate=0.1,
patience=10,
verbose=False,
threshold=1e-4,
threshold_mode='rel',
cooldown=0,
min_lr=0,
eps=1e-8,
dtype='float32'):
super(ReduceLROnPlateau, self).__init__(dtype=dtype)
mode = mode.lower()
if mode not in ['min', 'max']:
raise ValueError('mode ' + mode + ' is unknown!')
self.mode = mode
if decay_rate >= 1.0:
raise ValueError(
'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
)
self.decay_rate = decay_rate
threshold_mode = threshold_mode.lower()
if threshold_mode not in ['rel', 'abs']:
raise ValueError('threshold mode ' + threshold_mode +
' is unknown!')
self.threshold_mode = threshold_mode
check_type(learning_rate, 'learning_rate', (float, int, Variable),
'ReduceLROnPlateau')
if isinstance(learning_rate, (float, int)):
learning_rate = self.create_lr_var(learning_rate)
self.learning_rate = learning_rate
self.verbose = verbose
self.patience = patience
self.threshold = threshold
self.threshold_mode = threshold_mode
self.cooldown = cooldown
self.min_lr = self.create_lr_var(min_lr)
self.eps = eps
self.cooldown_counter = 0
self.best_loss = None
self.num_bad_epochs = 0
self.epoch = 0
def __call__(self):
return self.learning_rate
def step(self, loss):
"""
It should be invoked on each epoch. Update the learning rate in optimizer according to ``loss`` .
The new learning rate will take effect on next call to ``optimizer.minimize`` .
Args:
loss (Variable): A ``Variable`` that will be monitored to determine whether the learning rate will reduce.
If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. It should
be 1-D Tensor with shape [1].
Specially, if ``mode`` has been set to ``'max'`` , the learning rate will reduce when it stops ascending.
Returns:
None
Examples:
Please refer to the example of current LearningRateDecay.
"""
# loss must be 1-D Tensor with shape [1]
check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step')
assert len(loss.shape) == 1 and loss.shape[0] == 1, "the loss.shape " \
"should be (1L,), but the current loss.shape is {}. Maybe that " \
"you should call fluid.layers.mean to process it first.".format(loss.shape)
self.epoch += 1
if self.cooldown_counter > 0:
self.cooldown_counter -= 1
else:
if self.best_loss is None or self._is_better(loss, self.best_loss):
self.best_loss = loss
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
if self.num_bad_epochs > self.patience:
from .. import layers
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
new_lr = layers.elementwise_max(self.learning_rate *
self.decay_rate, self.min_lr)
if self.learning_rate - new_lr > self.eps:
if self.verbose:
print('Epoch {}: reducing learning rate from {} to {}.'.
format(self.epoch,
self.learning_rate.numpy()[0],
new_lr.numpy()[0]))
self.learning_rate = new_lr
def _is_better(self, current, best):
if self.mode == 'min' and self.threshold_mode == 'rel':
return current < best - best * self.threshold
elif self.mode == 'min' and self.threshold_mode == 'abs':
return current < best - self.threshold
elif self.mode == 'max' and self.threshold_mode == 'rel':
return current > best + best * self.threshold
else:
return current > best + self.threshold
......@@ -708,7 +708,7 @@ class Optimizer(object):
params_grads, table_param_and_grad, table_optimize_op = \
self._process_distribute_lookuptable(params_grads)
# 'minimize(grad_clip)' or 'set_gradient_clip'
# 'optimizer(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
else:
......@@ -1460,7 +1460,7 @@ class DGCMomentumOptimizer(Optimizer):
else:
dgc_params_grads.append((param, grad))
# 'minimize(grad_clip)' or 'set_gradient_clip'
# 'optimizer(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
not_dgc_params_grads = self._grad_clip(not_dgc_params_grads)
else:
......
......@@ -199,7 +199,7 @@ class TestLearningRateDecay(unittest.TestCase):
]
for py_decay_fn, fluid_decay_fn, kwargs in decay_fns:
print("class=" + self.__class__.__name__ + "decay_fn=" +
print("class=" + self.__class__.__name__ + " decay_fn=" +
py_decay_fn.__name__ + " kwargs=" + str(kwargs))
main_program = framework.Program()
startup_program = framework.Program()
......@@ -335,5 +335,111 @@ class TestLinearWamrupLearningRateDecayDygraphModeTypeCheck(unittest.TestCase):
end_lr=1.0)
def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,
var_list):
def is_better(current, best, m, n):
if m == 'min' and n == 'rel':
return current < best - best * threshold
elif m == 'min' and n == 'abs':
return current < best - threshold
elif m == 'max' and n == 'rel':
return current > best + best * threshold
else: # mode == 'max' and epsilon_mode == 'abs':
return current > best + threshold
if var_list[2] > 0:
var_list[2] -= 1
return var_list[1]
if is_better(loss, var_list[0], m, n):
var_list[0] = loss
var_list[3] = 0
else:
var_list[3] += 1
if var_list[3] > patience:
var_list[2] = cooldown
var_list[3] = 0
new_lr = var_list[1] * decay_rate
var_list[1] = new_lr if var_list[1] - new_lr > 1e-8 else var_list[1]
return var_list[1]
class TestReduceLROnPlateauDecay(unittest.TestCase):
def test_dygraph_mode(self):
with fluid.dygraph.guard():
# the decay rate must be less than 1.0
with self.assertRaises(ValueError):
fluid.dygraph.ReduceLROnPlateau(
learning_rate=1.0, decay_rate=2.0)
# the mode must be "min" or "max"
with self.assertRaises(ValueError):
fluid.dygraph.ReduceLROnPlateau(learning_rate=1.0, mode="test")
# the threshold_mode must be "rel" or "abs"
with self.assertRaises(ValueError):
fluid.dygraph.ReduceLROnPlateau(
learning_rate=1.0, threshold_mode="test")
base_lr = 1.0
patience = 3
cooldown = 1
decay_rate = 0.5
threshold = 1e-4
linear = fluid.dygraph.Linear(10, 10)
for m, n in zip(['min', 'max', 'min', 'max'],
['rel', 'rel', 'abs', 'abs']):
kwargs = {
'learning_rate': base_lr,
'decay_rate': decay_rate,
'threshold': threshold,
'verbose': True,
'patience': patience,
'cooldown': cooldown,
'mode': m,
'threshold_mode': n,
'eps': 1e-6
}
print("class=" + fluid.dygraph.ReduceLROnPlateau.__name__ +
" kwargs=" + str(kwargs))
lr = fluid.dygraph.ReduceLROnPlateau(**kwargs)
sgd = fluid.optimizer.SGD(learning_rate=lr,
parameter_list=linear.parameters())
best = float("-10000") if m == "max" else float("10000")
expected_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, expected_lr, cooldown_counter, num_bad_epochs]
step_num = 0
epoch_num = 0
for epoch in range(30):
total_loss = 0
for batch_id in range(2):
step_num += 1
x = fluid.dygraph.to_variable(
np.array([step_num]).astype('float32'))
loss = layers.sin(x)
sgd.minimize(loss)
total_loss += loss
epoch_num += 1
# get expected lr from fluid
avg_loss = total_loss / 1
lr.step(avg_loss)
actual_lr = lr().numpy()[0]
# get expected lr form python
expected_lr = reduce_lr_on_plateau(decay_rate, threshold,
cooldown, patience, m, n,
avg_loss, var_list)
self.assertEqual(
expected_lr,
actual_lr,
msg='Failed reduce lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
format(epoch_num, expected_lr, actual_lr))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册