From 3f5da42196dfc1244997dca39c5e767e3f9aa537 Mon Sep 17 00:00:00 2001 From: dongyonghan Date: Thu, 6 Aug 2020 16:44:16 +0800 Subject: [PATCH] upgrade lenet experiment to r0.5, unify codes for different platform --- .gitignore | 11 +- checkpoint/README.md | 465 ++++++++++++++++ checkpoint/images/prediction.png | Bin 0 -> 7521 bytes checkpoint/main.py | 159 ++++++ experiment_1/1-LeNet5_MNIST.ipynb | 357 ------------ experiment_1/main.py | 60 -- experiment_2/2-Save_And_Load_Model.ipynb | 581 -------------------- experiment_2/main.py | 140 ----- experiment_5/LeNet_MNIST_Windows.md | 171 ------ experiment_5/main.py | 62 --- experiment_6/Save_And_Load_Model_Windows.md | 346 ------------ experiment_6/main.py | 146 ----- lenet5/README.md | 309 +++++++++++ lenet5/images/mnist.png | Bin 0 -> 7485 bytes lenet5/main.py | 85 +++ 15 files changed, 1024 insertions(+), 1868 deletions(-) create mode 100644 checkpoint/README.md create mode 100644 checkpoint/images/prediction.png create mode 100644 checkpoint/main.py delete mode 100644 experiment_1/1-LeNet5_MNIST.ipynb delete mode 100644 experiment_1/main.py delete mode 100644 experiment_2/2-Save_And_Load_Model.ipynb delete mode 100644 experiment_2/main.py delete mode 100644 experiment_5/LeNet_MNIST_Windows.md delete mode 100644 experiment_5/main.py delete mode 100644 experiment_6/Save_And_Load_Model_Windows.md delete mode 100644 experiment_6/main.py create mode 100644 lenet5/README.md create mode 100644 lenet5/images/mnist.png create mode 100644 lenet5/main.py diff --git a/.gitignore b/.gitignore index 90a4621..118548d 100644 --- a/.gitignore +++ b/.gitignore @@ -129,14 +129,15 @@ dmypy.json .pyre/ # MindSpore files -.dat -.ir -.meta -.ckpt +*.dat +*.ir +*.meta +*.ckpt +*.pb # system files .DS_Store -.swap +*.swap # IDE .idea/ diff --git a/checkpoint/README.md b/checkpoint/README.md new file mode 100644 index 0000000..408c53a --- /dev/null +++ b/checkpoint/README.md @@ -0,0 +1,465 @@ +# 模型的保存和加载 + +## 实验介绍 + +本实验主要介绍使用MindSpore实现模型的保存和加载。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。 + +在模型训练过程中,可以添加检查点(Checkpoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下: + +- 训练后推理场景 + - 模型训练完毕后保存模型的参数,用于推理或预测操作。 + - 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。 +- 再训练场景 + - 进行长时间训练任务时,保存训练过程中的Checkpoint文件,防止任务异常退出后从初始状态开始训练。 + - Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。 + +## 实验目的 + +- 了解如何使用MindSpore实现训练时模型的保存。 +- 了解如何使用MindSpore加载保存的模型文件并继续训练。 +- 了解如何MindSpore的Callback功能。 + +## 预备知识 + +- 熟练使用Python,了解Shell及Linux操作系统基本知识。 +- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略、Checkpoint等。 +- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)、[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)、[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)、[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等功能。华为云官网:https://www.huaweicloud.com +- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn/ + +## 实验环境 + +- MindSpore 0.5.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套); +- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html +- Windows/Ubuntu x64笔记本,NVIDIA GPU服务器,或Atlas Ascend服务器等。 + +## 实验准备 + +### 创建OBS桶 + +本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。 + +> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。 + +创建OBS桶的参考配置如下: + +- 区域:华北-北京四 +- 数据冗余存储策略:单AZ存储 +- 桶名称:全局唯一的字符串 +- 存储类别:标准存储 +- 桶策略:公共读 +- 归档数据直读:关闭 +- 企业项目、标签等配置:免 + +### 数据集准备 + +MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。 + +从MNIST官网下载如下4个文件到本地并解压: + +``` +train-images-idx3-ubyte.gz: training set images (9912422 bytes) +train-labels-idx1-ubyte.gz: training set labels (28881 bytes) +t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) +t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) +``` + +### 脚本准备 + +从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。 + +### 上传文件 + +将脚本和数据集上传到OBS桶中,组织为如下形式: + +``` +checkpoint +├── MNIST +│   ├── test +│   │   ├── t10k-images-idx3-ubyte +│   │   └── t10k-labels-idx1-ubyte +│   └── train +│   ├── train-images-idx3-ubyte +│   └── train-labels-idx1-ubyte +└── main.py +``` + +## 实验步骤(ModelArts Notebook) + +### 创建Notebook + +可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。 + +创建Notebook的参考配置: + +- 计费模式:按需计费 +- 名称:checkpoint +- 工作环境:Python3 +- 资源池:公共资源 +- 类型:Ascend +- 规格:单卡1*Ascend 910 +- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的checkpoint文件夹 +- 自动停止等配置:默认 + +> **注意:** +> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。 +> - 打开Notebook后,选择MindSpore环境作为Kernel。 + +> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的"New"->"Terminal",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。 + +> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。 + +导入MindSpore模块和辅助模块: + +```python +import os +# os.environ['DEVICE_ID'] = '0' +# Log level includes 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG). +os.environ['GLOG_v'] = '2' + +import matplotlib.pyplot as plt +import numpy as np + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn, Tensor +from mindspore.train import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +import logging; logging.getLogger('matplotlib.font_manager').disabled = True + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU +``` + +### 数据处理 + +在使用数据集训练网络前,首先需要对数据进行预处理,如下: + +```python +def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32), repeat=1, + rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): + data_train = os.path.join(data_dir, 'train') # 训练集信息 + data_test = os.path.join(data_dir, 'test') # 测试集信息 + ds = ms.dataset.MnistDataset(data_train if training else data_test) + + ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) + ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32)) + ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(repeat) + + return ds +``` + +### 定义模型 + +定义LeNet5模型,模型结构如下图所示: + + + +[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf + +```python +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') + self.relu = nn.ReLU() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, x): + x = self.relu(self.conv1(x)) + x = self.pool(x) + x = self.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + + return x +``` + +### 保存模型Checkpoint + +MindSpore提供了Callback功能,可用于训练/测试过程中执行特定的任务。常用的Callback如下: + +- `ModelCheckpoint`:保存网络模型和参数,用于再训练或推理; +- `LossMonitor`:监控loss值,当loss值为Nan或Inf时停止训练; +- `SummaryStep`:把训练过程中的信息存储到文件中,用于后续查看或可视化展示。 + +`ModelCheckpoint`会生成模型(.meta)和Chekpoint(.ckpt)文件,如每个epoch结束时,都保存一次checkpoint。 + +```python +class CheckpointConfig: + """ + The config for model checkpoint. + + Args: + save_checkpoint_steps (int): Steps to save checkpoint. Default: 1. + save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0. + Can't be used with save_checkpoint_steps at the same time. + keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5. + keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0. + Can't be used with keep_checkpoint_max at the same time. + integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. + Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. + + Raises: + ValueError: If the input_param is None or 0. + """ + +class ModelCheckpoint(Callback): + """ + The checkpoint callback class. + + It is called to combine with train process and save the model and network parameters after traning. + + Args: + prefix (str): Checkpoint files names prefix. Default: "CKP". + directory (str): Lolder path into which checkpoint files will be saved. Default: None. + config (CheckpointConfig): Checkpoint strategy config. Default: None. + + Raises: + ValueError: If the prefix is invalid. + TypeError: If the config is not CheckpointConfig type. + """ +``` + +MindSpore提供了多种Metric评估指标,如`accuracy`、`loss`、`precision`、`recall`、`F1`。定义一个metrics字典/元组,里面包含多种指标,传递给`Model`,然后调用`model.eval`接口来计算这些指标。`model.eval`会返回一个字典,包含各个指标及其对应的值。 + +```python +# 请先删除旧的checkpoint目录`ckpt` +def train(data_dir, lr=0.01, momentum=0.9, num_epochs=2, ckpt_name="lenet"): + dataset_sink = context.get_context('device_target') == 'Ascend' + repeat = num_epochs if dataset_sink else 1 + ds_train = create_dataset(data_dir, repeat=repeat) + ds_eval = create_dataset(data_dir, training=False) + steps_per_epoch = ds_train.get_dataset_size() + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + + ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix=ckpt_name, directory='ckpt', config=ckpt_cfg) + loss_cb = LossMonitor(steps_per_epoch) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epochs, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=dataset_sink) + metrics = model.eval(ds_eval, dataset_sink_mode=dataset_sink) + print('Metrics:', metrics) + +train('MNIST') +print('Checkpoints after first training:') +print('\n'.join(sorted([x for x in os.listdir('ckpt') if x.startswith('lenet')]))) +``` + + epoch: 1 step 1875, loss is 0.23394052684307098 + Epoch time: 23049.360, per step time: 12.293, avg loss: 2.049 + ************************************************************ + epoch: 2 step 1875, loss is 0.4737345278263092 + Epoch time: 26768.848, per step time: 14.277, avg loss: 0.155 + ************************************************************ + Metrics: {'loss': 0.10531254443608654, 'acc': 0.9701522435897436} + Checkpoints after first training: + lenet-1_1875.ckpt + lenet-2_1875.ckpt + lenet-graph.meta + + +### 加载Checkpoint继续训练 + +```python +def load_checkpoint(ckpoint_file_name, net=None): + """ + Loads checkpoint info from a specified file. + + Args: + ckpoint_file_name (str): Checkpoint file name. + net (Cell): Cell network. Default: None + + Returns: + Dict, key is parameter name, value is a Parameter. + + Raises: + ValueError: Checkpoint file is incorrect. + """ + +def load_param_into_net(net, parameter_dict): + """ + Loads parameters into network. + + Args: + net (Cell): Cell network. + parameter_dict (dict): Parameter dict. + + Raises: + TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict. + """ +``` + +> 使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。 + +```python +CKPT_1 = 'ckpt/lenet-2_1875.ckpt' + +def resume_train(data_dir, lr=0.001, momentum=0.9, num_epochs=2, ckpt_name="lenet"): + dataset_sink = context.get_context('device_target') == 'Ascend' + repeat = num_epochs if dataset_sink else 1 + ds_train = create_dataset(data_dir, repeat=repeat) + ds_eval = create_dataset(data_dir, training=False) + steps_per_epoch = ds_train.get_dataset_size() + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + + param_dict = load_checkpoint(CKPT_1) + load_param_into_net(net, param_dict) + load_param_into_net(opt, param_dict) + + ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix=ckpt_name, directory='ckpt', config=ckpt_cfg) + loss_cb = LossMonitor(steps_per_epoch) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epochs, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=dataset_sink) + + metrics = model.eval(ds_eval, dataset_sink_mode=dataset_sink) + print('Metrics:', metrics) + +resume_train('MNIST') +print('Checkpoints after resuming training:') +print('\n'.join(sorted([x for x in os.listdir('ckpt') if x.startswith('lenet')]))) +``` + + epoch: 1 step 1875, loss is 0.07734094560146332 + Epoch time: 25687.625, per step time: 13.700, avg loss: 0.094 + ************************************************************ + epoch: 2 step 1875, loss is 0.007969829253852367 + Epoch time: 22888.613, per step time: 12.207, avg loss: 0.078 + ************************************************************ + Metrics: {'loss': 0.07375562800846708, 'acc': 0.975761217948718} + Checkpoints after resuming training: + lenet-1_1875.ckpt + lenet-2_1875.ckpt + lenet-graph.meta + lenet_1-1_1875.ckpt + lenet_1-2_1875.ckpt + lenet_1-graph.meta + + +### 加载Checkpoint进行推理 + +使用训练后的LeNet5模型对手写数字进行识别,使用matplotlib将推理结果可视化,可以看到识别结果基本上是正确的。 + +```python +CKPT_2 = 'ckpt/lenet_1-2_1875.ckpt' + +def infer(data_dir): + ds = create_dataset(data_dir, training=False).create_dict_iterator() + data = ds.get_next() + images = data['image'] + labels = data['label'] + net = LeNet5() + load_checkpoint(CKPT_2, net=net) + model = Model(net) + output = model.predict(Tensor(data['image'])) + preds = np.argmax(output.asnumpy(), axis=1) + + for i in range(1, 5): + plt.subplot(2, 2, i) + plt.imshow(np.squeeze(images[i])) + color = 'blue' if preds[i] == labels[i] else 'red' + plt.title("prediction: {}, truth: {}".format(preds[i], labels[i]), color=color) + plt.xticks([]) + plt.show() + +infer('MNIST') +``` + +![png](images/prediction.png) + +## 实验步骤(ModelArts训练作业) + +### 适配训练作业 + +创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。 + +```python +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--data_url', required=True, default=None, help='Location of data.') +parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') +args, unknown = parser.parse_known_args() +``` + +MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器: + +```python +import moxing +moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') +``` + +如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考: + +```python +import moxing +# dst_url形如's3://OBS/PATH',将ckpt目录拷贝至OBS后,可在OBS的`args.train_url`目录下看到ckpt目录 +moxing.file.copy_parallel(src_url='ckpt', dst_url=os.path.join(args.train_url, 'ckpt')) +``` + +### 创建训练作业 + +可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。 + +创建训练作业的参考配置: + +- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore +- 代码目录:选择上述新建的OBS桶中的checkpoint目录 +- 启动文件:选择上述新建的OBS桶中的checkpoint目录下的`main.py` +- 数据来源:数据存储位置->选择上述新建的OBS桶中的checkpoint文件夹下的MNIST目录 +- 训练输出位置:选择上述新建的OBS桶中的checkpoint目录并在其中创建output目录 +- 作业日志路径:同训练输出位置 +- 规格:Ascend:1*Ascend 910 +- 其他均为默认 + +启动并查看训练过程: + +1. 点击提交以开始训练; +2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理; +3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看; +4. 参考实验步骤(Notebook),在日志中找到对应的打印信息,检查实验是否成功。 + +## 实验步骤(本地CPU/GPU/Ascend) + +MindSpore还支持在本地CPU/GPU/Ascend环境上运行,如Windows/Ubuntu x64笔记本,NVIDIA GPU服务器,以及Atlas Ascend服务器等。在本地环境运行实验前,需要先参考[安装教程](https://www.mindspore.cn/install/)配置环境。 + +在Windows/Ubuntu x64笔记本上运行实验: + +```shell script +vim main.py # 将第23行的context设置为`device_target='CPU'` +python main.py --data_url=D:\dataset\MNIST +``` + +在Ascend服务器上运行实验: + +```shell script +vim main.py # 将第23行的context设置为`device_target='Ascend'` +python main.py --data_url=/PATH/TO/MNIST +``` + +## 实验小结 + +本实验展示了使用MindSpore实现训练时保存Checkpoint、断点继续训练、加载Checkpoint进行推理等高级特性: + +1. 使用MindSpore的ModelCheckpoint接口每个epoch保存一次Checkpoint,训练2个epoch并终止。 +2. 使用MindSpore的load_checkpoint和load_param_into_net接口加载上一步保存的Checkpoint继续训练2个epoch。 +3. 观察训练过程中Loss的变化情况,加载Checkpoint继续训练后loss进一步下降。 diff --git a/checkpoint/images/prediction.png b/checkpoint/images/prediction.png new file mode 100644 index 0000000000000000000000000000000000000000..5cd8cfc79833ae4d2aacf1cd2de10618c686fec7 GIT binary patch literal 7521 zcmb7}2Ut_tx`q>~Qb#EoKpYvWN)19F6b~XGC?rV#^ecP86=HPh`XTlw=8#7cd;M^rnL`Kwz<>1H7dS$S+huQr zgkH1{y%qoqMf@7<3Gn|lG|(>~)X(SE&*7fIAwB_cu!_2h#@U~}Lqh{~RaL+5R|yFA zQk8q~n0-~L?n}o1xE@hVBSmLlpDkdlFfM!u#mO0+@&4s!7eJqeHR#S~@1vhQN0X(4 zi#_GjsYZEIr-skzx_7+ki`88DoR7n66eNlKwgP`|ee08~x#khitGP(o6uyc>iow>p zK(5N7`CaQGXH4PKIK)I%-P0YGL!F%=W`jl=vodU%i9Hj7FgC7y1{{MYB$+i_dn|h7 z@YVk&Jn1v?>CpqjQ$J-Cte9uwwwgh%N&!oYcR|#FKroLjM)^?{QN{UOFh-u=^AXev z16I;30Uax7eu?Wqk+$ahTU1g)5by zcn$X}ZMkP0`e~!M+9qze6Mj8WF8Y>xjiA-&f^y72H?ZG-3`X2MHNS#@u*u#gGdv z-KA=JH(I#*BuAG>Ou#6eMyYt!Ez||rRbd+5buILA`o!Yi9rOl4D?D%V8yHMgJ5YRI zp%V_S9f@E1fSnM{Qdr-tAlky`ca%IczH(3mD2i?}%-J_DFCfJ!fxi!=DtVmVoKLu$ zFr9D@rdgnVDB}1_aN&Fn)^*}2wM&mQjRVV<`q#IX$9;OjN$Q4 z%MEhP>l@@T;E+0Y1!1DefPkk4j^}L*TzXw{0_%Wvvc@E~ui9!x$v*=e7b2dpAH0Wk zq!fFDW-g@~fbzMD`C`UW`-Zhj{7j;C&^5x=@VqBYv*AkjzS<(j+g4QqMFQh)5nU+0 z^yP3bg+NeX5!8aT5b9;{Ofp}3Pm{wi&?^JVpD@o~A0nY9@v^7n1bsPpNI(2T8h=6? zy@N6cr&wT32acEJ&%qE=BnN)I!&1VOK}W=W(h)e;dwTS6LO!l-|c3y2+M-5V=PGT^ z)f6wyRW! zwa#6Dk*qO=N2&7#>P^^XN#tETMD1DOEY`JP_*Tzj?2X9W5D#m~hlwE0a{t@E;Wxro z$s6bTX-UD*ifp4uVdfc1%a{2cBawDgD!#ec3E6CcMd|p%NHMI01|xbRo*cmGnx4Dz7AsG~Oh!g)XT|7ErMo^P@K1pn z(1usxUP@>nu2`N?@I_|8f)rWa3ub|b6~_1|uj$LR;n9`R{j|01Yj9Q~KdkQbR9*f} zT6j{nQ+z+Iu=}31aphFXV$_4AH3d1Ebs?buwvkfMpoLn(jVWf{z6Q6j8II0|jRQLP ziY?~^NUPW_`o7Ipp~A>c_VmaNC54b`U>#}i8~}{T!0Q)tXZbul7YV*#ZeNv+(`tnW zu`xT@CdQ4|sv2eJL#YdQPxYuxM=R2nmW`=A-1D=Tw^@xeDOrZC1ZM zz$LU=`;O{6qP@yUjUrl(y_%B5_2lY*4PRb$DhV7f-XmAE*h!ax;Q=X<8YU`3am{3) zu%TURc(NJey>wNQSpwg;&`wsGe74Mbs=7&3)}hJ^OAy1p?n$joRt4#;eEiyCA=`j7 zP0yfSDf7N^Zs?tDLXCL*^(9Bvu?Oegg+7`rM?bY(q z-bzH(hT=tIj_cpT=(>1jhs);e)(88UBmz7Fcj8+{&5oVUmXQJeThLun9uby38JOAe zP%jeZ#K%u(XZxf?>m7dXJ%&$u&Iir!g|~{?RvtKuis=>*5^F}Z-dka>Xu>L+YyKh zacdYNy?@fe+}FNDal>Qq6uk@h^S1ue_rJl$VK5xtk02d>tW7u5rGl8dcYU`WjTs7J z)zS@GGDTP=Bz0qoTjX?AEc3;@&-+> z&LLH`#Ze9P#Y#IRw45^O4@H@yzZ)sBLv$y<_i-JKbnRk9sFbkeGql9Ok<8O0-;>l% z=3d@%+lZT-Au%i7?YzQtm7xpZ_%OJ|C5CV6}!#Hqi=v{!M!PcM+;nJ7fo z{6*|x?Wb5he^@IXpG*4{;nN^i`_pV|s6yCdP1IxTUF;3|{``y8w!g=_58u$wBe?A8 zqmLF{0>HkqS!M{#&V^x|hRu|Jo5Jw4cI#2jTBc{K18lNanH-5Jgq2xcY}ICIo<57u24xp>;7 zO}Uocyz}@ojtVQ~fRHHD*T4Njjj&rPJ>6>jv{rSl_4<7}{>WeY?n|B(#5FS3LaIA% zoaX67-%elVUg8XOC>d2;VnX7=hSDJhw<$oK59jRn_nxHBE$?#P7kohN+Lev}UMOq0 zMnqNdYKbp3f=HYV>Of1R9Z_<_ju`f=-0*Ui>NSj5PA^`N*AnW39)OP#kM@2@+;f+* zdKu%>R%jpcLwPhJlB;7@SA&S#l&27S45Rb5rm5^ZsP1!4#*5xxYmT*V7)(65@0Uj` zg5t&-9jo=tk7NSGzJ;*PYi2HHrFtkq=l1C$ZI)w^aV5+9U(4xtSLq6=9n|vt>(61X zhb+rtuou6kOTPC!FSCdIUeDW}6FUisWe=!9!Yn@*eTvcc(f9x_srbeDgNIO+x6azU zs_?VXMYl7*H@V}?PA*Rzzvjgcs)4Mg0tfy;Lp5o8&X^t&+P1M)X10nrxVL@3Ygez& zrixHO6fFKk?*;a9dKD5vDN$~j7%PDI#n&?H#_ZH_@mF~*V+d{P+d*#fYKBp<2F!(& z7gsBdzRhP(-Px}*Hy93i1xuxOi9+6xv;oH5rZ*Z3WO5!8b2V+m%n^)G78R+E)dWMe zXngqk0eo%&;ko`7=*_BD_KcAxU9uj|3a2PhF1u9}v~U=AYm4Q3rDLqZShZ=IIg>Xt z`>(<+5O?m4;7tk%Ohp3#L@)wme+ZuWR>;^mbuz6+T zPDVznkNBn|Q<-vb{msm_CiC3e&zJP~7169{!zpZ;%BR?R0+HunysO)Js?26XQ57%< zNav8}aJ5W@ih2Wuc&G6BMlijy1&jsVOAc7VLnR=_G95Irc0tHYFGJw9MK_FmF6R712gGqBIyi714>PVN4&r1|1 zgg@fUZPrx;ImKj%OC0bTiQ&YxJ;a}ooj1mrg{N8a$Y-`CY}dgqr?)P6fTwV6`mt>7 z>#O^$C$+a6&bkJRd4he>EHOk|z+zx}zCH5bIPvWnm8z@c#X(&JnH$bAXq7t6A2U!w zqAzo)aKb~pYM~``yu#}H(S>2FUznVTjh84uL>E@K23UdgB@#Xks<4Re+N76&{G~yD z`Z7-!SIlVYO0o|;yoOZLg3a=fo=buHl0ed(WT)JhtIwd(6|86WGl5<;&@%~Q9{bjT zUO`>)y1?w|OYR9n90w9DjCt%Y!-?mcpO#F!zazemms1mpIPnsV5zG0Dd^(bM{KCEL zZfyoq7QbZG=|A0YuZT3^3}~M(=28kxd~Ht=MChh?GroKdCJ}@>F3e>`m!U;gqOtE~ z#;>U?{gG8^V=;YQHGA3tihN7gn;YE0Dtl}cD;yghjAdUGiaZFN-}(};B}j=c2y%N2 zGFQrW7g>QMA?lEOgS(BK@Dzi#jXiJAf-FKD_&5^gH1i< zu}=x$nX7wYM5AwS3%Lc}+I*jl%PhSH_u4afyUkxaMx?i2*|~)SD_?bd+(ArLhrIx3 zKitx&HiGlQih{~Pu%4g52A(}?Zin9kcRN9x1CYGS+OlQy?__5NhBXQYAW9PGBga2n z$)iY9Zh_OaD-LgpKO5~_vv6gj&wYG*Mff?no*}Ti3zDGC!5hZ%5qA*U%>gOHr1);; zck^c)=1+!*JaoXm)mq#+4%btt ziJQ!qFgDalYt)RoK8S;#)qD)ay_3SViB0m=^QffsbP5+s3I=fR^F%zs2J0e2-W3nJ z^4DinC*o)MRXGBRn&pN%9=z)~CH&~pyNY*G(T7unA_5q535<^(;q$Jk_Yp>e_ujZ8 zz4-h-;5^bZegr9xE-|Lv^1;xKKLt0e(RF;XGWHlWfB}1yVLH9SuMzzZ=wWB0VT2+Y zg^T0fgi?BkweVc~h!_4>6WeHmO?@$d{bKxuf()Mh)@e_eWgaWlQ{{ZuXo*6UUWlOx zjk_LTusm3L8*SQNL$qu%oDDDQc59s~w<6mJ#XK0T+-J|t@%$A@W-Vej6iFb)pa?9y zx}ftGtZv}4>rvG(t2X5=(LD!e_eCG8e3V~2ic;(~bHpOj@c*~F$S z_yBmDewdIlz_fX}FG&n@$G6{Lf3C4;g&Rk~wG5rmSPm zRzIagr_P>-(RIfFbc0?}P5aJh&+@=&nU>4n_A6x%af;xV4}m3(=PVQM&HCh;-C&Dn z(_!ZT1VasVCc%|M5PK65udY^PHo(S6J2_GTmEQd`+Yc}+D<|g4k_vwL7e)JB+s?C@ z$GQYu{br>IT)1ipc0*`Mvs7^*J+xRxXUdB@=^rGnGvjC{0>1HB10#e#Xaj#5TSjzD zgK?{tzl@0Ze1%Ut5Ev6tTc6t^YaC`WOXiCGC+1RpJL7LiA>ibj*P%On-xb)O)c57M zC!dzujuuUSWwJQv7`|A~=HC4VH`b^NrsgP0Guk{2F9RGD!*Y%E1f2pK@q zi+DP0)5h1Frq~VsG7Bf1D(}l-ru~OhlBp=3fGd~BCLbs-mi?&k&}M1PwS*!G10B_P z1g7r~5{+64HB2{+H9S~V#DJOi#@PC?H-x49O5Yq@d1@d9Nl((2L+p>hbNH+l5iwZs%EsA6-oO^B))1-J`7!}i@Cz5~U?VsYlf3oqZ zWZ<%TPzTEu{dQ#%sarp7H+mb65b6uBYx3DDdy^lC*vWr;tY_tAyibyQ^c+GTZW_;k zD~Rp5#wTF{07-Efr(Ls*pPIKR@6>MgNr8PirvQb+w73!W$lJ-d#XoBT_6=hR!p8ah zPHR5K3<2i=F<8frckAoz>34;asJ$6&s_NDAUcyz4qhauAi`UOrz(cmpae!Xed?!w(N8N8kDyiudxWj(5Ind14zJjSb2Qx$n9+ zU3%gOG94ldwhF!BB(F>u9_Bvr!H?zxz5H|HbE_O$cB~ z-E*slrJ-RPHtp!VIZN$wWXJwaw}EIcmHwV!s^o{u{;2c1scbbszp#6+Q9X$GmH7x> zeYV>J5$!ph(O&R8+BEex4rDOby8JQ8@ymVR@50=f!NDVdb9sqzpW1}LwMf7-drp4? z%k#6V>sCX3pwf;7{vbX0MF(5!qd(OAlJoiA>%&;iBx9fyfGK1i@9d3~{w!KF}cKm2!|?>?9jo ze#7O>6?J_2p>Y2qSIx@Y!cq_pwmqrG6B zCUg8SYL=z-V-J;36i%G%E!zLp?rCKDCc!f(2G36ms;{v%cRbP5&<8SDoIy=oy6oK_ z?v#b7lO5uVN3Y+5bZznzCd#(oj6$S0Mt+CD_u%bK@{xgqNZ*KuC~!l=VI zqWn9-JLmgalnDpT=dzx=BGsFl^@dvNvo&Xy^vCg~ny4JuMtOlVsFVENoYA(cjQOWI zgWwHp!C3KBbLJHhc#%I0WDz#WA6<^Z?C`i<5DfsMa@dmbvYUw}A%+bu?I|yw=lPXH zoR@0K*HRtRGo8tz$&dYketRjZrEt3MjJ$lz!5$=T9wpve!56bLe}Y;%C?xyo0P)E> zNDnay^7Qs(RSOP=4>7tnvf%DS9p@zVOzE)_`SGZ2OLs zn-G?_{_X0OUu=}ml}%$%+?yzzYBIH~^kFAUci=6h4?H&U;-2a))iGD0!P2)b1pcME z7i>#MaP#o95mI;F@s^+R#onS4O#5f_GW)IB^yIUlXvLwzwa;e2iU~LX&LgdcvPfvb zM&|gGo4sYvemU&s$2e`>)@o<{)6E#QiA1{J!8NIq+Y5d%uPTo927N7E?v~m}gp`ht z+J79+vHD65eAnp~`~Lo7KqCCtzi0CIon)D?C$r0T)J>m4Y6M-lDT!JZzqr#mRc#Z) zC_;tY7w%Z8mtK{V+@#(Pk$#7Pj*~Z7v4HwWlaA6nvz0<3Noq&};v8)N{mi z>+*Sx+(-XL>Hnks^UocKIoZ^RFfXZ-Y(tW5D(;SBf`|?q;>l*p|1z!)Z{WW|N;faJ zM~;%j7k2J^^MU&Sm!kSsVw6k#X?k%oN0S94BqJ|M<&JX4equT|*x zza(d`sP7)FIeB~!q9&4mveV$?ld}222fq_@*KA7g>L?ztlaL>WG=jx0qrWV|-~rKE z0d_5CKE2u=U4uC(oC>TiAT;~fxxxO^Lu#gp)*fOXGhBhYOYA%m!3PnU^oBJ6h8b75 z=h(928Y%YnFIVKW>?$?brf&&5zo!2yn4j?czrcKs9PN4z4l6=xNBQ3yDS1tb^=OeC zwo^p=nq2(aZpY6yf%Qk&1@K?&*MDTGPT)Jp!+-T%>=RW`+ C{msY# literal 0 HcmV?d00001 diff --git a/checkpoint/main.py b/checkpoint/main.py new file mode 100644 index 0000000..486bb69 --- /dev/null +++ b/checkpoint/main.py @@ -0,0 +1,159 @@ +# Save and load model + +import os +# os.environ['DEVICE_ID'] = '0' +# Log level includes 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG). +os.environ['GLOG_v'] = '2' + +import matplotlib.pyplot as plt +import numpy as np + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn, Tensor +from mindspore.train import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +import logging; logging.getLogger('matplotlib.font_manager').disabled = True + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU + + +def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32), repeat=1, + rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): + data_train = os.path.join(data_dir, 'train') # 训练集信息 + data_test = os.path.join(data_dir, 'test') # 测试集信息 + ds = ms.dataset.MnistDataset(data_train if training else data_test) + + ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) + ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32)) + ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(repeat) + + return ds + + +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') + self.relu = nn.ReLU() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, x): + x = self.relu(self.conv1(x)) + x = self.pool(x) + x = self.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + + return x + + +def train(data_dir, lr=0.01, momentum=0.9, num_epochs=2, ckpt_name="lenet"): + dataset_sink = context.get_context('device_target') == 'Ascend' + repeat = num_epochs if dataset_sink else 1 + ds_train = create_dataset(data_dir, repeat=repeat) + ds_eval = create_dataset(data_dir, training=False) + steps_per_epoch = ds_train.get_dataset_size() + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + + ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix=ckpt_name, directory='ckpt', config=ckpt_cfg) + loss_cb = LossMonitor(steps_per_epoch) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epochs, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=dataset_sink) + metrics = model.eval(ds_eval, dataset_sink_mode=dataset_sink) + print('Metrics:', metrics) + + +CKPT_1 = 'ckpt/lenet-2_1875.ckpt' + +def resume_train(data_dir, lr=0.001, momentum=0.9, num_epochs=2, ckpt_name="lenet"): + dataset_sink = context.get_context('device_target') == 'Ascend' + repeat = num_epochs if dataset_sink else 1 + ds_train = create_dataset(data_dir, repeat=repeat) + ds_eval = create_dataset(data_dir, training=False) + steps_per_epoch = ds_train.get_dataset_size() + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + + param_dict = load_checkpoint(CKPT_1) + load_param_into_net(net, param_dict) + load_param_into_net(opt, param_dict) + + ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix=ckpt_name, directory='ckpt', config=ckpt_cfg) + loss_cb = LossMonitor(steps_per_epoch) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epochs, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=dataset_sink) + + metrics = model.eval(ds_eval, dataset_sink_mode=dataset_sink) + print('Metrics:', metrics) + + +CKPT_2 = 'ckpt/lenet_1-2_1875.ckpt' + +def infer(data_dir): + ds = create_dataset(data_dir, training=False).create_dict_iterator() + data = ds.get_next() + images = data['image'] + labels = data['label'] + net = LeNet5() + load_checkpoint(CKPT_2, net=net) + model = Model(net) + output = model.predict(Tensor(data['image'])) + preds = np.argmax(output.asnumpy(), axis=1) + + for i in range(1, 5): + plt.subplot(2, 2, i) + plt.imshow(np.squeeze(images[i])) + color = 'blue' if preds[i] == labels[i] else 'red' + plt.title("prediction: {}, truth: {}".format(preds[i], labels[i]), color=color) + plt.xticks([]) + plt.show() + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--data_url', required=False, default='MNIST', help='Location of data.') + parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.') + args, unknown = parser.parse_known_args() + + if args.data_url.startswith('s3'): + import moxing + moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST') + args.data_url = 'MNIST' + + # 请先删除旧的checkpoint目录`ckpt` + train(args.data_url) + print('Checkpoints after first training:') + print('\n'.join(sorted([x for x in os.listdir('ckpt') if x.startswith('lenet')]))) + + resume_train(args.data_url) + print('Checkpoints after resuming training:') + print('\n'.join(sorted([x for x in os.listdir('ckpt') if x.startswith('lenet')]))) + + infer(args.data_url) + if args.data_url.startswith('s3'): + import moxing + # 将ckpt目录拷贝至OBS后,可在OBS的`args.train_url`目录下看到ckpt目录 + moxing.file.copy_parallel(src_url='ckpt', dst_url=os.path.join(args.data_url, 'ckpt')) diff --git a/experiment_1/1-LeNet5_MNIST.ipynb b/experiment_1/1-LeNet5_MNIST.ipynb deleted file mode 100644 index 1b385da..0000000 --- a/experiment_1/1-LeNet5_MNIST.ipynb +++ /dev/null @@ -1,357 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "

基于LeNet5的手写数字识别

\n", - "\n", - "## 实验介绍\n", - "\n", - "LeNet5 + MINST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。\n", - "\n", - "## 实验目的\n", - "\n", - "- 了解如何使用MindSpore进行简单卷积神经网络的开发。\n", - "- 了解如何使用MindSpore进行简单图片分类任务的训练。\n", - "- 了解如何使用MindSpore进行简单图片分类任务的验证。\n", - "\n", - "## 预备知识\n", - "\n", - "- 熟练使用Python,了解Shell及Linux操作系统基本知识。\n", - "- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略等。\n", - "- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)、[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)、[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)、[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等服务。华为云官网:https://www.huaweicloud.com\n", - "- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn\n", - "\n", - "## 实验环境\n", - "\n", - "- MindSpore 0.2.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套);\n", - "- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html\n", - "\n", - "## 实验准备\n", - "\n", - "### 创建OBS桶\n", - "\n", - "本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。\n", - "\n", - "> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。\n", - "\n", - "创建OBS桶的参考配置如下:\n", - "\n", - "- 区域:华北-北京四\n", - "- 数据冗余存储策略:单AZ存储\n", - "- 桶名称:如ms-course\n", - "- 存储类别:标准存储\n", - "- 桶策略:公共读\n", - "- 归档数据直读:关闭\n", - "- 企业项目、标签等配置:免\n", - "\n", - "### 数据集准备\n", - "\n", - "MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。\n", - "\n", - "从MNIST官网下载如下4个文件到本地并解压:\n", - "\n", - "```\n", - "train-images-idx3-ubyte.gz: training set images (9912422 bytes)\n", - "train-labels-idx1-ubyte.gz: training set labels (28881 bytes)\n", - "t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)\n", - "t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)\n", - "```\n", - "\n", - "### 脚本准备\n", - "\n", - "从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。\n", - "\n", - "### 上传文件\n", - "\n", - "将脚本和数据集上传到OBS桶中,组织为如下形式:\n", - "\n", - "```\n", - "experiment_1\n", - "├── MNIST\n", - "│   ├── test\n", - "│   │   ├── t10k-images-idx3-ubyte\n", - "│   │   └── t10k-labels-idx1-ubyte\n", - "│   └── train\n", - "│   ├── train-images-idx3-ubyte\n", - "│   └── train-labels-idx1-ubyte\n", - "├── *.ipynb\n", - "└── main.py\n", - "```\n", - "\n", - "## 实验步骤(方案一)\n", - "\n", - "### 创建Notebook\n", - "\n", - "可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。\n", - "\n", - "创建Notebook的参考配置:\n", - "\n", - "- 计费模式:按需计费\n", - "- 名称:experiment_1\n", - "- 工作环境:Python3\n", - "- 资源池:公共资源\n", - "- 类型:Ascend\n", - "- 规格:单卡1*Ascend 910\n", - "- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的experiment_1文件夹\n", - "- 自动停止等配置:默认\n", - "\n", - "> **注意:**\n", - "> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。\n", - "> - 打开Notebook后,选择MindSpore环境作为Kernel。\n", - "\n", - "> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的\"New\"->\"Terminal\",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。\n", - "\n", - "> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。\n", - "\n", - "导入MindSpore模块和辅助模块:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "# os.environ['DEVICE_ID'] = '0'\n", - "import matplotlib.pyplot as plt\n", - "import mindspore as ms\n", - "import mindspore.context as context\n", - "import mindspore.dataset.transforms.c_transforms as C\n", - "import mindspore.dataset.transforms.vision.c_transforms as CV\n", - "\n", - "from mindspore import nn\n", - "from mindspore.model_zoo.lenet import LeNet5\n", - "from mindspore.train import Model\n", - "from mindspore.train.callback import LossMonitor\n", - "\n", - "context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 数据处理\n", - "\n", - "在使用数据集训练网络前,首先需要对数据进行预处理,如下:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_DIR_TRAIN = \"MNIST/train\" # 训练集信息\n", - "DATA_DIR_TEST = \"MNIST/test\" # 测试集信息\n", - "\n", - "def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),\n", - " rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):\n", - " ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)\n", - " \n", - " ds = ds.map(input_columns=\"image\", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])\n", - " ds = ds.map(input_columns=\"label\", operations=C.TypeCast(ms.int32))\n", - " ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)\n", - " \n", - " return ds" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAATsAAAD7CAYAAAAVQzPHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcm0lEQVR4nO3deZRV1Zk28OepQWaBYrIQAkZBIKyICjjE1U3aEDHdaU268QuiTRxCVqKt+aJGErOiMdqxTaL9pfOZDh0ZooKxo+0QtQnNEhLRBis4oSggDhArTIIWU0FVvf3HPexzCupW3enc4ezntxar3nvGXfCy795n2JtmBhGRpKsqdQFERIpBlZ2IeEGVnYh4QZWdiHhBlZ2IeEGVnYh4QZVdBkguIHlbqcshUmg+5XZFVnYk3yG5lWSvyLIrSS4vYbEKiuRnSK4huZfkZpIXlbpMEr+k5zbJO4N8/ojkuyRvKta5K7KyC9QAuLbUhcgWyeoMthkHYBGAmwD0BTABwB9jLpqUj8TmNoB7AYwxs2MBnA3gYpJfjLdkKZVc2f0IwPUk+x25guRIkkayJrJsOckrg/jLJFeSvJvkbpKbSJ4dLN9MchvJWUccdiDJpSSbSK4gOSJy7DHBug9IvhlthQXdhJ+TfIrkXgCfzuB3+y6AX5jZ02bWYmY7zeytLP9+pHIlNrfN7E0z2xtZ1AbgpIz/ZvJQyZVdA4DlAK7Pcf8zALwCYABSragHAUxC6i/+EgA/I9k7sv1MAD8AMBDASwAeAICgu7E0OMZgADMA3EPyE5F9LwZwO4A+AJ4leTHJVzop25nBsV8l2UjyfpJ1Of6eUnmSnNsgOYfkHgBbAPQKjh+7Sq7sAOB7AP6R5KAc9n3bzOabWSuAXwMYDuBWM2s2s98BOIj23zhPmtnvzawZqe7lWSSHA/gbAO8Ex2oxszUAHgbw95F9HzOzlWbWZmYHzGyRmX2yk7INA3ApgL8DMApADwD/msPvKJUrqbkNM7sDqcrxNAD3Afgwh98xaxVd2ZnZWgC/BTAnh923RuL9wfGOXBb99tscOe8eAB8AGApgBIAzgi7DbpK7kfqmPK6jfTO0H8B8M1sfnOufAHwuy2NIBUtwbh8+j5nZi0FZvp/LMbJV0/UmZe9mAGsA/CSy7PA1gZ4APgri6D9QLoYfDoIuQB2A95H6x15hZlM72TfboWVeyWEfSZ4k5vaRagCcmOcxMlLRLTsAMLONSDXVr4ks2w7gTwAuIVlN8nLk/xf6OZLnkDwGqesbq8xsM1LfvqNJXkqyNvgzieTYPM41H8BlJD9OsieAG4PziEeSltskq0h+lWR/pkwGcBWAZXmWPyMVX9kFbkXqQmfUVwDcAGAngE8AeC7PcyxC6pv2AwCnI9Wch5k1AfgsgC8h9W34ZwD/DKBbugORnEnytXTrzWwegF8BWAXgXQDNiCS8eCVRuQ3gCwDeAtAE4H6krkUX5Xo0NXiniPggKS07EZFOqbITES+oshMRL+RV2ZGcFrxCspFkLs8DiZQl5Xby5HyDgqmXftcDmIrUax8vAJhhZq8XrngixafcTqZ8HiqeDGCjmW0CAJIPArgAQNqEOIbdrPtRd9GlFJqwa4eZ5fIqkg+U2xXqAPbioDWzo3X5VHbHo/2rIluQegG5HZKzAcwGgO7oiTN4bh6nlEL5b/vNu6UuQxlTbleoVZb++eR8rtl1VHse1Sc2s7lmNtHMJtamfxZRpJwotxMon8puCyLv1CE1Usf7+RVHpCwotxMon8ruBQCjSJ4QvFP3JQCPF6ZYIiWl3E6gnK/ZmVkLyasBLAFQDWCemXX2TpxIRVBuJ1NeQzyZ2VMAnipQWUTKhnI7efQGhYh4QZWdiHghCSMVl43G6852cdO4g1ntywPhLHQnzwmfXW1rasq/YCKilp2I+EGVnYh4Qd3YAqo7L3zu9JXxj2a177qD+1x83a0XhivUjfVaVc+eLn7v2gkubiuTFzYGvdTi4h6Pri5hSbqmlp2IeEGVnYh4Qd3YPO37YjgYxuQBL5SwJJIU1YPCkbca/88oFy/52p0uHlbTG+Vg6rrPu/jg/okuPmZJQymK0ym17ETEC6rsRMQLquxExAu6ZpdGVZ8+Lm457aS02116+xMunt03uyHPtrXudfFdWz8brmhp6WBr8cWhccNc/OJ37omsKY/rdFFLx4b5f8nNU1y89cCpLmZLOO5p1aq1LrYi57ladiLiBVV2IuIFdWMj2C18LH3flLEuXvGLubGcb+6u01383hl7I2v2Hr2xSJm7f+Ty8MPiMF5/KMznb5x/WbjNzt0ubNv9YbtjWXNzoYunlp2I+EGVnYh4Qd3YiN3TwztIi2//cWRN+d0FE6kUJ9b0cPG/PD3fxa0Wzlh52Xe/2W6fvvf/T8HLoZadiHhBlZ2IeMH7buy2q8Oh1L99zQMuPqG2cF3XU1bPcPHgu7u7uHrvochWayECADVrNrr4L78628ULfnaXizPJz3R5l6np/7bExdk+MB9VzbBNNbq2V4fbtNayw+WF1GXLjuQ8kttIro0sqyO5lOSG4Gf/eIspUnjKbb9k0o1dAGDaEcvmAFhmZqMALAs+i1SaBVBue6PLbqyZ/Z7kyCMWXwBgShAvBLAcwI0FLFfRHBgYxhf1/jD9hnn4aGfYdD9uRTjOl3W0sRRNueZ2dEa5Hr972cUzbrrexZl0+4as3+9irnyxw22q+/V18Z8WDG23bkrPDZFPHXc/03loT3jcH/50pouf+FbpxuTL9QbFEDNrBIDg5+B0G5KcTbKBZMMhFP6paJECU24nVOx3Y81srplNNLOJtSiTWUJECkC5XVlyvRu7lWS9mTWSrAewrZCFitsHl53l4jOnvRrLOaasDWcIG/6EnvCpIGWV29F3RPN50NY+Fc5MtuHy8L99VbdWFzec/v/b7dO/Oruua9T6A/UuPm7eSy4+r8+3XBydIW3kml3t9m/L+czp5fq/8HEAs4J4FoDHClMckZJTbidUJo+eLAbwPICTSW4heQWAOwBMJbkBwNTgs0hFUW77JZO7sTPSrDq3wGWJ1e5/CLuu478aPsA7/2N/iP3c2yeEf83HHhuWo9+vno/93JJeUnI7ndYpp7l4y9fDB9jfPmdBmj16plmemQeaBrj4vsc/7eKR+8I8H/bD5zrcN45u65F0MUlEvKDKTkS84M27sT0uaXRxMbquy8c/Gn4YH4bXNYZdi9V7wgm2ez6yKvYySTIdPC+cnHrP0FoXN52/x8VvnHNf7OVYs2eEiwevKUbHNDtq2YmIF1TZiYgXvOnGlouf1K9x8W237HPxHx7JfggeEQCw63a4+IXo5ZMii+b23Nv/7OLfROZE1ryxIiIxU2UnIl5IdDe2ekCdi7vXHOpky9LoVhWWqXrQcBe37gi7JTANBCWVJzqy8ZRF4Tu30Xlj29ZvcnExurRq2YmIF1TZiYgXEt2NPf6pcHicu49/KrKmPO58XtP/DRePem6ri+eeHb4/27p9e1HLJFJo6eaNvfriq1zMlS8hbmrZiYgXVNmJiBdU2YmIFxJ9zW5Ej50u7l3V9XW6S96Z4uLXfzXWxWu+9/OCluuwbgxf2j6/Zzgs9YZnwlvyy74cXr+zBk2kLUfr8c0wt0+5LRyi7+XJi0tRnKOkmyTbasIZ0uKfIlstOxHxhCo7EfFCIrqx6Sb6vajvLyNbdT1T0pY9/Vx83MMbXTyp+Wtp97nh24vC8+UxyXa0S3tD3Vsu/l2vv3CxvpmkI21rw0eYhvwonEVs0uj0eZutv7g6HG8x+sJ/JdH/HxHxgio7EfFCIrqx6BbOtvvghHtdHL3zk63omwt189O/xfDDXjNdfPPAcHl08u1iDAMvArR/E6FuZeGO+8cZI8MPSe3GkhxO8hmS60i+RvLaYHkdyaUkNwQ/+8dfXJHCUW77JZNubAuA68xsLIAzAVxFchyAOQCWmdkoAMuCzyKVRLntkUwmyW4E0BjETSTXATgewAUApgSbLQSwHMCNsZSySP56aNj1nHfLeSUsiRRDueR21Slj231+5wvl15C8fOiSUhchb1ndoCA5EsCpAFYBGBIky+GkGZxmn9kkG0g2HEJzR5uIlJxyO/kyruxI9gbwMIBvmNlHme5nZnPNbKKZTaxFt653ECky5bYfMrobS7IWqWR4wMweCRZvJVlvZo0k6wFsi6uQxRJ9mPeG2feUsCRSLOWQ2ztO69fu8zrlXiwyuRtLAPcCWGdmd0VWPQ5gVhDPAvBY4YsnEh/ltl8yadl9CsClAF4lefghnu8AuAPAQySvAPAegOnxFFEkNsptj2RyN/ZZpB+B5dzCFkekeEqZ2zXDh7m4aWQxBjgSvS4mIl5QZSciXkjGu7Ft4UTSbx6KPhIV3kQbXhPW65mMWlwMzRZOkr3pUMeTeLNFk2Qn0aYrPubiN76SvLuv5ZjbatmJiBdU2YmIFxLRjW3dscPF0QmmURXe5WpbHI4E/F9jnixKubry011jXPzMuSd2uE3VznCSHXVopVKUY26rZSciXlBlJyJeSEQ3FhY2gqMjDEc1t4wsUmGOdsrqcC7PwXeHd4Kr94Z3qWyr5oT1ycfvfc/FY/j1dusq9e5sdN7lHdeED02XS26rZSciXlBlJyJeSEY3NgP8STgbzqShhZtPMxND1u8Py7HyRRfr7qq/WjZvcfGJC9v/NxyDsFtb7l3aqes+7+KWO4e4+JiGhlIUp1Nq2YmIF1TZiYgXVNmJiBe8uWZ3zJLwGkJdCcshcqSWTe+0+3ziL1tcPNa+jnI26KWwrD2WrC5hSbqmlp2IeEGVnYh4wZturEiliD6W8rFbtnSypWRDLTsR8YIqOxHxQibzxnYnuZrkyyRfI/n9YHkdyaUkNwQ/+8dfXJHCUW77JZOWXTOAvzKzUwBMADCN5JkA5gBYZmajACwLPotUEuW2R7qs7CxlT/CxNvhjAC4AsDBYvhDAhbGUUCQmym2/ZHTNjmR1MGP6NgBLzWwVgCFm1ggAwc/BnR1DpBwpt/2RUWVnZq1mNgHAMACTSY7P9AQkZ5NsINlwCM25llMkFsptf2R1N9bMdgNYDmAagK0k6wEg+LktzT5zzWyimU2sRbc8iysSD+V28mVyN3YQyX5B3APAZwC8AeBxALOCzWYBeCyuQorEQbntl0zeoKgHsJBkNVKV40Nm9luSzwN4iOQVAN4DMD3GcorEQbntkS4rOzN7BcCpHSzfCeDcOAolUgzKbb/QrHiDg5PcDuDdop1QOjPCzAaVuhBJodwuG2nzuqiVnYhIqejdWBHxgio7EfGCKrsMkFxA8rZSl0Ok0HzK7Yqs7Ei+Q3IryV6RZVeSXF7CYhUMyTtJbib5Ecl3Sd5U6jJJcSQ9twGA5GdIriG5N8jzi4px3oqs7AI1AK4tdSGyFTzT1ZV7AYwxs2MBnA3gYpJfjLdkUkYSm9skxwFYBOAmAH2RGm3mjzEXDUBlV3Y/AnD94Sfgo0iOJGkkayLLlpO8Moi/THIlybtJ7ia5ieTZwfLNJLeRnHXEYQcGY5s1kVxBckTk2GOCdR+QfDP6TRV0E35O8imSewF8uqtfzMzeNLO9kUVtAE7K+G9GKl1icxvAdwH8wsyeNrMWM9tpZm9l+feTk0qu7BqQepfx+hz3PwPAKwAGIPVN8yCASUhVKpcA+BnJ3pHtZwL4AYCBAF4C8AAABN2NpcExBgOYAeAekp+I7HsxgNsB9AHwLMmLSb7SWeFIziG5B8AWAL2C44sfkpzbZwbHfpVkI8n7SRZldtNKruwA4HsA/pFkLg/Hvm1m882sFcCvAQwHcKuZNZvZ7wAcRPvW1JNm9nsza0aqCX4WyeEA/gbAO8GxWsxsDYCHAfx9ZN/HzGylmbWZ2QEzW2Rmn+yscGZ2B1IJdBqA+wB8mMPvKJUrqbk9DMClAP4OwCgAPQD8aw6/Y9YqurIzs7UAfovcRpLdGon3B8c7cln0229z5Lx7AHwAYCiAEQDOCLoMu0nuRuqb8riO9s1GMLjki0FZvp/LMaQyJTi39wOYb2brg3P9E4DPZXmMnCRhKsWbAawB8JPIssPXu3oC+CiIo/9AuRh+OAi6AHUA3kfqH3uFmU3tZN98X1OpAXBinseQypPE3H4lh30KoqJbdgBgZhuRaqpfE1m2HcCfAFzC1Ei0lyP/yuJzJM8heQxS1zdWmdlmpL59R5O8lGRt8GcSybG5nIRkFcmvkuzPlMkArkJqLgTxSNJyOzAfwGUkP06yJ4Abg/PEruIru8CtSF3Ej/oKgBsA7ATwCQDP5XmORUh9034A4HSkmvMwsyYAnwXwJaS+Df8M4J+B9KM5kpxJ8rVOzvUFAG8BaAJwP1LXNIpyXUPKTqJy28zmAfgVgFVIDZzQjEhlHicNBCAiXkhKy05EpFOq7ETEC3lVdiSnBU9VbySpiYQlMZTbyZPzNTum3oNbD2AqUk/5vwBghpm9XrjiiRSfcjuZ8nnObjKAjWa2CQBIPojUTOppE+IYdrPuR91YklJowq4dGpY9LeV2hTqAvThozexoXT6V3fFo//T0FqTeyUurO3rhDGoek3Lw3/YbzZeQnnK7Qq2y9I+j5lPZdVR7HtUnJjkbwGwA6I6eeZxOpGiU2wmUzw2KLYi8ZoLUC77vH7mRZk2XCqTcTqB8KrsXAIwieULwmsmXkJpJXaTSKbcTKOdurJm1kLwawBIA1QDmmVlnr0CJVATldjLlNeqJmT0F4KkClUWkbCQttw+eN9HFdt2OjPbp8c3uLm5b+0bBy1RseoNCRLygyk5EvJCEwTtFpAt7hta6+IXxj2a0z9QBl7k4Ca2iJPwOIiJdUmUnIl5IdDd229Vnu/jAwPjPN/I/d7m47eV18Z9QRDKmlp2IeEGVnYh4IRHdWHYL30vcPf1UF3/7mgdcfFHv+OeYPmHwbBcPfOEsF/dfv9/FXPlS7OUQkaOpZSciXlBlJyJeUGUnIl5IxDW7qn59XTz/trtcPPaY4g6o+PaFc8MPF4bhKatnuHjoh2NcnISXq6V81Qwf5uKmkR2OVO4VtexExAuq7ETEC4noxpa7lycvdvGUu8L+bbfPlqI0kmRVffq4eP1V4cjyG/7hnlIUp6yoZSciXlBlJyJeUDdWJEHevGOci//nb38cWaMJvNWyExEvqLITES8kohvbtvMDF1878+sutpqOH6Tc9n8PuDh6p1Sk0ln3VhcPrlbXNarLlh3JeSS3kVwbWVZHcinJDcHP/vEWU6TwlNt+yaQbuwDAtCOWzQGwzMxGAVgWfBapNAug3PZGl91YM/s9yZFHLL4AwJQgXghgOYAbC1iurFhLi4uj48WlextwSMsEF08a/bUuj9/SKzzSE9+6s926YTW9MyyllJtKyO24bWnZ4+LP3/mtduvqX9/g4lZUvlxvUAwxs0YACH4OLlyRREpKuZ1Qsd+gIDkbwGwA6I7ijkIiEifldmXJtbLbSrLezBpJ1gPYlm5DM5sLYC4AHMs6y/F8BRXt6tat7Hr76iHhl3vT9XpaJ+EqOrez1dQW5nP9f2xst651+/ZiFydWuf7PfRzArCCeBeCxwhRHpOSU2wmVyaMniwE8D+BkkltIXgHgDgBTSW4AMDX4LFJRlNt+yeRu7Iw0q84tcFnKSnSU1+hQOQOqK7K3Ih3wNbd9pQtQIuIFVXYi4oVEvBubLftU+FDxrtE9OtwmOkFJ+1Fe9b6h+OfgeRNdvGdobVb7Vh8KL/30+48XXWzNzfkXLAtq2YmIF1TZiYgXvOzGbrg8/LXfPv/nRT33sN67Xbxj4ngXW8PajjYX6VLV+HAu4mMH7O1y+22t4TZ3bY3M+hR5xxwAGMnP428OHzi+f+TyrMr39qHw/dsvf/hNF/dcvs7FbU1NWR0zF2rZiYgXVNmJiBe87MaWUrQLcNu8sPvxh092L0FpJAn23xUZeXv8o11uP3fX6S5+78x9Lq4e2H4wg3MXPO/iG+reyrl8J9SGw6Ct+MVcF0+dcZmLq1a8iLipZSciXlBlJyJeUDdWxGPVAwe6ePZzz7dbd37PXZFP2T1IXI7UshMRL6iyExEvqLITES/omp2Iz6rCAS9Orm0/An03dj2vximrwyEBm18Op9h94yv3dLR5O9P/bYmL77vp8+3W9XxkVZf7Z0stOxHxgio7EfGCl93YUfPCF54nPdv1JNmdueHbi1x8Ue8P8zqWSDFc1PePLn528YkuHl6TWdtnzLOXuvhj/6/axbtGZzdlwey+77v43/u2P3ccE1OqZSciXlBlJyJe8LIbm+0k2VV9+rj4zTvGtVs3snZH5FPlP2UuyTe6Npxa4L/GPBlZk34wimjXddg9YZ5z5ZrIgc8qSPniksm8scNJPkNyHcnXSF4bLK8juZTkhuBn/66OJVJOlNt+yaQb2wLgOjMbC+BMAFeRHAdgDoBlZjYKwLLgs0glUW57JJNJshsBNAZxE8l1AI4HcAGAKcFmCwEsB3BjLKUsMfYMZyB78q/vbrdu7DFx3DeSYlBuZ67P0+GYdAf7tbp4+y1nu9jGZTe0evSB5CHr9+dRusxkdYOC5EgApwJYBWBIkCyHk2ZwoQsnUizK7eTLuLIj2RvAwwC+YWYfZbHfbJINJBsOobjzRIpkQrnth4zuxpKsRSoZHjCzR4LFW0nWm1kjyXoA2zra18zmApgLAMeyLrunDjPEbt1cvHv6qS5urWVHm2etpVd4nD5VbXkda/n+8Pvl3tXnuHg0GvI6ruSm3HM7E++vqXfxAyMGuHhmn50FO8eOSWHX9aSTG128buwTOR9z8N3h3V+uLINh2UkSwL0A1pnZXZFVjwOYFcSzADxW+OKJxEe57ZdMWnafAnApgFdJHn5A7TsA7gDwEMkrALwHYHo8RRSJjXLbI5ncjX0WQLr+4LmFLU7mog/67psy1sWLb/+xi6OzGhVOfse85a2/dfHoK9V1LaVyze1snTAnHE79u4O+4OKZ5/+yYOd4+8K5XW+URrMdcvFPd4Uz6lXvDZcX4xqAXhcTES+oshMRL1Tsu7Etp53k4ujEu/l2M+OwqzWciHjXvvAB5eNKURhJNB4Ih1xadzDMu+hTBMNq4v8/Eu26Pr0vfNvumU9/3MW2fW3s5YhSy05EvKDKTkS8ULHd2EoyccVVLj756k0ubu1oY5E8nDzndRdfd+uFLm6cHl72efE7XU+Gk6/oXddo17V1x46ONi8KtexExAuq7ETEC+rGFkFbc3iHrHW3JuWR+LQ1RYZZisT1vw4f25366mWxl6PdA8NFvuuajlp2IuIFVXYi4oWK7cbWvr7FxZNu6nju11LO6RqdoCQ6T61IKbRu3+7iqhXbO9myMEo23lUn1LITES+oshMRL6iyExEvVOw1u+g1iLr5HV+D+GGvmS6+eWDsRWpn2PIDLm43kbCIlIRadiLiBVV2IuKFiu3GZmLwz54rdRFEpEyoZSciXlBlJyJeUGUnIl7IZJLs7iRXk3yZ5Gskvx8sryO5lOSG4Gf/ro4lUk6U237JpGXXDOCvzOwUABMATCN5JoA5AJaZ2SgAy4LPIpVEue2RLis7S9kTfKwN/hiACwAsDJYvBHBhB7uLlC3ltl8yumZHsprkSwC2AVhqZqsADDGzRgAIfg6Or5gi8VBu+yOjys7MWs1sAoBhACaTHJ/pCUjOJtlAsuEQmnMtp0gslNv+yOpurJntBrAcwDQAW0nWA0Dwc1uafeaa2UQzm1iLbnkWVyQeyu3ky+Ru7CCS/YK4B4DPAHgDwOMAZgWbzQLwWFyFFImDctsvmbwuVg9gIclqpCrHh8zstySfB/AQySsAvAdgeozlFImDctsjNCveAMoktwN4t2gnlM6MMLNBpS5EUii3y0bavC5qZSciUip6XUxEvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLzwv9NPrlrn6D7QAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "ds = create_dataset(training=False)\n", - "data = ds.create_dict_iterator().get_next()\n", - "images = data['image']\n", - "labels = data['label']\n", - "\n", - "for i in range(1, 5):\n", - " plt.subplot(2, 2, i)\n", - " plt.imshow(images[i][0])\n", - " plt.title('Number: %s' % labels[i])\n", - " plt.xticks([])\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 定义模型\n", - "\n", - "MindSpore model_zoo中提供了多种常见的模型,可以直接使用。这里使用其中的LeNet5模型,模型结构如下图所示:\n", - "\n", - "\n", - "\n", - "[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 训练\n", - "\n", - "使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。\n", - "\n", - "| batch size | number of epochs | learning rate | optimizer |\n", - "| -- | -- | -- | -- |\n", - "| 32 | 3 | 0.01 | Momentum 0.9 |" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 1 step: 1875 ,loss is 2.3086565\n", - "epoch: 2 step: 1875 ,loss is 0.22017351\n", - "epoch: 3 step: 1875 ,loss is 0.025683485\n", - "Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}\n" - ] - } - ], - "source": [ - "ds_train = create_dataset(num_epoch=3)\n", - "ds_eval = create_dataset(training=False)\n", - "\n", - "net = LeNet5()\n", - "loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", - "opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)\n", - "\n", - "loss_cb = LossMonitor(per_print_times=1)\n", - "\n", - "model = Model(net, loss, opt, metrics={'acc', 'loss'})\n", - "model.train(3, ds_train, callbacks=[loss_cb])\n", - "metrics = model.eval(ds_eval)\n", - "print('Metrics:', metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 实验步骤(方案二)\n", - "\n", - "除了Notebook,ModelArts还提供了训练作业服务。相比Notebook,训练作业资源池更大,且具有作业排队等功能,适合大规模并发使用。使用训练作业时,也会有修改代码和调试的需求,有如下三个方案:\n", - "\n", - "1. 在本地修改代码后重新上传;\n", - "\n", - "2. 使用[PyCharm ToolKit](https://support.huaweicloud.com/tg-modelarts/modelarts_15_0001.html)配置一个本地Pycharm+ModelArts的开发环境,便于上传代码、提交训练作业和获取训练日志。\n", - "\n", - "3. 在ModelArts上创建Notebook,然后设置[Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html),可以在线修改代码并自动同步到OBS中。因为只用Notebook来编辑代码,所以创建CPU类型最低规格的Notebook就行。\n", - "\n", - "### 代码梳理\n", - "\n", - "创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。\n", - "\n", - "```python\n", - "import argparse\n", - "parser = argparse.ArgumentParser()\n", - "parser.add_argument('--data_url', required=True, default=None, help='Location of data.')\n", - "parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')\n", - "parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')\n", - "args, unknown = parser.parse_known_args()\n", - "```\n", - "\n", - "MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器:\n", - "\n", - "```python\n", - "import moxing as mox\n", - "mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')\n", - "```\n", - "\n", - "如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考:\n", - "\n", - "```python\n", - "import moxing as mox\n", - "mox.file.copy_parallel(src_url='output', dst_url='s3://OBS/PATH')\n", - "```\n", - "\n", - "其他代码分析请参考方案一。\n", - "\n", - "### 创建训练作业\n", - "\n", - "可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。\n", - "\n", - "创建训练作业的参考配置:\n", - "\n", - "- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore\n", - "- 代码目录:选择上述新建的OBS桶中的experiment_1目录\n", - "- 启动文件:选择上述新建的OBS桶中的experiment_1目录下的`main.py`\n", - "- 数据来源:数据存储位置->选择上述新建的OBS桶中的experiment_1目录下的MNIST目录\n", - "- 训练输出位置:选择上述新建的OBS桶中的experiment_1目录并在其中创建output目录\n", - "- 作业日志路径:同训练输出位置\n", - "- 规格:Ascend:1*Ascend 910\n", - "- 其他均为默认\n", - "\n", - "启动并查看训练过程:\n", - "\n", - "1. 点击提交以开始训练;\n", - "2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理;\n", - "3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看;\n", - "4. 在训练日志中可以看到`epoch: 3 step: 1875 ,loss is 0.025683485`等字段,即训练过程的loss值;\n", - "5. 在训练日志中可以看到`Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}`字段,即训练完成后的验证精度。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 实验小结\n", - "\n", - "本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/experiment_1/main.py b/experiment_1/main.py deleted file mode 100644 index bf23697..0000000 --- a/experiment_1/main.py +++ /dev/null @@ -1,60 +0,0 @@ -# LeNet5 mnist - -import os -# os.environ['DEVICE_ID'] = '0' - -import mindspore as ms -import mindspore.context as context -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as CV - -from mindspore import nn -from mindspore.model_zoo.lenet import LeNet5 -from mindspore.train import Model -from mindspore.train.callback import LossMonitor - -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') - -DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 -DATA_DIR_TEST = "MNIST/test" # 测试集信息 - - -def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), - rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): - ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) - - ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) - ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) - ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch) - - return ds - - -def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - loss_cb = LossMonitor(per_print_times=1) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[loss_cb]) - metrics = model.eval(ds_eval) - print('Metrics:', metrics) - - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--data_url', required=True, default=None, help='Location of data.') - parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') - parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.') - args, unknown = parser.parse_known_args() - - import moxing as mox - mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') - - test_train() diff --git a/experiment_2/2-Save_And_Load_Model.ipynb b/experiment_2/2-Save_And_Load_Model.ipynb deleted file mode 100644 index 7976daf..0000000 --- a/experiment_2/2-Save_And_Load_Model.ipynb +++ /dev/null @@ -1,581 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "

训练时模型的保存和加载

\n", - "\n", - "## 实验介绍\n", - "\n", - "本实验主要介绍使用MindSpore实现训练时模型的保存和加载。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。\n", - "\n", - "在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下:\n", - "\n", - "- 训练后推理场景\n", - "\n", - " - 模型训练完毕后保存模型的参数,用于推理或预测操作。\n", - "\n", - " - 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。\n", - "\n", - "- 再训练场景\n", - "\n", - " - 进行长时间训练任务时,保存训练过程中的CheckPoint文件,防止任务异常退出后从初始状态开始训练。\n", - "\n", - " - Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。\n", - "\n", - "## 实验目的\n", - "\n", - "- 了解如何使用MindSpore实现训练时模型的保存。\n", - "- 了解如何使用MindSpore加载保存的模型文件并继续训练。\n", - "- 了解如何MindSpore的Callback功能。\n", - "\n", - "## 预备知识\n", - "\n", - "- 熟练使用Python,了解Shell及Linux操作系统基本知识。\n", - "- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略、Checkpoint等。\n", - "- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)、[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)、[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)、[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等功能。华为云官网:https://www.huaweicloud.com\n", - "- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn/\n", - "\n", - "## 实验环境\n", - "\n", - "- MindSpore 0.2.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套);\n", - "- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html\n", - "\n", - "## 实验准备\n", - "\n", - "### 创建OBS桶\n", - "\n", - "本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。\n", - "\n", - "> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。\n", - "\n", - "创建OBS桶的参考配置如下:\n", - "\n", - "- 区域:华北-北京四\n", - "- 数据冗余存储策略:单AZ存储\n", - "- 桶名称:如ms-course\n", - "- 存储类别:标准存储\n", - "- 桶策略:公共读\n", - "- 归档数据直读:关闭\n", - "- 企业项目、标签等配置:免\n", - "\n", - "### 数据集准备\n", - "\n", - "MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。\n", - "\n", - "从MNIST官网下载如下4个文件到本地并解压:\n", - "\n", - "```\n", - "train-images-idx3-ubyte.gz: training set images (9912422 bytes)\n", - "train-labels-idx1-ubyte.gz: training set labels (28881 bytes)\n", - "t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)\n", - "t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)\n", - "```\n", - "\n", - "### 脚本准备\n", - "\n", - "从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。\n", - "\n", - "### 上传文件\n", - "\n", - "将脚本和数据集上传到OBS桶中,组织为如下形式:\n", - "\n", - "```\n", - "experiment_2\n", - "├── MNIST\n", - "│   ├── test\n", - "│   │   ├── t10k-images-idx3-ubyte\n", - "│   │   └── t10k-labels-idx1-ubyte\n", - "│   └── train\n", - "│   ├── train-images-idx3-ubyte\n", - "│   └── train-labels-idx1-ubyte\n", - "├── *.ipynb\n", - "└── main.py\n", - "```\n", - "\n", - "## 实验步骤(方案一)\n", - "\n", - "### 创建Notebook\n", - "\n", - "可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。\n", - "\n", - "创建Notebook的参考配置:\n", - "\n", - "- 计费模式:按需计费\n", - "- 名称:experiment_2\n", - "- 工作环境:Python3\n", - "- 资源池:公共资源\n", - "- 类型:Ascend\n", - "- 规格:单卡1*Ascend 910\n", - "- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的experiment_2文件夹\n", - "- 自动停止等配置:默认\n", - "\n", - "> **注意:**\n", - "> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。\n", - "> - 打开Notebook后,选择MindSpore环境作为Kernel。\n", - "\n", - "> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的\"New\"->\"Terminal\",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。\n", - "\n", - "> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。\n", - "\n", - "导入MindSpore模块和辅助模块:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "# os.environ['DEVICE_ID'] = '0'\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "import mindspore as ms\n", - "import mindspore.context as context\n", - "import mindspore.dataset.transforms.c_transforms as C\n", - "import mindspore.dataset.transforms.vision.c_transforms as CV\n", - "\n", - "from mindspore.dataset.transforms.vision import Inter\n", - "from mindspore import nn, Tensor\n", - "from mindspore.train import Model\n", - "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n", - "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n", - "\n", - "import logging; logging.getLogger('matplotlib.font_manager').disabled = True\n", - "\n", - "context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 数据处理\n", - "\n", - "在使用数据集训练网络前,首先需要对数据进行预处理,如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_DIR_TRAIN = \"MNIST/train\" # 训练集信息\n", - "DATA_DIR_TEST = \"MNIST/test\" # 测试集信息\n", - "\n", - "def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),\n", - " rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):\n", - " ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)\n", - " \n", - " # define map operations\n", - " resize_op = CV.Resize(resize)\n", - " rescale_op = CV.Rescale(rescale, shift)\n", - " hwc2chw_op = CV.HWC2CHW()\n", - " \n", - " # apply map operations on images\n", - " ds = ds.map(input_columns=\"image\", operations=[resize_op, rescale_op, hwc2chw_op])\n", - " ds = ds.map(input_columns=\"label\", operations=C.TypeCast(ms.int32))\n", - " \n", - " ds = ds.shuffle(buffer_size=buffer_size)\n", - " ds = ds.batch(batch_size, drop_remainder=True)\n", - " ds = ds.repeat(num_epoch)\n", - " \n", - " return ds" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 定义模型\n", - "\n", - "定义LeNet5模型,模型结构如下图所示:\n", - "\n", - "\n", - "\n", - "[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class LeNet5(nn.Cell):\n", - " def __init__(self):\n", - " super(LeNet5, self).__init__()\n", - " self.relu = nn.ReLU()\n", - " self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')\n", - " self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')\n", - " self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n", - " self.flatten = nn.Flatten()\n", - " self.fc1 = nn.Dense(400, 120)\n", - " self.fc2 = nn.Dense(120, 84)\n", - " self.fc3 = nn.Dense(84, 10)\n", - " \n", - " def construct(self, input_x):\n", - " output = self.conv1(input_x)\n", - " output = self.relu(output)\n", - " output = self.pool(output)\n", - " output = self.conv2(output)\n", - " output = self.relu(output)\n", - " output = self.pool(output)\n", - " output = self.flatten(output)\n", - " output = self.fc1(output)\n", - " output = self.fc2(output)\n", - " output = self.fc3(output)\n", - " \n", - " return output" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 保存模型Checkpoint\n", - "\n", - "MindSpore提供了Callback功能,可用于训练/测试过程中执行特定的任务。常用的Callback如下:\n", - "\n", - "- `ModelCheckpoint`:保存网络模型和参数,用于再训练或推理;\n", - "- `LossMonitor`:监控loss值,当loss值为Nan或Inf时停止训练;\n", - "- `SummaryStep`:把训练过程中的信息存储到文件中,用于后续查看或可视化展示。\n", - "\n", - "`ModelCheckpoint`会生成模型(.meta)和Chekpoint(.ckpt)文件,如每个epoch结束时,都保存一次checkpoint。\n", - "\n", - "```python\n", - "class CheckpointConfig:\n", - " \"\"\"\n", - " The config for model checkpoint.\n", - "\n", - " Args:\n", - " save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.\n", - " save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.\n", - " Can't be used with save_checkpoint_steps at the same time.\n", - " keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.\n", - " keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.\n", - " Can't be used with keep_checkpoint_max at the same time.\n", - " integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.\n", - " Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.\n", - "\n", - " Raises:\n", - " ValueError: If the input_param is None or 0.\n", - " \"\"\"\n", - "\n", - "class ModelCheckpoint(Callback):\n", - " \"\"\"\n", - " The checkpoint callback class.\n", - "\n", - " It is called to combine with train process and save the model and network parameters after traning.\n", - "\n", - " Args:\n", - " prefix (str): Checkpoint files names prefix. Default: \"CKP\".\n", - " directory (str): Lolder path into which checkpoint files will be saved. Default: None.\n", - " config (CheckpointConfig): Checkpoint strategy config. Default: None.\n", - "\n", - " Raises:\n", - " ValueError: If the prefix is invalid.\n", - " TypeError: If the config is not CheckpointConfig type.\n", - " \"\"\"\n", - "```\n", - "\n", - "MindSpore提供了多种Metric评估指标,如`accuracy`、`loss`、`precision`、`recall`、`F1`。定义一个metrics字典/元组,里面包含多种指标,传递给`Model`,然后调用`model.eval`接口来计算这些指标。`model.eval`会返回一个字典,包含各个指标及其对应的值。" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 1 step: 1875 ,loss is 2.3151364\n", - "epoch: 2 step: 1875 ,loss is 0.3097728\n", - "Metrics: {'acc': 0.9417067307692307, 'loss': 0.18866610953894755}\n", - "b_lenet-1_1875.ckpt\n", - "b_lenet-2_1875.ckpt\n" - ] - } - ], - "source": [ - "os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件\n", - "\n", - "def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name=\"b_lenet\"):\n", - " ds_train = create_dataset(num_epoch=num_epoch)\n", - " ds_eval = create_dataset(training=False)\n", - " steps_per_epoch = ds_train.get_dataset_size()\n", - " \n", - " net = LeNet5()\n", - " loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", - " opt = nn.Momentum(net.trainable_params(), lr, momentum)\n", - " \n", - " ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n", - " ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)\n", - " loss_cb = LossMonitor(steps_per_epoch)\n", - " \n", - " model = Model(net, loss, opt, metrics={'acc', 'loss'})\n", - " model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True)\n", - " metrics = model.eval(ds_eval)\n", - " print('Metrics:', metrics)\n", - "\n", - "test_train()\n", - "print('\\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 加载Checkpoint继续训练\n", - "\n", - "```python\n", - "def load_checkpoint(ckpoint_file_name, net=None):\n", - " \"\"\"\n", - " Loads checkpoint info from a specified file.\n", - "\n", - " Args:\n", - " ckpoint_file_name (str): Checkpoint file name.\n", - " net (Cell): Cell network. Default: None\n", - "\n", - " Returns:\n", - " Dict, key is parameter name, value is a Parameter.\n", - "\n", - " Raises:\n", - " ValueError: Checkpoint file is incorrect.\n", - " \"\"\"\n", - "\n", - "def load_param_into_net(net, parameter_dict):\n", - " \"\"\"\n", - " Loads parameters into network.\n", - "\n", - " Args:\n", - " net (Cell): Cell network.\n", - " parameter_dict (dict): Parameter dict.\n", - "\n", - " Raises:\n", - " TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.\n", - " \"\"\"\n", - "```\n", - "\n", - "> 使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 1 step: 1875 ,loss is 0.1638589\n", - "epoch: 2 step: 1875 ,loss is 0.060048036\n", - "Metrics: {'acc': 0.9742588141025641, 'loss': 0.07910804035148034}\n", - "b_lenet_1-1_1875.ckpt\n", - "b_lenet_1-2_1875.ckpt\n" - ] - } - ], - "source": [ - "CKPT = 'b_lenet-2_1875.ckpt'\n", - "\n", - "def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name=\"b_lenet\"):\n", - " ds_train = create_dataset(num_epoch=num_epoch)\n", - " ds_eval = create_dataset(training=False)\n", - " steps_per_epoch = ds_train.get_dataset_size()\n", - " \n", - " net = LeNet5()\n", - " loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", - " opt = nn.Momentum(net.trainable_params(), lr, momentum)\n", - " \n", - " param_dict = load_checkpoint(CKPT)\n", - " load_param_into_net(net, param_dict)\n", - " load_param_into_net(opt, param_dict)\n", - " \n", - " ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n", - " ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)\n", - " loss_cb = LossMonitor(steps_per_epoch)\n", - " \n", - " model = Model(net, loss, opt, metrics={'acc', 'loss'})\n", - " model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb])\n", - " \n", - " metrics = model.eval(ds_eval)\n", - " print('Metrics:', metrics)\n", - "\n", - "resume_train()\n", - "print('\\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 加载Checkpoint进行推理\n", - " \n", - "使用matplotlib定义一个将推理结果可视化的辅助函数,如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_images(pred_fn, ds, net):\n", - " for i in range(1, 5):\n", - " pred, image, label = pred_fn(ds, net)\n", - " plt.subplot(2, 2, i)\n", - " plt.imshow(np.squeeze(image))\n", - " color = 'blue' if pred == label else 'red'\n", - " plt.title(\"prediction: {}, truth: {}\".format(pred, label), color=color)\n", - " plt.xticks([])\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "使用训练后的LeNet5模型对手写数字进行识别,可以看到识别结果基本上是正确的。" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAAD7CAYAAAAMyN1hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcv0lEQVR4nO3de5RU5Znv8e9D03TLJUIraCMgXgDJMlEZguRoEuJl1IkecpxlJo7jQpfaIdE1OmO8xMl9NPHkmMvMMpmIE4TxbtBRYuIkSjQRZVCWgSSKCqMIKHJROgEEpJvn/FGbXbuaqn5316Wrqvv3WatXP7v27d3VTz+1330rc3dERKSwAdVugIhIrVOhFBEJUKEUEQlQoRQRCVChFBEJUKEUEQmoyUJpxjwzbozij5nxSpHL+bEZXylv6+qDGePNcDMGVrstkqXcLl01crsmC2WSO0+7Myk0nRkXmbG4y7yz3fnnyrUuXvexZvzSjC1m9OjCVDNmmLG+DG1YY8ZppS4nsbxPmvGkGX8yY025litZyu3Uy6l6ble8UPaTPZo9wAPAJZVYeJXewx3AXOCaKqy7Lii3S1c3ue3uPf4BXwP+JfCXwLeC3wHeHI2bAb4e/Drwt8HvjF4/G3w5eDv4s+AfTizvBPAXwLeB3w9+H/iNyeUlph0L/hD4ZvB3wG8Fnwy+C7wTfDt4ezTtvH3LiYYvA18N/i74QvDRiXEOPht8VbRNPwS3Hr4vR4N7D6YfAr4TfG/U7u3go8G/Dr4A/C7wP4Nfmmdb4vcF/M5oGTujZVwLPj7aplnga8G3gP9TEX/r08DXFJMn9fij3FZu5/spZY/yAuAM4ChgIvDlxLhDgRbgcKDNjClkKvjngIOA24CFZjSZMQh4GLgzmuenwF/nW6EZDcCjwBvAeOAw4D53VgKzgSXuDHVneJ55TwG+DXwGaI2WcV+Xyc4GPgIcF013RjTvODPazRiX9s1Jw50dwFnAW1G7h7rzVjR6JrAAGA7cHVjOhcBa4JxoGd9JjD4ZmAScCnzVjMnRNp1sRns5t6cPUW6XqK/ldimF8lZ31rnzLnATcH5i3F7ga+7sdmcncBlwmztL3el0Zz6wG5ge/TQCP3BnjzsLgOcLrHMaMBq4xp0d7uxyzz12040LgLnuvODObuBLwEfNGJ+Y5mZ32t1ZCzwJHA/gzlp3hkev95Yl7jzszt7oPSzWN9zZ6c4KYAWZfxTcWZzvn04A5Xal1V1ul1Io1yXiN8j8kffZ7M6uxPDhwNXRJ1d7VO3HRvOMBt50zzlQ/EaBdY4F3nCno4j2jk4u153twDtkPrn3eTsRvwcMLWI95bIuPEkqtbRN9UK5XVl1l9ulFMqxiXgcxLvVwH5nx9YBN0WfXPt+BrtzL7ABOMwM67K8fNYB4wocAA6dkXuLTFIDYMYQMl2lNwPzVVqhdnd9fQcwODF8aMrlSM8pt8ujz+R2KYXycjPGmNEC3ADc3820twOzzTjRDDNjiBmfMmMYsAToAP7ejIFmnEumG5LPc2SS7+ZoGc1mnBSN2wiMiY4L5XMPcLEZx5vRBHwLWOpe+qUv0TY1Q2bdUbuaEuPnmTGvwOwbgYPMODCwmuXAX5nRYsahwFV5lnNkURuQhxkDom1qhMz2dfPe9jXK7YhyO6OUQnkP8CvgtejnxkITurOMzLGcW4GtwGrgomjc+8C50fBW4G+AhwospxM4BziazAHe9dH0AL8GXgTeNmNLnnkXAV8BHiSTkEcBn02zodEB7+3dHPA+HNgZrZ8oTl5IPBZ4psA2vQzcC7wWdd1G55uOzAmBFcAaMu9713/ebwNfjpbxxcAm7bvYeXs3k3w82o5fkNkL2hmttz9QbmcptwHLnCbvmegizUvdeaLHM/cz0SfVCuDD7uypdnuke8rt9PpTbveHC2arKtqrmFztdoiUW3/K7Zq/hVFEpNqK6nqLiPQnJe1RmtmZZvaKma02s+vL1SiRalNuS1LRe5Rm1gC8CpxO5gzd88D57v5S+Zon0vuU29JVKSdzpgGr3f01ADO7j8w9nAWTaZA1eTNDSlillMs2tm5x95HVbkeNUm7XqV3s4H3fbeEpe6aUQnkYubcirQdO7DqRmbUBbQDNDOZEO7WEVUq5POELCt1KJ8rturXUF1VkuaUco8xXtffrx7v7HHef6u5TG7MX9IvUMuW25CilUK4n957YMeTeEytSr5TbkqOUQvk8MMHMjjCzQWRumVpYnmaJVJVyW3IUfYzS3TvM7Argl0ADMNfdXwzMJlLzlNvSVUm3MLr7L8jcWC7Spyi3JUm3MIqIBKhQiogEqFCKiASoUIqIBKhQiogE6MG9KQwYnP3eo7VXHp8zbm+KGzJGLs9+sd4BDz9XtnaJSO/QHqWISIAKpYhIgLreKdiw7PeqL/jcLTnjJg8a3HXy/Rzx2KVxPPHh8rVLJK2BR46P402faE01z8hHV8dx5+bN5W5SXdEepYhIgAqliEiACqWISICOUYr0A8njks/f9G+p5pnS9Pk4PvTB7Ov98Xil9ihFRAJUKEVEAtT1FpG8Xvhqtov+kd3ZbnjLHep6i4hIFyqUIiIBKpQiIgEqlCIiASqUIiIBOust0g8kH3CRvJA8eWZbCgvuUZrZXDPbZGZ/TLzWYmaPm9mq6PeIyjZTpPyU25JWmq73PODMLq9dDyxy9wnAomhYpN7MQ7ktKQS73u7+WzMb3+XlmcCMKJ4PPAVcV8Z2iVRcf8rt5P3ZB60c0+P5P37F0jhe3PnROB7+H0tKa1idKPZkziHuvgEg+j2q0IRm1mZmy8xs2R52F7k6kV6j3Jb9VPyst7vPcfep7j61kRTfxCVSJ5Tb/UexZ703mlmru28ws1ZgUzkbJVJFyu08vtv6QhxPPnJ6HA+vRmOqoNg9yoXArCieBTxSnuaIVJ1yW/aT5vKge4ElwCQzW29mlwA3A6eb2Srg9GhYpK4otyWtNGe9zy8w6tQyt0WkVym3JS3dwigiEqBCKSISoHu9C2gYOTKON5x3dBwPG7C3Gs0RkSrSHqWISIAKpYhIgLreBez5YPZ+2N/d8KPEmKG93xgRqSrtUYqIBKhQiogEqOudYE3ZBxu8/4HGkpa1qXNHdrm7GkpalohUl/YoRUQCVChFRALU9U5oP++EOL73plsSY3p+pvuku78Yx5Nuir+SBV2uLlJ/tEcpIhKgQikiEqBCKSISoGOUCZ2NFsdHNJZ2B87AXdll7d22raRliUh1aY9SRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkIM33eo81syfNbKWZvWhmV0avt5jZ42a2Kvo9ovLNFSkf5baklWaPsgO42t0nA9OBy83sg8D1wCJ3nwAsioZF6olyW1IJXnDu7huADVG8zcxWAocBM4EZ0WTzgaeA6yrSygrqnDEljredtb2KLZHe1tdzW8qnR8cozWw8cAKwFDgkSrR9CTeqwDxtZrbMzJbtYXdprRWpEOW2dCd1oTSzocCDwFXu/ue087n7HHef6u5TG2kKzyDSy5TbEpLqXm8zaySTSHe7+0PRyxvNrNXdN5hZK7CpUo2spDdnNMfxyyfPrWJLpBr6cm5L+aQ5623AT4CV7v69xKiFwKwongU8Uv7miVSOclvSSrNHeRJwIfAHM1sevXYDcDPwgJldAqwFzqtME0UqRrktqaQ5670YsAKjTy1vc3pHw4Qj43hX654qtkSqqS/mtlSG7swREQlQoRQRCeiXTzhfeX1LHL9+1u1lW+76juwF6wN0WZ3UKOvwOH51z444PmrgATnTNVj+/aiO5uz8A4YNi+O+/CR/7VGKiASoUIqIBPTLrnelnPOda+N43Nzlcby3Go0RKWDA0j/G8VVnXRzHP3jsjpzpJjYOyTv/MxfcEsfTh/1jHE+4fGm5mlhztEcpIhKgQikiEqCud4mmfPPzcdz64Ko47nzvvWo0RyTIOzqyA++0x2GnF7r2PteohmyX3Js7y9auWqY9ShGRABVKEZEAdb1TSF5InjyzDV2625s391qbRKT3aI9SRCRAhVJEJECFUkQkoF8eoxz7s+znw+R1XwhOn3zARfKOG9BlQCL9gfYoRUQCVChFRAL6Zdf7gIefi+NxD/dsXj3gQvoSf29nHH/q5/+QOy7FXTcHL24se5tqkfYoRUQCVChFRAKCXW8zawZ+CzRF0y9w96+ZWQtwPzAeWAN8xt23Vq6pIuWl3M79+oa+/DzJUqXZo9wNnOLuxwHHA2ea2XTgemCRu08AFkXDIvVEuS2pBAulZ+y72bkx+nFgJjA/en0+8OmKtFCkQpTbklaqY5Rm1mBmy4FNwOPuvhQ4xN03AES/R1WumSKVodyWNFIVSnfvdPfjgTHANDM7Nu0KzKzNzJaZ2bI96DtcpbYotyWNHp31dvd24CngTGCjmbUCRL83FZhnjrtPdfepjTSV2FyRylBuS3eChdLMRprZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzZ04rMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKsFC6e6/B07I8/o7wKmVaJRIb1BuS1rm7r23MrPNwBu9tkLpzuHuPrLajegrlNs1oyJ53auFUkSkHulebxGRABVKEZGAmiyUZswz48Yo/pgZrxS5nB+b8ZXytq4+mDHeDDfrn88crVXK7dJVI7drslAmufO0O5NC05lxkRmLu8w7251/rlzr4nUfa8YvzdhiRo8O+poxw4z1ZWjDGjNOK3U5ieV90ownzfiTGWvKtVzJqofcjtZ/pBmPmrEtyvHvpJyvJnM7WuYUM35rxnYzNppxZXfTV7xQ9pM9mj3AA8AllVh4ld7DHcBc4JoqrLsu9IfcNmMQ8Djwa+BQMrd63lXG5ff6e2jGwcB/AbcBBwFHA7/qdiZ37/EP+BrwL4G/BL4V/A7w5mjcDPD14NeBvw1+Z/T62eDLwdvBnwX/cGJ5J4C/AL4N/H7w+8BvTC4vMe1Y8IfAN4O/A34r+GTwXeCd4NvB26Np5+1bTjR8Gfhq8HfBF4KPToxz8Nngq6Jt+iG49fB9OTp6Jk3a6YeA7wTfG7V7O/ho8K+DLwC/C/zP4Jfm2Zb4fQG/M1rGzmgZ14KPj7ZpFvha8C3g/1TE3/o08DXF5Ek9/ii393s/2sCfLuJ9rNncBv/Wvr9d2p9S9igvAM4AjgImAl9OjDsUaAEOB9rMmEJm7+RzZCr4bcBCM5qiT6yHgTujeX4K/HW+FZrRADxK5nq18cBhwH3urARmA0vcGerO8DzzngJ8G/gMmTsy3gDu6zLZ2cBHgOOi6c6I5h1nRrsZ49K+OWm4swM4C3gravdQd96KRs8EFgDDgbsDy7mQzB0k50TLSHaNTgYmkbmA+qtmTI626WQz2su5PX2IcjtrOrDGjMeibvdTZnyowLSxGs/t6cC7ZjxrxiYzfhb63y6lUN7qzjp33gVuAs5PjNsLfM2d3e7sBC4DbnNnqTud7swn89DU6dFPI/ADd/a4swB4vsA6pwGjgWvc2eHOLvfcYzfduACY684L7uwGvgR81IzxiWludqfdnbXAk2Qe5oo7a90ZHr3eW5a487A7e6P3sFjfcGenOyuAFWT+UXBncb5/OgGU20ljgM8C/xq17+fAI9GHQLGqndtjyNzDfyUwDngduLe7FZVSKNcl4jfIvIn7bHZnV2L4cODq6JOrPar2Y6N5RgNvuuecBCl0h8NY4A13Oopo7+jkct3ZDrxD5pN7n7cT8XvA0CLWUy7rwpOkUkvbVC+U21k7gcXuPObO+8AtZPacJxfRzn2qnds7gf905/nob/kN4H+ZcWChGUoplGMT8TiId6uB/c78rgNuij659v0MdudeYANwmBnWZXn5rAPGFTgAHDrb/BaZpAbAjCFk/uBvBuartELt7vr6DmBwYvjQlMuRnlNuZ/0+xfoLqdXc7rpN+2LLMy1QWqG83IwxZrQAN5D5MqZCbgdmm3GiGWbGEDM+ZcYwYAnQAfy9GQPNOJdMNySf58gk383RMprNOCkatxEY002X4B7gYjOON6MJ+Baw1L30S1+ibWqGzLqjdjUlxs8zY16B2TcCB3X3aRZZDvyVGS1mHApclWc5Rxa1AXmYMSDapkbIbF+J3a16otzOuguYbsZp0XHUq4AtwEqoz9wG7gD+T/R+NQJfIbPXXPC4ZimF8h4yp9Rfi35uLDShO8vIHMu5FdgKrAYuisa9D5wbDW8F/gZ4qMByOoFzyJzOXwusj6aHzOULLwJvm7Elz7yLyLwhD5JJyKPIHHsJig54b+/mgO/hZHbnX4yGd0LOhcRjgWcKbNPLZI6PvBZ13Ubnm47MCYEVZL4V8Ffs/8/7beDL0TK+GNikfRc7b+9mko9H2/ELMntBOwldQtF3KLezy34F+Dvgx9E2zAT+d7RtUIe57c6vyXwA/pzMQ5mPBv6222VmTpf3jGUuQL7UnSd6PHM/E+0FrAA+7M6eardHuqfcTq8/5Xafv2C22qJP3lIOfIvUpP6U2zV/C6OISLUVVSjdGe/OE2Z2ppm9YmarzUxfEi91T7kt+RT94N7oe0ZeBU4nc+D5eeB8d3+pfM0T6X3KbemqlGOU04DV7v4agJndR+aMWMFkGmRN3syQElYp5bKNrVtcXwVRiHK7Tu1iB+/77oLXQxarlEJ5GLlX2K8HTuxuhmaGcKLpO5tqwRO+QN/vUphyu04t9UUVWW4phTJf1d6vH29mbUAbQHPOxfciNUu5LTlKOeu9ntxbvcaQe6sXAO4+x92nuvvUxuzNKiK1TLktOUoplM8DE8zsCDMbROZOgIXlaZZIVSm3JUfRXW937zCzK4BfAg3AXHd/MTCbSM1TbktXJd2Z4+6/IHMvsEifotyWJN2ZIyISoEIpIhKgQikiEqBCKSISoEIpIhKg51GK9FGdM6bE8ZszmvNOM2B3Nh73L8tzxu19772KtKseaY9SRCRAhVJEJEBd7xK9d272oTK7Dsz/uTPi1ex3vNszy/NOI9ITafJu21nZ79d6+eS5eadZ35Gd5pxt1+aMa71/VRx3bt5cVDv7Cu1RiogEqFCKiASoUIqIBOgYZQo2MPs27T3x2JxxF970szhuO3C/RxYCcMRjl8bxxLxfFS/SM2d8/Tdx/OWDXy56OWMGDo3j393wo5xxp//h4jge8BsdoxQRkW6oUIqIBKjrncKAg1ri+Pt353ZPJg/Sd6VIBVn263saDj44jpsGvFaN1vRb2qMUEQlQoRQRCVDXW6SGJbvbbc8uieOzBm9NTNXYiy3qn7RHKSISoEIpIhKgrrdIjbGp2ZsaPjkvf3e7ycLd7eOeOz+OR30///MouzPwhdVxvLfHc/ctwT1KM5trZpvM7I+J11rM7HEzWxX9HlHZZoqUn3Jb0krT9Z4HnNnlteuBRe4+AVgUDYvUm3kotyWFYNfb3X9rZuO7vDwTmBHF84GngOvK2K6qG3DsMXG891+3xfHYgTqs21fUam53Dsl2q69p+Z/EmHB3+5jFF8bxuH9piGN75nc9bkd/724nFftff4i7bwCIfo8qX5NEqkq5Lfup+MkcM2sD2gCa0e1+0ncot/uPYgvlRjNrdfcNZtYKbCo0obvPAeYAfMBavMj19bqOgw6I48ePuS8xpudnD6Wu1HVu20vDsvEzz1axJX1LsV3vhcCsKJ4FPFKe5ohUnXJb9pPm8qB7gSXAJDNbb2aXADcDp5vZKuD0aFikrii3Ja00Z73PLzDq1DK3paY0vvWnOD7q19knPS/7xA9zphvRkP/Y1MVrPxbHBy/Wvbi1qFZye8Bxk3OGX/3bnh0RO33lOXE8cnlHWdqUVueMKXH85oziD0uN/8+tOcN7V6wselmVoGtdREQCVChFRAJ0r3cBnauyT5CedHX2Urq3l+ZON6KBvP77vz4Ux+Pu0NlHKWzLlOE5w6+f8289mr/9rjFx3PLwkm6mLL9kd3tl24+6mbJ7R4xqyxk+5rbs4Yha6IZrj1JEJECFUkQkQF3vAgYMy164u2Pa+DhuNt0BK/1bw4Qj43hX656yLPP1T8/JGZ686QtxPG5FWVZREu1RiogEqFCKiASo611Ax5Sj4/g3tyW7BUN7vzEiVdYw/MA4fuVr2fj1U26vyPo6mrO3zicPg+3dti3f5BWnPUoRkQAVShGRAHW9RSTozXmj43jZXySfd1CZ53A+c8EtcTx92D/G8YTLl+abvOK0RykiEqBCKSISoEIpIhKgY5RldMzt2bsJjvrJ2jju3ScEipTHu49OjOOffujf43hEw5AeLef1Pdvj+KIrsscbL/jOo3HcduBbOfOMSqzDmzt7tL5K0B6liEiACqWISIC63mU0bE32boKOdeur2BKpJ6N+syFnOHkI5+XLws94/PgV2UtmFnd+NI6H/0dpz6acOe73cTyxsWfd7Qe2Z+/e+b/f/XwcN8zeEsczBq9KzNGz5fc27VGKiASoUIqIBKjrneAnHR/Hm/5hV6p5jvjZZXF8zAvtcaynVkpaHa+tyRk+6t+z10lMaMp2W5N3qyTPCn+39YU4vvgL2def+sTUOB720qA4bv1uuq8meXDOKXHcNDv73MlrWv4nOO+ru1qzbb0z+0DJtS3Z/7E1k7JfgTGxsTzPtayUNN/rPdbMnjSzlWb2opldGb3eYmaPm9mq6PeIyjdXpHyU25JWmq53B3C1u08GpgOXm9kHgeuBRe4+AVgUDYvUE+W2pBLserv7BmBDFG8zs5XAYcBMYEY02XzgKeC6irSyl2ydeEAcr5g2L9U8R9+T7SbVwrfFSXq1mtvJKyYmfP/9OH7nsxbHowp8++cd457ODiTi/3fiUXE8d9gZZWhl96YNznbP7772lG6mrA89OpljZuOBE4ClwCFRou1LuFGF5xSpbcpt6U7qQmlmQ4EHgavc/c89mK/NzJaZ2bI97C6mjSIVpdyWkFRnvc2skUwi3e3uD0UvbzSzVnffYGatwKZ887r7HGAOwAesxfNNI1ItNZ/bu7MF+LPLL4njn56Qvfc6zcXgyTPV17SFL2Iv1V8Ozp7FXtkL66u0NGe9DfgJsNLdv5cYtRCYFcWzgEfK3zyRylFuS1pp9ihPAi4E/mBmy6PXbgBuBh4ws0uAtcB5lWmiSMUotyWVNGe9FwNWYPSp5W2OSO+ph9zubP9THB/66Wx8waMXx3Hynuy/HPaHOJ7W1Fjh1pVP8t5wyL1gfdCG6m+HbmEUEQlQoRQRCdC93iJ1qOXsV+P4aZrjeO4Pr4jjn3/q+73aplJ875uzc4YPvOu/43g8pT0urhy0RykiEqBCKSISoK63SB8y6fqX4vjqb366ii3pmeHtv8sZrrU7U7RHKSISoEIpIhKgrrdIH7J327bsQDKWkmiPUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJEB35iSMeHVnHB/x2KWp5pn81rtx3Fn2FolILdAepYhIgAqliEiAut4J9szyOJ74TLp51N0W6fuCe5Rm1mxmz5nZCjN70cy+Eb3eYmaPm9mq6PeIyjdXpHyU25JWmq73buAUdz8OOB4408ymA9cDi9x9ArAoGhapJ8ptSSVYKD1jezTYGP04MBOYH70+H6if586LoNyW9FKdzDGzBjNbDmwCHnf3pcAh7r4BIPo9qnLNFKkM5bakkapQununux8PjAGmmdmxaVdgZm1mtszMlu1hd7HtFKkI5bak0aPLg9y9HXgKOBPYaGatANHvTQXmmePuU919aiNNJTZXpDKU29KdNGe9R5rZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzHWUrMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKubee181bmabgTd6bYXSncPdfWS1G9FXKLdrRkXyulcLpYhIPdK93iIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgH/HyKwOLKinO/gAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "CKPT = 'b_lenet_1-2_1875.ckpt'\n", - "\n", - "def infer(ds, model):\n", - " data = ds.get_next()\n", - " images = data['image']\n", - " labels = data['label']\n", - " output = model.predict(Tensor(data['image']))\n", - " pred = np.argmax(output.asnumpy(), axis=1)\n", - " return pred[0], images[0], labels[0]\n", - "\n", - "ds = create_dataset(training=False, batch_size=1).create_dict_iterator()\n", - "net = LeNet5()\n", - "param_dict = load_checkpoint(CKPT, net)\n", - "model = Model(net)\n", - "plot_images(infer, ds, model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 实验步骤(方案二)\n", - "\n", - "### 代码梳理\n", - "\n", - "创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。\n", - "\n", - "```python\n", - "import argparse\n", - "parser = argparse.ArgumentParser()\n", - "parser.add_argument('--data_url', required=True, default=None, help='Location of data.')\n", - "parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')\n", - "parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')\n", - "args, unknown = parser.parse_known_args()\n", - "```\n", - "\n", - "MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器:\n", - "\n", - "```python\n", - "import moxing as mox\n", - "mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')\n", - "```\n", - "\n", - "如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考:\n", - "\n", - "```python\n", - "import moxing as mox\n", - "mox.file.copy_parallel(src_url='output', dst_url='s3://OBS/PATH')\n", - "```\n", - "\n", - "其他代码分析请参考方案一。\n", - "\n", - "### 创建训练作业\n", - "\n", - "可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。\n", - "\n", - "创建训练作业的参考配置:\n", - "\n", - "- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore\n", - "- 代码目录:选择上述新建的OBS桶中的experiment_2目录\n", - "- 启动文件:选择上述新建的OBS桶中的experiment_2目录下的`main.py`\n", - "- 数据来源:数据存储位置->选择上述新建的OBS桶中的experiment_2文件夹下的MNIST目录\n", - "- 训练输出位置:选择上述新建的OBS桶中的experiment_2目录并在其中创建output目录\n", - "- 作业日志路径:同训练输出位置\n", - "- 规格:Ascend:1*Ascend 910\n", - "- 其他均为默认\n", - "\n", - "启动并查看训练过程:\n", - "\n", - "1. 点击提交以开始训练;\n", - "2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理;\n", - "3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看;\n", - "4. 在训练日志中可以看到`epoch: 3 step: 1875 ,loss is 0.025683485`等字段,即训练过程的loss值;\n", - "5. 在训练日志中可以看到`Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}`等字段,即训练完成后的验证精度;\n", - "6. 在训练日志里可以看到`b_lenet_1-2_1875.ckpt`等字段,即训练过程保存的Checkpoint。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 实验小结\n", - "\n", - "本实验展示了MindSpore的Checkpoint、断点继续训练等高级特性:\n", - "\n", - "1. 使用MindSpore的ModelCheckpoint接口每个epoch保存一次Checkpoint,训练2个epoch并终止。\n", - "2. 使用MindSpore的load_checkpoint和load_param_into_net接口加载上一步保存的Checkpoint继续训练2个epoch。\n", - "3. 观察训练过程中Loss的变化情况,加载Checkpoint继续训练后loss进一步下降。" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/experiment_2/main.py b/experiment_2/main.py deleted file mode 100644 index 972f524..0000000 --- a/experiment_2/main.py +++ /dev/null @@ -1,140 +0,0 @@ -# Save and load model - -import os -# os.environ['DEVICE_ID'] = '0' -# Log level includes 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG). -os.environ['GLOG_v'] = '2' - -import matplotlib.pyplot as plt -import numpy as np - -import mindspore as ms -import mindspore.context as context -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as CV - -from mindspore.dataset.transforms.vision import Inter -from mindspore import nn, Tensor -from mindspore.train import Model -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -import logging; logging.getLogger('matplotlib.font_manager').disabled = True - -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') - -DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 -DATA_DIR_TEST = "MNIST/test" # 测试集信息 - - -def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), - rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): - ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) - - # define map operations - resize_op = CV.Resize(resize) - rescale_op = CV.Rescale(rescale, shift) - hwc2chw_op = CV.HWC2CHW() - - # apply map operations on images - ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op]) - ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) - - ds = ds.shuffle(buffer_size=buffer_size) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(num_epoch) - - return ds - - -class LeNet5(nn.Cell): - def __init__(self): - super(LeNet5, self).__init__() - self.relu = nn.ReLU() - self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - self.fc1 = nn.Dense(400, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, 10) - - def construct(self, input_x): - output = self.conv1(input_x) - output = self.relu(output) - output = self.pool(output) - output = self.conv2(output) - output = self.relu(output) - output = self.pool(output) - output = self.flatten(output) - output = self.fc1(output) - output = self.fc2(output) - output = self.fc3(output) - - return output - - -def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - steps_per_epoch = ds_train.get_dataset_size() - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) - ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg) - loss_cb = LossMonitor(steps_per_epoch) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True) - metrics = model.eval(ds_eval) - print('Metrics:', metrics) - - -CKPT = 'b_lenet-2_1875.ckpt' - -def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - steps_per_epoch = ds_train.get_dataset_size() - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - param_dict = load_checkpoint(CKPT) - load_param_into_net(net, param_dict) - load_param_into_net(opt, param_dict) - - ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) - ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg) - loss_cb = LossMonitor(steps_per_epoch) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb]) - - metrics = model.eval(ds_eval) - print('Metrics:', metrics) - - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--data_url', required=True, default=None, help='Location of data.') - parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') - parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.') - args, unknown = parser.parse_known_args() - - import moxing as mox - mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') - - os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件 - - test_train() - print('\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')]))) - - resume_train() - print('\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')]))) - \ No newline at end of file diff --git a/experiment_5/LeNet_MNIST_Windows.md b/experiment_5/LeNet_MNIST_Windows.md deleted file mode 100644 index 002be9e..0000000 --- a/experiment_5/LeNet_MNIST_Windows.md +++ /dev/null @@ -1,171 +0,0 @@ -# 在Windows上运行LeNet_MNIST - -## 实验介绍 - -LeNet5 + MINST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在Windows环境下MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。 - -## 实验目的 - -- 了解如何使用MindSpore进行简单卷积神经网络的开发。 -- 了解如何使用MindSpore进行简单图片分类任务的训练。 -- 了解如何使用MindSpore进行简单图片分类任务的验证。 - -## 预备知识 - -- 熟练使用Python,了解Shell及Linux操作系统基本知识。 -- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略等。 -- 了解并熟悉MindSpore AI计算框架,MindSpore官网:[https://www.mindspore.cn](https://www.mindspore.cn/) - -## 实验环境 - -- Windows-x64版本MindSpore 0.3.0;安装命令可见官网: - - [https://www.mindspore.cn/install](https://www.mindspore.cn/install)(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套)。 - -## 实验准备 - -### 创建目录 - -创建一个experiment文件夹,用于存放实验所需的文件代码等。 - -### 数据集准备 - -MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。 - -从MNIST官网下载如下4个文件到本地并解压: - -``` -train-images-idx3-ubyte.gz: training set images (9912422 bytes) -train-labels-idx1-ubyte.gz: training set labels (28881 bytes) -t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) -t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) -``` - -### 脚本准备 - -从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。 - -### 准备文件 - -将脚本和数据集放到到experiment文件夹中,组织为如下形式: - -``` -experiment -├── MNIST -│ ├── test -│ │ ├── t10k-images-idx3-ubyte -│ │ └── t10k-labels-idx1-ubyte -│ └── train -│ ├── train-images-idx3-ubyte -│ └── train-labels-idx1-ubyte -└── main.py -``` - -## 实验步骤 - -### 导入MindSpore模块和辅助模块 - -```python -import matplotlib.pyplot as plt - -import mindspore as ms -import mindspore.context as context -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as CV - -from mindspore import nn -from mindspore.model_zoo.lenet import LeNet5 -from mindspore.train import Model -from mindspore.train.callback import LossMonitor - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') -``` - -### 数据处理 - -在使用数据集训练网络前,首先需要对数据进行预处理,如下: - -```python -DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 -DATA_DIR_TEST = "MNIST/test" # 测试集信息 - -def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): - ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) - - ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) - ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) - ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch) - - return ds -``` - -对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。 - -```python -def show_dataset(): - ds = create_dataset(training=False) - data = ds.create_dict_iterator().get_next() - images = data['image'] - labels = data['label'] - - for i in range(1, 5): - plt.subplot(2, 2, i) - plt.imshow(images[i][0]) - plt.title('Number: %s' % labels[i]) - plt.xticks([]) - plt.show() -``` - -![img](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAATsAAAD7CAYAAAAVQzPHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcm0lEQVR4nO3deZRV1Zk28OepQWaBYrIQAkZBIKyICjjE1U3aEDHdaU268QuiTRxCVqKt+aJGErOiMdqxTaL9pfOZDh0ZooKxo+0QtQnNEhLRBis4oSggDhArTIIWU0FVvf3HPexzCupW3enc4ezntxar3nvGXfCy795n2JtmBhGRpKsqdQFERIpBlZ2IeEGVnYh4QZWdiHhBlZ2IeEGVnYh4QZVdBkguIHlbqcshUmg+5XZFVnYk3yG5lWSvyLIrSS4vYbEKiuRnSK4huZfkZpIXlbpMEr+k5zbJO4N8/ojkuyRvKta5K7KyC9QAuLbUhcgWyeoMthkHYBGAmwD0BTABwB9jLpqUj8TmNoB7AYwxs2MBnA3gYpJfjLdkKZVc2f0IwPUk+x25guRIkkayJrJsOckrg/jLJFeSvJvkbpKbSJ4dLN9MchvJWUccdiDJpSSbSK4gOSJy7DHBug9IvhlthQXdhJ+TfIrkXgCfzuB3+y6AX5jZ02bWYmY7zeytLP9+pHIlNrfN7E0z2xtZ1AbgpIz/ZvJQyZVdA4DlAK7Pcf8zALwCYABSragHAUxC6i/+EgA/I9k7sv1MAD8AMBDASwAeAICgu7E0OMZgADMA3EPyE5F9LwZwO4A+AJ4leTHJVzop25nBsV8l2UjyfpJ1Of6eUnmSnNsgOYfkHgBbAPQKjh+7Sq7sAOB7AP6R5KAc9n3bzOabWSuAXwMYDuBWM2s2s98BOIj23zhPmtnvzawZqe7lWSSHA/gbAO8Ex2oxszUAHgbw95F9HzOzlWbWZmYHzGyRmX2yk7INA3ApgL8DMApADwD/msPvKJUrqbkNM7sDqcrxNAD3Afgwh98xaxVd2ZnZWgC/BTAnh923RuL9wfGOXBb99tscOe8eAB8AGApgBIAzgi7DbpK7kfqmPK6jfTO0H8B8M1sfnOufAHwuy2NIBUtwbh8+j5nZi0FZvp/LMbJV0/UmZe9mAGsA/CSy7PA1gZ4APgri6D9QLoYfDoIuQB2A95H6x15hZlM72TfboWVeyWEfSZ4k5vaRagCcmOcxMlLRLTsAMLONSDXVr4ks2w7gTwAuIVlN8nLk/xf6OZLnkDwGqesbq8xsM1LfvqNJXkqyNvgzieTYPM41H8BlJD9OsieAG4PziEeSltskq0h+lWR/pkwGcBWAZXmWPyMVX9kFbkXqQmfUVwDcAGAngE8AeC7PcyxC6pv2AwCnI9Wch5k1AfgsgC8h9W34ZwD/DKBbugORnEnytXTrzWwegF8BWAXgXQDNiCS8eCVRuQ3gCwDeAtAE4H6krkUX5Xo0NXiniPggKS07EZFOqbITES+oshMRL+RV2ZGcFrxCspFkLs8DiZQl5Xby5HyDgqmXftcDmIrUax8vAJhhZq8XrngixafcTqZ8HiqeDGCjmW0CAJIPArgAQNqEOIbdrPtRd9GlFJqwa4eZ5fIqkg+U2xXqAPbioDWzo3X5VHbHo/2rIluQegG5HZKzAcwGgO7oiTN4bh6nlEL5b/vNu6UuQxlTbleoVZb++eR8rtl1VHse1Sc2s7lmNtHMJtamfxZRpJwotxMon8puCyLv1CE1Usf7+RVHpCwotxMon8ruBQCjSJ4QvFP3JQCPF6ZYIiWl3E6gnK/ZmVkLyasBLAFQDWCemXX2TpxIRVBuJ1NeQzyZ2VMAnipQWUTKhnI7efQGhYh4QZWdiHghCSMVl43G6852cdO4g1ntywPhLHQnzwmfXW1rasq/YCKilp2I+EGVnYh4Qd3YAqo7L3zu9JXxj2a177qD+1x83a0XhivUjfVaVc+eLn7v2gkubiuTFzYGvdTi4h6Pri5hSbqmlp2IeEGVnYh4Qd3YPO37YjgYxuQBL5SwJJIU1YPCkbca/88oFy/52p0uHlbTG+Vg6rrPu/jg/okuPmZJQymK0ym17ETEC6rsRMQLquxExAu6ZpdGVZ8+Lm457aS02116+xMunt03uyHPtrXudfFdWz8brmhp6WBr8cWhccNc/OJ37omsKY/rdFFLx4b5f8nNU1y89cCpLmZLOO5p1aq1LrYi57ladiLiBVV2IuIFdWMj2C18LH3flLEuXvGLubGcb+6u01383hl7I2v2Hr2xSJm7f+Ty8MPiMF5/KMznb5x/WbjNzt0ubNv9YbtjWXNzoYunlp2I+EGVnYh4Qd3YiN3TwztIi2//cWRN+d0FE6kUJ9b0cPG/PD3fxa0Wzlh52Xe/2W6fvvf/T8HLoZadiHhBlZ2IeMH7buy2q8Oh1L99zQMuPqG2cF3XU1bPcPHgu7u7uHrvochWayECADVrNrr4L78628ULfnaXizPJz3R5l6np/7bExdk+MB9VzbBNNbq2V4fbtNayw+WF1GXLjuQ8kttIro0sqyO5lOSG4Gf/eIspUnjKbb9k0o1dAGDaEcvmAFhmZqMALAs+i1SaBVBue6PLbqyZ/Z7kyCMWXwBgShAvBLAcwI0FLFfRHBgYxhf1/jD9hnn4aGfYdD9uRTjOl3W0sRRNueZ2dEa5Hr972cUzbrrexZl0+4as3+9irnyxw22q+/V18Z8WDG23bkrPDZFPHXc/03loT3jcH/50pouf+FbpxuTL9QbFEDNrBIDg5+B0G5KcTbKBZMMhFP6paJECU24nVOx3Y81srplNNLOJtSiTWUJECkC5XVlyvRu7lWS9mTWSrAewrZCFitsHl53l4jOnvRrLOaasDWcIG/6EnvCpIGWV29F3RPN50NY+Fc5MtuHy8L99VbdWFzec/v/b7dO/Oruua9T6A/UuPm7eSy4+r8+3XBydIW3kml3t9m/L+czp5fq/8HEAs4J4FoDHClMckZJTbidUJo+eLAbwPICTSW4heQWAOwBMJbkBwNTgs0hFUW77JZO7sTPSrDq3wGWJ1e5/CLuu478aPsA7/2N/iP3c2yeEf83HHhuWo9+vno/93JJeUnI7ndYpp7l4y9fDB9jfPmdBmj16plmemQeaBrj4vsc/7eKR+8I8H/bD5zrcN45u65F0MUlEvKDKTkS84M27sT0uaXRxMbquy8c/Gn4YH4bXNYZdi9V7wgm2ez6yKvYySTIdPC+cnHrP0FoXN52/x8VvnHNf7OVYs2eEiwevKUbHNDtq2YmIF1TZiYgXvOnGlouf1K9x8W237HPxHx7JfggeEQCw63a4+IXo5ZMii+b23Nv/7OLfROZE1ryxIiIxU2UnIl5IdDe2ekCdi7vXHOpky9LoVhWWqXrQcBe37gi7JTANBCWVJzqy8ZRF4Tu30Xlj29ZvcnExurRq2YmIF1TZiYgXEt2NPf6pcHicu49/KrKmPO58XtP/DRePem6ri+eeHb4/27p9e1HLJFJo6eaNvfriq1zMlS8hbmrZiYgXVNmJiBdU2YmIFxJ9zW5Ej50u7l3V9XW6S96Z4uLXfzXWxWu+9/OCluuwbgxf2j6/Zzgs9YZnwlvyy74cXr+zBk2kLUfr8c0wt0+5LRyi7+XJi0tRnKOkmyTbasIZ0uKfIlstOxHxhCo7EfFCIrqx6Sb6vajvLyNbdT1T0pY9/Vx83MMbXTyp+Wtp97nh24vC8+UxyXa0S3tD3Vsu/l2vv3CxvpmkI21rw0eYhvwonEVs0uj0eZutv7g6HG8x+sJ/JdH/HxHxgio7EfFCIrqx6BbOtvvghHtdHL3zk63omwt189O/xfDDXjNdfPPAcHl08u1iDAMvArR/E6FuZeGO+8cZI8MPSe3GkhxO8hmS60i+RvLaYHkdyaUkNwQ/+8dfXJHCUW77JZNubAuA68xsLIAzAVxFchyAOQCWmdkoAMuCzyKVRLntkUwmyW4E0BjETSTXATgewAUApgSbLQSwHMCNsZSySP56aNj1nHfLeSUsiRRDueR21Slj231+5wvl15C8fOiSUhchb1ndoCA5EsCpAFYBGBIky+GkGZxmn9kkG0g2HEJzR5uIlJxyO/kyruxI9gbwMIBvmNlHme5nZnPNbKKZTaxFt653ECky5bYfMrobS7IWqWR4wMweCRZvJVlvZo0k6wFsi6uQxRJ9mPeG2feUsCRSLOWQ2ztO69fu8zrlXiwyuRtLAPcCWGdmd0VWPQ5gVhDPAvBY4YsnEh/ltl8yadl9CsClAF4lefghnu8AuAPAQySvAPAegOnxFFEkNsptj2RyN/ZZpB+B5dzCFkekeEqZ2zXDh7m4aWQxBjgSvS4mIl5QZSciXkjGu7Ft4UTSbx6KPhIV3kQbXhPW65mMWlwMzRZOkr3pUMeTeLNFk2Qn0aYrPubiN76SvLuv5ZjbatmJiBdU2YmIFxLRjW3dscPF0QmmURXe5WpbHI4E/F9jnixKubry011jXPzMuSd2uE3VznCSHXVopVKUY26rZSciXlBlJyJeSEQ3FhY2gqMjDEc1t4wsUmGOdsrqcC7PwXeHd4Kr94Z3qWyr5oT1ycfvfc/FY/j1dusq9e5sdN7lHdeED02XS26rZSciXlBlJyJeSEY3NgP8STgbzqShhZtPMxND1u8Py7HyRRfr7qq/WjZvcfGJC9v/NxyDsFtb7l3aqes+7+KWO4e4+JiGhlIUp1Nq2YmIF1TZiYgXVNmJiBe8uWZ3zJLwGkJdCcshcqSWTe+0+3ziL1tcPNa+jnI26KWwrD2WrC5hSbqmlp2IeEGVnYh4wZturEiliD6W8rFbtnSypWRDLTsR8YIqOxHxQibzxnYnuZrkyyRfI/n9YHkdyaUkNwQ/+8dfXJHCUW77JZOWXTOAvzKzUwBMADCN5JkA5gBYZmajACwLPotUEuW2R7qs7CxlT/CxNvhjAC4AsDBYvhDAhbGUUCQmym2/ZHTNjmR1MGP6NgBLzWwVgCFm1ggAwc/BnR1DpBwpt/2RUWVnZq1mNgHAMACTSY7P9AQkZ5NsINlwCM25llMkFsptf2R1N9bMdgNYDmAagK0k6wEg+LktzT5zzWyimU2sRbc8iysSD+V28mVyN3YQyX5B3APAZwC8AeBxALOCzWYBeCyuQorEQbntl0zeoKgHsJBkNVKV40Nm9luSzwN4iOQVAN4DMD3GcorEQbntkS4rOzN7BcCpHSzfCeDcOAolUgzKbb/QrHiDg5PcDuDdop1QOjPCzAaVuhBJodwuG2nzuqiVnYhIqejdWBHxgio7EfGCKrsMkFxA8rZSl0Ok0HzK7Yqs7Ei+Q3IryV6RZVeSXF7CYhUMyTtJbib5Ecl3Sd5U6jJJcSQ9twGA5GdIriG5N8jzi4px3oqs7AI1AK4tdSGyFTzT1ZV7AYwxs2MBnA3gYpJfjLdkUkYSm9skxwFYBOAmAH2RGm3mjzEXDUBlV3Y/AnD94Sfgo0iOJGkkayLLlpO8Moi/THIlybtJ7ia5ieTZwfLNJLeRnHXEYQcGY5s1kVxBckTk2GOCdR+QfDP6TRV0E35O8imSewF8uqtfzMzeNLO9kUVtAE7K+G9GKl1icxvAdwH8wsyeNrMWM9tpZm9l+feTk0qu7BqQepfx+hz3PwPAKwAGIPVN8yCASUhVKpcA+BnJ3pHtZwL4AYCBAF4C8AAABN2NpcExBgOYAeAekp+I7HsxgNsB9AHwLMmLSb7SWeFIziG5B8AWAL2C44sfkpzbZwbHfpVkI8n7SRZldtNKruwA4HsA/pFkLg/Hvm1m882sFcCvAQwHcKuZNZvZ7wAcRPvW1JNm9nsza0aqCX4WyeEA/gbAO8GxWsxsDYCHAfx9ZN/HzGylmbWZ2QEzW2Rmn+yscGZ2B1IJdBqA+wB8mMPvKJUrqbk9DMClAP4OwCgAPQD8aw6/Y9YqurIzs7UAfovcRpLdGon3B8c7cln0229z5Lx7AHwAYCiAEQDOCLoMu0nuRuqb8riO9s1GMLjki0FZvp/LMaQyJTi39wOYb2brg3P9E4DPZXmMnCRhKsWbAawB8JPIssPXu3oC+CiIo/9AuRh+OAi6AHUA3kfqH3uFmU3tZN98X1OpAXBinseQypPE3H4lh30KoqJbdgBgZhuRaqpfE1m2HcCfAFzC1Ei0lyP/yuJzJM8heQxS1zdWmdlmpL59R5O8lGRt8GcSybG5nIRkFcmvkuzPlMkArkJqLgTxSNJyOzAfwGUkP06yJ4Abg/PEruIru8CtSF3Ej/oKgBsA7ATwCQDP5XmORUh9034A4HSkmvMwsyYAnwXwJaS+Df8M4J+B9KM5kpxJ8rVOzvUFAG8BaAJwP1LXNIpyXUPKTqJy28zmAfgVgFVIDZzQjEhlHicNBCAiXkhKy05EpFOq7ETEC3lVdiSnBU9VbySpiYQlMZTbyZPzNTum3oNbD2AqUk/5vwBghpm9XrjiiRSfcjuZ8nnObjKAjWa2CQBIPojUTOppE+IYdrPuR91YklJowq4dGpY9LeV2hTqAvThozexoXT6V3fFo//T0FqTeyUurO3rhDGoek3Lw3/YbzZeQnnK7Qq2y9I+j5lPZdVR7HtUnJjkbwGwA6I6eeZxOpGiU2wmUzw2KLYi8ZoLUC77vH7mRZk2XCqTcTqB8KrsXAIwieULwmsmXkJpJXaTSKbcTKOdurJm1kLwawBIA1QDmmVlnr0CJVATldjLlNeqJmT0F4KkClUWkbCQttw+eN9HFdt2OjPbp8c3uLm5b+0bBy1RseoNCRLygyk5EvJCEwTtFpAt7hta6+IXxj2a0z9QBl7k4Ca2iJPwOIiJdUmUnIl5IdDd229Vnu/jAwPjPN/I/d7m47eV18Z9QRDKmlp2IeEGVnYh4IRHdWHYL30vcPf1UF3/7mgdcfFHv+OeYPmHwbBcPfOEsF/dfv9/FXPlS7OUQkaOpZSciXlBlJyJeUGUnIl5IxDW7qn59XTz/trtcPPaY4g6o+PaFc8MPF4bhKatnuHjoh2NcnISXq6V81Qwf5uKmkR2OVO4VtexExAuq7ETEC4noxpa7lycvdvGUu8L+bbfPlqI0kmRVffq4eP1V4cjyG/7hnlIUp6yoZSciXlBlJyJeUDdWJEHevGOci//nb38cWaMJvNWyExEvqLITES8kohvbtvMDF1878+sutpqOH6Tc9n8PuDh6p1Sk0ln3VhcPrlbXNarLlh3JeSS3kVwbWVZHcinJDcHP/vEWU6TwlNt+yaQbuwDAtCOWzQGwzMxGAVgWfBapNAug3PZGl91YM/s9yZFHLL4AwJQgXghgOYAbC1iurFhLi4uj48WlextwSMsEF08a/bUuj9/SKzzSE9+6s926YTW9MyyllJtKyO24bWnZ4+LP3/mtduvqX9/g4lZUvlxvUAwxs0YACH4OLlyRREpKuZ1Qsd+gIDkbwGwA6I7ijkIiEifldmXJtbLbSrLezBpJ1gPYlm5DM5sLYC4AHMs6y/F8BRXt6tat7Hr76iHhl3vT9XpaJ+EqOrez1dQW5nP9f2xst651+/ZiFydWuf7PfRzArCCeBeCxwhRHpOSU2wmVyaMniwE8D+BkkltIXgHgDgBTSW4AMDX4LFJRlNt+yeRu7Iw0q84tcFnKSnSU1+hQOQOqK7K3Ih3wNbd9pQtQIuIFVXYi4oVEvBubLftU+FDxrtE9OtwmOkFJ+1Fe9b6h+OfgeRNdvGdobVb7Vh8KL/30+48XXWzNzfkXLAtq2YmIF1TZiYgXvOzGbrg8/LXfPv/nRT33sN67Xbxj4ngXW8PajjYX6VLV+HAu4mMH7O1y+22t4TZ3bY3M+hR5xxwAGMnP428OHzi+f+TyrMr39qHw/dsvf/hNF/dcvs7FbU1NWR0zF2rZiYgXVNmJiBe87MaWUrQLcNu8sPvxh092L0FpJAn23xUZeXv8o11uP3fX6S5+78x9Lq4e2H4wg3MXPO/iG+reyrl8J9SGw6Ct+MVcF0+dcZmLq1a8iLipZSciXlBlJyJeUDdWxGPVAwe6ePZzz7dbd37PXZFP2T1IXI7UshMRL6iyExEvqLITES/omp2Iz6rCAS9Orm0/An03dj2vximrwyEBm18Op9h94yv3dLR5O9P/bYmL77vp8+3W9XxkVZf7Z0stOxHxgio7EfGCl93YUfPCF54nPdv1JNmdueHbi1x8Ue8P8zqWSDFc1PePLn528YkuHl6TWdtnzLOXuvhj/6/axbtGZzdlwey+77v43/u2P3ccE1OqZSciXlBlJyJe8LIbm+0k2VV9+rj4zTvGtVs3snZH5FPlP2UuyTe6Npxa4L/GPBlZk34wimjXddg9YZ5z5ZrIgc8qSPniksm8scNJPkNyHcnXSF4bLK8juZTkhuBn/66OJVJOlNt+yaQb2wLgOjMbC+BMAFeRHAdgDoBlZjYKwLLgs0glUW57JJNJshsBNAZxE8l1AI4HcAGAKcFmCwEsB3BjLKUsMfYMZyB78q/vbrdu7DFx3DeSYlBuZ67P0+GYdAf7tbp4+y1nu9jGZTe0evSB5CHr9+dRusxkdYOC5EgApwJYBWBIkCyHk2ZwoQsnUizK7eTLuLIj2RvAwwC+YWYfZbHfbJINJBsOobjzRIpkQrnth4zuxpKsRSoZHjCzR4LFW0nWm1kjyXoA2zra18zmApgLAMeyLrunDjPEbt1cvHv6qS5urWVHm2etpVd4nD5VbXkda/n+8Pvl3tXnuHg0GvI6ruSm3HM7E++vqXfxAyMGuHhmn50FO8eOSWHX9aSTG128buwTOR9z8N3h3V+uLINh2UkSwL0A1pnZXZFVjwOYFcSzADxW+OKJxEe57ZdMWnafAnApgFdJHn5A7TsA7gDwEMkrALwHYHo8RRSJjXLbI5ncjX0WQLr+4LmFLU7mog/67psy1sWLb/+xi6OzGhVOfse85a2/dfHoK9V1LaVyze1snTAnHE79u4O+4OKZ5/+yYOd4+8K5XW+URrMdcvFPd4Uz6lXvDZcX4xqAXhcTES+oshMRL1Tsu7Etp53k4ujEu/l2M+OwqzWciHjXvvAB5eNKURhJNB4Ih1xadzDMu+hTBMNq4v8/Eu26Pr0vfNvumU9/3MW2fW3s5YhSy05EvKDKTkS8ULHd2EoyccVVLj756k0ubu1oY5E8nDzndRdfd+uFLm6cHl72efE7XU+Gk6/oXddo17V1x46ONi8KtexExAuq7ETEC+rGFkFbc3iHrHW3JuWR+LQ1RYZZisT1vw4f25366mWxl6PdA8NFvuuajlp2IuIFVXYi4oWK7cbWvr7FxZNu6nju11LO6RqdoCQ6T61IKbRu3+7iqhXbO9myMEo23lUn1LITES+oshMRL6iyExEvVOw1u+g1iLr5HV+D+GGvmS6+eWDsRWpn2PIDLm43kbCIlIRadiLiBVV2IuKFiu3GZmLwz54rdRFEpEyoZSciXlBlJyJeUGUnIl7IZJLs7iRXk3yZ5Gskvx8sryO5lOSG4Gf/ro4lUk6U237JpGXXDOCvzOwUABMATCN5JoA5AJaZ2SgAy4LPIpVEue2RLis7S9kTfKwN/hiACwAsDJYvBHBhB7uLlC3ltl8yumZHsprkSwC2AVhqZqsADDGzRgAIfg6Or5gi8VBu+yOjys7MWs1sAoBhACaTHJ/pCUjOJtlAsuEQmnMtp0gslNv+yOpurJntBrAcwDQAW0nWA0Dwc1uafeaa2UQzm1iLbnkWVyQeyu3ky+Ru7CCS/YK4B4DPAHgDwOMAZgWbzQLwWFyFFImDctsvmbwuVg9gIclqpCrHh8zstySfB/AQySsAvAdgeozlFImDctsjNCveAMoktwN4t2gnlM6MMLNBpS5EUii3y0bavC5qZSciUip6XUxEvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLzwv9NPrlrn6D7QAAAAAElFTkSuQmCC) - -### 定义模型 - -MindSpore model_zoo中提供了多种常见的模型,可以直接使用。这里使用其中的LeNet5模型,模型结构如下图所示: - -![img](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg) - -图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf - -### 训练 - -使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。 - -| batch size | number of epochs | learning rate | optimizer | -| ---------: | ---------------: | ------------: | -----------: | -| 32 | 3 | 0.01 | Momentum 0.9 | - -```python -def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - loss_cb = LossMonitor(per_print_times=1) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[loss_cb], dataset_sink_mode=False) - metrics = model.eval(ds_eval, dataset_sink_mode=False) - print('Metrics:', metrics) -``` - -### 实验结果 - -1. 在训练日志中可以看到`epoch: 1 step: 1875, loss is 0.29772663`等字段,即训练过程的loss值; -2. 在训练日志中可以看到`Metrics: {'loss': 0.06830393138807267, 'acc': 0.9785657051282052}`字段,即训练完成后的验证精度。 - -```python - ... ->>> epoch: 1 step: 1875, loss is 0.29772663 - ... ->>> epoch: 2 step: 1875, loss is 0.049111396 - ... ->>> epoch: 3 step: 1875, loss is 0.08183163 ->>> Metrics: {'loss': 0.06830393138807267, 'acc': 0.9785657051282052} -``` - -## 实验小结 - -本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。 \ No newline at end of file diff --git a/experiment_5/main.py b/experiment_5/main.py deleted file mode 100644 index 2d4c691..0000000 --- a/experiment_5/main.py +++ /dev/null @@ -1,62 +0,0 @@ -# LeNet5 mnist -import matplotlib.pyplot as plt - -import mindspore as ms -import mindspore.context as context -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as CV - -from mindspore import nn -from mindspore.model_zoo.lenet import LeNet5 -from mindspore.train import Model -from mindspore.train.callback import LossMonitor - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - -DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 -DATA_DIR_TEST = "MNIST/test" # 测试集信息 - - -def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), - rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): - ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) - - ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) - ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) - ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch) - - return ds - - -def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - loss_cb = LossMonitor(per_print_times=1) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[loss_cb], dataset_sink_mode=False) - metrics = model.eval(ds_eval, dataset_sink_mode=False) - print('Metrics:', metrics) - -def show_dataset(): - ds = create_dataset(training=False) - data = ds.create_dict_iterator().get_next() - images = data['image'] - labels = data['label'] - - for i in range(1, 5): - plt.subplot(2, 2, i) - plt.imshow(images[i][0]) - plt.title('Number: %s' % labels[i]) - plt.xticks([]) - plt.show() - -if __name__ == "__main__": - show_dataset() - - test_train() \ No newline at end of file diff --git a/experiment_6/Save_And_Load_Model_Windows.md b/experiment_6/Save_And_Load_Model_Windows.md deleted file mode 100644 index 1c1c805..0000000 --- a/experiment_6/Save_And_Load_Model_Windows.md +++ /dev/null @@ -1,346 +0,0 @@ -# 在Windows上运行训练时模型的保存和加载 - -## 实验介绍 - -本实验主要介绍在Windows环境下使用MindSpore实现训练时模型的保存和加载。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。 - -在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下: - -- 训练后推理场景 - - 模型训练完毕后保存模型的参数,用于推理或预测操作。 - - 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。 -- 再训练场景 - - 进行长时间训练任务时,保存训练过程中的CheckPoint文件,防止任务异常退出后从初始状态开始训练。 - - Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。 - -## 实验目的 - -- 了解如何使用MindSpore实现训练时模型的保存。 -- 了解如何使用MindSpore加载保存的模型文件并继续训练。 -- 了解如何MindSpore的Callback功能。 - -## 预备知识 - -- 熟练使用Python,了解Shell及Linux操作系统基本知识。 -- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略、Checkpoint等。 -- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn/ - -## 实验环境 - -- Windows-x64版本MindSpore 0.3.0;安装命令可见官网: - - [https://www.mindspore.cn/install](https://www.mindspore.cn/install)(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套)。 - -## 实验准备 - -### 创建目录 - -创建一个experiment文件夹,用于存放实验所需的文件代码等。 - -### 数据集准备 - -MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。 - -从MNIST官网下载如下4个文件到本地并解压: - -``` -train-images-idx3-ubyte.gz: training set images (9912422 bytes) -train-labels-idx1-ubyte.gz: training set labels (28881 bytes) -t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) -t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) -``` - -### 脚本准备 - -从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。 - -### 准备文件 - -将脚本和数据集放到到experiment文件夹中,组织为如下形式: - -``` -experiment -├── MNIST -│ ├── test -│ │ ├── t10k-images-idx3-ubyte -│ │ └── t10k-labels-idx1-ubyte -│ └── train -│ ├── train-images-idx3-ubyte -│ └── train-labels-idx1-ubyte -└── main.py -``` - -## 实验步骤 - -### 导入MindSpore模块和辅助模块 - -```python -import matplotlib.pyplot as plt -import numpy as np - -import mindspore as ms -import mindspore.context as context -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as CV - -from mindspore import nn, Tensor -from mindspore.train import Model -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') -``` - -### 数据处理 - -在使用数据集训练网络前,首先需要对数据进行预处理,如下: - -```python -DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 -DATA_DIR_TEST = "MNIST/test" # 测试集信息 - - -def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), - rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): - ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) - - # define map operations - resize_op = CV.Resize(resize) - rescale_op = CV.Rescale(rescale, shift) - hwc2chw_op = CV.HWC2CHW() - - # apply map operations on images - ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op]) - ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) - - ds = ds.shuffle(buffer_size=buffer_size) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(num_epoch) - - return ds -``` - -### 定义模型 - -定义LeNet5模型,模型结构如下图所示: - -![img](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg) - -图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf - -```python -class LeNet5(nn.Cell): - def __init__(self): - super(LeNet5, self).__init__() - self.relu = nn.ReLU() - self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - self.fc1 = nn.Dense(400, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, 10) - - def construct(self, input_x): - output = self.conv1(input_x) - output = self.relu(output) - output = self.pool(output) - output = self.conv2(output) - output = self.relu(output) - output = self.pool(output) - output = self.flatten(output) - output = self.fc1(output) - output = self.fc2(output) - output = self.fc3(output) - - return output -``` - -### 保存模型Checkpoint - -MindSpore提供了Callback功能,可用于训练/测试过程中执行特定的任务。常用的Callback如下: - -- `ModelCheckpoint`:保存网络模型和参数,用于再训练或推理; -- `LossMonitor`:监控loss值,当loss值为Nan或Inf时停止训练; -- `SummaryStep`:把训练过程中的信息存储到文件中,用于后续查看或可视化展示。 - -`ModelCheckpoint`会生成模型(.meta)和Chekpoint(.ckpt)文件,如每个epoch结束时,都保存一次checkpoint。 - -```python -class CheckpointConfig: - """ - The config for model checkpoint. - - Args: - save_checkpoint_steps (int): Steps to save checkpoint. Default: 1. - save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0. - Can't be used with save_checkpoint_steps at the same time. - keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5. - keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0. - Can't be used with keep_checkpoint_max at the same time. - integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. - Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. - - Raises: - ValueError: If the input_param is None or 0. - """ - -class ModelCheckpoint(Callback): - """ - The checkpoint callback class. - - It is called to combine with train process and save the model and network parameters after traning. - - Args: - prefix (str): Checkpoint files names prefix. Default: "CKP". - directory (str): Lolder path into which checkpoint files will be saved. Default: None. - config (CheckpointConfig): Checkpoint strategy config. Default: None. - - Raises: - ValueError: If the prefix is invalid. - TypeError: If the config is not CheckpointConfig type. - """ -``` - -MindSpore提供了多种Metric评估指标,如`accuracy`、`loss`、`precision`、`recall`、`F1`。定义一个metrics字典/元组,里面包含多种指标,传递给`Model`,然后调用`model.eval`接口来计算这些指标。`model.eval`会返回一个字典,包含各个指标及其对应的值。 - -```python -def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - steps_per_epoch = ds_train.get_dataset_size() - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) - ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg) - loss_cb = LossMonitor(steps_per_epoch) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False) - metrics = model.eval(ds_eval, dataset_sink_mode=False) - print('Metrics:', metrics) -``` - -### 加载Checkpoint继续训练 - -```python -def load_checkpoint(ckpoint_file_name, net=None): - """ - Loads checkpoint info from a specified file. - - Args: - ckpoint_file_name (str): Checkpoint file name. - net (Cell): Cell network. Default: None - - Returns: - Dict, key is parameter name, value is a Parameter. - - Raises: - ValueError: Checkpoint file is incorrect. - """ - -def load_param_into_net(net, parameter_dict): - """ - Loads parameters into network. - - Args: - net (Cell): Cell network. - parameter_dict (dict): Parameter dict. - - Raises: - TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict. - """ -``` - -使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。 - -```python -CKPT = 'b_lenet-2_1875.ckpt' - -def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - steps_per_epoch = ds_train.get_dataset_size() - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - param_dict = load_checkpoint(CKPT) - load_param_into_net(net, param_dict) - load_param_into_net(opt, param_dict) - - ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) - ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg) - loss_cb = LossMonitor(steps_per_epoch) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False) - metrics = model.eval(ds_eval, dataset_sink_mode=False) - print('Metrics:', metrics) -``` - -### 加载Checkpoint进行推理 - -使用matplotlib定义一个将推理结果可视化的辅助函数,如下: - -```python -def plot_images(pred_fn, ds, net): - for i in range(1, 5): - pred, image, label = pred_fn(ds, net) - plt.subplot(2, 2, i) - plt.imshow(np.squeeze(image)) - color = 'blue' if pred == label else 'red' - plt.title("prediction: {}, truth: {}".format(pred, label), color=color) - plt.xticks([]) - plt.show() -``` - -使用训练后的LeNet5模型对手写数字进行识别,可以看到识别结果基本上是正确的。 - -```python -CKPT = 'b_lenet_1-2_1875.ckpt' - -def infer(ds, model): - data = ds.get_next() - images = data['image'] - labels = data['label'] - output = model.predict(Tensor(data['image'])) - pred = np.argmax(output.asnumpy(), axis=1) - return pred[0], images[0], labels[0] - -def test_infer(): - ds = create_dataset(training=False, batch_size=1).create_dict_iterator() - net = LeNet5() - param_dict = load_checkpoint(CKPT, net) - model = Model(net) - plot_images(infer, ds, model) -``` - -![img](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAUoAAAD7CAYAAAAMyN1hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcv0lEQVR4nO3de5RU5Znv8e9D03TLJUIraCMgXgDJMlEZguRoEuJl1IkecpxlJo7jQpfaIdE1OmO8xMl9NPHkmMvMMpmIE4TxbtBRYuIkSjQRZVCWgSSKCqMIKHJROgEEpJvn/FGbXbuaqn5316Wrqvv3WatXP7v27d3VTz+1330rc3dERKSwAdVugIhIrVOhFBEJUKEUEQlQoRQRCVChFBEJUKEUEQmoyUJpxjwzbozij5nxSpHL+bEZXylv6+qDGePNcDMGVrstkqXcLl01crsmC2WSO0+7Myk0nRkXmbG4y7yz3fnnyrUuXvexZvzSjC1m9OjCVDNmmLG+DG1YY8ZppS4nsbxPmvGkGX8yY025litZyu3Uy6l6ble8UPaTPZo9wAPAJZVYeJXewx3AXOCaKqy7Lii3S1c3ue3uPf4BXwP+JfCXwLeC3wHeHI2bAb4e/Drwt8HvjF4/G3w5eDv4s+AfTizvBPAXwLeB3w9+H/iNyeUlph0L/hD4ZvB3wG8Fnwy+C7wTfDt4ezTtvH3LiYYvA18N/i74QvDRiXEOPht8VbRNPwS3Hr4vR4N7D6YfAr4TfG/U7u3go8G/Dr4A/C7wP4Nfmmdb4vcF/M5oGTujZVwLPj7aplnga8G3gP9TEX/r08DXFJMn9fij3FZu5/spZY/yAuAM4ChgIvDlxLhDgRbgcKDNjClkKvjngIOA24CFZjSZMQh4GLgzmuenwF/nW6EZDcCjwBvAeOAw4D53VgKzgSXuDHVneJ55TwG+DXwGaI2WcV+Xyc4GPgIcF013RjTvODPazRiX9s1Jw50dwFnAW1G7h7rzVjR6JrAAGA7cHVjOhcBa4JxoGd9JjD4ZmAScCnzVjMnRNp1sRns5t6cPUW6XqK/ldimF8lZ31rnzLnATcH5i3F7ga+7sdmcncBlwmztL3el0Zz6wG5ge/TQCP3BnjzsLgOcLrHMaMBq4xp0d7uxyzz12040LgLnuvODObuBLwEfNGJ+Y5mZ32t1ZCzwJHA/gzlp3hkev95Yl7jzszt7oPSzWN9zZ6c4KYAWZfxTcWZzvn04A5Xal1V1ul1Io1yXiN8j8kffZ7M6uxPDhwNXRJ1d7VO3HRvOMBt50zzlQ/EaBdY4F3nCno4j2jk4u153twDtkPrn3eTsRvwcMLWI95bIuPEkqtbRN9UK5XVl1l9ulFMqxiXgcxLvVwH5nx9YBN0WfXPt+BrtzL7ABOMwM67K8fNYB4wocAA6dkXuLTFIDYMYQMl2lNwPzVVqhdnd9fQcwODF8aMrlSM8pt8ujz+R2KYXycjPGmNEC3ADc3820twOzzTjRDDNjiBmfMmMYsAToAP7ejIFmnEumG5LPc2SS7+ZoGc1mnBSN2wiMiY4L5XMPcLEZx5vRBHwLWOpe+qUv0TY1Q2bdUbuaEuPnmTGvwOwbgYPMODCwmuXAX5nRYsahwFV5lnNkURuQhxkDom1qhMz2dfPe9jXK7YhyO6OUQnkP8CvgtejnxkITurOMzLGcW4GtwGrgomjc+8C50fBW4G+AhwospxM4BziazAHe9dH0AL8GXgTeNmNLnnkXAV8BHiSTkEcBn02zodEB7+3dHPA+HNgZrZ8oTl5IPBZ4psA2vQzcC7wWdd1G55uOzAmBFcAaMu9713/ebwNfjpbxxcAm7bvYeXs3k3w82o5fkNkL2hmttz9QbmcptwHLnCbvmegizUvdeaLHM/cz0SfVCuDD7uypdnuke8rt9PpTbveHC2arKtqrmFztdoiUW3/K7Zq/hVFEpNqK6nqLiPQnJe1RmtmZZvaKma02s+vL1SiRalNuS1LRe5Rm1gC8CpxO5gzd88D57v5S+Zon0vuU29JVKSdzpgGr3f01ADO7j8w9nAWTaZA1eTNDSlillMs2tm5x95HVbkeNUm7XqV3s4H3fbeEpe6aUQnkYubcirQdO7DqRmbUBbQDNDOZEO7WEVUq5POELCt1KJ8rturXUF1VkuaUco8xXtffrx7v7HHef6u5TG7MX9IvUMuW25CilUK4n957YMeTeEytSr5TbkqOUQvk8MMHMjjCzQWRumVpYnmaJVJVyW3IUfYzS3TvM7Argl0ADMNfdXwzMJlLzlNvSVUm3MLr7L8jcWC7Spyi3JUm3MIqIBKhQiogEqFCKiASoUIqIBKhQiogE6MG9KQwYnP3eo7VXHp8zbm+KGzJGLs9+sd4BDz9XtnaJSO/QHqWISIAKpYhIgLreKdiw7PeqL/jcLTnjJg8a3HXy/Rzx2KVxPPHh8rVLJK2BR46P402faE01z8hHV8dx5+bN5W5SXdEepYhIgAqliEiACqWISICOUYr0A8njks/f9G+p5pnS9Pk4PvTB7Ov98Xil9ihFRAJUKEVEAtT1FpG8Xvhqtov+kd3ZbnjLHep6i4hIFyqUIiIBKpQiIgEqlCIiASqUIiIBOust0g8kH3CRvJA8eWZbCgvuUZrZXDPbZGZ/TLzWYmaPm9mq6PeIyjZTpPyU25JWmq73PODMLq9dDyxy9wnAomhYpN7MQ7ktKQS73u7+WzMb3+XlmcCMKJ4PPAVcV8Z2iVRcf8rt5P3ZB60c0+P5P37F0jhe3PnROB7+H0tKa1idKPZkziHuvgEg+j2q0IRm1mZmy8xs2R52F7k6kV6j3Jb9VPyst7vPcfep7j61kRTfxCVSJ5Tb/UexZ703mlmru28ws1ZgUzkbJVJFyu08vtv6QhxPPnJ6HA+vRmOqoNg9yoXArCieBTxSnuaIVJ1yW/aT5vKge4ElwCQzW29mlwA3A6eb2Srg9GhYpK4otyWtNGe9zy8w6tQyt0WkVym3JS3dwigiEqBCKSISoHu9C2gYOTKON5x3dBwPG7C3Gs0RkSrSHqWISIAKpYhIgLreBez5YPZ+2N/d8KPEmKG93xgRqSrtUYqIBKhQiogEqOudYE3ZBxu8/4HGkpa1qXNHdrm7GkpalohUl/YoRUQCVChFRALU9U5oP++EOL73plsSY3p+pvuku78Yx5Nuir+SBV2uLlJ/tEcpIhKgQikiEqBCKSISoGOUCZ2NFsdHNJZ2B87AXdll7d22raRliUh1aY9SRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkIM33eo81syfNbKWZvWhmV0avt5jZ42a2Kvo9ovLNFSkf5baklWaPsgO42t0nA9OBy83sg8D1wCJ3nwAsioZF6olyW1IJXnDu7huADVG8zcxWAocBM4EZ0WTzgaeA6yrSygrqnDEljredtb2KLZHe1tdzW8qnR8cozWw8cAKwFDgkSrR9CTeqwDxtZrbMzJbtYXdprRWpEOW2dCd1oTSzocCDwFXu/ue087n7HHef6u5TG2kKzyDSy5TbEpLqXm8zaySTSHe7+0PRyxvNrNXdN5hZK7CpUo2spDdnNMfxyyfPrWJLpBr6cm5L+aQ5623AT4CV7v69xKiFwKwongU8Uv7miVSOclvSSrNHeRJwIfAHM1sevXYDcDPwgJldAqwFzqtME0UqRrktqaQ5670YsAKjTy1vc3pHw4Qj43hX654qtkSqqS/mtlSG7swREQlQoRQRCeiXTzhfeX1LHL9+1u1lW+76juwF6wN0WZ3UKOvwOH51z444PmrgATnTNVj+/aiO5uz8A4YNi+O+/CR/7VGKiASoUIqIBPTLrnelnPOda+N43Nzlcby3Go0RKWDA0j/G8VVnXRzHP3jsjpzpJjYOyTv/MxfcEsfTh/1jHE+4fGm5mlhztEcpIhKgQikiEqCud4mmfPPzcdz64Ko47nzvvWo0RyTIOzqyA++0x2GnF7r2PteohmyX3Js7y9auWqY9ShGRABVKEZEAdb1TSF5InjyzDV2625s391qbRKT3aI9SRCRAhVJEJECFUkQkoF8eoxz7s+znw+R1XwhOn3zARfKOG9BlQCL9gfYoRUQCVChFRAL6Zdf7gIefi+NxD/dsXj3gQvoSf29nHH/q5/+QOy7FXTcHL24se5tqkfYoRUQCVChFRAKCXW8zawZ+CzRF0y9w96+ZWQtwPzAeWAN8xt23Vq6pIuWl3M79+oa+/DzJUqXZo9wNnOLuxwHHA2ea2XTgemCRu08AFkXDIvVEuS2pBAulZ+y72bkx+nFgJjA/en0+8OmKtFCkQpTbklaqY5Rm1mBmy4FNwOPuvhQ4xN03AES/R1WumSKVodyWNFIVSnfvdPfjgTHANDM7Nu0KzKzNzJaZ2bI96DtcpbYotyWNHp31dvd24CngTGCjmbUCRL83FZhnjrtPdfepjTSV2FyRylBuS3eChdLMRprZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzZ04rMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKsFC6e6/B07I8/o7wKmVaJRIb1BuS1rm7r23MrPNwBu9tkLpzuHuPrLajegrlNs1oyJ53auFUkSkHulebxGRABVKEZGAmiyUZswz48Yo/pgZrxS5nB+b8ZXytq4+mDHeDDfrn88crVXK7dJVI7drslAmufO0O5NC05lxkRmLu8w7251/rlzr4nUfa8YvzdhiRo8O+poxw4z1ZWjDGjNOK3U5ieV90ownzfiTGWvKtVzJqofcjtZ/pBmPmrEtyvHvpJyvJnM7WuYUM35rxnYzNppxZXfTV7xQ9pM9mj3AA8AllVh4ld7DHcBc4JoqrLsu9IfcNmMQ8Djwa+BQMrd63lXG5ff6e2jGwcB/AbcBBwFHA7/qdiZ37/EP+BrwL4G/BL4V/A7w5mjcDPD14NeBvw1+Z/T62eDLwdvBnwX/cGJ5J4C/AL4N/H7w+8BvTC4vMe1Y8IfAN4O/A34r+GTwXeCd4NvB26Np5+1bTjR8Gfhq8HfBF4KPToxz8Nngq6Jt+iG49fB9OTp6Jk3a6YeA7wTfG7V7O/ho8K+DLwC/C/zP4Jfm2Zb4fQG/M1rGzmgZ14KPj7ZpFvha8C3g/1TE3/o08DXF5Ek9/ii393s/2sCfLuJ9rNncBv/Wvr9d2p9S9igvAM4AjgImAl9OjDsUaAEOB9rMmEJm7+RzZCr4bcBCM5qiT6yHgTujeX4K/HW+FZrRADxK5nq18cBhwH3urARmA0vcGerO8DzzngJ8G/gMmTsy3gDu6zLZ2cBHgOOi6c6I5h1nRrsZ49K+OWm4swM4C3gravdQd96KRs8EFgDDgbsDy7mQzB0k50TLSHaNTgYmkbmA+qtmTI626WQz2su5PX2IcjtrOrDGjMeibvdTZnyowLSxGs/t6cC7ZjxrxiYzfhb63y6lUN7qzjp33gVuAs5PjNsLfM2d3e7sBC4DbnNnqTud7swn89DU6dFPI/ADd/a4swB4vsA6pwGjgWvc2eHOLvfcYzfduACY684L7uwGvgR81IzxiWludqfdnbXAk2Qe5oo7a90ZHr3eW5a487A7e6P3sFjfcGenOyuAFWT+UXBncb5/OgGU20ljgM8C/xq17+fAI9GHQLGqndtjyNzDfyUwDngduLe7FZVSKNcl4jfIvIn7bHZnV2L4cODq6JOrPar2Y6N5RgNvuuecBCl0h8NY4A13Oopo7+jkct3ZDrxD5pN7n7cT8XvA0CLWUy7rwpOkUkvbVC+U21k7gcXuPObO+8AtZPacJxfRzn2qnds7gf905/nob/kN4H+ZcWChGUoplGMT8TiId6uB/c78rgNuij659v0MdudeYANwmBnWZXn5rAPGFTgAHDrb/BaZpAbAjCFk/uBvBuartELt7vr6DmBwYvjQlMuRnlNuZ/0+xfoLqdXc7rpN+2LLMy1QWqG83IwxZrQAN5D5MqZCbgdmm3GiGWbGEDM+ZcYwYAnQAfy9GQPNOJdMNySf58gk383RMprNOCkatxEY002X4B7gYjOON6MJ+Baw1L30S1+ibWqGzLqjdjUlxs8zY16B2TcCB3X3aRZZDvyVGS1mHApclWc5Rxa1AXmYMSDapkbIbF+J3a16otzOuguYbsZp0XHUq4AtwEqoz9wG7gD+T/R+NQJfIbPXXPC4ZimF8h4yp9Rfi35uLDShO8vIHMu5FdgKrAYuisa9D5wbDW8F/gZ4qMByOoFzyJzOXwusj6aHzOULLwJvm7Elz7yLyLwhD5JJyKPIHHsJig54b+/mgO/hZHbnX4yGd0LOhcRjgWcKbNPLZI6PvBZ13Ubnm47MCYEVZL4V8Ffs/8/7beDL0TK+GNikfRc7b+9mko9H2/ELMntBOwldQtF3KLezy34F+Dvgx9E2zAT+d7RtUIe57c6vyXwA/pzMQ5mPBv6222VmTpf3jGUuQL7UnSd6PHM/E+0FrAA+7M6eardHuqfcTq8/5Xafv2C22qJP3lIOfIvUpP6U2zV/C6OISLUVVSjdGe/OE2Z2ppm9YmarzUxfEi91T7kt+RT94N7oe0ZeBU4nc+D5eeB8d3+pfM0T6X3KbemqlGOU04DV7v4agJndR+aMWMFkGmRN3syQElYp5bKNrVtcXwVRiHK7Tu1iB+/77oLXQxarlEJ5GLlX2K8HTuxuhmaGcKLpO5tqwRO+QN/vUphyu04t9UUVWW4phTJf1d6vH29mbUAbQHPOxfciNUu5LTlKOeu9ntxbvcaQe6sXAO4+x92nuvvUxuzNKiK1TLktOUoplM8DE8zsCDMbROZOgIXlaZZIVSm3JUfRXW937zCzK4BfAg3AXHd/MTCbSM1TbktXJd2Z4+6/IHMvsEifotyWJN2ZIyISoEIpIhKgQikiEqBCKSISoEIpIhKg51GK9FGdM6bE8ZszmvNOM2B3Nh73L8tzxu19772KtKseaY9SRCRAhVJEJEBd7xK9d272oTK7Dsz/uTPi1ex3vNszy/NOI9ITafJu21nZ79d6+eS5eadZ35Gd5pxt1+aMa71/VRx3bt5cVDv7Cu1RiogEqFCKiASoUIqIBOgYZQo2MPs27T3x2JxxF970szhuO3C/RxYCcMRjl8bxxLxfFS/SM2d8/Tdx/OWDXy56OWMGDo3j393wo5xxp//h4jge8BsdoxQRkW6oUIqIBKjrncKAg1ri+Pt353ZPJg/Sd6VIBVn263saDj44jpsGvFaN1vRb2qMUEQlQoRQRCVDXW6SGJbvbbc8uieOzBm9NTNXYiy3qn7RHKSISoEIpIhKgrrdIjbGp2ZsaPjkvf3e7ycLd7eOeOz+OR30///MouzPwhdVxvLfHc/ctwT1KM5trZpvM7I+J11rM7HEzWxX9HlHZZoqUn3Jb0krT9Z4HnNnlteuBRe4+AVgUDYvUm3kotyWFYNfb3X9rZuO7vDwTmBHF84GngOvK2K6qG3DsMXG891+3xfHYgTqs21fUam53Dsl2q69p+Z/EmHB3+5jFF8bxuH9piGN75nc9bkd/724nFftff4i7bwCIfo8qX5NEqkq5Lfup+MkcM2sD2gCa0e1+0ncot/uPYgvlRjNrdfcNZtYKbCo0obvPAeYAfMBavMj19bqOgw6I48ePuS8xpudnD6Wu1HVu20vDsvEzz1axJX1LsV3vhcCsKJ4FPFKe5ohUnXJb9pPm8qB7gSXAJDNbb2aXADcDp5vZKuD0aFikrii3Ja00Z73PLzDq1DK3paY0vvWnOD7q19knPS/7xA9zphvRkP/Y1MVrPxbHBy/Wvbi1qFZye8Bxk3OGX/3bnh0RO33lOXE8cnlHWdqUVueMKXH85oziD0uN/8+tOcN7V6wselmVoGtdREQCVChFRAJ0r3cBnauyT5CedHX2Urq3l+ZON6KBvP77vz4Ux+Pu0NlHKWzLlOE5w6+f8289mr/9rjFx3PLwkm6mLL9kd3tl24+6mbJ7R4xqyxk+5rbs4Yha6IZrj1JEJECFUkQkQF3vAgYMy164u2Pa+DhuNt0BK/1bw4Qj43hX656yLPP1T8/JGZ686QtxPG5FWVZREu1RiogEqFCKiASo611Ax5Sj4/g3tyW7BUN7vzEiVdYw/MA4fuVr2fj1U26vyPo6mrO3zicPg+3dti3f5BWnPUoRkQAVShGRAHW9RSTozXmj43jZXySfd1CZ53A+c8EtcTx92D/G8YTLl+abvOK0RykiEqBCKSISoEIpIhKgY5RldMzt2bsJjvrJ2jju3ScEipTHu49OjOOffujf43hEw5AeLef1Pdvj+KIrsscbL/jOo3HcduBbOfOMSqzDmzt7tL5K0B6liEiACqWISIC63mU0bE32boKOdeur2BKpJ6N+syFnOHkI5+XLws94/PgV2UtmFnd+NI6H/0dpz6acOe73cTyxsWfd7Qe2Z+/e+b/f/XwcN8zeEsczBq9KzNGz5fc27VGKiASoUIqIBKjrneAnHR/Hm/5hV6p5jvjZZXF8zAvtcaynVkpaHa+tyRk+6t+z10lMaMp2W5N3qyTPCn+39YU4vvgL2def+sTUOB720qA4bv1uuq8meXDOKXHcNDv73MlrWv4nOO+ru1qzbb0z+0DJtS3Z/7E1k7JfgTGxsTzPtayUNN/rPdbMnjSzlWb2opldGb3eYmaPm9mq6PeIyjdXpHyU25JWmq53B3C1u08GpgOXm9kHgeuBRe4+AVgUDYvUE+W2pBLserv7BmBDFG8zs5XAYcBMYEY02XzgKeC6irSyl2ydeEAcr5g2L9U8R9+T7SbVwrfFSXq1mtvJKyYmfP/9OH7nsxbHowp8++cd457ODiTi/3fiUXE8d9gZZWhl96YNznbP7772lG6mrA89OpljZuOBE4ClwCFRou1LuFGF5xSpbcpt6U7qQmlmQ4EHgavc/c89mK/NzJaZ2bI97C6mjSIVpdyWkFRnvc2skUwi3e3uD0UvbzSzVnffYGatwKZ887r7HGAOwAesxfNNI1ItNZ/bu7MF+LPLL4njn56Qvfc6zcXgyTPV17SFL2Iv1V8Ozp7FXtkL66u0NGe9DfgJsNLdv5cYtRCYFcWzgEfK3zyRylFuS1pp9ihPAi4E/mBmy6PXbgBuBh4ws0uAtcB5lWmiSMUotyWVNGe9FwNWYPSp5W2OSO+ph9zubP9THB/66Wx8waMXx3Hynuy/HPaHOJ7W1Fjh1pVP8t5wyL1gfdCG6m+HbmEUEQlQoRQRCdC93iJ1qOXsV+P4aZrjeO4Pr4jjn3/q+73aplJ875uzc4YPvOu/43g8pT0urhy0RykiEqBCKSISoK63SB8y6fqX4vjqb366ii3pmeHtv8sZrrU7U7RHKSISoEIpIhKgrrdIH7J327bsQDKWkmiPUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJEB35iSMeHVnHB/x2KWp5pn81rtx3Fn2FolILdAepYhIgAqliEiAut4J9szyOJ74TLp51N0W6fuCe5Rm1mxmz5nZCjN70cy+Eb3eYmaPm9mq6PeIyjdXpHyU25JWmq73buAUdz8OOB4408ymA9cDi9x9ArAoGhapJ8ptSSVYKD1jezTYGP04MBOYH70+H6if586LoNyW9FKdzDGzBjNbDmwCHnf3pcAh7r4BIPo9qnLNFKkM5bakkapQununux8PjAGmmdmxaVdgZm1mtszMlu1hd7HtFKkI5bak0aPLg9y9HXgKOBPYaGatANHvTQXmmePuU919aiNNJTZXpDKU29KdNGe9R5rZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzHWUrMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKubee181bmabgTd6bYXSncPdfWS1G9FXKLdrRkXyulcLpYhIPdK93iIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgH/HyKwOLKinO/gAAAAAElFTkSuQmCC) - -### 实验结果 - -1. 在训练日志中可以看到两阶段的loss值和验证精度打印,第一阶段为初始训练,第二阶段为加载Checkpoint继续训练; -2. 在训练目录里可以看到`b_lenet-graph.meta`、`b_lenet-2_1875.ckpt`等文件,即训练过程保存的Checkpoint。 - -```python ->>> epoch: 1 step: 1875, loss is 2.2984316 ->>> epoch: 2 step: 1875, loss is 0.06388051 ->>> Metrics: {'loss': 0.11160586341821517, 'acc': 0.9637419871794872} - ->>> epoch: 1 step: 1875, loss is 0.008898618 ->>> epoch: 2 step: 1875, loss is 0.05747048 ->>> Metrics: {'loss': 0.07453688951276351, 'acc': 0.9767628205128205} -``` - -## 实验小结 - -本实验展示了MindSpore的Checkpoint、断点继续训练等高级特性: - -1. 使用MindSpore的ModelCheckpoint接口每个epoch保存一次Checkpoint,训练2个epoch并终止。 -2. 使用MindSpore的load_checkpoint和load_param_into_net接口加载上一步保存的Checkpoint继续训练2个epoch。 -3. 观察训练过程中Loss的变化情况,加载Checkpoint继续训练后loss进一步下降。 \ No newline at end of file diff --git a/experiment_6/main.py b/experiment_6/main.py deleted file mode 100644 index 0c74a75..0000000 --- a/experiment_6/main.py +++ /dev/null @@ -1,146 +0,0 @@ -# Save and load model - -import matplotlib.pyplot as plt -import numpy as np - -import mindspore as ms -import mindspore.context as context -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as CV - -from mindspore import nn, Tensor -from mindspore.train import Model -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - -DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 -DATA_DIR_TEST = "MNIST/test" # 测试集信息 - - -def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), - rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): - ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) - - # define map operations - resize_op = CV.Resize(resize) - rescale_op = CV.Rescale(rescale, shift) - hwc2chw_op = CV.HWC2CHW() - - # apply map operations on images - ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op]) - ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) - - ds = ds.shuffle(buffer_size=buffer_size) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(num_epoch) - - return ds - - -class LeNet5(nn.Cell): - def __init__(self): - super(LeNet5, self).__init__() - self.relu = nn.ReLU() - self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - self.fc1 = nn.Dense(400, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, 10) - - def construct(self, input_x): - output = self.conv1(input_x) - output = self.relu(output) - output = self.pool(output) - output = self.conv2(output) - output = self.relu(output) - output = self.pool(output) - output = self.flatten(output) - output = self.fc1(output) - output = self.fc2(output) - output = self.fc3(output) - - return output - - -def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - steps_per_epoch = ds_train.get_dataset_size() - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) - ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg) - loss_cb = LossMonitor(steps_per_epoch) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False) - metrics = model.eval(ds_eval, dataset_sink_mode=False) - print('Metrics:', metrics) - - -CKPT = 'b_lenet-2_1875.ckpt' - - -def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"): - ds_train = create_dataset(num_epoch=num_epoch) - ds_eval = create_dataset(training=False) - steps_per_epoch = ds_train.get_dataset_size() - - net = LeNet5() - loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - opt = nn.Momentum(net.trainable_params(), lr, momentum) - - param_dict = load_checkpoint(CKPT) - load_param_into_net(net, param_dict) - load_param_into_net(opt, param_dict) - - ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) - ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg) - loss_cb = LossMonitor(steps_per_epoch) - - model = Model(net, loss, opt, metrics={'acc', 'loss'}) - model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False) - metrics = model.eval(ds_eval, dataset_sink_mode=False) - print('Metrics:', metrics) - -def plot_images(pred_fn, ds, net): - for i in range(1, 5): - pred, image, label = pred_fn(ds, net) - plt.subplot(2, 2, i) - plt.imshow(np.squeeze(image)) - color = 'blue' if pred == label else 'red' - plt.title("prediction: {}, truth: {}".format(pred, label), color=color) - plt.xticks([]) - plt.show() - -CKPT = 'b_lenet_1-2_1875.ckpt' - -def infer(ds, model): - data = ds.get_next() - images = data['image'] - labels = data['label'] - output = model.predict(Tensor(data['image'])) - pred = np.argmax(output.asnumpy(), axis=1) - return pred[0], images[0], labels[0] - -def test_infer(): - ds = create_dataset(training=False, batch_size=1).create_dict_iterator() - net = LeNet5() - param_dict = load_checkpoint(CKPT, net) - model = Model(net) - plot_images(infer, ds, model) - -if __name__ == "__main__": - - test_train() - - resume_train() - - test_infer() \ No newline at end of file diff --git a/lenet5/README.md b/lenet5/README.md new file mode 100644 index 0000000..6b5cdcf --- /dev/null +++ b/lenet5/README.md @@ -0,0 +1,309 @@ +# 基于LeNet5的手写数字识别 + +## 实验介绍 + +LeNet5 + MINST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。 + +## 实验目的 + +- 了解如何使用MindSpore进行简单卷积神经网络的开发。 +- 了解如何使用MindSpore进行简单图片分类任务的训练。 +- 了解如何使用MindSpore进行简单图片分类任务的验证。 + +## 预备知识 + +- 熟练使用Python,了解Shell及Linux操作系统基本知识。 +- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略等。 +- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)、[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)、[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)、[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等服务。华为云官网:https://www.huaweicloud.com +- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn + +## 实验环境 + +- MindSpore 0.5.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套); +- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html +- Windows/Ubuntu x64笔记本,NVIDIA GPU服务器,或Atlas Ascend服务器等。 + +## 实验准备 + +### 创建OBS桶 + +本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。 + +> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。 + +创建OBS桶的参考配置如下: + +- 区域:华北-北京四 +- 数据冗余存储策略:单AZ存储 +- 桶名称:全局唯一的字符串 +- 存储类别:标准存储 +- 桶策略:公共读 +- 归档数据直读:关闭 +- 企业项目、标签等配置:免 + +### 数据集准备 + +MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。 + +从MNIST官网下载如下4个文件到本地并解压: + +``` +train-images-idx3-ubyte.gz: training set images (9912422 bytes) +train-labels-idx1-ubyte.gz: training set labels (28881 bytes) +t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) +t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) +``` + +### 脚本准备 + +从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。 + +### 上传文件 + +将脚本和数据集上传到OBS桶中,组织为如下形式: + +``` +lenet5 +├── MNIST +│   ├── test +│   │   ├── t10k-images-idx3-ubyte +│   │   └── t10k-labels-idx1-ubyte +│   └── train +│   ├── train-images-idx3-ubyte +│   └── train-labels-idx1-ubyte +└── main.py +``` + +## 实验步骤(ModelArts Notebook) + +### 创建Notebook + +可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。 + +创建Notebook的参考配置: + +- 计费模式:按需计费 +- 名称:lenet5 +- 工作环境:Python3 +- 资源池:公共资源 +- 类型:Ascend +- 规格:单卡1*Ascend 910 +- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的lenet5文件夹 +- 自动停止等配置:默认 + +> **注意:** +> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。 +> - 打开Notebook后,选择MindSpore环境作为Kernel。 + +> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的"New"->"Terminal",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。 + +> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。 + +导入MindSpore模块和辅助模块: + +```python +import os +# os.environ['DEVICE_ID'] = '0' + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn +from mindspore.train import Model +from mindspore.train.callback import LossMonitor + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU +``` + +### 数据处理 + +在使用数据集训练网络前,首先需要对数据进行预处理,如下: + +```python +def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32), + rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): + data_train = os.path.join(data_dir, 'train') # 训练集信息 + data_test = os.path.join(data_dir, 'test') # 测试集信息 + ds = ms.dataset.MnistDataset(data_train if training else data_test) + + ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) + ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32)) + # When `dataset_sink_mode=True` on Ascend, append `ds = ds.repeat(num_epochs) to the end + ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True) + + return ds +``` + +对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。 + +```python +ds = create_dataset('MNIST', training=False) +data = ds.create_dict_iterator().get_next() +images = data['image'] +labels = data['label'] + +for i in range(1, 5): + plt.subplot(2, 2, i) + plt.imshow(images[i][0]) + plt.title('Number: %s' % labels[i]) + plt.xticks([]) +plt.show() +``` + +![png](images/mnist.png) + +### 定义模型 + +MindSpore model_zoo中提供了多种常见的模型,可以直接使用。LeNet5模型结构如下图所示: + +![LeNet5](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg) + +[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf + +```python +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') + self.relu = nn.ReLU() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, x): + x = self.relu(self.conv1(x)) + x = self.pool(x) + x = self.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + + return x +``` + +### 训练 + +使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。 + +| batch size | number of epochs | learning rate | optimizer | +| -- | -- | -- | -- | +| 32 | 3 | 0.01 | Momentum 0.9 | + +```python +def train(data_dir, lr=0.01, momentum=0.9, num_epochs=3): + ds_train = create_dataset(data_dir) + ds_eval = create_dataset(data_dir, training=False) + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size()) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + # dataset_sink_mode can be True when using Ascend + model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=False) + metrics = model.eval(ds_eval, dataset_sink_mode=False) + print('Metrics:', metrics) + +train('MNIST') +``` + + epoch: 1 step 1875, loss is 0.23394052684307098 + Epoch time: 23049.360, per step time: 12.293, avg loss: 2.049 + ************************************************************ + epoch: 2 step 1875, loss is 0.4737345278263092 + Epoch time: 26768.848, per step time: 14.277, avg loss: 0.155 + ************************************************************ + epoch: 3 step 1875, loss is 0.07734094560146332 + Epoch time: 25687.625, per step time: 13.700, avg loss: 0.094 + ************************************************************ + Metrics: {'loss': 0.10531254443608654, 'acc': 0.9701522435897436} + + +## 实验步骤(ModelArts训练作业) + +除了Notebook,ModelArts还提供了训练作业服务。相比Notebook,训练作业资源池更大,且具有作业排队等功能,适合大规模并发使用。使用训练作业时,也会有修改代码和调试的需求,有如下三个方案: + +1. 在本地修改代码后重新上传; + +2. 使用[PyCharm ToolKit](https://support.huaweicloud.com/tg-modelarts/modelarts_15_0001.html)配置一个本地Pycharm+ModelArts的开发环境,便于上传代码、提交训练作业和获取训练日志。 + +3. 在ModelArts上创建Notebook,然后设置[Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html),可以在线修改代码并自动同步到OBS中。因为只用Notebook来编辑代码,所以创建CPU类型最低规格的Notebook就行。 + +### 适配训练作业 + +创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。 + +```python +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--data_url', required=True, default=None, help='Location of data.') +parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') +args, unknown = parser.parse_known_args() +``` + +MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器: + +```python +import moxing +moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') +``` + +如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考: + +```python +import moxing +# dst_url形如's3://OBS/PATH',将ckpt目录拷贝至OBS后,可在OBS的`args.train_url`目录下看到ckpt目录 +moxing.file.copy_parallel(src_url='ckpt', dst_url=os.path.join(args.train_url, 'ckpt')) +``` + +### 创建训练作业 + +可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。 + +创建训练作业的参考配置: + +- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore +- 代码目录:选择上述新建的OBS桶中的lenet5目录 +- 启动文件:选择上述新建的OBS桶中的lenet5目录下的`main.py` +- 数据来源:数据存储位置->选择上述新建的OBS桶中的lenet5目录下的MNIST目录 +- 训练输出位置:选择上述新建的OBS桶中的lenet5目录并在其中创建output目录 +- 作业日志路径:同训练输出位置 +- 规格:Ascend:1*Ascend 910 +- 其他均为默认 + +启动并查看训练过程: + +1. 点击提交以开始训练; +2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理; +3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看; +4. 参考实验步骤(Notebook),在日志中找到对应的打印信息,检查实验是否成功。 + +## 实验步骤(本地CPU/GPU/Ascend) + +MindSpore还支持在本地CPU/GPU/Ascend环境上运行,如Windows/Ubuntu x64笔记本,NVIDIA GPU服务器,以及Atlas Ascend服务器等。在本地环境运行实验前,需要先参考[安装教程](https://www.mindspore.cn/install/)配置环境。 + +在Windows/Ubuntu x64笔记本上运行实验: + +```shell script +vim main.py # 将第15行的context设置为`device_target='CPU'` +python main.py --data_url=D:\dataset\MNIST +``` + +在Ascend服务器上运行实验: + +```shell script +vim main.py # 将第15行的context设置为`device_target='Ascend'` +python main.py --data_url=/PATH/TO/MNIST +``` + +## 实验小结 + +本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。 diff --git a/lenet5/images/mnist.png b/lenet5/images/mnist.png new file mode 100644 index 0000000000000000000000000000000000000000..8db03ee7fb2f7786c73bf3917790bcc0ce103ff5 GIT binary patch literal 7485 zcmaJ`2UwE(+kSDDtJ2g8x0tz5BS+4L11m?GrYVk8v{W<)Zslm1W?A40XQkFLS7ND% zDP?N7IF%#9ous)(Dn{x@r%s*!_x-Pr>*4`kUVhK}{;ubK?&rQqv9U52FUN8TNg6?ZwH`28jFB*2ioAw7>;~u(Nm>zRu!WAg&84YJaF>ZHh?~wf$ zk^0@U6{_2D#lZ^Nb=|X;Fgn<&muIy2<~DNKhLVZk#4vzL>XmI6(8%Eq*jD_>3p#UR zT763xFqWt36E@EP#(>6d@@YT{z}zEMAJJ`d2VlV9$z>62tU6DSKE?9r|oYzyL z<6l&}DbV1{ak#LfTLrd}xIDVd(d!dZfHUxgFHM?ss&W<>jRv-`6Za(6V(DF&+BG8}-Q=|qUF#g$LU;N8(s_mxn|>UO)nVz&8kCc2Vn(0>FrGe9G7c znj|fh06iv&S;u~DclAKp>C<%R_q|;y9?}*CkWXva%Rt`UNzKEyZ`AYwjhb6}mkh2iO+4=iI=F&9hj37CO`IFNJXfIdYF8fbL{SV}!z4 zqTnb$FwF56*JxaErKvb0>WEgvqhzLxlA8V=tW_yNWRLov-@}AG4Vx#{c!GfSf>JqC zG87o-<%yWYgwpJ|0nWQOdi|u;fx8~aD)Gt)xzp7>qIGnSQcyE5pa)4I>+L~6h3I@td(UhHgf z-k~~?LxmkEez}ESSuNXk)`qBgvFeY5^V~oYykK{u-MRvBTl}G;oI=T<)xEqb8s3SY zQU+25*$BE9;KSCe^5CX6C`CKhlLmO29Z;Zb<7Cp26s#Zlh3DGOW+S^(UNrkw+e{Do z%DehRr@2nfG1X8->drz$fZyegW_P&@u`cTFYnuS6 zXFP%6_Vy*d#MT|yrXCVk-}sOZm2!9oX1Gr`P)W?utrVyc(&D7jjdx1VW=&r4KmHXO zklAMTg+@@+(YfNMbVL+O1X?hpLz1P|KAR6)LO2_PD}670cK2Iqy?>>}<=c4pd`h{} z?U~TCr`wY@;wjyBE{&_?QYiAY?);^`$Iqa(a+AahXt{B2`|lylQ~teLjB4yiMm5xR zDV12A7dT(}4)JYxf3&>cn$+`ZIGz+gvAozauHhPgsB1DXFW#xXyJraMPv7()c{+uJ z+)sV4PYv9@;$^up#4S>~SF!&~&{DI3;Ye$RSdU{QExehADr{Qcwtmrh z|JP={V4-Gf%UbUDHpg~zLXGa%1|bIQ9H>D{qa+fkK5F6qYX?7B4nWIKyMd|RtVeI< zk3(mt!YI_x$n4V3+qJJ+X+|2-_T5h`wHj*AU#CCqC$uoG^?G-vO&H+l@XA}NW%9&~ ziKJ1V?TAz*+lY>UX0Nxm5Pr zr*^AzJ&_>a;#NKXMq=N-Hu2Rhjm@M;rmxTYP-B#-ZcldD)9SsJ#82o@Gon;;G_4he zY8Y*Q^1)92nb>yaw`dcSh6LXA+vr=dnw)H)Z5eA$4B9K_sO zVCP2HXU5m&#VbHfLWCHmozG4Ji~D|22X7as@+=x)(k)Q3k=z)4;EVzhMD87}ft#L1 zu@urt)v&kadUa&)l)m}+9i+;S%?FIg`rDrI$(w_8#jEk%?NZ>ZqLHeC+na+pSKDCy z{C=fJ#E8y;7T)(GMn1=L45J=872tVuc9v+gy+NHc)C@LM)54_j zUWb)Q<7hXj+7Nmv^onLfWoNtbfw+CrNdsC*h}6F5mpf~tjR?wXhb<0E8mQ;xMN$8{ zAg+_&+};|P7veU{aIEpUn#xdI0;-E7E`v=k+7 zhtwY6$qc%ld2Nc3E|>agZIHe_J3eCa$)vxrT24;W&RbE=6&|Wn8+9x{l{6qR?bZ2q z%0V=%H0rM)@w)jPBn1sYE`>}*J;G9k>6m`qn@CT3e_$~ClAr$0D?D_|aH}*ZB9C-C zUYYRFU_Qq6-L|0k?I=`P>STdrlQF#^Hu`KQ#Mbxor(4gzj+eP~Wva`felPI%)eFaV zqfBn+L+25Ft6>eNhwuPbX8ZAd|o;)cQ#45HrMc2 z7OR=hEOl$tSrgSH9xN~@si)886&2?~=$Bvae>+|%=gj7qrK`&9a6+!3`t!lB&;MLd zmv;^`w0oOnxakrz*aAivly3N4S|^UtQ(vwyG%4t(r=w2}IYYM)Rn0-wP55H@0KXRV zyGh>;<(c}wo%#})ThUyvClRRm^<})mn{DYmO9}t1KPA>v%No5JjgWZ0l3p>D%OAG# zuDNxw#TAB#Ss@2ShAz)t6yQ|ncp5KqALTuk#%kX$-?GbMt_>$)wY(oU18*{0{@@-d z$sW{nxQtotZZ6oDf0rf7pC2>R<*$l?fm}?3B``;UEIjQH(3wD$%QPX*Vb2_-pt%0p z_sSA70MeBra{~{#a@GFs0@WD=_*Rtf`_;Vg&PWB-l)Wd36X-rAPkLF|N{{h~z+_bJTe7DkJMy>SW2JmTuSgBV8xl3xNSfGMaAo3Pu2IQYwcBfT1 z_U?KTZJkXmwyKOfbbYItf%z3ZhO6pfBKS10&f(~T3^#Q!6moaiP32g097|OcmS)tt z9lk`1t_qP%*@6Y2o|AkMQUxqZVLmVIJqyw92)OMPth44+Ed2*#@THh}al|P+ zU906fvHp^mCXcT0bLc@iSDU{py^T~IFlO>LQ0j(}u;{P}*j1!|27mK{i$J%S74G?G zoA&gG8q;(y#WxRCj6RL)G~pUaRgAmz-=!-2=2pN^bxZ4;oe$574l^!0OQdmqbMq`v zJwFO$_eIb_G>d9P1&nU0X`-cuQBnPhUvdljH@1@Liv^l?7 z=z@ES?_bNchKZc)U)a@Iy=*EzGyk~xf_C7Qr4)YpwvI&N6KCgn#D~%uqlNJ-{pAnT!Id_&A~9#86XGVnvH9OhItu~AG#{?91I_>% zwj;SV z3e)NDLUTF_xgWq(uuD{iNmT1L`c=>dsVc&|*aN6{unu}RCGPkjC)o#R1dV=Ylx`#_ zLbdY5jaSEsbJ=Hx@BPX9a%}p+4 zi?N-fgSvbQwgT*%_^Ui3F)@8mXmsd7zKtI2WfxCo!)}ZFgHZ6Lm4-jriW6QsHRSkn zM-!Au0(V@FE!Wy+pDg|>rep3#`7FHYWUvaw73+N-=i>VEqNo5x-vV`TV-F?VV))$j zVT&se!#WV=hx*icy^?z4$}sy^99lr7T9_F=c}GV!0bqhs_In z<-yD2nX83$IUZ~>@!m5bO!`g??5m|{s7RdEvJx&D6w^@V*3~t9q~E@HFS`n~Ys6smOUOX)jrH_&w)Q)4Cn|gCtHio7Cs_sHky*7|k-=EN?Ul0u z=6=dBcFu=`3BvpVxN%XYJ(B&_Pxeodk>owBeYDjO+aJD}r$ zry+}*V$IVVGrR&8eS|61#JsCsqLeD!EH#4yRGzpDwj+?8#uH<|(WB>Z#N`tg*#g~9sx<;Wh?VazRL z?M?u0l-zq}`jW`_K6K5z`H&;d*7s$l^lX)q!Sf{FSm<6+5Oa$2akfu{=igPL1dm)v2O-_X9iX9 zhFWqXGNG5hX+fHZMOtR_6|;q7j50|axG5KTVlju3eJ95PQT$ekEj}mNBmc8IjR(lW zKp*f4@8nZ;Eae;*1AUmGHJyStd`j&FWfOIGRy=@w6Oxld*#4Xn(eyqa5|d13 zQgxP%!-@5C47NtFF**EOLdxzLazyiwoE|jXJSQH{%Heg1IlI7|yldpK=r6u2W=vIl z-K}|~)n&tWma!XKRQihBV&ZqSpiE;Ep$Sd3Wd5&nCFYi8L)jxd%X1#7O_1duaiTz9x1#en%8E6W{E zr=}KtY%Pff2f;hJ#H|l(K7-yKd8=aZR%E21W~jkK>p4udBqR6K>$&=XOSF&U)RBz} z^vtx|$mqDkxh9uD8v)H4Ghj!Y5Yz>0a_ST@-==zmJ>31MxAC+Sa>9xXv^V7$fqsfG z+;zaBpYwed>Le_e-yI4_bn103#c+$xP4`#KMzN-4$Fif}twvo?98=b-tLf_%`xLSn z@9$q%8OB<>D6Z&0+@i)xkjrHXxJRpB?lmu-U{`V`rARRqWUy=03+D`X1&s;;b`~f= zLyM_ud-nUDOrU0O9x3Zl3v%i>6`C~Sf{uIg&dvq10xHI|!vk)0ecc<` zE&=SLP(KC=N$WPTorLm^^|kwpn{@Sfrlt70OH)s4fTg9ABuxB2@?_aQh9tPa|+s!ge^MHZCnCiLii>_?UL(5I35@3sR{QWr>{SLXUxW4!FMSw5$K z&1(tSL+{V7yc91`G^DRWX!+m8(0A?J>q{gLZRC&LYMvZwwtFuLau9G|^3ZUzK80h-MahD%x zCqPM$gkY~wLoSWcUrV8=3UBjVpNf+65qU=T0g;k^8(WJk#&F+ceZ-Li%X*a)4+lf^ zMHZ?`!KoR^+I3$l}o0kXxl{q6f`FnEV{*$I}p4 z``es%MPJnb1`f>elvn28N7xJKLJqbWe%?YV?%13@Fb^#5rlwNCv{#LeaR-IZ@b3oK8uDy{_(HyokrP3FsRubGj=bx{k zC}FipN^BY($4_n(+$SSgFNui246_=^;3ezuYp`ki1;K3bj|VWW;{Q0n=VnROYLD_K z#;0D9Hqfw-Ps+6Nk|PG17D0@Ur_S!PI3l&>M`_r zymkCwYi(@6EF;;LP5M7P7Mtp!_oohDX|;zMZA@E2QFFeYSc5!+$uM!Z@R@5+RGyUU zTI7;G!a6CvY6>ynJMb*5JuOXJo$oiNOt*<0cMgfTULIC=kbz}B)iOUFi1bz<2-$(2*5@J=dr~Q z=bYG2IW!Yku45`HQtUiL>N~X9d?AC|OsMC@P=N?e@<#y4o!iYe2wbl1O-T%7U tgkENEn5}XY;>^#o`P-lS|HESQJL}2|lT${2vH$7>EKIFTDo?m4{13eiR+s<) literal 0 HcmV?d00001 diff --git a/lenet5/main.py b/lenet5/main.py new file mode 100644 index 0000000..e767ecd --- /dev/null +++ b/lenet5/main.py @@ -0,0 +1,85 @@ +# LeNet5 MNIST + +import os +# os.environ['DEVICE_ID'] = '0' + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn +from mindspore.train import Model +from mindspore.train.callback import LossMonitor + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU + + +def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32), + rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): + data_train = os.path.join(data_dir, 'train') # 训练集信息 + data_test = os.path.join(data_dir, 'test') # 测试集信息 + ds = ms.dataset.MnistDataset(data_train if training else data_test) + + ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) + ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32)) + # When `dataset_sink_mode=True` on Ascend, append `ds = ds.repeat(num_epochs) to the end + ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True) + + return ds + + +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') + self.relu = nn.ReLU() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, x): + x = self.relu(self.conv1(x)) + x = self.pool(x) + x = self.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + + return x + + +def train(data_dir, lr=0.01, momentum=0.9, num_epochs=3): + ds_train = create_dataset(data_dir) + ds_eval = create_dataset(data_dir, training=False) + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size()) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + # dataset_sink_mode can be True when using Ascend + model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=False) + metrics = model.eval(ds_eval, dataset_sink_mode=False) + print('Metrics:', metrics) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--data_url', required=False, default='MNIST', help='Location of data.') + parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.') + args, unknown = parser.parse_known_args() + + if args.data_url.startswith('s3'): + import moxing + moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST') + args.data_url = 'MNIST' + + train(args.data_url) -- GitLab