Unverified commit 23c32aa8, authored by zhouweiwei2014, committed by GitHub

add args check for learning rate scheduler API (#34394)

Parent: 81fe3ac9
@@ -570,7 +570,7 @@ class PolynomialDecay(LRScheduler):
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
-       decay_steps(int): The decay step size. It determines the decay cycle.
+       decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease
@@ -639,6 +639,8 @@ class PolynomialDecay(LRScheduler):
                 cycle=False,
                 last_epoch=-1,
                 verbose=False):
+       assert decay_steps > 0 and isinstance(
+           decay_steps, int), " 'decay_steps' must be a positive integer."
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        self.power = power
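A minimal sketch of what this first check buys, assuming the class is the one exposed as paddle.optimizer.lr.PolynomialDecay (the values below are illustrative):

    import paddle

    # A positive int passes the new check and construction succeeds.
    sched = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20)

    # A float (or a non-positive value) now fails fast at construction
    # time instead of surfacing later inside get_lr().
    try:
        paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20.0)
    except AssertionError as e:
        print(e)  # 'decay_steps' must be a positive integer.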
@@ -688,7 +690,7 @@ class LinearWarmup(LRScheduler):
    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
-       warmup_steps (int): total steps of warm up.
+       warmup_steps (int): total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
@@ -763,6 +765,8 @@ class LinearWarmup(LRScheduler):
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".
                format(learning_rate))
        self.learning_rate = learning_rate
+       assert warmup_steps > 0 and isinstance(
+           warmup_steps, int), " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
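The LinearWarmup check behaves the same way; a sketch, again assuming the paddle.optimizer.lr path:

    import paddle

    # Warm up linearly from 0.0 to 0.5 over 20 steps, then use the wrapped rate.
    sched = paddle.optimizer.lr.LinearWarmup(
        learning_rate=0.5, warmup_steps=20, start_lr=0.0, end_lr=0.5)

    # warmup_steps=0 is now rejected at construction time by the new assert.
    try:
        paddle.optimizer.lr.LinearWarmup(
            learning_rate=0.5, warmup_steps=0, start_lr=0.0, end_lr=0.5)
    except AssertionError as e:
        print(e)  # 'warmup_steps' must be a positive integer.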
@@ -1010,7 +1014,7 @@ class StepDecay(LRScheduler):
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
-       step_size (int): the interval to update.
+       step_size (int): the interval to update. It must be a positive integer.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
@@ -1083,6 +1087,8 @@ class StepDecay(LRScheduler):
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')
+       assert step_size > 0 and isinstance(
+           step_size, int), " 'step_size' must be a positive integer."
        self.step_size = step_size
        self.gamma = gamma
        super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)
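For StepDecay, the assert sits next to the existing gamma check, so both an invalid gamma and an invalid step_size are caught before training starts; an illustrative sketch:

    import paddle

    # Multiply the learning rate by gamma every step_size epochs.
    sched = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.1)

    # A non-positive interval now trips the new assert immediately.
    try:
        paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=-5)
    except AssertionError as e:
        print(e)  # 'step_size' must be a positive integer.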
@@ -1415,7 +1421,7 @@ class CosineAnnealingDecay(LRScheduler):
    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
-       T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate.
+       T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -1487,6 +1493,8 @@ class CosineAnnealingDecay(LRScheduler):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min))
+       assert T_max > 0 and isinstance(
+           T_max, int), " 'T_max' must be a positive integer."
        self.T_max = T_max
        self.eta_min = float(eta_min)
        super(CosineAnnealingDecay, self).__init__(learning_rate, last_epoch,
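Finally, for CosineAnnealingDecay, where a bad T_max would otherwise only show up as a wrong cosine schedule, a sketch of the checked construction and normal stepping (method names as in the LRScheduler base class shown above):

    import paddle

    sched = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=0.5, T_max=10, eta_min=0)
    for epoch in range(3):
        sched.step()           # advance one epoch along the cosine curve
        print(sched.get_lr())  # learning rate for the new epoch

    # T_max=0 is now rejected up front by the new assert.
    try:
        paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=0)
    except AssertionError as e:
        print(e)  # 'T_max' must be a positive integer.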