提交 16d64255 编写于 作者: D dyonghan

bugfix for ModelCheckpoint

上级 90d94759
......@@ -8,7 +8,21 @@
"\n",
"## 实验介绍\n",
"\n",
"本实验主要介绍使用MindSpore实现训练时模型的保存和加载。训练过程中保存模型以及训练中断后基于断点继续训练是一项非常常用的功能。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。\n",
"本实验主要介绍使用MindSpore实现训练时模型的保存和加载。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。\n",
"\n",
"在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下:\n",
"\n",
"- 训练后推理场景\n",
"\n",
" - 模型训练完毕后保存模型的参数,用于推理或预测操作。\n",
"\n",
" - 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。\n",
"\n",
"- 再训练场景\n",
"\n",
" - 进行长时间训练任务时,保存训练过程中的CheckPoint文件,防止任务异常退出后从初始状态开始训练。\n",
"\n",
" - Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。\n",
"\n",
"## 实验目的\n",
"\n",
......@@ -76,7 +90,8 @@
"│   └── train\n",
"│   ├── train-images-idx3-ubyte\n",
"│   └── train-labels-idx1-ubyte\n",
"└── 脚本等文件\n",
"├── *.ipynb\n",
"└── main.py\n",
"```\n",
"\n",
"## 实验步骤(方案一)\n",
......@@ -100,9 +115,9 @@
"> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。\n",
"> - 打开Notebook后,选择MindSpore环境作为Kernel。\n",
"\n",
"> **提示:**上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的\"New\"->\"Terminal\",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。\n",
"> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的\"New\"->\"Terminal\",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。\n",
"\n",
"> **提示:**请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。\n",
"> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。\n",
"\n",
"导入MindSpore模块和辅助模块:"
]
......@@ -179,11 +194,11 @@
"source": [
"### 定义模型\n",
"\n",
"定义LeNet5模型,模型结构如下图所示\n",
"定义LeNet5模型,模型结构如下图所示\n",
"\n",
"<img src=\"http://deeplearning.net/tutorial/_images/mylenet.png\">\n",
"<img src=\"https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg\">\n",
"\n",
"[1] 图片来源于http://deeplearning.net"
"[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf"
]
},
{
......@@ -192,9 +207,9 @@
"metadata": {},
"outputs": [],
"source": [
"class LeNet(nn.Cell):\n",
"class LeNet5(nn.Cell):\n",
" def __init__(self):\n",
" super(LeNet, self).__init__()\n",
" super(LeNet5, self).__init__()\n",
" self.relu = nn.ReLU()\n",
" self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')\n",
" self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')\n",
......@@ -225,43 +240,51 @@
"source": [
"### 保存模型Checkpoint\n",
"\n",
"使用MNIST数据集对上述定义的LeNet5模型进行单机单卡训练,包含:\n",
"\n",
"- 在MNIST数据集上训练模型。\n",
"- 通过`ModelCheckpoint`保存Checkpoint。\n",
"- 通过`LossMonitor`输出训练过程中的Loss。\n",
"\n",
"Callback是模型训练/测试过程中的一种调试工具,可用在训练/测试过程中执行特定的任务。MindSpore框架提供的Callback:\n",
"MindSpore提供了Callback功能,可用于训练/测试过程中执行特定的任务。常用的Callback如下:\n",
"\n",
"- `ModelCheckpoint`:保存网络模型和参数,默认会保存最后一次训练的参数。\n",
"- `SummaryStep`:对Tensor值进行监控。此功能会在MindData平台训练脚本中使用。\n",
"- `LossMonitor`:监控loss值,当loss值为Nan或Inf时停止训练。此功能会在MindData平台训练脚本中使用。\n",
"- `ModelCheckpoint`:保存网络模型和参数,用于再训练或推理;\n",
"- `LossMonitor`:监控loss值,当loss值为Nan或Inf时停止训练;\n",
"- `SummaryStep`:把训练过程中的信息存储到文件中,用于后续查看或可视化展示。\n",
"\n",
"`ModelCheckpoint`用于保存模型和参数,如每个epoch结束时,都保存一次checkpoint。\n",
"`ModelCheckpoint`会生成模型(.meta)和Chekpoint(.ckpt)文件,如每个epoch结束时,都保存一次checkpoint。\n",
"\n",
"1. 首先需要初始化一个`CheckpointConfig`类对象,用以声明保存策略。调用方法如:\n",
" \n",
" ```py\n",
" CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)\n",
" ```\n",
" \n",
" 参数说明:\n",
" \n",
" - `save_checkpoint_steps`:每多少step保存一个checkpoint文件,单位为step;\n",
" - `keep_checkpoint_max`:最多保留checkpoint文件的数量(按最新的文件)。\n",
"\n",
"2. 创建`ModelCheckpoint`对象。调用方法如:\n",
" \n",
" ```py\n",
" ModelCheckpoint(prefix=DEFAULT_CHECKPOINT_PREFIX_NAME, config=None)\n",
" ```\n",
" \n",
" 参数说明:\n",
" \n",
" - `prefix`:保存的文件前缀名,如'ck_lenet'。\n",
" - `config`:配置策略信息,传入上文创建的CheckpointConfig对象。\n",
"```python\n",
"class CheckpointConfig:\n",
" \"\"\"\n",
" The config for model checkpoint.\n",
"\n",
" Args:\n",
" save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.\n",
" save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.\n",
" Can't be used with save_checkpoint_steps at the same time.\n",
" keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.\n",
" keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.\n",
" Can't be used with keep_checkpoint_max at the same time.\n",
" integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.\n",
" Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.\n",
"\n",
" Raises:\n",
" ValueError: If the input_param is None or 0.\n",
" \"\"\"\n",
"\n",
"class ModelCheckpoint(Callback):\n",
" \"\"\"\n",
" The checkpoint callback class.\n",
"\n",
" It is called to combine with train process and save the model and network parameters after traning.\n",
"\n",
" Args:\n",
" prefix (str): Checkpoint files names prefix. Default: \"CKP\".\n",
" directory (str): Lolder path into which checkpoint files will be saved. Default: None.\n",
" config (CheckpointConfig): Checkpoint strategy config. Default: None.\n",
"\n",
" Raises:\n",
" ValueError: If the prefix is invalid.\n",
" TypeError: If the config is not CheckpointConfig type.\n",
" \"\"\"\n",
"```\n",
"\n",
"> `ModelCheckpoint`会生成和保存模型(.meta)和Chekpoint(.ckpt)文件。"
"MindSpore提供了多种Metric评估指标,如`accuracy`、`loss`、`precision`、`recall`、`F1`。定义一个metrics字典/元组,里面包含多种指标,传递给`Model`,然后调用`model.eval`接口来计算这些指标。`model.eval`会返回一个字典,包含各个指标及其对应的值。"
]
},
{
......@@ -283,20 +306,19 @@
],
"source": [
"os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件\n",
"LOOP_SINK = context.get_context('enable_loop_sink')\n",
"\n",
"def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name=\"b_lenet\"):\n",
" ds_train = create_dataset(num_epoch=num_epoch)\n",
" ds_eval = create_dataset(training=False)\n",
" steps_per_epoch = ds_train.get_dataset_size()\n",
" \n",
" net = LeNet()\n",
" net = LeNet5()\n",
" loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
" opt = nn.Momentum(net.trainable_params(), lr, momentum)\n",
" \n",
" ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n",
" ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)\n",
" loss_cb = LossMonitor(per_print_times=1 if LOOP_SINK else steps_per_epoch)\n",
" loss_cb = LossMonitor(steps_per_epoch)\n",
" \n",
" model = Model(net, loss, opt, metrics={'acc', 'loss'})\n",
" model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True)\n",
......@@ -313,29 +335,34 @@
"source": [
"### 加载Checkpoint继续训练\n",
"\n",
"模型训练过程偶尔会中断,可以通过加载Checkpoint文件继续训练。\n",
"```python\n",
"def load_checkpoint(ckpoint_file_name, net=None):\n",
" \"\"\"\n",
" Loads checkpoint info from a specified file.\n",
"\n",
"1. 读取Checkpoint文件,调用方法如:\n",
" \n",
" ```py\n",
" load_checkpoint(ckpoint_file_name)\n",
" ```\n",
" \n",
" 参数说明:\n",
" \n",
" - `ckpoint_file_name`:checkpoint文件名,如'ck_lenet-7_1875.ckpt'。\n",
" - 返回值:返回一个字典。key为参数name,value为parameter类型的实例。\n",
" Args:\n",
" ckpoint_file_name (str): Checkpoint file name.\n",
" net (Cell): Cell network. Default: None\n",
"\n",
"2. 加载参数后继续训练,调用方法如:\n",
" \n",
" ```py\n",
" load_param_into_net(net, param_dict)\n",
" ```\n",
" \n",
" 参数说明:\n",
" \n",
" - `net`:初始不带优化器和损失函数的网络,如:`Resnet()`。\n",
" - `param_dict`:加载checkpoint文件后生成的字典。\n",
" Returns:\n",
" Dict, key is parameter name, value is a Parameter.\n",
"\n",
" Raises:\n",
" ValueError: Checkpoint file is incorrect.\n",
" \"\"\"\n",
"\n",
"def load_param_into_net(net, parameter_dict):\n",
" \"\"\"\n",
" Loads parameters into network.\n",
"\n",
" Args:\n",
" net (Cell): Cell network.\n",
" parameter_dict (dict): Parameter dict.\n",
"\n",
" Raises:\n",
" TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.\n",
" \"\"\"\n",
"```\n",
"\n",
"> 使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。"
]
......@@ -365,7 +392,7 @@
" ds_eval = create_dataset(training=False)\n",
" steps_per_epoch = ds_train.get_dataset_size()\n",
" \n",
" net = LeNet()\n",
" net = LeNet5()\n",
" loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
" opt = nn.Momentum(net.trainable_params(), lr, momentum)\n",
" \n",
......@@ -375,7 +402,7 @@
" \n",
" ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n",
" ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)\n",
" loss_cb = LossMonitor(per_print_times=1 if LOOP_SINK else steps_per_epoch)\n",
" loss_cb = LossMonitor(steps_per_epoch)\n",
" \n",
" model = Model(net, loss, opt, metrics={'acc', 'loss'})\n",
" model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb])\n",
......@@ -391,18 +418,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### 推理\n",
"\n",
"加载Checkpoint,并执行验证。读取模型和Checkpoint文件,调用方法如:\n",
" \n",
" ```py\n",
" load(model_file_name, ckpoint_file_name)\n",
" ```\n",
" \n",
" 参数说明:\n",
" \n",
" - `model_file_name`:模型文件名,如'ck_lenet-model.pkl'。\n",
" - `ckpoint_file_name`:checkpoint文件名,如'ck_lenet-7_1875.ckpt'。\n",
"### 加载Checkpoint进行推理\n",
" \n",
"使用matplotlib定义一个将推理结果可视化的辅助函数,如下:"
]
......@@ -428,7 +444,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"使用训练后的LeNet模型对手写数字进行识别,可以看到识别结果基本上是正确的。"
"使用训练后的LeNet5模型对手写数字进行识别,可以看到识别结果基本上是正确的。"
]
},
{
......@@ -461,9 +477,8 @@
" return pred[0], images[0], labels[0]\n",
"\n",
"ds = create_dataset(training=False, batch_size=1).create_dict_iterator()\n",
"net = LeNet()\n",
"param_dict = load_checkpoint(CKPT)\n",
"load_param_into_net(net, param_dict)\n",
"net = LeNet5()\n",
"param_dict = load_checkpoint(CKPT, net)\n",
"model = Model(net)\n",
"plot_images(infer, ds, model)"
]
......@@ -535,6 +550,7 @@
"## 实验小结\n",
"\n",
"本实验展示了MindSpore的Checkpoint、断点继续训练等高级特性:\n",
"\n",
"1. 使用MindSpore的ModelCheckpoint接口每个epoch保存一次Checkpoint,训练2个epoch并终止。\n",
"2. 使用MindSpore的load_checkpoint和load_param_into_net接口加载上一步保存的Checkpoint继续训练2个epoch。\n",
"3. 观察训练过程中Loss的变化情况,加载Checkpoint继续训练后loss进一步下降。"
......@@ -562,4 +578,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
\ No newline at end of file
......@@ -47,9 +47,9 @@ def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
return ds
class LeNet(nn.Cell):
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
super(LeNet5, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
......@@ -73,21 +73,19 @@ class LeNet(nn.Cell):
return output
LOOP_SINK = context.get_context('enable_loop_sink')
def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"):
ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False)
steps_per_epoch = ds_train.get_dataset_size()
net = LeNet()
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)
loss_cb = LossMonitor(per_print_times=1 if LOOP_SINK else steps_per_epoch)
loss_cb = LossMonitor(steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True)
......@@ -102,7 +100,7 @@ def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"):
ds_eval = create_dataset(training=False)
steps_per_epoch = ds_train.get_dataset_size()
net = LeNet()
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
......@@ -112,7 +110,7 @@ def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"):
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)
loss_cb = LossMonitor(per_print_times=1 if LOOP_SINK else steps_per_epoch)
loss_cb = LossMonitor(steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册