提交 1e07ba05 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!16 support checkpoint path configuration for resnet

Merge pull request !16 from gengdongjie/master
......@@ -128,7 +128,8 @@ if __name__ == '__main__':
dataset = create_dataset()
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
checkpoint_path = args_opt.checkpoint_path if args_opt.checkpoint_path is not None else "./"
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory=checkpoint_path, config=config_ck)
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册