提交 4ea05a54 编写于 作者: S ShusenTang

add code 9.6.0: prepare pikachu data

上级 e7179e4a
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 9.6.0 准备皮卡丘数据集"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from mxnet.gluon import utils as gutils # pip install mxnet\n",
"from mxnet import image\n",
"\n",
"data_dir = '../../data/pikachu'\n",
"os.makedirs(data_dir, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 下载原始数据集\n",
"见http://zh.d2l.ai/chapter_computer-vision/object-detection-dataset.html"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def _download_pikachu(data_dir):\n",
" root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'\n",
" 'gluon/dataset/pikachu/')\n",
" dataset = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',\n",
" 'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',\n",
" 'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'}\n",
" for k, v in dataset.items():\n",
" gutils.download(root_url + k, os.path.join(data_dir, k), sha1_hash=v)\n",
"\n",
"if not os.path.exists(os.path.join(data_dir, \"train.rec\")):\n",
" print(\"下载原始数据集到%s...\" % data_dir)\n",
" _download_pikachu(data_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. MXNet数据迭代器"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_data_pikachu(batch_size, edge_size=256): # edge_size:输出图像的宽和高\n",
" train_iter = image.ImageDetIter(\n",
" path_imgrec=os.path.join(data_dir, 'train.rec'),\n",
" path_imgidx=os.path.join(data_dir, 'train.idx'),\n",
" batch_size=batch_size,\n",
" data_shape=(3, edge_size, edge_size), # 输出图像的形状\n",
"# shuffle=False, # 以随机顺序读取数据集\n",
"# rand_crop=1, # 随机裁剪的概率为1\n",
" min_object_covered=0.95, max_attempts=200)\n",
" val_iter = image.ImageDetIter(\n",
" path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,\n",
" data_shape=(3, edge_size, edge_size), shuffle=False)\n",
" return train_iter, val_iter"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((3, 256, 256), (1, 5))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_size, edge_size = 1, 256\n",
"train_iter, val_iter = load_data_pikachu(batch_size, edge_size)\n",
"batch = train_iter.next()\n",
"batch.data[0][0].shape, batch.label[0][0].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 转换成PNG图片并保存"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def process(data_iter, save_dir):\n",
" \"\"\"batch size == 1\"\"\"\n",
" data_iter.reset() # 从头开始\n",
" all_label = dict()\n",
" id = 1\n",
" os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)\n",
" for sample in tqdm(data_iter):\n",
" x = sample.data[0][0].asnumpy().transpose((1,2,0))\n",
" plt.imsave(os.path.join(save_dir, 'images', str(id) + '.png'), x / 255.0)\n",
"\n",
" y = sample.label[0][0][0].asnumpy()\n",
"\n",
" label = {}\n",
" label[\"class\"] = int(y[0])\n",
" label[\"loc\"] = y[1:].tolist()\n",
"\n",
" all_label[str(id) + '.png'] = label.copy()\n",
"\n",
" id += 1\n",
"\n",
" with open(os.path.join(save_dir, 'label.json'), 'w') as f:\n",
" json.dump(all_label, f, indent=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"900it [00:40, 22.03it/s]\n"
]
}
],
"source": [
"process(data_iter = train_iter, save_dir = os.path.join(data_dir, \"train\"))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100it [00:04, 22.86it/s]\n"
]
}
],
"source": [
"process(data_iter = val_iter, save_dir = os.path.join(data_dir, \"val\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册