Commit da522568 authored by Megvii Engine Team

fix(mge/data): process mnist without generate new files

GitOrigin-RevId: 44a697c3fe9197bf6ae8889afc0df72d6095cf1f
Parent: 538d3de9
@@ -78,7 +78,7 @@ class CIFAR10(VisionDataset):
         else:
             raise ValueError(
                 "dir does not contain target file\
-                %s,please set download=True"
+                %s, please set download=True"
                 % (self.target_file)
             )
@@ -108,7 +108,7 @@ class CIFAR10(VisionDataset):
     def untar(self, file_path, dirs):
         assert file_path.endswith(".tar.gz")
-        logger.debug("untar file %s to %s" % (file_path, dirs))
+        logger.debug("untar file %s to %s", file_path, dirs)
         t = tarfile.open(file_path)
         t.extractall(path=dirs)
@@ -117,13 +117,13 @@ class CIFAR10(VisionDataset):
         label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
             batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
             data.extend(list(batch_data[..., [2, 1, 0]]))
             label.extend(dic[b"labels"])
-        label = np.array(label)
+        label = np.array(label, dtype=np.int32)
         return (data, label)

     def process(self):
@@ -153,7 +153,7 @@ class CIFAR100(CIFAR10):
         coarse_label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
             batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
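As context for the unpickle changes above: each CIFAR batch file is a pickled dict whose b"data" entry is an (N, 3072) uint8 array; the reshape/transpose converts it to NHWC, the [2, 1, 0] index flips RGB to BGR to match the cv2 convention used by the other vision datasets, and the labels are now pinned to int32. A minimal standalone sketch of that decode (the batch path is hypothetical):

    import pickle

    import numpy as np

    def unpickle_cifar_batch(path):
        # each batch pickles to {b"data": (N, 3072) uint8, b"labels": [int, ...]}
        with open(path, "rb") as fo:
            dic = pickle.load(fo, encoding="bytes")
        # (N, 3072) -> (N, 3, 32, 32) -> NHWC, then RGB -> BGR
        images = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        return images[..., [2, 1, 0]], np.array(dic[b"labels"], dtype=np.int32)

    # hypothetical location of an extracted CIFAR-10 batch
    images, labels = unpickle_cifar_batch("cifar-10-batches-py/data_batch_1")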
...

@@ -71,7 +71,7 @@ class Cityscapes(VisionDataset):
             elif k == "mask":
                 mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 mask = self._trans_mask(mask)
-                mask = mask[:, :, None]
+                mask = mask[:, :, np.newaxis]
                 target.append(mask)
             elif k == "info":
                 if image is None:
@@ -109,9 +109,9 @@ class Cityscapes(VisionDataset):
             33,
         ]
         label = np.ones(mask.shape) * 255
-        for i in range(len(trans_labels)):
-            label[mask == trans_labels[i]] = i
-        return label.astype("uint8")
+        for i, tl in enumerate(trans_labels):
+            label[mask == tl] = i
+        return label.astype(np.uint8)

     def _get_target_suffix(self, mode, target_type):
         if target_type == "instance":
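The rewritten loop maps raw Cityscapes label ids to contiguous training ids, with 255 as the ignore value. Because the raw ids are small integers, the same remap can also be done in a single vectorized pass with a lookup table; a sketch assuming the standard 19-class id list (only the trailing `33` is visible in the hunk above):

    import numpy as np

    # standard Cityscapes train ids; assumed here, truncated in the diff above
    TRANS_LABELS = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                    23, 24, 25, 26, 27, 28, 31, 32, 33]

    def trans_mask_lut(mask):
        lut = np.full(256, 255, dtype=np.uint8)  # unmapped ids -> ignore (255)
        for i, tl in enumerate(TRANS_LABELS):
            lut[tl] = i
        return lut[mask]  # one fancy-indexing pass instead of a Python loop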
...

@@ -139,7 +139,7 @@ class COCO(VisionDataset):
                 target.append(image)
             elif k == "boxes":
                 boxes = [obj["bbox"] for obj in anno]
-                boxes = np.array(boxes).reshape(-1, 4)
+                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                 # transfer boxes from xywh to xyxy
                 boxes[:, 2:] += boxes[:, :2]
                 target.append(boxes)
@@ -148,17 +148,21 @@ class COCO(VisionDataset):
                 boxes_category = [
                     self.json_category_id_to_contiguous_id[c] for c in boxes_category
                 ]
-                boxes_category = np.array(boxes_category)
+                boxes_category = np.array(boxes_category, dtype=np.int32)
                 target.append(boxes_category)
-            # TODO: need to check
-            # elif k == "keypoints":
-            #     keypoints = [obj["keypoints"] for obj in anno]
-            #     keypoints = np.array(keypoints).reshape(-1, len(self.keypoint_names), 3)
-            #     target.append(keypoints)
-            # elif k == "polygons":
-            #     polygons = [obj["segmentation"] for obj in anno]
-            #     polygons = [[np.array(p).reshape(-1, 2) for p in ps] for ps in polygons]
-            #     target.append(polygons)
+            elif k == "keypoints":
+                keypoints = [obj["keypoints"] for obj in anno]
+                keypoints = np.array(keypoints, dtype=np.float32).reshape(
+                    -1, len(self.keypoint_names), 3
+                )
+                target.append(keypoints)
+            elif k == "polygons":
+                polygons = [obj["segmentation"] for obj in anno]
+                polygons = [
+                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
+                    for ps in polygons
+                ]
+                target.append(polygons)
             elif k == "info":
                 info = self.imgs[img_id]
                 info = [info["height"], info["width"], info["file_name"]]
...

@@ -19,6 +19,7 @@ import os
 from typing import Dict, List, Tuple

 import cv2
+import numpy as np

 from .meta_vision import VisionDataset
 from .utils import is_img
@@ -78,7 +79,7 @@ class ImageFolder(VisionDataset):
     def collect_class(self) -> Dict:
         classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
         classes.sort()
-        return {classes[i]: i for i in range(len(classes))}
+        return {classes[i]: np.int32(i) for i in range(len(classes))}

     def __getitem__(self, index: int) -> Tuple:
         path, label = self.samples[index]
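Sorting the scanned directory names makes the class-to-index mapping deterministic across runs and machines, and wrapping the index in np.int32 keeps labels as fixed-width integers when they are later collated into arrays. A quick sketch of the resulting mapping (directory names are hypothetical):

    import numpy as np

    # e.g. root/cat/*.jpg and root/dog/*.jpg scan to:
    classes = sorted(["dog", "cat"])
    class_to_idx = {name: np.int32(i) for i, name in enumerate(classes)}
    # {'cat': 0, 'dog': 1}, with np.int32 values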
...

@@ -93,7 +93,7 @@ class ImageNet(ImageFolder):
         self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

         if not os.path.exists(self.devkit_dir):
-            logger.warning("devkit directory %s does not exists" % self.devkit_dir)
+            logger.warning("devkit directory %s does not exists", self.devkit_dir)
             self._prepare_devkit()

         self.train = train
@@ -105,8 +105,8 @@ class ImageNet(ImageFolder):
         if not os.path.exists(self.target_folder):
             logger.warning(
-                "expected image folder %s does not exist, try to load from raw file"
-                % self.target_folder
+                "expected image folder %s does not exist, try to load from raw file",
+                self.target_folder,
             )
             if not self.check_raw_file():
                 raise FileNotFoundError(
@@ -117,8 +117,10 @@ class ImageNet(ImageFolder):
                 raise RuntimeError(
                     "extracting raw file shouldn't be done in distributed mode, use single process instead"
                 )
+            elif train:
+                self._prepare_train()
             else:
-                self._prepare_train() if train else self._prepare_val()
+                self._prepare_val()

         super().__init__(self.target_folder, **kwargs)
@@ -145,12 +147,12 @@ class ImageNet(ImageFolder):
         try:
             return load(os.path.join(self.devkit_dir, "meta.pkl"))
         except FileNotFoundError:
-            import scipy.io as sio
+            import scipy.io

             meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
             if not os.path.exists(meta_path):
                 raise FileNotFoundError("meta file %s does not exist" % meta_path)
-            meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
+            meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
             nums_children = list(zip(*meta))[4]
             meta = [
                 meta[idx]
@@ -159,8 +161,8 @@ class ImageNet(ImageFolder):
             ]
             idcs, wnids, classes = list(zip(*meta))[:3]
             classes = [tuple(clss.split(", ")) for clss in classes]
-            idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-            wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+            idx_to_wnid = dict(zip(idcs, wnids))
+            wnid_to_classes = dict(zip(wnids, classes))
             logger.info(
                 "saving cached meta file to %s",
                 os.path.join(self.devkit_dir, "meta.pkl"),
@@ -208,11 +210,11 @@ class ImageNet(ImageFolder):
         assert not self.train
         raw_filename, checksum = self.raw_file_meta["val"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum valid tar file {} ..".format(raw_file))
+        logger.info("checksum valid tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
-        logger.info("extract valid tar file.. this may take 10-20 minutes")
+        logger.info("extract valid tar file... this may take 10-20 minutes")
         untar(os.path.join(self.root, raw_file), self.target_folder)
         self._organize_val_data()
@@ -220,7 +222,7 @@ class ImageNet(ImageFolder):
         assert self.train
         raw_filename, checksum = self.raw_file_meta["train"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum train tar file {} ..".format(raw_file))
+        logger.info("checksum train tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
@@ -238,7 +240,7 @@ class ImageNet(ImageFolder):
     def _prepare_devkit(self):
         raw_filename, checksum = self.raw_file_meta["devkit"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum devkit tar file {} ..".format(raw_file))
+        logger.info("checksum devkit tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
...

@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import gzip
 import os
-import pickle
 import struct
 from typing import Tuple
@@ -48,14 +47,6 @@ class MNIST(VisionDataset):
     """
     md5 for checking raw files
     """
-    train_file = "train.pkl"
-    """
-    default pickle file name of training set and its meta data
-    """
-    test_file = "test.pkl"
-    """
-    default pickle file name of test set and its meta data
-    """

     def __init__(
         self,
@@ -65,30 +56,11 @@ class MNIST(VisionDataset):
         timeout: int = 500,
     ):
         r"""
-        initialization:
-
-        1. check root path and target file (train or test)
-        2. check target file exists
-
-        * if exists:
-            * load pickle file as meta-data and data in MNIST dataset
-        * else:
-            * if download:
-                a. load all raw datas (both train and test set) by url
-                b. process raw data ( idx3/idx1 -> dict (meta-data) ,numpy.array (data) )
-                c. save meta-data and data as pickle file
-                d. load pickle file as meta-data and data in MNIST dataset
-
         :param root: path for mnist dataset downloading or loading, if ``None``,
             set ``root`` to the ``_default_root``
         :param train: if ``True``, loading trainingset, else loading test set
-        :param download: after checking the target files existence, if target files do not
-            exists and download sets to ``True``, download raw files and process,
-            then load, otherwise raise ValueError, default is True
+        :param download: if raw files do not exists and download sets to ``True``,
+            download raw files and process, otherwise raise ValueError, default is True
         """
         super().__init__(root, order=("image", "image_category"))
@@ -105,28 +77,14 @@ class MNIST(VisionDataset):
         if not os.path.exists(self.root):
            raise ValueError("dir %s does not exist" % self.root)

-        # choose the target pickle file
-        if train:
-            self.target_file = os.path.join(self.root, self.train_file)
-        else:
-            self.target_file = os.path.join(self.root, self.test_file)
-
-        # check existence of target pickle file, if exists load the
-        # pickle file no matter what download is set
-        if os.path.exists(self.target_file):
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        elif self._check_raw_files():
-            self.process()
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        else:
-            if download:
-                self.download()
-                self._meta_data, self.arrays = self._load_file(self.target_file)
-            else:
-                raise ValueError(
-                    "dir does not contain target file\
-                    %s,please set download=True"
-                    % (self.target_file)
-                )
+        if self._check_raw_files():
+            self.process(train)
+        elif download:
+            self.download()
+            self.process(train)
+        else:
+            raise ValueError(
+                "root does not contain valid raw files, please set download=True"
+            )

     def __getitem__(self, index: int) -> Tuple:
@@ -143,10 +101,6 @@ class MNIST(VisionDataset):
     def meta(self):
         return self._meta_data

-    def _load_file(self, target_file):
-        with open(target_file, "rb") as f:
-            return pickle.load(f)
-
     def _check_raw_files(self):
         return all(
             [
@@ -159,45 +113,35 @@ class MNIST(VisionDataset):
         for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
             url = self.url_path + file_name
             load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
-        self.process()

-    def process(self):
+    def process(self, train):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
-        logger.info("process raw data ...")
-        meta_data_images_train, images_train = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[0])
-        )
-        meta_data_labels_train, labels_train = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[1])
-        )
-        meta_data_images_test, images_test = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[2])
-        )
-        meta_data_labels_test, labels_test = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[3])
-        )
-        meta_data_train = {
-            "images": meta_data_images_train,
-            "labels": meta_data_labels_train,
-        }
-        meta_data_test = {
-            "images": meta_data_images_test,
-            "labels": meta_data_labels_test,
-        }
-        dataset_train = (images_train, labels_train)
-        dataset_test = (images_test, labels_test)
-        # save both training set and test set as pickle files
-        with open(os.path.join(self.root, self.train_file), "wb") as f:
-            pickle.dump((meta_data_train, dataset_train), f, pickle.HIGHEST_PROTOCOL)
-        with open(os.path.join(self.root, self.test_file), "wb") as f:
-            pickle.dump((meta_data_test, dataset_test), f, pickle.HIGHEST_PROTOCOL)
+        logger.info("process the raw files of %s set...", "train" if train else "test")
+        if train:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[0])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[1])
+            )
+        else:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[2])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[3])
+            )
+
+        self._meta_data = {
+            "images": meta_data_images,
+            "labels": meta_data_labels,
+        }
+        self.arrays = (images, labels.astype(np.int32))

 def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
-    logger.debug("parse idx3 file %s ..." % idx3_file)
+    logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")
     with gzip.open(idx3_file, "rb") as f:
         bin_data = f.read()
@@ -223,7 +167,7 @@ def parse_idx3(idx3_file):
 def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
-    logger.debug("parse idx1 file %s ..." % idx1_file)
+    logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")
     with gzip.open(idx1_file, "rb") as f:
         bin_data = f.read()
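For reference, the IDX files that parse_idx3/parse_idx1 read have a simple layout: a 4-byte big-endian magic number, one 32-bit big-endian size per dimension, then a flat uint8 payload. A sketch of the image decode under that layout (the real parse_idx3 also returns a meta-data dict; the helper name here is ours):

    import gzip
    import struct

    import numpy as np

    def parse_idx3_sketch(idx3_file):
        with gzip.open(idx3_file, "rb") as f:
            bin_data = f.read()
        # magic, image count, rows, cols: four big-endian uint32 (16 bytes)
        _, num_images, rows, cols = struct.unpack_from(">IIII", bin_data, 0)
        pixels = np.frombuffer(bin_data, dtype=np.uint8, offset=16)
        return pixels.reshape(num_images, rows, cols)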
...

@@ -32,7 +32,7 @@ def load_raw_data_from_url(
 ):
     cached_file = os.path.join(raw_data_dir, filename)
     logger.debug(
-        "load_raw_data_from_url: downloading to or using cached %s ..." % cached_file
+        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
     )
     if not os.path.exists(cached_file):
         if is_distributed():
@@ -45,7 +45,7 @@ def load_raw_data_from_url(
     else:
         md5 = calculate_md5(cached_file)
         if target_md5 == md5:
-            logger.debug("%s exists with correct md5: %s" % (filename, target_md5))
+            logger.debug("%s exists with correct md5: %s", filename, target_md5)
         else:
             os.remove(cached_file)
             raise RuntimeError("{} exists but fail to match md5".format(filename))
...

@@ -77,13 +77,13 @@ class PascalVOC(VisionDataset):
                 if "aug" in self.image_set:
                     mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 else:
-                    mask = np.array(cv2.imread(self.masks[index], cv2.IMREAD_COLOR))
+                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                     mask = self._trans_mask(mask)
                 mask = mask[:, :, np.newaxis]
                 target.append(mask)
-            # elif k == "boxes":
-            #     boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
-            #     target.append(boxes)
+            elif k == "boxes":
+                boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
+                target.append(boxes)
             elif k == "info":
                 if image is None:
                     image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
@@ -104,7 +104,7 @@ class PascalVOC(VisionDataset):
             label[
                 (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
             ] = i
-        return label.astype("uint8")
+        return label.astype(np.uint8)

     def parse_voc_xml(self, node):
         voc_dict = {}
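The re-enabled "boxes" branch feeds each annotation XML through parse_voc_xml, whose body is truncated above. VOC annotations are XML with repeated <object> nodes, and a recursive node-to-dict fold is the common way to parse them; a sketch of that approach (ours, not necessarily the class's exact logic):

    import xml.etree.ElementTree as ET

    def parse_voc_xml_sketch(node):
        children = list(node)
        if not children:
            return node.text
        voc_dict = {}
        for child in children:
            value = parse_voc_xml_sketch(child)
            if child.tag in voc_dict:
                # repeated tags (e.g. <object>) collect into a list
                if not isinstance(voc_dict[child.tag], list):
                    voc_dict[child.tag] = [voc_dict[child.tag]]
                voc_dict[child.tag].append(value)
            else:
                voc_dict[child.tag] = value
        return voc_dict

    # usage: parse_voc_xml_sketch(ET.parse("Annotations/000001.xml").getroot())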
...