未验证 提交 c87574b5 编写于 作者: L lilong12 提交者: GitHub

parameterize lr_decay_factor, step_boundaries and log_period (#34)

上级 c24ce4a4
......@@ -136,6 +136,9 @@ class Entry(object):
self.model_save_dir = os.path.abspath(self.model_save_dir)
if self.dataset_dir:
self.dataset_dir = os.path.abspath(self.dataset_dir)
self.lr_decay_factor = 0.1
self.log_period = 200
logger.info('=' * 30)
logger.info("Default configuration:")
......@@ -143,6 +146,8 @@ class Entry(object):
logger.info('\t' + str(key) + ": " + str(self.config[key]))
logger.info('trainer_id: {}, num_trainers: {}'.format(
trainer_id, num_trainers))
logger.info('default lr_decay_factor: {}'.format(self.lr_decay_factor))
logger.info('default log period: {}'.format(self.log_period))
logger.info('=' * 30)
def set_val_targets(self, targets):
......@@ -157,6 +162,20 @@ class Entry(object):
self.global_train_batch_size = batch_size * self.num_trainers
logger.info("Set train batch size to {}.".format(batch_size))
def set_log_period(self, period):
self.log_period = period
logger.info("Set log period to {}.".format(period))
def set_lr_decay_factor(self, factor):
self.lr_decay_factor = factor
logger.info("Set lr decay factor to {}.".format(factor))
def set_step_boundaries(self, boundaries):
if not isinstance(boundaries, list):
raise ValueError("The parameter must be of type list.")
self.lr_steps = boundaries
logger.info("Set step boundaries to {}.".format(boundaries))
def set_mixed_precision(self,
use_fp16,
init_loss_scaling=1.0,
......@@ -332,7 +351,8 @@ class Entry(object):
warmup_steps = steps_per_pass * self.warmup_epochs
batch_denom = 1024
base_lr = start_lr * global_batch_size / batch_denom
lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
lr_decay_factor = self.lr_decay_factor
lr = [base_lr * (lr_decay_factor ** i) for i in range(len(bd) + 1)]
logger.info("LR boundaries: {}".format(bd))
logger.info("lr_step: {}".format(lr))
if self.warmup_epochs:
......@@ -938,7 +958,7 @@ class Entry(object):
local_time = 0.0
nsamples = 0
inspect_steps = 200
inspect_steps = self.log_period
global_batch_size = self.global_train_batch_size
for pass_id in range(self.train_epochs):
self.train_pass_id = pass_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册