Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
a26546593
dive-into-dl-pytorch
提交
b4c13914
D
dive-into-dl-pytorch
项目概览
a26546593
/
dive-into-dl-pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
dive-into-dl-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b4c13914
编写于
10月 31, 2019
作者:
S
ShusenTang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add code 9.6 in d2lzh
上级
73ccd30e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
52 addition
and
0 deletion
+52
-0
code/d2lzh_pytorch/utils.py
code/d2lzh_pytorch/utils.py
+52
-0
未找到文件。
code/d2lzh_pytorch/utils.py
浏览文件 @
b4c13914
...
...
@@ -1028,6 +1028,58 @@ def MultiBoxDetection(cls_prob, loc_pred, anchor, nms_threshold = 0.5):
# ################################# 9.6 ############################
class
PikachuDetDataset
(
torch
.
utils
.
data
.
Dataset
):
"""皮卡丘检测数据集类"""
def
__init__
(
self
,
data_dir
,
part
,
image_size
=
(
256
,
256
)):
assert
part
in
[
"train"
,
"val"
]
self
.
image_size
=
image_size
self
.
image_dir
=
os
.
path
.
join
(
data_dir
,
part
,
"images"
)
with
open
(
os
.
path
.
join
(
data_dir
,
part
,
"label.json"
))
as
f
:
self
.
label
=
json
.
load
(
f
)
self
.
transform
=
torchvision
.
transforms
.
Compose
([
# 将 PIL 图片转换成位于[0.0, 1.0]的floatTensor, shape (C x H x W)
torchvision
.
transforms
.
ToTensor
()])
def
__len__
(
self
):
return
len
(
self
.
label
)
def
__getitem__
(
self
,
index
):
image_path
=
str
(
index
+
1
)
+
".png"
cls
=
self
.
label
[
image_path
][
"class"
]
label
=
np
.
array
([
cls
]
+
self
.
label
[
image_path
][
"loc"
],
dtype
=
"float32"
)[
None
,
:]
PIL_img
=
Image
.
open
(
os
.
path
.
join
(
self
.
image_dir
,
image_path
)
).
convert
(
'RGB'
).
resize
(
self
.
image_size
)
img
=
self
.
transform
(
PIL_img
)
sample
=
{
"label"
:
label
,
# shape: (1, 5) [class, xmin, ymin, xmax, ymax]
"image"
:
img
# shape: (3, *image_size)
}
return
sample
def
load_data_pikachu
(
batch_size
,
edge_size
=
256
,
data_dir
=
'../../data/pikachu'
):
"""edge_size:输出图像的宽和高"""
image_size
=
(
edge_size
,
edge_size
)
train_dataset
=
PikachuDetDataset
(
data_dir
,
'train'
,
image_size
)
val_dataset
=
PikachuDetDataset
(
data_dir
,
'val'
,
image_size
)
train_iter
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
4
)
val_iter
=
torch
.
utils
.
data
.
DataLoader
(
val_dataset
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
4
)
return
train_iter
,
val_iter
# ############################# 10.7 ##########################
def
read_imdb
(
folder
=
'train'
,
data_root
=
"/S1/CSCL/tangss/Datasets/aclImdb"
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录