Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
da522568
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
da522568
编写于
4月 07, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/data): process mnist without generate new files
GitOrigin-RevId: 44a697c3fe9197bf6ae8889afc0df72d6095cf1f
上级
538d3de9
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
80 addition
and
129 deletion
+80
-129
python_module/megengine/data/dataset/vision/cifar.py
python_module/megengine/data/dataset/vision/cifar.py
+5
-5
python_module/megengine/data/dataset/vision/cityscapes.py
python_module/megengine/data/dataset/vision/cityscapes.py
+4
-4
python_module/megengine/data/dataset/vision/coco.py
python_module/megengine/data/dataset/vision/coco.py
+15
-11
python_module/megengine/data/dataset/vision/folder.py
python_module/megengine/data/dataset/vision/folder.py
+2
-1
python_module/megengine/data/dataset/vision/imagenet.py
python_module/megengine/data/dataset/vision/imagenet.py
+14
-12
python_module/megengine/data/dataset/vision/mnist.py
python_module/megengine/data/dataset/vision/mnist.py
+33
-89
python_module/megengine/data/dataset/vision/utils.py
python_module/megengine/data/dataset/vision/utils.py
+2
-2
python_module/megengine/data/dataset/vision/voc.py
python_module/megengine/data/dataset/vision/voc.py
+5
-5
未找到文件。
python_module/megengine/data/dataset/vision/cifar.py
浏览文件 @
da522568
...
@@ -78,7 +78,7 @@ class CIFAR10(VisionDataset):
...
@@ -78,7 +78,7 @@ class CIFAR10(VisionDataset):
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"dir does not contain target file
\
"dir does not contain target file
\
%s,please set download=True"
%s,
please set download=True"
%
(
self
.
target_file
)
%
(
self
.
target_file
)
)
)
...
@@ -108,7 +108,7 @@ class CIFAR10(VisionDataset):
...
@@ -108,7 +108,7 @@ class CIFAR10(VisionDataset):
def
untar
(
self
,
file_path
,
dirs
):
def
untar
(
self
,
file_path
,
dirs
):
assert
file_path
.
endswith
(
".tar.gz"
)
assert
file_path
.
endswith
(
".tar.gz"
)
logger
.
debug
(
"untar file %s to %s"
%
(
file_path
,
dirs
)
)
logger
.
debug
(
"untar file %s to %s"
,
file_path
,
dirs
)
t
=
tarfile
.
open
(
file_path
)
t
=
tarfile
.
open
(
file_path
)
t
.
extractall
(
path
=
dirs
)
t
.
extractall
(
path
=
dirs
)
...
@@ -117,13 +117,13 @@ class CIFAR10(VisionDataset):
...
@@ -117,13 +117,13 @@ class CIFAR10(VisionDataset):
label
=
[]
label
=
[]
for
filename
in
filenames
:
for
filename
in
filenames
:
path
=
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_dir
,
filename
)
path
=
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_dir
,
filename
)
logger
.
debug
(
"unpickle file %s"
%
path
)
logger
.
debug
(
"unpickle file %s"
,
path
)
with
open
(
path
,
"rb"
)
as
fo
:
with
open
(
path
,
"rb"
)
as
fo
:
dic
=
pickle
.
load
(
fo
,
encoding
=
"bytes"
)
dic
=
pickle
.
load
(
fo
,
encoding
=
"bytes"
)
batch_data
=
dic
[
b
"data"
].
reshape
(
-
1
,
3
,
32
,
32
).
transpose
((
0
,
2
,
3
,
1
))
batch_data
=
dic
[
b
"data"
].
reshape
(
-
1
,
3
,
32
,
32
).
transpose
((
0
,
2
,
3
,
1
))
data
.
extend
(
list
(
batch_data
[...,
[
2
,
1
,
0
]]))
data
.
extend
(
list
(
batch_data
[...,
[
2
,
1
,
0
]]))
label
.
extend
(
dic
[
b
"labels"
])
label
.
extend
(
dic
[
b
"labels"
])
label
=
np
.
array
(
label
)
label
=
np
.
array
(
label
,
dtype
=
np
.
int32
)
return
(
data
,
label
)
return
(
data
,
label
)
def
process
(
self
):
def
process
(
self
):
...
@@ -153,7 +153,7 @@ class CIFAR100(CIFAR10):
...
@@ -153,7 +153,7 @@ class CIFAR100(CIFAR10):
coarse_label
=
[]
coarse_label
=
[]
for
filename
in
filenames
:
for
filename
in
filenames
:
path
=
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_dir
,
filename
)
path
=
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_dir
,
filename
)
logger
.
debug
(
"unpickle file %s"
%
path
)
logger
.
debug
(
"unpickle file %s"
,
path
)
with
open
(
path
,
"rb"
)
as
fo
:
with
open
(
path
,
"rb"
)
as
fo
:
dic
=
pickle
.
load
(
fo
,
encoding
=
"bytes"
)
dic
=
pickle
.
load
(
fo
,
encoding
=
"bytes"
)
batch_data
=
dic
[
b
"data"
].
reshape
(
-
1
,
3
,
32
,
32
).
transpose
((
0
,
2
,
3
,
1
))
batch_data
=
dic
[
b
"data"
].
reshape
(
-
1
,
3
,
32
,
32
).
transpose
((
0
,
2
,
3
,
1
))
...
...
python_module/megengine/data/dataset/vision/cityscapes.py
浏览文件 @
da522568
...
@@ -71,7 +71,7 @@ class Cityscapes(VisionDataset):
...
@@ -71,7 +71,7 @@ class Cityscapes(VisionDataset):
elif
k
==
"mask"
:
elif
k
==
"mask"
:
mask
=
cv2
.
imread
(
self
.
masks
[
index
],
cv2
.
IMREAD_GRAYSCALE
)
mask
=
cv2
.
imread
(
self
.
masks
[
index
],
cv2
.
IMREAD_GRAYSCALE
)
mask
=
self
.
_trans_mask
(
mask
)
mask
=
self
.
_trans_mask
(
mask
)
mask
=
mask
[:,
:,
None
]
mask
=
mask
[:,
:,
np
.
newaxis
]
target
.
append
(
mask
)
target
.
append
(
mask
)
elif
k
==
"info"
:
elif
k
==
"info"
:
if
image
is
None
:
if
image
is
None
:
...
@@ -109,9 +109,9 @@ class Cityscapes(VisionDataset):
...
@@ -109,9 +109,9 @@ class Cityscapes(VisionDataset):
33
,
33
,
]
]
label
=
np
.
ones
(
mask
.
shape
)
*
255
label
=
np
.
ones
(
mask
.
shape
)
*
255
for
i
in
range
(
len
(
trans_labels
)
):
for
i
,
tl
in
enumerate
(
trans_labels
):
label
[
mask
==
t
rans_labels
[
i
]
]
=
i
label
[
mask
==
t
l
]
=
i
return
label
.
astype
(
"uint8"
)
return
label
.
astype
(
np
.
uint8
)
def
_get_target_suffix
(
self
,
mode
,
target_type
):
def
_get_target_suffix
(
self
,
mode
,
target_type
):
if
target_type
==
"instance"
:
if
target_type
==
"instance"
:
...
...
python_module/megengine/data/dataset/vision/coco.py
浏览文件 @
da522568
...
@@ -139,7 +139,7 @@ class COCO(VisionDataset):
...
@@ -139,7 +139,7 @@ class COCO(VisionDataset):
target
.
append
(
image
)
target
.
append
(
image
)
elif
k
==
"boxes"
:
elif
k
==
"boxes"
:
boxes
=
[
obj
[
"bbox"
]
for
obj
in
anno
]
boxes
=
[
obj
[
"bbox"
]
for
obj
in
anno
]
boxes
=
np
.
array
(
boxes
).
reshape
(
-
1
,
4
)
boxes
=
np
.
array
(
boxes
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
4
)
# transfer boxes from xywh to xyxy
# transfer boxes from xywh to xyxy
boxes
[:,
2
:]
+=
boxes
[:,
:
2
]
boxes
[:,
2
:]
+=
boxes
[:,
:
2
]
target
.
append
(
boxes
)
target
.
append
(
boxes
)
...
@@ -148,17 +148,21 @@ class COCO(VisionDataset):
...
@@ -148,17 +148,21 @@ class COCO(VisionDataset):
boxes_category
=
[
boxes_category
=
[
self
.
json_category_id_to_contiguous_id
[
c
]
for
c
in
boxes_category
self
.
json_category_id_to_contiguous_id
[
c
]
for
c
in
boxes_category
]
]
boxes_category
=
np
.
array
(
boxes_category
)
boxes_category
=
np
.
array
(
boxes_category
,
dtype
=
np
.
int32
)
target
.
append
(
boxes_category
)
target
.
append
(
boxes_category
)
# TODO: need to check
elif
k
==
"keypoints"
:
# elif k == "keypoints":
keypoints
=
[
obj
[
"keypoints"
]
for
obj
in
anno
]
# keypoints = [obj["keypoints"] for obj in anno]
keypoints
=
np
.
array
(
keypoints
,
dtype
=
np
.
float32
).
reshape
(
# keypoints = np.array(keypoints).reshape(-1, len(self.keypoint_names), 3)
-
1
,
len
(
self
.
keypoint_names
),
3
# target.append(keypoints)
)
# elif k == "polygons":
target
.
append
(
keypoints
)
# polygons = [obj["segmentation"] for obj in anno]
elif
k
==
"polygons"
:
# polygons = [[np.array(p).reshape(-1, 2) for p in ps] for ps in polygons]
polygons
=
[
obj
[
"segmentation"
]
for
obj
in
anno
]
# target.append(polygons)
polygons
=
[
[
np
.
array
(
p
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
)
for
p
in
ps
]
for
ps
in
polygons
]
target
.
append
(
polygons
)
elif
k
==
"info"
:
elif
k
==
"info"
:
info
=
self
.
imgs
[
img_id
]
info
=
self
.
imgs
[
img_id
]
info
=
[
info
[
"height"
],
info
[
"width"
],
info
[
"file_name"
]]
info
=
[
info
[
"height"
],
info
[
"width"
],
info
[
"file_name"
]]
...
...
python_module/megengine/data/dataset/vision/folder.py
浏览文件 @
da522568
...
@@ -19,6 +19,7 @@ import os
...
@@ -19,6 +19,7 @@ import os
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
cv2
import
cv2
import
numpy
as
np
from
.meta_vision
import
VisionDataset
from
.meta_vision
import
VisionDataset
from
.utils
import
is_img
from
.utils
import
is_img
...
@@ -78,7 +79,7 @@ class ImageFolder(VisionDataset):
...
@@ -78,7 +79,7 @@ class ImageFolder(VisionDataset):
def
collect_class
(
self
)
->
Dict
:
def
collect_class
(
self
)
->
Dict
:
classes
=
[
d
.
name
for
d
in
os
.
scandir
(
self
.
root
)
if
d
.
is_dir
()]
classes
=
[
d
.
name
for
d
in
os
.
scandir
(
self
.
root
)
if
d
.
is_dir
()]
classes
.
sort
()
classes
.
sort
()
return
{
classes
[
i
]:
i
for
i
in
range
(
len
(
classes
))}
return
{
classes
[
i
]:
np
.
int32
(
i
)
for
i
in
range
(
len
(
classes
))}
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
:
path
,
label
=
self
.
samples
[
index
]
path
,
label
=
self
.
samples
[
index
]
...
...
python_module/megengine/data/dataset/vision/imagenet.py
浏览文件 @
da522568
...
@@ -93,7 +93,7 @@ class ImageNet(ImageFolder):
...
@@ -93,7 +93,7 @@ class ImageNet(ImageFolder):
self
.
devkit_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
default_devkit_dir
)
self
.
devkit_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
default_devkit_dir
)
if
not
os
.
path
.
exists
(
self
.
devkit_dir
):
if
not
os
.
path
.
exists
(
self
.
devkit_dir
):
logger
.
warning
(
"devkit directory %s does not exists"
%
self
.
devkit_dir
)
logger
.
warning
(
"devkit directory %s does not exists"
,
self
.
devkit_dir
)
self
.
_prepare_devkit
()
self
.
_prepare_devkit
()
self
.
train
=
train
self
.
train
=
train
...
@@ -105,8 +105,8 @@ class ImageNet(ImageFolder):
...
@@ -105,8 +105,8 @@ class ImageNet(ImageFolder):
if
not
os
.
path
.
exists
(
self
.
target_folder
):
if
not
os
.
path
.
exists
(
self
.
target_folder
):
logger
.
warning
(
logger
.
warning
(
"expected image folder %s does not exist, try to load from raw file"
"expected image folder %s does not exist, try to load from raw file"
,
%
self
.
target_folder
self
.
target_folder
,
)
)
if
not
self
.
check_raw_file
():
if
not
self
.
check_raw_file
():
raise
FileNotFoundError
(
raise
FileNotFoundError
(
...
@@ -117,8 +117,10 @@ class ImageNet(ImageFolder):
...
@@ -117,8 +117,10 @@ class ImageNet(ImageFolder):
raise
RuntimeError
(
raise
RuntimeError
(
"extracting raw file shouldn't be done in distributed mode, use single process instead"
"extracting raw file shouldn't be done in distributed mode, use single process instead"
)
)
elif
train
:
self
.
_prepare_train
()
else
:
else
:
self
.
_prepare_
train
()
if
train
else
self
.
_prepare_
val
()
self
.
_prepare_val
()
super
().
__init__
(
self
.
target_folder
,
**
kwargs
)
super
().
__init__
(
self
.
target_folder
,
**
kwargs
)
...
@@ -145,12 +147,12 @@ class ImageNet(ImageFolder):
...
@@ -145,12 +147,12 @@ class ImageNet(ImageFolder):
try
:
try
:
return
load
(
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
))
return
load
(
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
))
except
FileNotFoundError
:
except
FileNotFoundError
:
import
scipy.io
as
sio
import
scipy.io
meta_path
=
os
.
path
.
join
(
self
.
devkit_dir
,
"data"
,
"meta.mat"
)
meta_path
=
os
.
path
.
join
(
self
.
devkit_dir
,
"data"
,
"meta.mat"
)
if
not
os
.
path
.
exists
(
meta_path
):
if
not
os
.
path
.
exists
(
meta_path
):
raise
FileNotFoundError
(
"meta file %s does not exist"
%
meta_path
)
raise
FileNotFoundError
(
"meta file %s does not exist"
%
meta_path
)
meta
=
sio
.
loadmat
(
meta_path
,
squeeze_me
=
True
)[
"synsets"
]
meta
=
s
cipy
.
io
.
loadmat
(
meta_path
,
squeeze_me
=
True
)[
"synsets"
]
nums_children
=
list
(
zip
(
*
meta
))[
4
]
nums_children
=
list
(
zip
(
*
meta
))[
4
]
meta
=
[
meta
=
[
meta
[
idx
]
meta
[
idx
]
...
@@ -159,8 +161,8 @@ class ImageNet(ImageFolder):
...
@@ -159,8 +161,8 @@ class ImageNet(ImageFolder):
]
]
idcs
,
wnids
,
classes
=
list
(
zip
(
*
meta
))[:
3
]
idcs
,
wnids
,
classes
=
list
(
zip
(
*
meta
))[:
3
]
classes
=
[
tuple
(
clss
.
split
(
", "
))
for
clss
in
classes
]
classes
=
[
tuple
(
clss
.
split
(
", "
))
for
clss
in
classes
]
idx_to_wnid
=
{
idx
:
wnid
for
idx
,
wnid
in
zip
(
idcs
,
wnids
)}
idx_to_wnid
=
dict
(
zip
(
idcs
,
wnids
))
wnid_to_classes
=
{
wnid
:
clss
for
wnid
,
clss
in
zip
(
wnids
,
classes
)}
wnid_to_classes
=
dict
(
zip
(
wnids
,
classes
))
logger
.
info
(
logger
.
info
(
"saving cached meta file to %s"
,
"saving cached meta file to %s"
,
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
),
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
),
...
@@ -208,11 +210,11 @@ class ImageNet(ImageFolder):
...
@@ -208,11 +210,11 @@ class ImageNet(ImageFolder):
assert
not
self
.
train
assert
not
self
.
train
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"val"
]
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"val"
]
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
logger
.
info
(
"checksum valid tar file
{} .."
.
format
(
raw_file
)
)
logger
.
info
(
"checksum valid tar file
%s ..."
,
raw_file
)
assert
(
assert
(
calculate_md5
(
raw_file
)
==
checksum
calculate_md5
(
raw_file
)
==
checksum
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
logger
.
info
(
"extract valid tar file.. this may take 10-20 minutes"
)
logger
.
info
(
"extract valid tar file..
.
this may take 10-20 minutes"
)
untar
(
os
.
path
.
join
(
self
.
root
,
raw_file
),
self
.
target_folder
)
untar
(
os
.
path
.
join
(
self
.
root
,
raw_file
),
self
.
target_folder
)
self
.
_organize_val_data
()
self
.
_organize_val_data
()
...
@@ -220,7 +222,7 @@ class ImageNet(ImageFolder):
...
@@ -220,7 +222,7 @@ class ImageNet(ImageFolder):
assert
self
.
train
assert
self
.
train
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"train"
]
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"train"
]
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
logger
.
info
(
"checksum train tar file
{} .."
.
format
(
raw_file
)
)
logger
.
info
(
"checksum train tar file
%s ..."
,
raw_file
)
assert
(
assert
(
calculate_md5
(
raw_file
)
==
checksum
calculate_md5
(
raw_file
)
==
checksum
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
...
@@ -238,7 +240,7 @@ class ImageNet(ImageFolder):
...
@@ -238,7 +240,7 @@ class ImageNet(ImageFolder):
def
_prepare_devkit
(
self
):
def
_prepare_devkit
(
self
):
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"devkit"
]
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"devkit"
]
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
logger
.
info
(
"checksum devkit tar file
{} .."
.
format
(
raw_file
)
)
logger
.
info
(
"checksum devkit tar file
%s ..."
,
raw_file
)
assert
(
assert
(
calculate_md5
(
raw_file
)
==
checksum
calculate_md5
(
raw_file
)
==
checksum
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
...
...
python_module/megengine/data/dataset/vision/mnist.py
浏览文件 @
da522568
...
@@ -8,7 +8,6 @@
...
@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
gzip
import
gzip
import
os
import
os
import
pickle
import
struct
import
struct
from
typing
import
Tuple
from
typing
import
Tuple
...
@@ -48,14 +47,6 @@ class MNIST(VisionDataset):
...
@@ -48,14 +47,6 @@ class MNIST(VisionDataset):
"""
"""
md5 for checking raw files
md5 for checking raw files
"""
"""
train_file
=
"train.pkl"
"""
default pickle file name of training set and its meta data
"""
test_file
=
"test.pkl"
"""
default pickle file name of test set and its meta data
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -65,30 +56,11 @@ class MNIST(VisionDataset):
...
@@ -65,30 +56,11 @@ class MNIST(VisionDataset):
timeout
:
int
=
500
,
timeout
:
int
=
500
,
):
):
r
"""
r
"""
initialization:
1. check root path and target file (train or test)
2. check target file exists
* if exists:
* load pickle file as meta-data and data in MNIST dataset
* else:
* if download:
a. load all raw datas (both train and test set) by url
b. process raw data ( idx3/idx1 -> dict (meta-data) ,numpy.array (data) )
c. save meta-data and data as pickle file
d. load pickle file as meta-data and data in MNIST dataset
:param root: path for mnist dataset downloading or loading, if ``None``,
:param root: path for mnist dataset downloading or loading, if ``None``,
set ``root`` to the ``_default_root``
set ``root`` to the ``_default_root``
:param train: if ``True``, loading trainingset, else loading test set
:param train: if ``True``, loading trainingset, else loading test set
:param download: after checking the target files existence, if target files do not
:param download: if raw files do not exists and download sets to ``True``,
exists and download sets to ``True``, download raw files and process,
download raw files and process, otherwise raise ValueError, default is True
then load, otherwise raise ValueError, default is True
"""
"""
super
().
__init__
(
root
,
order
=
(
"image"
,
"image_category"
))
super
().
__init__
(
root
,
order
=
(
"image"
,
"image_category"
))
...
@@ -105,28 +77,14 @@ class MNIST(VisionDataset):
...
@@ -105,28 +77,14 @@ class MNIST(VisionDataset):
if
not
os
.
path
.
exists
(
self
.
root
):
if
not
os
.
path
.
exists
(
self
.
root
):
raise
ValueError
(
"dir %s does not exist"
%
self
.
root
)
raise
ValueError
(
"dir %s does not exist"
%
self
.
root
)
# choose the target pickle file
if
self
.
_check_raw_files
():
if
train
:
self
.
process
(
train
)
self
.
target_file
=
os
.
path
.
join
(
self
.
root
,
self
.
train_file
)
elif
download
:
else
:
self
.
target_file
=
os
.
path
.
join
(
self
.
root
,
self
.
test_file
)
# check existence of target pickle file, if exists load the
# pickle file no matter what download is set
if
os
.
path
.
exists
(
self
.
target_file
):
self
.
_meta_data
,
self
.
arrays
=
self
.
_load_file
(
self
.
target_file
)
elif
self
.
_check_raw_files
():
self
.
process
()
self
.
_meta_data
,
self
.
arrays
=
self
.
_load_file
(
self
.
target_file
)
else
:
if
download
:
self
.
download
()
self
.
download
()
self
.
_meta_data
,
self
.
arrays
=
self
.
_load_file
(
self
.
target_file
)
self
.
process
(
train
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"dir does not contain target file
\
"root does not contain valid raw files, please set download=True"
%s,please set download=True"
%
(
self
.
target_file
)
)
)
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
:
...
@@ -143,10 +101,6 @@ class MNIST(VisionDataset):
...
@@ -143,10 +101,6 @@ class MNIST(VisionDataset):
def
meta
(
self
):
def
meta
(
self
):
return
self
.
_meta_data
return
self
.
_meta_data
def
_load_file
(
self
,
target_file
):
with
open
(
target_file
,
"rb"
)
as
f
:
return
pickle
.
load
(
f
)
def
_check_raw_files
(
self
):
def
_check_raw_files
(
self
):
return
all
(
return
all
(
[
[
...
@@ -159,45 +113,35 @@ class MNIST(VisionDataset):
...
@@ -159,45 +113,35 @@ class MNIST(VisionDataset):
for
file_name
,
md5
in
zip
(
self
.
raw_file_name
,
self
.
raw_file_md5
):
for
file_name
,
md5
in
zip
(
self
.
raw_file_name
,
self
.
raw_file_md5
):
url
=
self
.
url_path
+
file_name
url
=
self
.
url_path
+
file_name
load_raw_data_from_url
(
url
,
file_name
,
md5
,
self
.
root
,
self
.
timeout
)
load_raw_data_from_url
(
url
,
file_name
,
md5
,
self
.
root
,
self
.
timeout
)
self
.
process
()
def
process
(
self
):
def
process
(
self
,
train
):
# load raw files and transform them into meta data and datasets Tuple(np.array)
# load raw files and transform them into meta data and datasets Tuple(np.array)
logger
.
info
(
"process raw data ..."
)
logger
.
info
(
"process the raw files of %s set..."
,
"train"
if
train
else
"test"
)
meta_data_images_train
,
images_train
=
parse_idx3
(
if
train
:
meta_data_images
,
images
=
parse_idx3
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
0
])
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
0
])
)
)
meta_data_labels_train
,
labels_train
=
parse_idx1
(
meta_data_labels
,
labels
=
parse_idx1
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
1
])
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
1
])
)
)
meta_data_images_test
,
images_test
=
parse_idx3
(
else
:
meta_data_images
,
images
=
parse_idx3
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
2
])
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
2
])
)
)
meta_data_labels_test
,
labels_test
=
parse_idx1
(
meta_data_labels
,
labels
=
parse_idx1
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
3
])
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
3
])
)
)
meta_data_train
=
{
self
.
_meta_data
=
{
"images"
:
meta_data_images
_train
,
"images"
:
meta_data_images
,
"labels"
:
meta_data_labels
_train
,
"labels"
:
meta_data_labels
,
}
}
meta_data_test
=
{
self
.
arrays
=
(
images
,
labels
.
astype
(
np
.
int32
))
"images"
:
meta_data_images_test
,
"labels"
:
meta_data_labels_test
,
}
dataset_train
=
(
images_train
,
labels_train
)
dataset_test
=
(
images_test
,
labels_test
)
# save both training set and test set as pickle files
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
train_file
),
"wb"
)
as
f
:
pickle
.
dump
((
meta_data_train
,
dataset_train
),
f
,
pickle
.
HIGHEST_PROTOCOL
)
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
test_file
),
"wb"
)
as
f
:
pickle
.
dump
((
meta_data_test
,
dataset_test
),
f
,
pickle
.
HIGHEST_PROTOCOL
)
def
parse_idx3
(
idx3_file
):
def
parse_idx3
(
idx3_file
):
# parse idx3 file to meta data and data in numpy array (images)
# parse idx3 file to meta data and data in numpy array (images)
logger
.
debug
(
"parse idx3 file %s ..."
%
idx3_file
)
logger
.
debug
(
"parse idx3 file %s ..."
,
idx3_file
)
assert
idx3_file
.
endswith
(
".gz"
)
assert
idx3_file
.
endswith
(
".gz"
)
with
gzip
.
open
(
idx3_file
,
"rb"
)
as
f
:
with
gzip
.
open
(
idx3_file
,
"rb"
)
as
f
:
bin_data
=
f
.
read
()
bin_data
=
f
.
read
()
...
@@ -223,7 +167,7 @@ def parse_idx3(idx3_file):
...
@@ -223,7 +167,7 @@ def parse_idx3(idx3_file):
def
parse_idx1
(
idx1_file
):
def
parse_idx1
(
idx1_file
):
# parse idx1 file to meta data and data in numpy array (labels)
# parse idx1 file to meta data and data in numpy array (labels)
logger
.
debug
(
"parse idx1 file %s ..."
%
idx1_file
)
logger
.
debug
(
"parse idx1 file %s ..."
,
idx1_file
)
assert
idx1_file
.
endswith
(
".gz"
)
assert
idx1_file
.
endswith
(
".gz"
)
with
gzip
.
open
(
idx1_file
,
"rb"
)
as
f
:
with
gzip
.
open
(
idx1_file
,
"rb"
)
as
f
:
bin_data
=
f
.
read
()
bin_data
=
f
.
read
()
...
...
python_module/megengine/data/dataset/vision/utils.py
浏览文件 @
da522568
...
@@ -32,7 +32,7 @@ def load_raw_data_from_url(
...
@@ -32,7 +32,7 @@ def load_raw_data_from_url(
):
):
cached_file
=
os
.
path
.
join
(
raw_data_dir
,
filename
)
cached_file
=
os
.
path
.
join
(
raw_data_dir
,
filename
)
logger
.
debug
(
logger
.
debug
(
"load_raw_data_from_url: downloading to or using cached %s ..."
%
cached_file
"load_raw_data_from_url: downloading to or using cached %s ..."
,
cached_file
)
)
if
not
os
.
path
.
exists
(
cached_file
):
if
not
os
.
path
.
exists
(
cached_file
):
if
is_distributed
():
if
is_distributed
():
...
@@ -45,7 +45,7 @@ def load_raw_data_from_url(
...
@@ -45,7 +45,7 @@ def load_raw_data_from_url(
else
:
else
:
md5
=
calculate_md5
(
cached_file
)
md5
=
calculate_md5
(
cached_file
)
if
target_md5
==
md5
:
if
target_md5
==
md5
:
logger
.
debug
(
"%s exists with correct md5: %s"
%
(
filename
,
target_md5
)
)
logger
.
debug
(
"%s exists with correct md5: %s"
,
filename
,
target_md5
)
else
:
else
:
os
.
remove
(
cached_file
)
os
.
remove
(
cached_file
)
raise
RuntimeError
(
"{} exists but fail to match md5"
.
format
(
filename
))
raise
RuntimeError
(
"{} exists but fail to match md5"
.
format
(
filename
))
...
...
python_module/megengine/data/dataset/vision/voc.py
浏览文件 @
da522568
...
@@ -77,13 +77,13 @@ class PascalVOC(VisionDataset):
...
@@ -77,13 +77,13 @@ class PascalVOC(VisionDataset):
if
"aug"
in
self
.
image_set
:
if
"aug"
in
self
.
image_set
:
mask
=
cv2
.
imread
(
self
.
masks
[
index
],
cv2
.
IMREAD_GRAYSCALE
)
mask
=
cv2
.
imread
(
self
.
masks
[
index
],
cv2
.
IMREAD_GRAYSCALE
)
else
:
else
:
mask
=
np
.
array
(
cv2
.
imread
(
self
.
masks
[
index
],
cv2
.
IMREAD_COLOR
)
)
mask
=
cv2
.
imread
(
self
.
masks
[
index
],
cv2
.
IMREAD_COLOR
)
mask
=
self
.
_trans_mask
(
mask
)
mask
=
self
.
_trans_mask
(
mask
)
mask
=
mask
[:,
:,
np
.
newaxis
]
mask
=
mask
[:,
:,
np
.
newaxis
]
target
.
append
(
mask
)
target
.
append
(
mask
)
#
elif k == "boxes":
elif
k
==
"boxes"
:
#
boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
boxes
=
self
.
parse_voc_xml
(
ET
.
parse
(
self
.
annotations
[
index
]).
getroot
())
#
target.append(boxes)
target
.
append
(
boxes
)
elif
k
==
"info"
:
elif
k
==
"info"
:
if
image
is
None
:
if
image
is
None
:
image
=
cv2
.
imread
(
self
.
images
[
index
],
cv2
.
IMREAD_COLOR
)
image
=
cv2
.
imread
(
self
.
images
[
index
],
cv2
.
IMREAD_COLOR
)
...
@@ -104,7 +104,7 @@ class PascalVOC(VisionDataset):
...
@@ -104,7 +104,7 @@ class PascalVOC(VisionDataset):
label
[
label
[
(
mask
[:,
:,
0
]
==
b
)
&
(
mask
[:,
:,
1
]
==
g
)
&
(
mask
[:,
:,
2
]
==
r
)
(
mask
[:,
:,
0
]
==
b
)
&
(
mask
[:,
:,
1
]
==
g
)
&
(
mask
[:,
:,
2
]
==
r
)
]
=
i
]
=
i
return
label
.
astype
(
"uint8"
)
return
label
.
astype
(
np
.
uint8
)
def
parse_voc_xml
(
self
,
node
):
def
parse_voc_xml
(
self
,
node
):
voc_dict
=
{}
voc_dict
=
{}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录