Commit da522568 authored by Megvii Engine Team

fix(mge/data): process mnist without generating new files

GitOrigin-RevId: 44a697c3fe9197bf6ae8889afc0df72d6095cf1f
Parent 538d3de9
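
In effect, the MNIST dataset now parses the raw IDX files of the requested split in memory on every construction, instead of converting both splits into train.pkl/test.pkl once and loading those. A minimal usage sketch of the changed class (the root path is a placeholder, not from the diff):

```python
# Hedged usage sketch: "/data/mnist" is a placeholder path.
from megengine.data.dataset import MNIST

# With download=True the raw *.gz IDX files are fetched if missing, then
# parsed directly -- after this commit no pickle cache is written to root.
train_set = MNIST(root="/data/mnist", train=True, download=True)
image, label = train_set[0]  # label is np.int32 after this change
print(len(train_set), label)
```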
@@ -78,7 +78,7 @@ class CIFAR10(VisionDataset):
             else:
                 raise ValueError(
                     "dir does not contain target file\
-                    %s,please set download=True"
+                    %s, please set download=True"
                     % (self.target_file)
                 )
@@ -108,7 +108,7 @@ class CIFAR10(VisionDataset):
     def untar(self, file_path, dirs):
         assert file_path.endswith(".tar.gz")
-        logger.debug("untar file %s to %s" % (file_path, dirs))
+        logger.debug("untar file %s to %s", file_path, dirs)
         t = tarfile.open(file_path)
         t.extractall(path=dirs)
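
This hunk (and the many like it below) switches from %-interpolating the log message eagerly to passing the arguments to the logger, which only formats when the record is actually emitted. A standalone illustration of the difference:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

# Eager: the message string is built even though DEBUG is filtered out.
logger.debug("untar file %s to %s" % ("a.tar.gz", "/tmp/out"))

# Lazy: interpolation is skipped entirely when the level is disabled.
logger.debug("untar file %s to %s", "a.tar.gz", "/tmp/out")
```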
@@ -117,13 +117,13 @@ class CIFAR10(VisionDataset):
         label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
             batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
             data.extend(list(batch_data[..., [2, 1, 0]]))
             label.extend(dic[b"labels"])
-        label = np.array(label)
+        label = np.array(label, dtype=np.int32)
         return (data, label)
 
     def process(self):
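
The fancy index batch_data[..., [2, 1, 0]] above reverses the channel axis, turning the RGB order of the CIFAR pickles into the BGR order the cv2-based pipeline expects. A tiny sketch of the same indexing trick:

```python
import numpy as np

# Fake 2x2 "image" with distinct channel values: R=1, G=2, B=3.
rgb = np.stack([np.full((2, 2), c, dtype=np.uint8) for c in (1, 2, 3)], axis=-1)
bgr = rgb[..., [2, 1, 0]]        # reorder last axis: (R, G, B) -> (B, G, R)
assert (bgr[..., 0] == 3).all()  # blue channel is now first
```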
@@ -153,7 +153,7 @@ class CIFAR100(CIFAR10):
         coarse_label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
             batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
......
@@ -71,7 +71,7 @@ class Cityscapes(VisionDataset):
             elif k == "mask":
                 mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 mask = self._trans_mask(mask)
-                mask = mask[:, :, None]
+                mask = mask[:, :, np.newaxis]
                 target.append(mask)
             elif k == "info":
                 if image is None:
@@ -109,9 +109,9 @@ class Cityscapes(VisionDataset):
             33,
         ]
         label = np.ones(mask.shape) * 255
-        for i in range(len(trans_labels)):
-            label[mask == trans_labels[i]] = i
-        return label.astype("uint8")
+        for i, tl in enumerate(trans_labels):
+            label[mask == tl] = i
+        return label.astype(np.uint8)
 
     def _get_target_suffix(self, mode, target_type):
         if target_type == "instance":
......
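
The rewritten loop maps each raw Cityscapes ID in trans_labels to its position (the contiguous train ID) and leaves every other pixel at the ignore value 255. The same mapping can also be done in one vectorized pass with a 256-entry lookup table; a sketch assuming the standard 19-class ID list, of which only the tail is visible above:

```python
import numpy as np

# Assumed full list; the diff above only shows the trailing "33".
trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 31, 32, 33]

lut = np.full(256, 255, dtype=np.uint8)  # default: ignore label
for i, tl in enumerate(trans_labels):
    lut[tl] = i

mask = np.array([[7, 0], [33, 8]], dtype=np.uint8)
label = lut[mask]  # vectorized remap -> [[0, 255], [18, 1]]
```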
@@ -139,7 +139,7 @@ class COCO(VisionDataset):
                 target.append(image)
             elif k == "boxes":
                 boxes = [obj["bbox"] for obj in anno]
-                boxes = np.array(boxes).reshape(-1, 4)
+                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                 # transfer boxes from xywh to xyxy
                 boxes[:, 2:] += boxes[:, :2]
                 target.append(boxes)
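
COCO stores a bbox as (x, y, width, height); adding the origin columns onto the size columns converts it in place to corner form (x1, y1, x2, y2), which is what the in-diff comment describes. For instance:

```python
import numpy as np

boxes = np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32)  # x, y, w, h
boxes[:, 2:] += boxes[:, :2]
# boxes is now [[10., 20., 40., 60.]] == x1, y1, x2, y2
```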
@@ -148,17 +148,21 @@ class COCO(VisionDataset):
                 boxes_category = [
                     self.json_category_id_to_contiguous_id[c] for c in boxes_category
                 ]
-                boxes_category = np.array(boxes_category)
+                boxes_category = np.array(boxes_category, dtype=np.int32)
                 target.append(boxes_category)
-            # TODO: need to check
-            # elif k == "keypoints":
-            #     keypoints = [obj["keypoints"] for obj in anno]
-            #     keypoints = np.array(keypoints).reshape(-1, len(self.keypoint_names), 3)
-            #     target.append(keypoints)
-            # elif k == "polygons":
-            #     polygons = [obj["segmentation"] for obj in anno]
-            #     polygons = [[np.array(p).reshape(-1, 2) for p in ps] for ps in polygons]
-            #     target.append(polygons)
+            elif k == "keypoints":
+                keypoints = [obj["keypoints"] for obj in anno]
+                keypoints = np.array(keypoints, dtype=np.float32).reshape(
+                    -1, len(self.keypoint_names), 3
+                )
+                target.append(keypoints)
+            elif k == "polygons":
+                polygons = [obj["segmentation"] for obj in anno]
+                polygons = [
+                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
+                    for ps in polygons
+                ]
+                target.append(polygons)
             elif k == "info":
                 info = self.imgs[img_id]
                 info = [info["height"], info["width"], info["file_name"]]
......
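
The newly enabled keypoints branch relies on COCO flattening each instance's keypoints into (x, y, visibility) triplets, so the reshape recovers an (instances, num_keypoints, 3) array. A toy check, assuming the standard 17 COCO person keypoints:

```python
import numpy as np

num_keypoints = 17  # assumption: standard COCO person keypoints
anno = [{"keypoints": [0.0] * (num_keypoints * 3)}]  # one all-zero instance

kps = np.array([obj["keypoints"] for obj in anno], dtype=np.float32)
kps = kps.reshape(-1, num_keypoints, 3)  # (instances, keypoints, x/y/vis)
assert kps.shape == (1, 17, 3)
```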
@@ -19,6 +19,7 @@ import os
 from typing import Dict, List, Tuple
 
 import cv2
+import numpy as np
 
 from .meta_vision import VisionDataset
 from .utils import is_img
@@ -78,7 +79,7 @@ class ImageFolder(VisionDataset):
     def collect_class(self) -> Dict:
         classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
         classes.sort()
-        return {classes[i]: i for i in range(len(classes))}
+        return {classes[i]: np.int32(i) for i in range(len(classes))}
 
     def __getitem__(self, index: int) -> Tuple:
         path, label = self.samples[index]
@@ -93,7 +93,7 @@ class ImageNet(ImageFolder):
         self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
 
         if not os.path.exists(self.devkit_dir):
-            logger.warning("devkit directory %s does not exists" % self.devkit_dir)
+            logger.warning("devkit directory %s does not exists", self.devkit_dir)
             self._prepare_devkit()
 
         self.train = train
@@ -105,8 +105,8 @@ class ImageNet(ImageFolder):
         if not os.path.exists(self.target_folder):
             logger.warning(
-                "expected image folder %s does not exist, try to load from raw file"
-                % self.target_folder
+                "expected image folder %s does not exist, try to load from raw file",
+                self.target_folder,
             )
             if not self.check_raw_file():
                 raise FileNotFoundError(
@@ -117,8 +117,10 @@ class ImageNet(ImageFolder):
                 raise RuntimeError(
                     "extracting raw file shouldn't be done in distributed mode, use single process instead"
                 )
-            else:
-                self._prepare_train() if train else self._prepare_val()
+            elif train:
+                self._prepare_train()
+            else:
+                self._prepare_val()
 
         super().__init__(self.target_folder, **kwargs)
@@ -145,12 +147,12 @@ class ImageNet(ImageFolder):
         try:
             return load(os.path.join(self.devkit_dir, "meta.pkl"))
         except FileNotFoundError:
-            import scipy.io as sio
+            import scipy.io
 
             meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
             if not os.path.exists(meta_path):
                 raise FileNotFoundError("meta file %s does not exist" % meta_path)
-            meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
+            meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
             nums_children = list(zip(*meta))[4]
             meta = [
                 meta[idx]
@@ -159,8 +161,8 @@ class ImageNet(ImageFolder):
             ]
             idcs, wnids, classes = list(zip(*meta))[:3]
             classes = [tuple(clss.split(", ")) for clss in classes]
-            idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-            wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+            idx_to_wnid = dict(zip(idcs, wnids))
+            wnid_to_classes = dict(zip(wnids, classes))
             logger.info(
                 "saving cached meta file to %s",
                 os.path.join(self.devkit_dir, "meta.pkl"),
@@ -208,11 +210,11 @@ class ImageNet(ImageFolder):
         assert not self.train
         raw_filename, checksum = self.raw_file_meta["val"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum valid tar file {} ..".format(raw_file))
+        logger.info("checksum valid tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
-        logger.info("extract valid tar file.. this may take 10-20 minutes")
+        logger.info("extract valid tar file... this may take 10-20 minutes")
         untar(os.path.join(self.root, raw_file), self.target_folder)
         self._organize_val_data()
@@ -220,7 +222,7 @@ class ImageNet(ImageFolder):
         assert self.train
         raw_filename, checksum = self.raw_file_meta["train"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum train tar file {} ..".format(raw_file))
+        logger.info("checksum train tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
@@ -238,7 +240,7 @@ class ImageNet(ImageFolder):
     def _prepare_devkit(self):
         raw_filename, checksum = self.raw_file_meta["devkit"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum devkit tar file {} ..".format(raw_file))
+        logger.info("checksum devkit tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
......
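
calculate_md5 itself is not part of this diff; a typical chunked implementation, given here as an assumption rather than MegEngine's actual helper, would be:

```python
import hashlib

def calculate_md5(path, chunk_size=1 << 20):
    # Hash in 1 MiB chunks so multi-GB ImageNet tars never load into memory.
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()
```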
@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import gzip
 import os
-import pickle
 import struct
 from typing import Tuple
@@ -48,14 +47,6 @@ class MNIST(VisionDataset):
     """
     md5 for checking raw files
     """
-    train_file = "train.pkl"
-    """
-    default pickle file name of training set and its meta data
-    """
-    test_file = "test.pkl"
-    """
-    default pickle file name of test set and its meta data
-    """
 
     def __init__(
         self,
@@ -65,30 +56,11 @@ class MNIST(VisionDataset):
         timeout: int = 500,
     ):
         r"""
-        initialization:
-
-        1. check root path and target file (train or test)
-        2. check target file exists
-
-            * if exists:
-                * load pickle file as meta-data and data in MNIST dataset
-            * else:
-                * if download:
-                    a. load all raw datas (both train and test set) by url
-                    b. process raw data ( idx3/idx1 -> dict (meta-data) ,numpy.array (data) )
-                    c. save meta-data and data as pickle file
-                    d. load pickle file as meta-data and data in MNIST dataset
-
         :param root: path for mnist dataset downloading or loading, if ``None``,
             set ``root`` to the ``_default_root``
         :param train: if ``True``, loading trainingset, else loading test set
-        :param download: after checking the target files existence, if target files do not
-            exists and download sets to ``True``, download raw files and process,
-            then load, otherwise raise ValueError, default is True
+        :param download: if raw files do not exists and download sets to ``True``,
+            download raw files and process, otherwise raise ValueError, default is True
         """
         super().__init__(root, order=("image", "image_category"))
@@ -105,29 +77,15 @@ class MNIST(VisionDataset):
         if not os.path.exists(self.root):
             raise ValueError("dir %s does not exist" % self.root)
 
-        # choose the target pickle file
-        if train:
-            self.target_file = os.path.join(self.root, self.train_file)
-        else:
-            self.target_file = os.path.join(self.root, self.test_file)
-
-        # check existence of target pickle file, if exists load the
-        # pickle file no matter what download is set
-        if os.path.exists(self.target_file):
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        elif self._check_raw_files():
-            self.process()
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        else:
-            if download:
-                self.download()
-                self._meta_data, self.arrays = self._load_file(self.target_file)
-            else:
-                raise ValueError(
-                    "dir does not contain target file\
-                    %s,please set download=True"
-                    % (self.target_file)
-                )
+        if self._check_raw_files():
+            self.process(train)
+        elif download:
+            self.download()
+            self.process(train)
+        else:
+            raise ValueError(
+                "root does not contain valid raw files, please set download=True"
+            )
 
     def __getitem__(self, index: int) -> Tuple:
         return tuple(array[index] for array in self.arrays)
@@ -143,10 +101,6 @@ class MNIST(VisionDataset):
     def meta(self):
         return self._meta_data
 
-    def _load_file(self, target_file):
-        with open(target_file, "rb") as f:
-            return pickle.load(f)
-
     def _check_raw_files(self):
         return all(
             [
@@ -159,45 +113,35 @@ class MNIST(VisionDataset):
         for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
             url = self.url_path + file_name
             load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
-        self.process()
 
-    def process(self):
+    def process(self, train):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
-        logger.info("process raw data ...")
-        meta_data_images_train, images_train = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[0])
-        )
-        meta_data_labels_train, labels_train = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[1])
-        )
-        meta_data_images_test, images_test = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[2])
-        )
-        meta_data_labels_test, labels_test = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[3])
-        )
-
-        meta_data_train = {
-            "images": meta_data_images_train,
-            "labels": meta_data_labels_train,
-        }
-        meta_data_test = {
-            "images": meta_data_images_test,
-            "labels": meta_data_labels_test,
-        }
+        logger.info("process the raw files of %s set...", "train" if train else "test")
+        if train:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[0])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[1])
+            )
+        else:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[2])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[3])
+            )
 
-        dataset_train = (images_train, labels_train)
-        dataset_test = (images_test, labels_test)
-
-        # save both training set and test set as pickle files
-        with open(os.path.join(self.root, self.train_file), "wb") as f:
-            pickle.dump((meta_data_train, dataset_train), f, pickle.HIGHEST_PROTOCOL)
-        with open(os.path.join(self.root, self.test_file), "wb") as f:
-            pickle.dump((meta_data_test, dataset_test), f, pickle.HIGHEST_PROTOCOL)
+        self._meta_data = {
+            "images": meta_data_images,
+            "labels": meta_data_labels,
+        }
+        self.arrays = (images, labels.astype(np.int32))
 
 def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
-    logger.debug("parse idx3 file %s ..." % idx3_file)
+    logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")
     with gzip.open(idx3_file, "rb") as f:
         bin_data = f.read()
@@ -223,7 +167,7 @@ def parse_idx3(idx3_file):
 
 def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
-    logger.debug("parse idx1 file %s ..." % idx1_file)
+    logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")
     with gzip.open(idx1_file, "rb") as f:
         bin_data = f.read()
......
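
parse_idx3/parse_idx1 read the classic MNIST IDX format: a big-endian header (a magic number followed by one size field per axis) and then a raw uint8 payload. A minimal standalone reader for the idx3 image layout, as a sketch of the format rather than a copy of MegEngine's function:

```python
import gzip
import struct

import numpy as np

def read_idx3_images(path):
    # idx3 header: magic 2051, num_images, rows, cols -- four big-endian uint32s.
    with gzip.open(path, "rb") as f:
        data = f.read()
    magic, num, rows, cols = struct.unpack(">IIII", data[:16])
    assert magic == 2051, "not an idx3 image file"
    return np.frombuffer(data, dtype=np.uint8, offset=16).reshape(num, rows, cols)
```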
@@ -32,7 +32,7 @@ def load_raw_data_from_url(
 ):
     cached_file = os.path.join(raw_data_dir, filename)
     logger.debug(
-        "load_raw_data_from_url: downloading to or using cached %s ..." % cached_file
+        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
     )
     if not os.path.exists(cached_file):
         if is_distributed():
@@ -45,7 +45,7 @@ def load_raw_data_from_url(
     else:
         md5 = calculate_md5(cached_file)
         if target_md5 == md5:
-            logger.debug("%s exists with correct md5: %s" % (filename, target_md5))
+            logger.debug("%s exists with correct md5: %s", filename, target_md5)
         else:
             os.remove(cached_file)
             raise RuntimeError("{} exists but fail to match md5".format(filename))
......
@@ -77,13 +77,13 @@ class PascalVOC(VisionDataset):
                 if "aug" in self.image_set:
                     mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 else:
-                    mask = np.array(cv2.imread(self.masks[index], cv2.IMREAD_COLOR))
+                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                 mask = self._trans_mask(mask)
                 mask = mask[:, :, np.newaxis]
                 target.append(mask)
-            # elif k == "boxes":
-            #     boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
-            #     target.append(boxes)
+            elif k == "boxes":
+                boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
+                target.append(boxes)
             elif k == "info":
                 if image is None:
                     image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
......@@ -104,7 +104,7 @@ class PascalVOC(VisionDataset):
label[
(mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
] = i
return label.astype("uint8")
return label.astype(np.uint8)
def parse_voc_xml(self, node):
voc_dict = {}
......
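
The PascalVOC masks encode classes as palette colors, and since cv2 loads images in BGR order, _trans_mask matches each (b, g, r) triple and writes its class index. A toy illustration with a hypothetical two-color palette (not the real VOC colors):

```python
import numpy as np

palette_bgr = [(0, 0, 0), (0, 0, 128)]  # hypothetical: background, class 1
mask = np.zeros((2, 2, 3), dtype=np.uint8)
mask[0, 0] = (0, 0, 128)  # one "class 1" pixel

label = np.zeros(mask.shape[:2], dtype=np.uint8)
for i, (b, g, r) in enumerate(palette_bgr):
    label[(mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)] = i
assert label[0, 0] == 1
```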