Commit da41c9d6 authored by Payne

modify wrong args

Parent b7425d3e
@@ -26,8 +26,9 @@ def set_config(args):
         "batch_size": 150,
         "epoch_size": 15,
         "warmup_epochs": 0,
-        "lr_max": 0.03,
+        "lr_init": .0,
+        "lr_end": 0.03,
+        "lr_max": 0.03,
         "momentum": 0.9,
         "weight_decay": 4e-5,
         "label_smooth": 0.1,
@@ -45,9 +46,9 @@ def set_config(args):
         "batch_size": 150,
         "epoch_size": 200,
         "warmup_epochs": 0,
-        "lr": 0.8,
-        "lr_max": 0.03,
-        "lr_end": 0.03,
+        "lr_init": .0,
+        "lr_end": .0,
+        "lr_max": 0.8,
         "momentum": 0.9,
         "weight_decay": 4e-5,
         "label_smooth": 0.1,
@@ -66,9 +67,9 @@ def set_config(args):
         "batch_size": 256,
         "epoch_size": 200,
         "warmup_epochs": 4,
-        "lr": 0.4,
-        "lr_max": 0.03,
-        "lr_end": 0.03,
+        "lr_init": 0.00,
+        "lr_end": 0.00,
+        "lr_max": 0.4,
         "momentum": 0.9,
         "weight_decay": 4e-5,
         "label_smooth": 0.1,
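All three platform blocks now carry the same ("lr_init", "lr_end", "lr_max") triple in place of the earlier misnamed "lr"/"lr_max" entries, so the learning-rate generator is fed identically regardless of platform; a hedged sketch of that generator follows the last hunk below.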
@@ -17,7 +17,6 @@ from mindspore import context
 from mindspore import nn
 from mindspore.common import dtype as mstype
 from mindspore.train.model import ParallelMode
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.communication.management import get_rank, init, get_group_size
@@ -63,7 +62,7 @@ def set_context(config):
                             device_id=config.device_id, save_graphs=False)
     elif config.platform == "GPU":
         context.set_context(mode=context.GRAPH_MODE,
-                            device_target=args_opt.platform, save_graphs=False)
+                            device_target=config.platform, save_graphs=False)
 def config_ckpoint(config, lr, step_size):
     cb = None
@@ -77,7 +77,7 @@ if __name__ == '__main__':
     # get learning rate
     lr = Tensor(get_lr(global_step=0,
-                       lr_init=0,
+                       lr_init=config.lr_init,
                        lr_end=config.lr_end,
                        lr_max=config.lr_max,
                        warmup_epochs=config.warmup_epochs,
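This hunk shows why the renamed keys matter: the training script builds the whole per-step schedule with get_lr, and the old hard-coded lr_init=0 silently ignored whatever the config specified. The repository's exact lr generator is not shown in this diff; the sketch below is a minimal, assumption-laden version — linear warmup from lr_init to lr_max, then cosine decay down to lr_end — chosen only to match the argument names.

    import numpy as np

    def get_lr(global_step, lr_init, lr_end, lr_max,
               warmup_epochs, total_epochs, steps_per_epoch):
        """Per-step learning rates: linear warmup from lr_init to lr_max,
        then cosine decay from lr_max down to lr_end (assumed shape)."""
        total_steps = steps_per_epoch * total_epochs
        warmup_steps = steps_per_epoch * warmup_epochs
        lr_each_step = []
        for i in range(total_steps):
            if i < warmup_steps:
                # warmup_epochs == 0 never enters this branch, so no divide-by-zero
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                cosine = 0.5 * (1 + np.cos(np.pi * (i - warmup_steps)
                                           / (total_steps - warmup_steps)))
                lr = lr_end + (lr_max - lr_end) * cosine
            lr_each_step.append(lr)
        # skip steps already taken when resuming from a checkpoint
        return np.array(lr_each_step[global_step:], dtype=np.float32)

After the fix, warmup starts from each platform's configured lr_init rather than a constant 0, which is exactly what the GPU config (warmup_epochs=4) relies on.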