{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 style=\"text-align:center\">Saving and Loading Models During Training</h1>\n",
    "\n",
    "## Introduction\n",
    "\n",
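    "This experiment shows how to save and load a model during training with MindSpore. We recommend first reading the section on saving and loading model parameters in the MindSpore tutorials.\n",
    "\n",
    "During training, you can add checkpoints (CheckPoint) to save the model parameters. You can then use them for inference or to resume training after an interruption. Typical scenarios are:\n",
    "\n",
    "- Inference after training\n",
    "\n",
    "    - Save the model parameters after training finishes and use them for inference or prediction.\n",
    "\n",
    "    - During training, validate accuracy in real time and save the parameters of the most accurate model for prediction.\n",
    "\n",
    "- Retraining\n",
    "\n",
    "    - For long-running training jobs, save checkpoint files during training. If the job exits unexpectedly, training then does not have to restart from the initial state.\n",
    "\n",
    "    - Fine-tuning: train a model, save its parameters, and use them as the starting point for training on a second, similar task.\n",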
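    "\n",
    "## Objectives\n",
    "\n",
    "- Learn how to save a model during training with MindSpore.\n",
    "- Learn how to load a saved model file with MindSpore and continue training.\n",
    "- Learn how to use the Callback feature of MindSpore.\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "- Proficiency in Python and basic knowledge of Shell and the Linux operating system.\n",
    "- Some background in deep learning theory, such as convolutional neural networks, loss functions, optimizers, training strategies, and checkpoints.\n",
    "- Familiarity with basic Huawei Cloud services, including [OBS (Object Storage Service)](https://www.huaweicloud.com/product/obs.html), [ModelArts (AI development platform)](https://www.huaweicloud.com/product/modelarts.html), [Notebook (development tool)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html), and [Training Jobs](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html). Huawei Cloud website: https://www.huaweicloud.com\n",
    "- Knowledge of and familiarity with the MindSpore AI computing framework. MindSpore website: https://www.mindspore.cn/\n",
    "\n",
    "## Environment\n",
    "\n",
    "- MindSpore 0.2.0 (MindSpore is updated regularly, and this guide is refreshed periodically to match each release);\n",
    "- Huawei Cloud ModelArts: ModelArts is Huawei Cloud's one-stop AI development platform for developers. It integrates a resource pool of Ascend AI processors, and you can try out MindSpore on it. ModelArts website: https://www.huaweicloud.com/product/modelarts.html\n",
    "\n",
    "## Preparation\n",
    "\n",
    "### Creating an OBS Bucket\n",
    "\n",
    "This experiment uses Huawei Cloud OBS to store the experiment scripts and the dataset. See [Uploading and Downloading Files Quickly via the OBS Console](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html) to learn how to create a bucket and how to upload and download files with OBS.\n",
    "\n",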
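    "> **Tip:** New Huawei Cloud users usually need to create and configure an access key before using OBS. You can do this by following the prompts the first time you use OBS. Alternatively, see [Obtaining an Access Key and Completing the ModelArts Global Configuration](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html) to obtain and configure an access key.\n",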
    "\n",
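    "Reference configuration for the OBS bucket:\n",
    "\n",
    "- Region: CN North-Beijing4\n",
    "- Data redundancy policy: single-AZ storage\n",
    "- Bucket name: e.g. ms-course\n",
    "- Storage class: Standard\n",
    "- Bucket policy: Public Read\n",
    "- Direct reading of archived data: Disabled\n",
    "- Enterprise project, tags, and other settings: not required\n",
    "\n",
    "### Preparing the Dataset\n",
    "\n",
    "MNIST is a dataset of handwritten digits in 10 classes. The training set contains 60,000 images and the test set contains 10,000 images. MNIST website: [THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/).\n",
    "\n",
    "Download the following 4 files from the MNIST website and decompress them locally:\n",
    "\n",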
    "```\n",
    "train-images-idx3-ubyte.gz:  training set images (9912422 bytes)\n",
    "train-labels-idx1-ubyte.gz:  training set labels (28881 bytes)\n",
    "t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes)\n",
    "t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)\n",
    "```\n",
    "\n",
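    "If you prefer to decompress the archives with Python rather than with shell tools, the following minimal sketch unpacks the four downloaded `.gz` files. It uses only the standard library and writes them into the `MNIST/train` and `MNIST/test` layout described under Uploading Files below. It assumes the archives are in a local directory passed as `src_dir`. `prepare_mnist` and `MNIST_FILES` are hypothetical names used only for illustration.\n",
    "\n",
    "```python\n",
    "import gzip\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "# Hypothetical mapping: downloaded archive -> target sub-directory (see the layout below).\n",
    "MNIST_FILES = {\n",
    "    'train-images-idx3-ubyte.gz': 'train',\n",
    "    'train-labels-idx1-ubyte.gz': 'train',\n",
    "    't10k-images-idx3-ubyte.gz': 'test',\n",
    "    't10k-labels-idx1-ubyte.gz': 'test',\n",
    "}\n",
    "\n",
    "def prepare_mnist(src_dir, dst_dir='MNIST'):\n",
    "    # Decompress each archive into dst_dir/train or dst_dir/test and return the written paths.\n",
    "    written = []\n",
    "    for archive, split in MNIST_FILES.items():\n",
    "        out_dir = os.path.join(dst_dir, split)\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        out_path = os.path.join(out_dir, archive[:-3])  # drop the '.gz' suffix\n",
    "        with gzip.open(os.path.join(src_dir, archive), 'rb') as f_in, open(out_path, 'wb') as f_out:\n",
    "            shutil.copyfileobj(f_in, f_out)\n",
    "        written.append(out_path)\n",
    "    return written\n",
    "```\n",
    "\n",
    "For example, `prepare_mnist('downloads')` would create `MNIST/train/train-images-idx3-ubyte` and the other three files relative to the current working directory.\n",
    "\n",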
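    "### Preparing the Scripts\n",
    "\n",
    "Download the scripts for this experiment from the [course Gitee repository](https://gitee.com/mindspore/course).\n",
    "\n",
    "### Uploading Files\n",
    "\n",
    "Upload the scripts and the dataset to the OBS bucket, organized as follows:\n",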
    "\n",
    "```\n",
    "experiment_2\n",
    "├── MNIST\n",
    "│   ├── test\n",
    "│   │   ├── t10k-images-idx3-ubyte\n",
    "│   │   └── t10k-labels-idx1-ubyte\n",
    "│   └── train\n",
    "│       ├── train-images-idx3-ubyte\n",
    "│       └── train-labels-idx1-ubyte\n",
    "├── *.ipynb\n",
    "└── main.py\n",
    "```\n",
    "\n",
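    "## Procedure (Option 1)\n",
    "\n",
    "### Creating a Notebook\n",
    "\n",
    "See [Creating and Opening a Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html) to create and open the Notebook for this experiment.\n",
    "\n",
    "Reference configuration for the Notebook:\n",
    "\n",
    "- Billing mode: Pay-per-use\n",
    "- Name: experiment_2\n",
    "- Work environment: Python3\n",
    "- Resource pool: Public resources\n",
    "- Type: Ascend\n",
    "- Specifications: single card, 1*Ascend 910\n",
    "- Storage path: Object Storage Service (OBS) -> select the experiment_2 folder in the OBS bucket created above\n",
    "- Auto stop and other settings: default\n",
    "\n",
    "> **Note:**\n",
    "> - Before opening the Notebook, go to the Jupyter Notebook file list page and select all files and folders in the directory (the experiment scripts and the dataset). Then click the \"Sync OBS\" button above the list to synchronize all files in the OBS bucket to the Notebook environment. Otherwise the code in the Notebook cannot access the dataset. See [Using the Sync OBS Function](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html).\n",
    "> - After opening the Notebook, select the MindSpore environment as the kernel.\n",
    "\n",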
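    "> **Tip:** You can also prepare the dataset and scripts inside the Notebook environment. On the Jupyter Notebook file list page, click \"New\" -> \"Terminal\" in the upper right corner to open a terminal in the Notebook environment. Change to the `work` directory and use common Linux shell commands such as `wget, gzip, tar, mkdir, mv` to download and prepare the dataset and scripts.\n",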
    "\n",
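    "> **Tip:** Read the tips and run the code cells from top to bottom. While a cell is running, [\\*] is shown to its left. When it finishes, a number such as [1] or [2] appears. Wait for each cell to finish before running the next one.\n",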
    "\n",
    "Import the MindSpore modules and helper modules:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ['DEVICE_ID'] = '0'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import mindspore as ms\n",
    "import mindspore.context as context\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "import mindspore.dataset.transforms.vision.c_transforms as CV\n",
    "\n",
    "from mindspore.dataset.transforms.vision import Inter\n",
    "from mindspore import nn, Tensor\n",
    "from mindspore.train import Model\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
    "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
    "\n",
    "import logging; logging.getLogger('matplotlib.font_manager').disabled = True\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Processing\n",
    "\n",
    "Before using the dataset to train the network, preprocess the data as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR_TRAIN = \"MNIST/train\" # training set directory\n",
    "DATA_DIR_TEST = \"MNIST/test\" # test set directory\n",
    "\n",
    "def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),\n",
    "                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):\n",
    "    ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)\n",
    "    \n",
    "    # define map operations\n",
    "    resize_op = CV.Resize(resize)\n",
    "    rescale_op = CV.Rescale(rescale, shift)\n",
    "    hwc2chw_op = CV.HWC2CHW()\n",
    "    \n",
    "    # apply map operations on images\n",
    "    ds = ds.map(input_columns=\"image\", operations=[resize_op, rescale_op, hwc2chw_op])\n",
    "    ds = ds.map(input_columns=\"label\", operations=C.TypeCast(ms.int32))\n",
    "    \n",
    "    ds = ds.shuffle(buffer_size=buffer_size)\n",
    "    ds = ds.batch(batch_size, drop_remainder=True)\n",
    "    ds = ds.repeat(num_epoch)\n",
    "    \n",
    "    return ds"
   ]
  },
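  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default `rescale` and `shift` arguments fold the usual MNIST normalization into a single `CV.Rescale` step. That step computes `x * rescale + shift` for each pixel value `x` in [0, 255], which is the same as `(x / 255 - 0.1307) / 0.3081`. Here 0.1307 and 0.3081 are the commonly used mean and standard deviation of MNIST pixel intensities. The plain-Python check below compares the two forms. It needs no MindSpore, and the function names are illustrative only.\n",
    "\n",
    "```python\n",
    "MEAN, STD = 0.1307, 0.3081\n",
    "RESCALE = 1 / (255 * STD)  # default rescale in create_dataset\n",
    "SHIFT = -MEAN / STD        # default shift in create_dataset\n",
    "\n",
    "def rescale_op(x):\n",
    "    # Affine form applied by CV.Rescale: x * rescale + shift.\n",
    "    return x * RESCALE + SHIFT\n",
    "\n",
    "def normalize(x):\n",
    "    # Textbook form: scale to [0, 1], then standardize with the dataset mean and std.\n",
    "    return (x / 255 - MEAN) / STD\n",
    "\n",
    "for pixel in (0, 128, 255):\n",
    "    assert abs(rescale_op(pixel) - normalize(pixel)) < 1e-9\n",
    "print(round(rescale_op(0), 4), round(rescale_op(255), 4))  # -0.4242 2.8215\n",
    "```\n",
    "\n",
    "A black pixel therefore maps to about -0.42 and a white pixel to about 2.82. Over the whole dataset, the inputs have roughly zero mean and unit standard deviation."
   ]
  },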
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Model\n",
    "\n",
    "Define the LeNet5 model. Its structure is shown below:\n",
    "\n",
    "<img src=\"https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg\">\n",
    "\n",
    "[1] Image source: http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNet5(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')\n",
    "        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc1 = nn.Dense(400, 120)\n",
    "        self.fc2 = nn.Dense(120, 84)\n",
    "        self.fc3 = nn.Dense(84, 10)\n",
    "    \n",
    "    def construct(self, input_x):\n",
    "        output = self.conv1(input_x)\n",
    "        output = self.relu(output)\n",
    "        output = self.pool(output)\n",
    "        output = self.conv2(output)\n",
    "        output = self.relu(output)\n",
    "        output = self.pool(output)\n",
    "        output = self.flatten(output)\n",
    "        output = self.fc1(output)\n",
    "        output = self.relu(output)\n",
    "        output = self.fc2(output)\n",
    "        output = self.relu(output)\n",
    "        output = self.fc3(output)\n",
    "        \n",
    "        return output"
   ]
  },
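  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does `fc1` expect 400 input features? `create_dataset` resizes the images to 32x32 (`resize=(32, 32)`), the convolutions use `'valid'` padding, and max pooling uses a 2x2 window with stride 2. Together these shrink the feature maps step by step. The plain-Python sketch below reproduces that arithmetic with the standard output-size formula for an unpadded window, `(n - kernel) // stride + 1`. `out_size` and `lenet5_flatten_size` are illustrative names.\n",
    "\n",
    "```python\n",
    "def out_size(n, kernel, stride):\n",
    "    # Output length of an unpadded ('valid') convolution or pooling window.\n",
    "    return (n - kernel) // stride + 1\n",
    "\n",
    "def lenet5_flatten_size(input_size=32):\n",
    "    n = out_size(input_size, 5, 1)  # conv1, 5x5, stride 1: 32 -> 28\n",
    "    n = out_size(n, 2, 2)           # max pool, 2x2, stride 2: 28 -> 14\n",
    "    n = out_size(n, 5, 1)           # conv2, 5x5, stride 1: 14 -> 10\n",
    "    n = out_size(n, 2, 2)           # max pool, 2x2, stride 2: 10 -> 5\n",
    "    return 16 * n * n               # 16 channels * 5 * 5 after flatten\n",
    "\n",
    "print(lenet5_flatten_size())  # 400\n",
    "```\n",
    "\n",
    "Feeding the raw 28x28 MNIST images instead would give `lenet5_flatten_size(28) == 256` features, which would not match `nn.Dense(400, 120)`."
   ]
  },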
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving Model Checkpoints\n",
    "\n",
    "MindSpore provides a Callback mechanism that runs specific tasks during training or evaluation. Commonly used callbacks include:\n",
    "\n",
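    "- `ModelCheckpoint`: saves the network model and its parameters for retraining or inference;\n",
    "- `LossMonitor`: prints the loss periodically and stops training when the loss becomes NaN or Inf;\n",
    "- `SummaryStep`: writes information from the training process to files for later inspection or visualization.\n",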
    "\n",
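    "`ModelCheckpoint` generates model (.meta) and checkpoint (.ckpt) files. For example, it can save a checkpoint at the end of every epoch. Its saving policy is configured with `CheckpointConfig`. The docstrings of both classes are:\n",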
    "\n",
    "```python\n",
    "class CheckpointConfig:\n",
    "    \"\"\"\n",
    "    The config for model checkpoint.\n",
    "\n",
    "    Args:\n",
    "        save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.\n",
    "        save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.\n",
    "            Can't be used with save_checkpoint_steps at the same time.\n",
    "        keep_checkpoint_max (int): Maximum number of checkpoint files to keep. Default: 5.\n",
    "        keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.\n",
    "            Can't be used with keep_checkpoint_max at the same time.\n",
    "        integrated_save (bool): Whether to perform integrated save in automatic model parallel scene. Default: True.\n",
    "            Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the input_param is None or 0.\n",
    "    \"\"\"\n",
    "\n",
    "class ModelCheckpoint(Callback):\n",
    "    \"\"\"\n",
    "    The checkpoint callback class.\n",
    "\n",
    "    It is called to combine with train process and save the model and network parameters after training.\n",
    "\n",
    "    Args:\n",
    "        prefix (str): Checkpoint files names prefix. Default: \"CKP\".\n",
    "        directory (str): Folder path into which checkpoint files will be saved. Default: None.\n",
    "        config (CheckpointConfig): Checkpoint strategy config. Default: None.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the prefix is invalid.\n",
    "        TypeError: If the config is not CheckpointConfig type.\n",
    "    \"\"\"\n",
    "```\n",
    "\n",
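    "MindSpore provides several evaluation metrics, such as `accuracy`, `loss`, `precision`, `recall`, and `F1`. Define a set (or dict) of the metrics you need, pass it to `Model`, and then call `model.eval` to compute them. `model.eval` returns a dict that maps each metric to its value."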
   ]
  },
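  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The small pure-Python simulation below shows how `save_checkpoint_steps` and `keep_checkpoint_max` interact. It is an illustrative model, not MindSpore's implementation, and rests on three assumptions. A checkpoint is written every `save_checkpoint_steps` global steps. Each file is named `{prefix}-{epoch}_{step}.ckpt`, matching the names printed by the training cell below. Once more than `keep_checkpoint_max` files exist, the oldest one is removed. `simulate_checkpoints` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "def simulate_checkpoints(num_epoch, steps_per_epoch, save_checkpoint_steps,\n",
    "                         keep_checkpoint_max, prefix='b_lenet'):\n",
    "    # Simplified model of checkpoint saving and rotation (an assumption, not MindSpore's code).\n",
    "    kept = []\n",
    "    for global_step in range(1, num_epoch * steps_per_epoch + 1):\n",
    "        if global_step % save_checkpoint_steps != 0:\n",
    "            continue\n",
    "        epoch = (global_step - 1) // steps_per_epoch + 1\n",
    "        step_in_epoch = (global_step - 1) % steps_per_epoch + 1\n",
    "        kept.append('{}-{}_{}.ckpt'.format(prefix, epoch, step_in_epoch))\n",
    "        if len(kept) > keep_checkpoint_max:\n",
    "            kept.pop(0)  # delete the oldest checkpoint\n",
    "    return kept\n",
    "\n",
    "# Settings used in test_train(): 60000 // 32 = 1875 steps per epoch, one checkpoint per epoch, keep at most 5.\n",
    "print(simulate_checkpoints(2, 1875, 1875, 5))\n",
    "# ['b_lenet-1_1875.ckpt', 'b_lenet-2_1875.ckpt']\n",
    "```\n",
    "\n",
    "With the same settings and `num_epoch=7`, only the files for epochs 3 to 7 would remain in this model. Saving every 1875 steps is why each file name ends in `_1875`."
   ]
  },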
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 1875 ,loss is 2.3151364\n",
      "epoch: 2 step: 1875 ,loss is 0.3097728\n",
      "Metrics: {'acc': 0.9417067307692307, 'loss': 0.18866610953894755}\n",
      "b_lenet-1_1875.ckpt\n",
      "b_lenet-2_1875.ckpt\n"
     ]
    }
   ],
   "source": [
    "os.system('rm -f *.ckpt *.ir *.meta') # clean up files from previous runs\n",
    "\n",
    "def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name=\"b_lenet\"):\n",
    "    ds_train = create_dataset(num_epoch=num_epoch)\n",
    "    ds_eval = create_dataset(training=False)\n",
    "    steps_per_epoch = ds_train.get_dataset_size()\n",
    "    \n",
    "    net = LeNet5()\n",
    "    loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
    "    opt = nn.Momentum(net.trainable_params(), lr, momentum)\n",
    "    \n",
    "    ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n",
    "    ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)\n",
    "    loss_cb = LossMonitor(steps_per_epoch)\n",
    "    \n",
    "    model = Model(net, loss, opt, metrics={'acc', 'loss'})\n",
    "    model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True)\n",
    "    metrics = model.eval(ds_eval)\n",
    "    print('Metrics:', metrics)\n",
    "\n",
    "test_train()\n",
    "print('\\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))"
   ]
  },
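  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `acc` value reported by `model.eval` is the fraction of samples whose highest-scoring class equals the label. Below is a minimal NumPy sketch of that computation on a toy 3-class batch. It is illustrative only and is not MindSpore's own metric implementation. `accuracy` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def accuracy(logits, labels):\n",
    "    # Predicted class = index of the largest logit in each row.\n",
    "    preds = np.argmax(logits, axis=1)\n",
    "    return float(np.mean(preds == labels))\n",
    "\n",
    "logits = np.array([[0.1, 2.0, -1.0],\n",
    "                   [3.0, 0.5, 0.2],\n",
    "                   [0.0, 0.1, 0.9]])\n",
    "labels = np.array([1, 0, 1])\n",
    "print(accuracy(logits, labels))  # 0.6666666666666666 (2 of 3 correct)\n",
    "```\n",
    "\n",
    "The `infer` function later in this notebook uses the same `np.argmax(..., axis=1)` step to turn the network's output into a predicted digit."
   ]
  },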
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading a Checkpoint to Resume Training\n",
    "\n",
    "```python\n",
    "def load_checkpoint(ckpoint_file_name, net=None):\n",
    "    \"\"\"\n",
    "    Loads checkpoint info from a specified file.\n",
    "\n",
    "    Args:\n",
    "        ckpoint_file_name (str): Checkpoint file name.\n",
    "        net (Cell): Cell network. Default: None\n",
    "\n",
    "    Returns:\n",
    "        Dict, key is parameter name, value is a Parameter.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: Checkpoint file is incorrect.\n",
    "    \"\"\"\n",
    "\n",
    "def load_param_into_net(net, parameter_dict):\n",
    "    \"\"\"\n",
    "    Loads parameters into network.\n",
    "\n",
    "    Args:\n",
    "        net (Cell): Cell network.\n",
    "        parameter_dict (dict): Parameter dict.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.\n",
    "    \"\"\"\n",
    "```\n",
    "\n",
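    "> When loading parameters with `load_checkpoint`, load them into the original network, not into the training network that wraps it together with the optimizer and loss function."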
   ]
  },
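  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, `load_checkpoint` returns a dict keyed by parameter name (see the docstring above). `load_param_into_net` then copies each entry into the parameter of the same name in the target cell. The sketch below mimics that name matching with plain dicts. It is a simplified analog for illustration, not MindSpore's implementation, and its parameter names are hypothetical. In this model, the next cell can call `load_param_into_net` on both `net` and `opt` with the same `param_dict` because each call picks out only the names that its target owns.\n",
    "\n",
    "```python\n",
    "def load_params_by_name(target_params, param_dict):\n",
    "    # Copy every checkpoint entry whose name matches a target parameter.\n",
    "    # Returns the target names that were NOT found in the checkpoint.\n",
    "    missing = []\n",
    "    for name in target_params:\n",
    "        if name in param_dict:\n",
    "            target_params[name] = param_dict[name]\n",
    "        else:\n",
    "            missing.append(name)\n",
    "    return missing\n",
    "\n",
    "# Hypothetical checkpoint contents: network weights plus optimizer state.\n",
    "ckpt = {'conv1.weight': 'w1', 'fc3.bias': 'b3', 'moments.fc3.bias': 'm3'}\n",
    "net_params = {'conv1.weight': None, 'fc3.bias': None}\n",
    "opt_params = {'moments.fc3.bias': None}\n",
    "\n",
    "print(load_params_by_name(net_params, ckpt))  # []\n",
    "print(load_params_by_name(opt_params, ckpt))  # []\n",
    "print(net_params)  # {'conv1.weight': 'w1', 'fc3.bias': 'b3'}\n",
    "```"
   ]
  },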
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 1875 ,loss is 0.1638589\n",
      "epoch: 2 step: 1875 ,loss is 0.060048036\n",
      "Metrics: {'acc': 0.9742588141025641, 'loss': 0.07910804035148034}\n",
      "b_lenet_1-1_1875.ckpt\n",
      "b_lenet_1-2_1875.ckpt\n"
     ]
    }
   ],
   "source": [
    "CKPT = 'b_lenet-2_1875.ckpt'\n",
    "\n",
    "def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name=\"b_lenet\"):\n",
    "    ds_train = create_dataset(num_epoch=num_epoch)\n",
    "    ds_eval = create_dataset(training=False)\n",
    "    steps_per_epoch = ds_train.get_dataset_size()\n",
    "    \n",
    "    net = LeNet5()\n",
    "    loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
    "    opt = nn.Momentum(net.trainable_params(), lr, momentum)\n",
    "    \n",
    "    param_dict = load_checkpoint(CKPT)\n",
    "    load_param_into_net(net, param_dict)\n",
    "    load_param_into_net(opt, param_dict)\n",
    "    \n",
    "    ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n",
    "    ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)\n",
    "    loss_cb = LossMonitor(steps_per_epoch)\n",
    "    \n",
    "    model = Model(net, loss, opt, metrics={'acc', 'loss'})\n",
    "    model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb])\n",
    "    \n",
    "    metrics = model.eval(ds_eval)\n",
    "    print('Metrics:', metrics)\n",
    "\n",
    "resume_train()\n",
    "print('\\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))"
   ]
  },
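  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above hard-codes `CKPT`. When resuming after an interruption, you may instead want the most recent checkpoint for a given prefix. Below is a sketch of a hypothetical helper (`latest_checkpoint` is not a MindSpore API). It parses the `{prefix}-{epoch}_{step}.ckpt` names seen in the outputs above.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def latest_checkpoint(directory, prefix):\n",
    "    # Hypothetical helper: newest '{prefix}-{epoch}_{step}.ckpt' file in directory, or None.\n",
    "    head, tail = prefix + '-', '.ckpt'\n",
    "    best_path, best_key = None, None\n",
    "    for name in os.listdir(directory):\n",
    "        if not (name.startswith(head) and name.endswith(tail)):\n",
    "            continue\n",
    "        epoch, _, step = name[len(head):-len(tail)].partition('_')\n",
    "        if not (epoch.isdigit() and step.isdigit()):\n",
    "            continue  # ignore files that do not follow the naming pattern\n",
    "        key = (int(epoch), int(step))  # compare numerically, not as strings\n",
    "        if best_key is None or key > best_key:\n",
    "            best_path, best_key = os.path.join(directory, name), key\n",
    "    return best_path\n",
    "```\n",
    "\n",
    "Comparing `(epoch, step)` as integers avoids the lexicographic trap where `-10_` sorts before `-2_`. The resumed run above wrote `b_lenet_1-...` files even though `ckpt_name='b_lenet'` was passed. This suggests that `ModelCheckpoint` adds a suffix to the prefix when files with that prefix already exist, so pass the prefix that actually appears on disk."
   ]
  },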
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading a Checkpoint for Inference\n",
    "  \n",
    "Define a helper function that uses matplotlib to visualize inference results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_images(pred_fn, ds, net):\n",
    "    for i in range(1, 5):\n",
    "        pred, image, label = pred_fn(ds, net)\n",
    "        plt.subplot(2, 2, i)\n",
    "        plt.imshow(np.squeeze(image))\n",
    "        color = 'blue' if pred == label else 'red'\n",
    "        plt.title(\"prediction: {}, truth: {}\".format(pred, label), color=color)\n",
    "        plt.xticks([])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the trained LeNet5 model to recognize handwritten digits. Most of the predictions are correct."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAAD7CAYAAAAMyN1hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcv0lEQVR4nO3de5RU5Znv8e9D03TLJUIraCMgXgDJMlEZguRoEuJl1IkecpxlJo7jQpfaIdE1OmO8xMl9NPHkmMvMMpmIE4TxbtBRYuIkSjQRZVCWgSSKCqMIKHJROgEEpJvn/FGbXbuaqn5316Wrqvv3WatXP7v27d3VTz+1330rc3dERKSwAdVugIhIrVOhFBEJUKEUEQlQoRQRCVChFBEJUKEUEQmoyUJpxjwzbozij5nxSpHL+bEZXylv6+qDGePNcDMGVrstkqXcLl01crsmC2WSO0+7Myk0nRkXmbG4y7yz3fnnyrUuXvexZvzSjC1m9OjCVDNmmLG+DG1YY8ZppS4nsbxPmvGkGX8yY025litZyu3Uy6l6ble8UPaTPZo9wAPAJZVYeJXewx3AXOCaKqy7Lii3S1c3ue3uPf4BXwP+JfCXwLeC3wHeHI2bAb4e/Drwt8HvjF4/G3w5eDv4s+AfTizvBPAXwLeB3w9+H/iNyeUlph0L/hD4ZvB3wG8Fnwy+C7wTfDt4ezTtvH3LiYYvA18N/i74QvDRiXEOPht8VbRNPwS3Hr4vR4N7D6YfAr4TfG/U7u3go8G/Dr4A/C7wP4Nfmmdb4vcF/M5oGTujZVwLPj7aplnga8G3gP9TEX/r08DXFJMn9fij3FZu5/spZY/yAuAM4ChgIvDlxLhDgRbgcKDNjClkKvjngIOA24CFZjSZMQh4GLgzmuenwF/nW6EZDcCjwBvAeOAw4D53VgKzgSXuDHVneJ55TwG+DXwGaI2WcV+Xyc4GPgIcF013RjTvODPazRiX9s1Jw50dwFnAW1G7h7rzVjR6JrAAGA7cHVjOhcBa4JxoGd9JjD4ZmAScCnzVjMnRNp1sRns5t6cPUW6XqK/ldimF8lZ31rnzLnATcH5i3F7ga+7sdmcncBlwmztL3el0Zz6wG5ge/TQCP3BnjzsLgOcLrHMaMBq4xp0d7uxyzz12040LgLnuvODObuBLwEfNGJ+Y5mZ32t1ZCzwJHA/gzlp3hkev95Yl7jzszt7oPSzWN9zZ6c4KYAWZfxTcWZzvn04A5Xal1V1ul1Io1yXiN8j8kffZ7M6uxPDhwNXRJ1d7VO3HRvOMBt50zzlQ/EaBdY4F3nCno4j2jk4u153twDtkPrn3eTsRvwcMLWI95bIuPEkqtbRN9UK5XVl1l9ulFMqxiXgcxLvVwH5nx9YBN0WfXPt+BrtzL7ABOMwM67K8fNYB4wocAA6dkXuLTFIDYMYQMl2lNwPzVVqhdnd9fQcwODF8aMrlSM8pt8ujz+R2KYXycjPGmNEC3ADc3820twOzzTjRDDNjiBmfMmMYsAToAP7ejIFmnEumG5LPc2SS7+ZoGc1mnBSN2wiMiY4L5XMPcLEZx5vRBHwLWOpe+qUv0TY1Q2bdUbuaEuPnmTGvwOwbgYPMODCwmuXAX5nRYsahwFV5lnNkURuQhxkDom1qhMz2dfPe9jXK7YhyO6OUQnkP8CvgtejnxkITurOMzLGcW4GtwGrgomjc+8C50fBW4G+AhwospxM4BziazAHe9dH0AL8GXgTeNmNLnnkXAV8BHiSTkEcBn02zodEB7+3dHPA+HNgZrZ8oTl5IPBZ4psA2vQzcC7wWdd1G55uOzAmBFcAaMu9713/ebwNfjpbxxcAm7bvYeXs3k3w82o5fkNkL2hmttz9QbmcptwHLnCbvmegizUvdeaLHM/cz0SfVCuDD7uypdnuke8rt9PpTbveHC2arKtqrmFztdoiUW3/K7Zq/hVFEpNqK6nqLiPQn
Je1RmtmZZvaKma02s+vL1SiRalNuS1LRe5Rm1gC8CpxO5gzd88D57v5S+Zon0vuU29JVKSdzpgGr3f01ADO7j8w9nAWTaZA1eTNDSlillMs2tm5x95HVbkeNUm7XqV3s4H3fbeEpe6aUQnkYubcirQdO7DqRmbUBbQDNDOZEO7WEVUq5POELCt1KJ8rturXUF1VkuaUco8xXtffrx7v7HHef6u5TG7MX9IvUMuW25CilUK4n957YMeTeEytSr5TbkqOUQvk8MMHMjjCzQWRumVpYnmaJVJVyW3IUfYzS3TvM7Argl0ADMNfdXwzMJlLzlNvSVUm3MLr7L8jcWC7Spyi3JUm3MIqIBKhQiogEqFCKiASoUIqIBKhQiogE6MG9KQwYnP3eo7VXHp8zbm+KGzJGLs9+sd4BDz9XtnaJSO/QHqWISIAKpYhIgLreKdiw7PeqL/jcLTnjJg8a3HXy/Rzx2KVxPPHh8rVLJK2BR46P402faE01z8hHV8dx5+bN5W5SXdEepYhIgAqliEiACqWISICOUYr0A8njks/f9G+p5pnS9Pk4PvTB7Ov98Xil9ihFRAJUKEVEAtT1FpG8Xvhqtov+kd3ZbnjLHep6i4hIFyqUIiIBKpQiIgEqlCIiASqUIiIBOust0g8kH3CRvJA8eWZbCgvuUZrZXDPbZGZ/TLzWYmaPm9mq6PeIyjZTpPyU25JWmq73PODMLq9dDyxy9wnAomhYpN7MQ7ktKQS73u7+WzMb3+XlmcCMKJ4PPAVcV8Z2iVRcf8rt5P3ZB60c0+P5P37F0jhe3PnROB7+H0tKa1idKPZkziHuvgEg+j2q0IRm1mZmy8xs2R52F7k6kV6j3Jb9VPyst7vPcfep7j61kRTfxCVSJ5Tb/UexZ703mlmru28ws1ZgUzkbJVJFyu08vtv6QhxPPnJ6HA+vRmOqoNg9yoXArCieBTxSnuaIVJ1yW/aT5vKge4ElwCQzW29mlwA3A6eb2Srg9GhYpK4otyWtNGe9zy8w6tQyt0WkVym3JS3dwigiEqBCKSISoHu9C2gYOTKON5x3dBwPG7C3Gs0RkSrSHqWISIAKpYhIgLreBez5YPZ+2N/d8KPEmKG93xgRqSrtUYqIBKhQiogEqOudYE3ZBxu8/4HGkpa1qXNHdrm7GkpalohUl/YoRUQCVChFRALU9U5oP++EOL73plsSY3p+pvuku78Yx5Nuir+SBV2uLlJ/tEcpIhKgQikiEqBCKSISoGOUCZ2NFsdHNJZ2B87AXdll7d22raRliUh1aY9SRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkIM33eo81syfNbKWZvWhmV0avt5jZ42a2Kvo9ovLNFSkf5baklWaPsgO42t0nA9OBy83sg8D1wCJ3nwAsioZF6olyW1IJXnDu7huADVG8zcxWAocBM4EZ0WTzgaeA6yrSygrqnDEljredtb2KLZHe1tdzW8qnR8cozWw8cAKwFDgkSrR9CTeqwDxtZrbMzJbtYXdprRWpEOW2dCd1oTSzocCDwFXu/ue087n7HHef6u5TG2kKzyDSy5TbEpLqXm8zaySTSHe7+0PRyxvNrNXdN5hZK7CpUo2spDdnNMfxyyfPrWJLpBr6cm5L+aQ5623AT4CV7v69xKiFwKwongU8Uv7miVSOclvSSrNHeRJwIfAHM1sevXYDcDPwgJldAqwFzqtME0UqRrktqaQ5670YsAKjTy1vc3pHw4Qj43hX654qtkSqqS/mtlSG7swREQlQoRQRCeiXTzhfeX1LHL9+1u1lW+76juwF6wN0WZ3UKOvwOH51z444PmrgATnTNVj+/aiO5uz8A4YNi+O+/CR/7VGKiASoUIqIBPTLrnelnPOda+N43Nzlcby3Go0RKWDA0j/G8VVnXRzHP3jsjpzpJjYOyTv/MxfcEsfTh/1jHE+4fGm5mlhztEcpIhKgQikiEqCud4mmfPPzcdz64Ko47nzvvWo0RyTI
OzqyA++0x2GnF7r2PteohmyX3Js7y9auWqY9ShGRABVKEZEAdb1TSF5InjyzDV2625s391qbRKT3aI9SRCRAhVJEJECFUkQkoF8eoxz7s+znw+R1XwhOn3zARfKOG9BlQCL9gfYoRUQCVChFRAL6Zdf7gIefi+NxD/dsXj3gQvoSf29nHH/q5/+QOy7FXTcHL24se5tqkfYoRUQCVChFRAKCXW8zawZ+CzRF0y9w96+ZWQtwPzAeWAN8xt23Vq6pIuWl3M79+oa+/DzJUqXZo9wNnOLuxwHHA2ea2XTgemCRu08AFkXDIvVEuS2pBAulZ+y72bkx+nFgJjA/en0+8OmKtFCkQpTbklaqY5Rm1mBmy4FNwOPuvhQ4xN03AES/R1WumSKVodyWNFIVSnfvdPfjgTHANDM7Nu0KzKzNzJaZ2bI96DtcpbYotyWNHp31dvd24CngTGCjmbUCRL83FZhnjrtPdfepjTSV2FyRylBuS3eChdLMRprZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzZ04rMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKsFC6e6/B07I8/o7wKmVaJRIb1BuS1rm7r23MrPNwBu9tkLpzuHuPrLajegrlNs1oyJ53auFUkSkHulebxGRABVKEZGAmiyUZswz48Yo/pgZrxS5nB+b8ZXytq4+mDHeDDfrn88crVXK7dJVI7drslAmufO0O5NC05lxkRmLu8w7251/rlzr4nUfa8YvzdhiRo8O+poxw4z1ZWjDGjNOK3U5ieV90ownzfiTGWvKtVzJqofcjtZ/pBmPmrEtyvHvpJyvJnM7WuYUM35rxnYzNppxZXfTV7xQ9pM9mj3AA8AllVh4ld7DHcBc4JoqrLsu9IfcNmMQ8Djwa+BQMrd63lXG5ff6e2jGwcB/AbcBBwFHA7/qdiZ37/EP+BrwL4G/BL4V/A7w5mjcDPD14NeBvw1+Z/T62eDLwdvBnwX/cGJ5J4C/AL4N/H7w+8BvTC4vMe1Y8IfAN4O/A34r+GTwXeCd4NvB26Np5+1bTjR8Gfhq8HfBF4KPToxz8Nngq6Jt+iG49fB9OTp6Jk3a6YeA7wTfG7V7O/ho8K+DLwC/C/zP4Jfm2Zb4fQG/M1rGzmgZ14KPj7ZpFvha8C3g/1TE3/o08DXF5Ek9/ii393s/2sCfLuJ9rNncBv/Wvr9d2p9S9igvAM4AjgImAl9OjDsUaAEOB9rMmEJm7+RzZCr4bcBCM5qiT6yHgTujeX4K/HW+FZrRADxK5nq18cBhwH3urARmA0vcGerO8DzzngJ8G/gMmTsy3gDu6zLZ2cBHgOOi6c6I5h1nRrsZ49K+OWm4swM4C3gravdQd96KRs8EFgDDgbsDy7mQzB0k50TLSHaNTgYmkbmA+qtmTI626WQz2su5PX2IcjtrOrDGjMeibvdTZnyowLSxGs/t6cC7ZjxrxiYzfhb63y6lUN7qzjp33gVuAs5PjNsLfM2d3e7sBC4DbnNnqTud7swn89DU6dFPI/ADd/a4swB4vsA6pwGjgWvc2eHOLvfcYzfduACY684L7uwGvgR81IzxiWludqfdnbXAk2Qe5oo7a90ZHr3eW5a487A7e6P3sFjfcGenOyuAFWT+UXBncb5/OgGU20ljgM8C/xq17+fAI9GHQLGqndtjyNzDfyUwDngduLe7FZVSKNcl4jfIvIn7bHZnV2L4cODq6JOrPar2Y6N5RgNvuuecBCl0h8NY4A13Oopo7+jkct3ZDrxD5pN7n7cT8XvA0CLWUy7rwpOkUkvbVC+U21k7gcXuPObO+8AtZPacJxfRzn2qnds7gf905/nob/kN4H+ZcWChGUoplGMT8TiId6uB/c78rgNuij659v0MdudeYANwmBnWZXn5rAPGFTgAHDrb/BaZpAbAjCFk/uBvBuartELt7vr6DmBwYvjQlMuRnlNuZ/0+xfoLqdXc7rpN+2LLMy1Q
WqG83IwxZrQAN5D5MqZCbgdmm3GiGWbGEDM+ZcYwYAnQAfy9GQPNOJdMNySf58gk383RMprNOCkatxEY002X4B7gYjOON6MJ+Baw1L30S1+ibWqGzLqjdjUlxs8zY16B2TcCB3X3aRZZDvyVGS1mHApclWc5Rxa1AXmYMSDapkbIbF+J3a16otzOuguYbsZp0XHUq4AtwEqoz9wG7gD+T/R+NQJfIbPXXPC4ZimF8h4yp9Rfi35uLDShO8vIHMu5FdgKrAYuisa9D5wbDW8F/gZ4qMByOoFzyJzOXwusj6aHzOULLwJvm7Elz7yLyLwhD5JJyKPIHHsJig54b+/mgO/hZHbnX4yGd0LOhcRjgWcKbNPLZI6PvBZ13Ubnm47MCYEVZL4V8Ffs/8/7beDL0TK+GNikfRc7b+9mko9H2/ELMntBOwldQtF3KLezy34F+Dvgx9E2zAT+d7RtUIe57c6vyXwA/pzMQ5mPBv6222VmTpf3jGUuQL7UnSd6PHM/E+0FrAA+7M6eardHuqfcTq8/5Xafv2C22qJP3lIOfIvUpP6U2zV/C6OISLUVVSjdGe/OE2Z2ppm9YmarzUxfEi91T7kt+RT94N7oe0ZeBU4nc+D5eeB8d3+pfM0T6X3KbemqlGOU04DV7v4agJndR+aMWMFkGmRN3syQElYp5bKNrVtcXwVRiHK7Tu1iB+/77oLXQxarlEJ5GLlX2K8HTuxuhmaGcKLpO5tqwRO+QN/vUphyu04t9UUVWW4phTJf1d6vH29mbUAbQHPOxfciNUu5LTlKOeu9ntxbvcaQe6sXAO4+x92nuvvUxuzNKiK1TLktOUoplM8DE8zsCDMbROZOgIXlaZZIVSm3JUfRXW937zCzK4BfAg3AXHd/MTCbSM1TbktXJd2Z4+6/IHMvsEifotyWJN2ZIyISoEIpIhKgQikiEqBCKSISoEIpIhKg51GK9FGdM6bE8ZszmvNOM2B3Nh73L8tzxu19772KtKseaY9SRCRAhVJEJEBd7xK9d272oTK7Dsz/uTPi1ex3vNszy/NOI9ITafJu21nZ79d6+eS5eadZ35Gd5pxt1+aMa71/VRx3bt5cVDv7Cu1RiogEqFCKiASoUIqIBOgYZQo2MPs27T3x2JxxF970szhuO3C/RxYCcMRjl8bxxLxfFS/SM2d8/Tdx/OWDXy56OWMGDo3j393wo5xxp//h4jge8BsdoxQRkW6oUIqIBKjrncKAg1ri+Pt353ZPJg/Sd6VIBVn263saDj44jpsGvFaN1vRb2qMUEQlQoRQRCVDXW6SGJbvbbc8uieOzBm9NTNXYiy3qn7RHKSISoEIpIhKgrrdIjbGp2ZsaPjkvf3e7ycLd7eOeOz+OR30///MouzPwhdVxvLfHc/ctwT1KM5trZpvM7I+J11rM7HEzWxX9HlHZZoqUn3Jb0krT9Z4HnNnlteuBRe4+AVgUDYvUm3kotyWFYNfb3X9rZuO7vDwTmBHF84GngOvK2K6qG3DsMXG891+3xfHYgTqs21fUam53Dsl2q69p+Z/EmHB3+5jFF8bxuH9piGN75nc9bkd/724nFftff4i7bwCIfo8qX5NEqkq5Lfup+MkcM2sD2gCa0e1+0ncot/uPYgvlRjNrdfcNZtYKbCo0obvPAeYAfMBavMj19bqOgw6I48ePuS8xpudnD6Wu1HVu20vDsvEzz1axJX1LsV3vhcCsKJ4FPFKe5ohUnXJb9pPm8qB7gSXAJDNbb2aXADcDp5vZKuD0aFikrii3Ja00Z73PLzDq1DK3paY0vvWnOD7q19knPS/7xA9zphvRkP/Y1MVrPxbHBy/Wvbi1qFZye8Bxk3OGX/3bnh0RO33lOXE8cnlHWdqUVueMKXH85oziD0uN/8+tOcN7V6wselmVoGtdREQCVChFRAJ0r3cBnauyT5CedHX2Urq3l+ZON6KBvP77vz4Ux+Pu0NlHKWzLlOE5w6+f8289mr/9rjFx3PLwkm6m
LL9kd3tl24+6mbJ7R4xqyxk+5rbs4Yha6IZrj1JEJECFUkQkQF3vAgYMy164u2Pa+DhuNt0BK/1bw4Qj43hX656yLPP1T8/JGZ686QtxPG5FWVZREu1RiogEqFCKiASo611Ax5Sj4/g3tyW7BUN7vzEiVdYw/MA4fuVr2fj1U26vyPo6mrO3zicPg+3dti3f5BWnPUoRkQAVShGRAHW9RSTozXmj43jZXySfd1CZ53A+c8EtcTx92D/G8YTLl+abvOK0RykiEqBCKSISoEIpIhKgY5RldMzt2bsJjvrJ2jju3ScEipTHu49OjOOffujf43hEw5AeLef1Pdvj+KIrsscbL/jOo3HcduBbOfOMSqzDmzt7tL5K0B6liEiACqWISIC63mU0bE32boKOdeur2BKpJ6N+syFnOHkI5+XLws94/PgV2UtmFnd+NI6H/0dpz6acOe73cTyxsWfd7Qe2Z+/e+b/f/XwcN8zeEsczBq9KzNGz5fc27VGKiASoUIqIBKjrneAnHR/Hm/5hV6p5jvjZZXF8zAvtcaynVkpaHa+tyRk+6t+z10lMaMp2W5N3qyTPCn+39YU4vvgL2def+sTUOB720qA4bv1uuq8meXDOKXHcNDv73MlrWv4nOO+ru1qzbb0z+0DJtS3Z/7E1k7JfgTGxsTzPtayUNN/rPdbMnjSzlWb2opldGb3eYmaPm9mq6PeIyjdXpHyU25JWmq53B3C1u08GpgOXm9kHgeuBRe4+AVgUDYvUE+W2pBLserv7BmBDFG8zs5XAYcBMYEY02XzgKeC6irSyl2ydeEAcr5g2L9U8R9+T7SbVwrfFSXq1mtvJKyYmfP/9OH7nsxbHowp8++cd457ODiTi/3fiUXE8d9gZZWhl96YNznbP7772lG6mrA89OpljZuOBE4ClwCFRou1LuFGF5xSpbcpt6U7qQmlmQ4EHgavc/c89mK/NzJaZ2bI97C6mjSIVpdyWkFRnvc2skUwi3e3uD0UvbzSzVnffYGatwKZ887r7HGAOwAesxfNNI1ItNZ/bu7MF+LPLL4njn56Qvfc6zcXgyTPV17SFL2Iv1V8Ozp7FXtkL66u0NGe9DfgJsNLdv5cYtRCYFcWzgEfK3zyRylFuS1pp9ihPAi4E/mBmy6PXbgBuBh4ws0uAtcB5lWmiSMUotyWVNGe9FwNWYPSp5W2OSO+ph9zubP9THB/66Wx8waMXx3Hynuy/HPaHOJ7W1Fjh1pVP8t5wyL1gfdCG6m+HbmEUEQlQoRQRCdC93iJ1qOXsV+P4aZrjeO4Pr4jjn3/q+73aplJ875uzc4YPvOu/43g8pT0urhy0RykiEqBCKSISoK63SB8y6fqX4vjqb366ii3pmeHtv8sZrrU7U7RHKSISoEIpIhKgrrdIH7J327bsQDKWkmiPUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJEB35iSMeHVnHB/x2KWp5pn81rtx3Fn2FolILdAepYhIgAqliEiAut4J9szyOJ74TLp51N0W6fuCe5Rm1mxmz5nZCjN70cy+Eb3eYmaPm9mq6PeIyjdXpHyU25JWmq73buAUdz8OOB4408ymA9cDi9x9ArAoGhapJ8ptSSVYKD1jezTYGP04MBOYH70+H6if586LoNyW9FKdzDGzBjNbDmwCHnf3pcAh7r4BIPo9qnLNFKkM5bakkapQununux8PjAGmmdmxaVdgZm1mtszMlu1hd7HtFKkI5bak0aPLg9y9HXgKOBPYaGatANHvTQXmmePuU919aiNNJTZXpDKU29KdNGe9R5rZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzHWUrMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKubee181bmabgTd6bYXSncPdfWS1G9FXKLdrRkXyulcLpYhIPdK93iIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgEq
lCIiASqUIiIBKpQiIgH/HyKwOLKinO/gAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "CKPT = 'b_lenet_1-2_1875.ckpt'\n",
    "\n",
    "def infer(ds, model):\n",
    "    data = ds.get_next()\n",
    "    images = data['image']\n",
    "    labels = data['label']\n",
    "    output = model.predict(Tensor(data['image']))\n",
    "    pred = np.argmax(output.asnumpy(), axis=1)\n",
    "    return pred[0], images[0], labels[0]\n",
    "\n",
    "ds = create_dataset(training=False, batch_size=1).create_dict_iterator()\n",
    "net = LeNet5()\n",
    "param_dict = load_checkpoint(CKPT, net)\n",
    "model = Model(net)\n",
    "plot_images(infer, ds, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
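    "## Procedure (Option 2)\n",
    "\n",
    "### Code Walkthrough\n",
    "\n",
    "When a training job is created, the run parameters are passed to the script as command-line arguments, and the script must parse them before it can use them. Examples are `data_url` and `train_url`, which give the data storage path (an OBS path) and the training output path (an OBS path). The script stores the parsed arguments in the `args` variable for use in the rest of the code.\n",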
    "\n",
    "```python\n",
    "import argparse\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--data_url', required=True, default=None, help='Location of data.')\n",
    "parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')\n",
    "parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')\n",
    "args, unknown = parser.parse_known_args()\n",
    "```\n",
    "\n",
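    "Because the script uses `parse_known_args` instead of `parse_args`, arguments it does not define are returned in `unknown` instead of causing an error. The local check below parses a hand-written argument list. The `s3://bucket/...` paths and `--extra_flag` are placeholders.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--data_url', required=True, help='Location of data.')\n",
    "parser.add_argument('--train_url', required=True, help='Location of training outputs.')\n",
    "parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')\n",
    "\n",
    "# Simulated command line; '--extra_flag x' stands in for an argument the script does not define.\n",
    "argv = ['--data_url', 's3://bucket/MNIST/', '--train_url', 's3://bucket/output/',\n",
    "        '--extra_flag', 'x']\n",
    "args, unknown = parser.parse_known_args(argv)\n",
    "print(args.data_url, args.num_epochs)  # s3://bucket/MNIST/ 1\n",
    "print(unknown)                          # ['--extra_flag', 'x']\n",
    "```\n",
    "\n",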
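    "MindSpore does not yet provide an interface for accessing OBS data directly, so use the API provided by MoXing to interact with OBS. Copy the data stored in OBS to the execution container:\n",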
    "\n",
    "```python\n",
    "import moxing as mox\n",
    "mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')\n",
    "```\n",
    "\n",
    "To copy training outputs (such as model checkpoints) from the execution container to OBS, use:\n",
    "\n",
    "```python\n",
    "import moxing as mox\n",
    "mox.file.copy_parallel(src_url='output', dst_url='s3://OBS/PATH')\n",
    "```\n",
    "\n",
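    "MoXing is only importable inside the ModelArts environment. To dry-run the same copy steps on a local machine, one option is a thin wrapper that uses MoXing for OBS URLs and the standard library otherwise. `copy_dir` below is a hypothetical helper and is not part of MoXing or MindSpore. It assumes that OBS URLs start with `s3://`, as in the snippet above.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "def copy_dir(src_url, dst_url):\n",
    "    # Hypothetical helper. Assumption: OBS URLs start with 's3://', as in the example above.\n",
    "    if src_url.startswith('s3://') or dst_url.startswith('s3://'):\n",
    "        import moxing as mox  # available inside ModelArts\n",
    "        mox.file.copy_parallel(src_url=src_url, dst_url=dst_url)\n",
    "        return\n",
    "    # Local fallback. shutil.copytree(dirs_exist_ok=True) needs Python 3.8+, so walk the tree instead.\n",
    "    for root, _, files in os.walk(src_url):\n",
    "        target = os.path.join(dst_url, os.path.relpath(root, src_url))\n",
    "        os.makedirs(target, exist_ok=True)\n",
    "        for name in files:\n",
    "            shutil.copy2(os.path.join(root, name), os.path.join(target, name))\n",
    "```\n",
    "\n",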
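    "For the rest of the code, refer to Option 1.\n",
    "\n",
    "### Creating a Training Job\n",
    "\n",
    "See [Training a Model with a Common Framework](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html) to create and start a training job.\n",
    "\n",
    "Reference configuration for the training job:\n",
    "\n",
    "- Algorithm source: Frequently-used frameworks -> Ascend-Powered-Engine -> MindSpore\n",
    "- Code directory: select the experiment_2 directory in the OBS bucket created above\n",
    "- Boot file: select `main.py` in the experiment_2 directory of the OBS bucket created above\n",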
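    "- Data source: Data storage path -> select the MNIST directory under the experiment_2 folder in the OBS bucket created above\n",
    "- Training output path: select the experiment_2 directory in the OBS bucket created above and create an output directory in it\n",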
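    "- Job log path: same as the training output path\n",
    "- Specifications: Ascend: 1*Ascend 910\n",
    "- All other settings: default\n",
    "\n",
    "Start the job and monitor training:\n",
    "\n",
    "1. Click Submit to start training;\n",
    "2. The new training job appears in the training job list, and version management is available on the training job page;\n",
    "3. Click the running training job to expand a panel with the job configuration and the training log. The log refreshes continuously, and after the job finishes you can also download it for local viewing;\n",
    "4. The training log contains lines such as `epoch: 3 step: 1875 ,loss is 0.025683485`, which report the loss during training;\n",
    "5. The training log contains lines such as `Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}`, which report the validation accuracy after training;\n",
    "6. The training log contains entries such as `b_lenet_1-2_1875.ckpt`, which are the checkpoints saved during training."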
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This experiment demonstrated advanced MindSpore features, including checkpointing and resuming training from a checkpoint:\n",
    "\n",
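    "1. Use MindSpore's `ModelCheckpoint` interface to save a checkpoint once per epoch, train for 2 epochs, and stop.\n",
    "2. Use MindSpore's `load_checkpoint` and `load_param_into_net` interfaces to load the checkpoint saved in the previous step and train for 2 more epochs.\n",
    "3. Observe how the loss changes during training: after loading the checkpoint and resuming, the loss decreases further."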
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}