diff --git a/experiment_5/LeNet_MNIST_Windows.md b/experiment_5/LeNet_MNIST_Windows.md new file mode 100644 index 0000000000000000000000000000000000000000..002be9e61304091e44b55bbc096ac3ea3ad22007 --- /dev/null +++ b/experiment_5/LeNet_MNIST_Windows.md @@ -0,0 +1,171 @@ +# 在Windows上运行LeNet_MNIST + +## 实验介绍 + +LeNet5 + MNIST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在Windows环境下MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。 + +## 实验目的 + +- 了解如何使用MindSpore进行简单卷积神经网络的开发。 +- 了解如何使用MindSpore进行简单图片分类任务的训练。 +- 了解如何使用MindSpore进行简单图片分类任务的验证。 + +## 预备知识 + +- 熟练使用Python,了解Shell及Linux操作系统基本知识。 +- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器、训练策略等。 +- 了解并熟悉MindSpore AI计算框架,MindSpore官网:[https://www.mindspore.cn](https://www.mindspore.cn/) + +## 实验环境 + +- Windows-x64版本MindSpore 0.3.0;安装命令可见官网: + + [https://www.mindspore.cn/install](https://www.mindspore.cn/install)(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套)。 + +## 实验准备 + +### 创建目录 + +创建一个experiment文件夹,用于存放实验所需的文件代码等。 + +### 数据集准备 + +MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。 + +从MNIST官网下载如下4个文件到本地并解压: + +``` +train-images-idx3-ubyte.gz: training set images (9912422 bytes) +train-labels-idx1-ubyte.gz: training set labels (28881 bytes) +t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) +t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) +``` + +### 脚本准备 + +从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。 + +### 准备文件 + +将脚本和数据集放到experiment文件夹中,组织为如下形式: + +``` +experiment +├── MNIST +│ ├── test +│ │ ├── t10k-images-idx3-ubyte +│ │ └── t10k-labels-idx1-ubyte +│ └── train +│ ├── train-images-idx3-ubyte +│ └── train-labels-idx1-ubyte +└── main.py +``` + +## 实验步骤 + +### 导入MindSpore模块和辅助模块 + +```python +import matplotlib.pyplot as plt + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn +from mindspore.model_zoo.lenet 
import LeNet5 +from mindspore.train import Model +from mindspore.train.callback import LossMonitor + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') +``` + +### 数据处理 + +在使用数据集训练网络前,首先需要对数据进行预处理,如下: + +```python +DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 +DATA_DIR_TEST = "MNIST/test" # 测试集信息 + +def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): + ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) + + ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]) + ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) + ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch) + + return ds +``` + +对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。 + +```python +def show_dataset(): + ds = create_dataset(training=False) + data = ds.create_dict_iterator().get_next() + images = data['image'] + labels = data['label'] + + for i in range(1, 5): + plt.subplot(2, 2, i) + plt.imshow(images[i][0]) + plt.title('Number: %s' % labels[i]) + plt.xticks([]) + plt.show() +``` + +![img]() + +### 定义模型 + +MindSpore model_zoo中提供了多种常见的模型,可以直接使用。这里使用其中的LeNet5模型,模型结构如下图所示: + +![img](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg) + +图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf + +### 训练 + +使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。 + +| batch size | number of epochs | learning rate | optimizer | +| ---------: | ---------------: | ------------: | -----------: | +| 32 | 3 | 0.01 | Momentum 0.9 | + +```python +def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"): + ds_train = create_dataset(num_epoch=num_epoch) + ds_eval = create_dataset(training=False) + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = 
nn.Momentum(net.trainable_params(), lr, momentum) + + loss_cb = LossMonitor(per_print_times=1) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epoch, ds_train, callbacks=[loss_cb], dataset_sink_mode=False) + metrics = model.eval(ds_eval, dataset_sink_mode=False) + print('Metrics:', metrics) +``` + +### 实验结果 + +1. 在训练日志中可以看到`epoch: 1 step: 1875, loss is 0.29772663`等字段,即训练过程的loss值; +2. 在训练日志中可以看到`Metrics: {'loss': 0.06830393138807267, 'acc': 0.9785657051282052}`字段,即训练完成后的验证精度。 + +```python + ... +>>> epoch: 1 step: 1875, loss is 0.29772663 + ... +>>> epoch: 2 step: 1875, loss is 0.049111396 + ... +>>> epoch: 3 step: 1875, loss is 0.08183163 +>>> Metrics: {'loss': 0.06830393138807267, 'acc': 0.9785657051282052} +``` + +## 实验小结 + +本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。 \ No newline at end of file diff --git a/experiment_5/main.py b/experiment_5/main.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4c691cb687a13776e6200198efebd1740fe159 --- /dev/null +++ b/experiment_5/main.py @@ -0,0 +1,62 @@ +# LeNet5 mnist +import matplotlib.pyplot as plt + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn +from mindspore.model_zoo.lenet import LeNet5 +from mindspore.train import Model +from mindspore.train.callback import LossMonitor + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + +DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 +DATA_DIR_TEST = "MNIST/test" # 测试集信息 + + +def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), + rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): + ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) + + ds = ds.map(input_columns="image", operations=[CV.Resize(resize), 
CV.Rescale(rescale, shift), CV.HWC2CHW()])
+    ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
+    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)
+
+    return ds
+
+
+def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"):
+    """Train the model-zoo LeNet5 on MNIST, then evaluate it and print the metrics.
+
+    Args:
+        lr: Learning rate passed to the Momentum optimizer.
+        momentum: Momentum factor passed to the Momentum optimizer.
+        num_epoch: Epoch count; used both as the dataset repeat count and as
+            the first argument of ``model.train``.
+        ckpt_name: Not used anywhere in this function body.
+    """
+    ds_train = create_dataset(num_epoch=num_epoch)
+    ds_eval = create_dataset(training=False)
+
+    net = LeNet5()
+    loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
+    opt = nn.Momentum(net.trainable_params(), lr, momentum)
+
+    # per_print_times=1: report the loss after every step
+    loss_cb = LossMonitor(per_print_times=1)
+
+    model = Model(net, loss, opt, metrics={'acc', 'loss'})
+    model.train(num_epoch, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
+    metrics = model.eval(ds_eval, dataset_sink_mode=False)
+    print('Metrics:', metrics)
+
+def show_dataset():
+    """Plot four images from the first test batch in a 2x2 grid, titled with their labels."""
+    ds = create_dataset(training=False)
+    data = ds.create_dict_iterator().get_next()
+    images = data['image']
+    labels = data['label']
+
+    # i doubles as the 1-based subplot slot and the batch index,
+    # so samples 1..4 of the batch are shown (sample 0 is skipped)
+    for i in range(1, 5):
+        plt.subplot(2, 2, i)
+        # images[i][0]: first channel of sample i (the pipeline ends with HWC2CHW)
+        plt.imshow(images[i][0])
+        plt.title('Number: %s' % labels[i])
+        plt.xticks([])
+    plt.show()
+
+if __name__ == "__main__":
+    show_dataset()
+
+    test_train()
\ No newline at end of file
diff --git a/experiment_6/Save_And_Load_Model_Windows.md b/experiment_6/Save_And_Load_Model_Windows.md
new file mode 100644
index 0000000000000000000000000000000000000000..1c1c805607ad4ef971186d4110ceb380139ad7c4
--- /dev/null
+++ b/experiment_6/Save_And_Load_Model_Windows.md
@@ -0,0 +1,346 @@
+# 在Windows上运行训练时模型的保存和加载
+
+## 实验介绍
+
+本实验主要介绍在Windows环境下使用MindSpore实现训练时模型的保存和加载。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。
+
+在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下:
+
+- 训练后推理场景
+  - 模型训练完毕后保存模型的参数,用于推理或预测操作。
+  - 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。
+- 再训练场景
+  - 进行长时间训练任务时,保存训练过程中的CheckPoint文件,防止任务异常退出后从初始状态开始训练。
+  - Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。
+
+## 实验目的
+
+- 了解如何使用MindSpore实现训练时模型的保存。
+- 了解如何使用MindSpore加载保存的模型文件并继续训练。
+- 
了解如何使用MindSpore的Callback功能。 + +## 预备知识 + +- 熟练使用Python,了解Shell及Linux操作系统基本知识。 +- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器、训练策略、Checkpoint等。 +- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn/ + +## 实验环境 + +- Windows-x64版本MindSpore 0.3.0;安装命令可见官网: + + [https://www.mindspore.cn/install](https://www.mindspore.cn/install)(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套)。 + +## 实验准备 + +### 创建目录 + +创建一个experiment文件夹,用于存放实验所需的文件代码等。 + +### 数据集准备 + +MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。 + +从MNIST官网下载如下4个文件到本地并解压: + +``` +train-images-idx3-ubyte.gz: training set images (9912422 bytes) +train-labels-idx1-ubyte.gz: training set labels (28881 bytes) +t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) +t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) +``` + +### 脚本准备 + +从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。 + +### 准备文件 + +将脚本和数据集放到experiment文件夹中,组织为如下形式: + +``` +experiment +├── MNIST +│ ├── test +│ │ ├── t10k-images-idx3-ubyte +│ │ └── t10k-labels-idx1-ubyte +│ └── train +│ ├── train-images-idx3-ubyte +│ └── train-labels-idx1-ubyte +└── main.py +``` + +## 实验步骤 + +### 导入MindSpore模块和辅助模块 + +```python +import matplotlib.pyplot as plt +import numpy as np + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn, Tensor +from mindspore.train import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') +``` + +### 数据处理 + +在使用数据集训练网络前,首先需要对数据进行预处理,如下: + +```python +DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 +DATA_DIR_TEST = "MNIST/test" # 测试集信息 + + +def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), + rescale=1 / (255 * 
0.3081), shift=-0.1307 / 0.3081, buffer_size=64): + ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) + + # define map operations + resize_op = CV.Resize(resize) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + + # apply map operations on images + ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op]) + ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) + + ds = ds.shuffle(buffer_size=buffer_size) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(num_epoch) + + return ds +``` + +### 定义模型 + +定义LeNet5模型,模型结构如下图所示: + +![img](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg) + +图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf + +```python +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, input_x): + output = self.conv1(input_x) + output = self.relu(output) + output = self.pool(output) + output = self.conv2(output) + output = self.relu(output) + output = self.pool(output) + output = self.flatten(output) + output = self.fc1(output) + output = self.fc2(output) + output = self.fc3(output) + + return output +``` + +### 保存模型Checkpoint + +MindSpore提供了Callback功能,可用于训练/测试过程中执行特定的任务。常用的Callback如下: + +- `ModelCheckpoint`:保存网络模型和参数,用于再训练或推理; +- `LossMonitor`:监控loss值,当loss值为NaN或Inf时停止训练; +- `SummaryStep`:把训练过程中的信息存储到文件中,用于后续查看或可视化展示。 + +`ModelCheckpoint`会生成模型(.meta)和Checkpoint(.ckpt)文件,如每个epoch结束时,都保存一次checkpoint。 + +```python +class CheckpointConfig: + """ + The config for model checkpoint. + + Args: + save_checkpoint_steps (int): Steps to save checkpoint. Default: 1. 
+ save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0. + Can't be used with save_checkpoint_steps at the same time. + keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5. + keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0. + Can't be used with keep_checkpoint_max at the same time. + integrated_save (bool): Whether to perform integrated save in automatic model parallel scene. Default: True. + Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. + + Raises: + ValueError: If the input_param is None or 0. + """ + +class ModelCheckpoint(Callback): + """ + The checkpoint callback class. + + It is called to combine with train process and save the model and network parameters after training. + + Args: + prefix (str): Checkpoint files names prefix. Default: "CKP". + directory (str): Folder path into which checkpoint files will be saved. Default: None. + config (CheckpointConfig): Checkpoint strategy config. Default: None. + + Raises: + ValueError: If the prefix is invalid. + TypeError: If the config is not CheckpointConfig type. 
+ """ +``` + +MindSpore提供了多种Metric评估指标,如`accuracy`、`loss`、`precision`、`recall`、`F1`。定义一个metrics字典/元组,里面包含多种指标,传递给`Model`,然后调用`model.eval`接口来计算这些指标。`model.eval`会返回一个字典,包含各个指标及其对应的值。 + +```python +def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"): + ds_train = create_dataset(num_epoch=num_epoch) + ds_eval = create_dataset(training=False) + steps_per_epoch = ds_train.get_dataset_size() + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + + ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg) + loss_cb = LossMonitor(steps_per_epoch) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False) + metrics = model.eval(ds_eval, dataset_sink_mode=False) + print('Metrics:', metrics) +``` + +### 加载Checkpoint继续训练 + +```python +def load_checkpoint(ckpoint_file_name, net=None): + """ + Loads checkpoint info from a specified file. + + Args: + ckpoint_file_name (str): Checkpoint file name. + net (Cell): Cell network. Default: None + + Returns: + Dict, key is parameter name, value is a Parameter. + + Raises: + ValueError: Checkpoint file is incorrect. + """ + +def load_param_into_net(net, parameter_dict): + """ + Loads parameters into network. + + Args: + net (Cell): Cell network. + parameter_dict (dict): Parameter dict. + + Raises: + TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict. 
+ """ +``` + +使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。 + +```python +CKPT = 'b_lenet-2_1875.ckpt' + +def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"): + ds_train = create_dataset(num_epoch=num_epoch) + ds_eval = create_dataset(training=False) + steps_per_epoch = ds_train.get_dataset_size() + + net = LeNet5() + loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + opt = nn.Momentum(net.trainable_params(), lr, momentum) + + param_dict = load_checkpoint(CKPT) + load_param_into_net(net, param_dict) + load_param_into_net(opt, param_dict) + + ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg) + loss_cb = LossMonitor(steps_per_epoch) + + model = Model(net, loss, opt, metrics={'acc', 'loss'}) + model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False) + metrics = model.eval(ds_eval, dataset_sink_mode=False) + print('Metrics:', metrics) +``` + +### 加载Checkpoint进行推理 + +使用matplotlib定义一个将推理结果可视化的辅助函数,如下: + +```python +def plot_images(pred_fn, ds, net): + for i in range(1, 5): + pred, image, label = pred_fn(ds, net) + plt.subplot(2, 2, i) + plt.imshow(np.squeeze(image)) + color = 'blue' if pred == label else 'red' + plt.title("prediction: {}, truth: {}".format(pred, label), color=color) + plt.xticks([]) + plt.show() +``` + +使用训练后的LeNet5模型对手写数字进行识别,可以看到识别结果基本上是正确的。 + +```python +CKPT = 'b_lenet_1-2_1875.ckpt' + +def infer(ds, model): + data = ds.get_next() + images = data['image'] + labels = data['label'] + output = model.predict(Tensor(data['image'])) + pred = np.argmax(output.asnumpy(), axis=1) + return pred[0], images[0], labels[0] + +def test_infer(): + ds = create_dataset(training=False, batch_size=1).create_dict_iterator() + net = LeNet5() + param_dict = load_checkpoint(CKPT, net) + model = Model(net) + plot_images(infer, ds, model) +``` + +![img]() + +### 
实验结果 + +1. 在训练日志中可以看到两阶段的loss值和验证精度打印,第一阶段为初始训练,第二阶段为加载Checkpoint继续训练; +2. 在训练目录里可以看到`b_lenet-graph.meta`、`b_lenet-2_1875.ckpt`等文件,即训练过程保存的Checkpoint。 + +```python +>>> epoch: 1 step: 1875, loss is 2.2984316 +>>> epoch: 2 step: 1875, loss is 0.06388051 +>>> Metrics: {'loss': 0.11160586341821517, 'acc': 0.9637419871794872} + +>>> epoch: 1 step: 1875, loss is 0.008898618 +>>> epoch: 2 step: 1875, loss is 0.05747048 +>>> Metrics: {'loss': 0.07453688951276351, 'acc': 0.9767628205128205} +``` + +## 实验小结 + +本实验展示了MindSpore的Checkpoint、断点继续训练等高级特性: + +1. 使用MindSpore的ModelCheckpoint接口每个epoch保存一次Checkpoint,训练2个epoch并终止。 +2. 使用MindSpore的load_checkpoint和load_param_into_net接口加载上一步保存的Checkpoint继续训练2个epoch。 +3. 观察训练过程中Loss的变化情况,加载Checkpoint继续训练后loss进一步下降。 \ No newline at end of file diff --git a/experiment_6/main.py b/experiment_6/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0c74a752e51fd50684349dffbbb92324c77e9118 --- /dev/null +++ b/experiment_6/main.py @@ -0,0 +1,146 @@ +# Save and load model + +import matplotlib.pyplot as plt +import numpy as np + +import mindspore as ms +import mindspore.context as context +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV + +from mindspore import nn, Tensor +from mindspore.train import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + +DATA_DIR_TRAIN = "MNIST/train" # 训练集信息 +DATA_DIR_TEST = "MNIST/test" # 测试集信息 + + +def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), + rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64): + ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) + + # define map operations + resize_op = CV.Resize(resize) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = 
CV.HWC2CHW()
+
+    # apply map operations on images
+    ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
+    # cast labels to int32 before they reach the loss
+    ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
+
+    ds = ds.shuffle(buffer_size=buffer_size)
+    ds = ds.batch(batch_size, drop_remainder=True)
+    ds = ds.repeat(num_epoch)
+
+    return ds
+
+
+class LeNet5(nn.Cell):
+    """LeNet5-style classifier: two conv/ReLU/max-pool stages, then three dense layers.
+
+    With 32x32 single-channel inputs (the default ``resize`` of
+    ``create_dataset``) the feature map shrinks 32 -> 28 -> 14 -> 10 -> 5, so
+    the flattened feature size is 16 * 5 * 5 = 400, matching ``fc1``.
+    """
+
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.relu = nn.ReLU()
+        self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Dense(400, 120)
+        self.fc2 = nn.Dense(120, 84)
+        self.fc3 = nn.Dense(84, 10)
+
+    def construct(self, input_x):
+        """Run the forward pass and return the 10-way output of ``fc3``."""
+        output = self.conv1(input_x)
+        output = self.relu(output)
+        output = self.pool(output)
+        output = self.conv2(output)
+        output = self.relu(output)
+        output = self.pool(output)
+        output = self.flatten(output)
+        # NOTE: the dense layers are chained without an activation in between
+        output = self.fc1(output)
+        output = self.fc2(output)
+        output = self.fc3(output)
+
+        return output
+
+
+def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"):
+    """Train LeNet5 from scratch while saving checkpoints, then evaluate and print metrics.
+
+    Args:
+        lr: Learning rate passed to the Momentum optimizer.
+        momentum: Momentum factor passed to the Momentum optimizer.
+        num_epoch: Epoch count; used both as the dataset repeat count and as
+            the first argument of ``model.train``.
+        check_point_name: File-name prefix handed to ``ModelCheckpoint``.
+    """
+    ds_train = create_dataset(num_epoch=num_epoch)
+    ds_eval = create_dataset(training=False)
+    # the same value drives the checkpoint interval and the loss-print interval
+    steps_per_epoch = ds_train.get_dataset_size()
+
+    net = LeNet5()
+    loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
+    opt = nn.Momentum(net.trainable_params(), lr, momentum)
+
+    ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
+    ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)
+    loss_cb = LossMonitor(steps_per_epoch)
+
+    model = Model(net, loss, opt, metrics={'acc', 'loss'})
+    model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False)
+    metrics = model.eval(ds_eval, dataset_sink_mode=False)
+    print('Metrics:', metrics)
+
+
+# Checkpoint that resume_train() starts from; the file name matches the
+# "b_lenet" prefix used by test_train().
+CKPT = 'b_lenet-2_1875.ckpt'
+
+
+def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"):
+    """Restore network and optimizer state from ``CKPT`` and continue training.
+
+    Args:
+        lr: Learning rate passed to the Momentum optimizer.
+        momentum: Momentum factor passed to the Momentum optimizer.
+        num_epoch: Epoch count; used both as the dataset repeat count and as
+            the first argument of ``model.train``.
+        ckpt_name: File-name prefix handed to ``ModelCheckpoint`` for the
+            checkpoints written during this run.
+    """
+    ds_train = create_dataset(num_epoch=num_epoch)
+    ds_eval = create_dataset(training=False)
+    steps_per_epoch = ds_train.get_dataset_size()
+
+    net = LeNet5()
+    loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
+    opt = nn.Momentum(net.trainable_params(), lr, momentum)
+
+    # the same parameter dict is loaded into the bare network and into the optimizer
+    param_dict = load_checkpoint(CKPT)
+    load_param_into_net(net, param_dict)
+    load_param_into_net(opt, param_dict)
+
+    ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
+    ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)
+    loss_cb = LossMonitor(steps_per_epoch)
+
+    model = Model(net, loss, opt, metrics={'acc', 'loss'})
+    model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False)
+    metrics = model.eval(ds_eval, dataset_sink_mode=False)
+    print('Metrics:', metrics)
+
+def plot_images(pred_fn, ds, net):
+    """Draw four predictions in a 2x2 grid; the title is blue when correct, red otherwise.
+
+    Args:
+        pred_fn: Callable invoked as ``pred_fn(ds, net)`` that returns a
+            ``(prediction, image, label)`` triple.
+        ds: Data source forwarded unchanged to ``pred_fn``.
+        net: Object forwarded unchanged to ``pred_fn`` (``test_infer`` passes a ``Model``).
+    """
+    for i in range(1, 5):
+        pred, image, label = pred_fn(ds, net)
+        plt.subplot(2, 2, i)
+        plt.imshow(np.squeeze(image))
+        color = 'blue' if pred == label else 'red'
+        plt.title("prediction: {}, truth: {}".format(pred, label), color=color)
+        plt.xticks([])
+    plt.show()
+
+# Checkpoint used for inference. It needs its own name: rebinding CKPT here
+# would also change what resume_train() loads, because that function reads the
+# global when it is called, and this file does not exist until a resumed run
+# has written it.
+# NOTE(review): the "_1" in the file name presumably comes from ModelCheckpoint
+# de-duplicating the "b_lenet" prefix on the second run - confirm.
+INFER_CKPT = 'b_lenet_1-2_1875.ckpt'
+
+def infer(ds, model):
+    """Fetch the next batch from ``ds`` and return ``(prediction, image, label)`` for its first sample."""
+    data = ds.get_next()
+    images = data['image']
+    labels = data['label']
+    output = model.predict(Tensor(images))
+    # argmax over axis 1 picks one predicted class per sample
+    pred = np.argmax(output.asnumpy(), axis=1)
+    return pred[0], images[0], labels[0]
+
+def test_infer():
+    """Load ``INFER_CKPT`` into a fresh LeNet5 and visualize predictions on test images."""
+    # batch_size=1 so that every infer() call yields exactly one image
+    ds = create_dataset(training=False, batch_size=1).create_dict_iterator()
+    net = LeNet5()
+    param_dict = load_checkpoint(INFER_CKPT, net)
+    model = Model(net)
+    plot_images(infer, ds, model)
+
+if __name__ == "__main__":
+
+    test_train()
+
+    resume_train()
+
+    test_infer()
\ No newline at end of file