import json
import os

import matplotlib.pyplot as plt
import numpy as np
from mxnet import image
from mxnet.gluon import utils as gutils  # pip install mxnet
from tqdm import tqdm

# 9.6.0 Prepare the Pikachu dataset

# Everything (raw .rec/.idx files and the exported PNG/JSON) lives under here.
data_dir = '../../data/pikachu'
os.makedirs(data_dir, exist_ok=True)

# 1. Download the raw dataset
# See http://zh.d2l.ai/chapter_computer-vision/object-detection-dataset.html

# File name -> SHA-1 digest passed to the downloader as `sha1_hash`.
_PIKACHU_FILES = {
    'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
    'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
    'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520',
}


def _download_pikachu(data_dir):
    """Fetch the raw Pikachu RecordIO files into ``data_dir``.

    Each entry of ``_PIKACHU_FILES`` is downloaded from the MXNet S3 bucket
    with ``gutils.download``; the matching SHA-1 digest is handed over as
    ``sha1_hash``.

    Args:
        data_dir: Directory the files are written to (must already exist).
    """
    root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'
                'gluon/dataset/pikachu/')
    for file_name, sha1 in _PIKACHU_FILES.items():
        gutils.download(root_url + file_name,
                        os.path.join(data_dir, file_name),
                        sha1_hash=sha1)


# Trigger the download when ANY expected file is absent. Checking only
# 'train.rec' would leave an interrupted download (e.g. missing 'val.rec')
# permanently unrepaired.
_missing_files = [
    file_name for file_name in _PIKACHU_FILES
    if not os.path.exists(os.path.join(data_dir, file_name))
]
if _missing_files:
    print("下载原始数据集到%s..." % data_dir)
    _download_pikachu(data_dir)
# 2. MXNet data iterators


def load_data_pikachu(batch_size, edge_size=256):
    """Build the training and validation ``ImageDetIter`` objects.

    The RecordIO files are read from the module-level ``data_dir``.

    Args:
        batch_size: Number of images per batch for both iterators.
        edge_size: Width and height of the output images; the data shape
            requested from the iterators is ``(3, edge_size, edge_size)``.

    Returns:
        A ``(train_iter, val_iter)`` tuple.
    """
    output_shape = (3, edge_size, edge_size)
    train_rec = os.path.join(data_dir, 'train.rec')
    train_idx = os.path.join(data_dir, 'train.idx')
    val_rec = os.path.join(data_dir, 'val.rec')

    # NOTE: `shuffle` and `rand_crop` are deliberately not passed for the
    # training iterator, so the iterator's own defaults apply.
    train_iter = image.ImageDetIter(
        path_imgrec=train_rec,
        path_imgidx=train_idx,
        batch_size=batch_size,
        data_shape=output_shape,
        min_object_covered=0.95,
        max_attempts=200)
    val_iter = image.ImageDetIter(
        path_imgrec=val_rec,
        batch_size=batch_size,
        data_shape=output_shape,
        shuffle=False)
    return train_iter, val_iter


# Peek at one training batch to confirm the per-sample shapes.
batch_size, edge_size = 1, 256
train_iter, val_iter = load_data_pikachu(batch_size, edge_size)
batch = train_iter.next()
# Recorded notebook output for this expression: ((3, 256, 256), (1, 5))
batch.data[0][0].shape, batch.label[0][0].shape
# 3. Convert to PNG images and save them


def process(data_iter, save_dir):
    """Export every batch of ``data_iter`` as a PNG plus one JSON label file.

    Only the first image of each batch (``sample.data[0][0]``) and the first
    label row (``sample.label[0][0][0]``) are exported, so the iterator must
    have been created with a batch size of 1; with larger batches the
    remaining samples would be silently skipped.

    Images are written to ``<save_dir>/images/<n>.png`` with ``n`` counting
    from 1. Labels go to ``<save_dir>/label.json`` as a mapping
    ``"<n>.png" -> {"class": int, "loc": [float, ...]}``, where ``class`` is
    the first value of the label row and ``loc`` holds the remaining values.

    Args:
        data_iter: Iterator yielding batches with ``data`` and ``label``
            attributes, e.g. one returned by ``load_data_pikachu``.
        save_dir: Output directory; ``images/`` is created inside it.
    """
    data_iter.reset()  # start from the first record

    image_dir = os.path.join(save_dir, 'images')
    os.makedirs(image_dir, exist_ok=True)

    all_label = {}
    # enumerate(start=1) replaces a manual counter that shadowed builtin `id`.
    for index, sample in enumerate(tqdm(data_iter), start=1):
        file_name = '%d.png' % index

        # Move the leading (channel) axis last, as expected by imsave, and
        # scale the values by 1/255 before saving.
        pixels = sample.data[0][0].asnumpy().transpose((1, 2, 0))
        plt.imsave(os.path.join(image_dir, file_name), pixels / 255.0)

        y = sample.label[0][0][0].asnumpy()
        all_label[file_name] = {"class": int(y[0]), "loc": y[1:].tolist()}

    with open(os.path.join(save_dir, 'label.json'), 'w') as f:
        # indent=1 is what the former `indent=True` evaluated to (True == 1).
        json.dump(all_label, f, indent=1)


# Recorded notebook progress output: 900 iterations.
process(data_iter=train_iter, save_dir=os.path.join(data_dir, "train"))

# Recorded notebook progress output: 100 iterations.
process(data_iter=val_iter, save_dir=os.path.join(data_dir, "val"))