提交 da41c9d6 编写于 作者: P Payne

modify wrong args

上级 b7425d3e
...@@ -26,8 +26,9 @@ def set_config(args): ...@@ -26,8 +26,9 @@ def set_config(args):
"batch_size": 150, "batch_size": 150,
"epoch_size": 15, "epoch_size": 15,
"warmup_epochs": 0, "warmup_epochs": 0,
"lr_max": 0.03, "lr_init": .0,
"lr_end": 0.03, "lr_end": 0.03,
"lr_max": 0.03,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 4e-5, "weight_decay": 4e-5,
"label_smooth": 0.1, "label_smooth": 0.1,
...@@ -45,9 +46,9 @@ def set_config(args): ...@@ -45,9 +46,9 @@ def set_config(args):
"batch_size": 150, "batch_size": 150,
"epoch_size": 200, "epoch_size": 200,
"warmup_epochs": 0, "warmup_epochs": 0,
"lr": 0.8, "lr_init": .0,
"lr_max": 0.03, "lr_end": .0,
"lr_end": 0.03, "lr_max": 0.8,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 4e-5, "weight_decay": 4e-5,
"label_smooth": 0.1, "label_smooth": 0.1,
...@@ -66,9 +67,9 @@ def set_config(args): ...@@ -66,9 +67,9 @@ def set_config(args):
"batch_size": 256, "batch_size": 256,
"epoch_size": 200, "epoch_size": 200,
"warmup_epochs": 4, "warmup_epochs": 4,
"lr": 0.4, "lr_init": 0.00,
"lr_max": 0.03, "lr_end": 0.00,
"lr_end": 0.03, "lr_max": 0.4,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 4e-5, "weight_decay": 4e-5,
"label_smooth": 0.1, "label_smooth": 0.1,
......
...@@ -17,7 +17,6 @@ from mindspore import context ...@@ -17,7 +17,6 @@ from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode from mindspore.train.model import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init, get_group_size from mindspore.communication.management import get_rank, init, get_group_size
...@@ -63,7 +62,7 @@ def set_context(config): ...@@ -63,7 +62,7 @@ def set_context(config):
device_id=config.device_id, save_graphs=False) device_id=config.device_id, save_graphs=False)
elif config.platform == "GPU": elif config.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target=args_opt.platform, save_graphs=False) device_target=config.platform, save_graphs=False)
def config_ckpoint(config, lr, step_size): def config_ckpoint(config, lr, step_size):
cb = None cb = None
......
...@@ -77,7 +77,7 @@ if __name__ == '__main__': ...@@ -77,7 +77,7 @@ if __name__ == '__main__':
# get learning rate # get learning rate
lr = Tensor(get_lr(global_step=0, lr = Tensor(get_lr(global_step=0,
lr_init=0, lr_init=config.lr_init,
lr_end=config.lr_end, lr_end=config.lr_end,
lr_max=config.lr_max, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, warmup_epochs=config.warmup_epochs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册