未验证 提交 23d24764 编写于 作者: G Guanghua Yu 提交者: GitHub

adapt infer reader (#1727)

上级 7a65af0c
......@@ -39,6 +39,8 @@ TestReader:
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true}
- Permute: {channel_first: true, to_bgr: false}
batch_transforms:
- PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: false}
batch_size: 1
shuffle: false
drop_last: false
......@@ -68,6 +68,23 @@ class DetDataset(Dataset):
return os.path.join(self.dataset_dir, self.anno_path)
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
return f.lower().endswith(extensions)
def _make_dataset(dir):
dir = os.path.expanduser(dir)
if not os.path.isdir(d):
raise ('{} should be a dir'.format(dir))
images = []
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
images.append(path)
return images
@register
@serializable
class ImageFolder(DetDataset):
......@@ -76,11 +93,18 @@ class ImageFolder(DetDataset):
image_dir=None,
anno_path=None,
sample_num=-1,
use_default_label=None,
**kwargs):
super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path,
sample_num)
sample_num, use_default_label)
self._imid2path = {}
self.roidbs = None
def parse_dataset(self):
def parse_dataset(self, with_background=True):
if not self.roidbs:
self.roidbs = self._load_images()
def _parse(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
......@@ -91,4 +115,27 @@ class ImageFolder(DetDataset):
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
self.roidbs = images
return images
def _load_images(self):
images = self._parse()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
......@@ -33,7 +33,6 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.visualizer import visualize_results
from ppdet.utils.cli import ArgsParser
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_weight
from ppdet.utils.eval_utils import get_infer_results
import logging
......@@ -120,22 +119,24 @@ def get_test_images(infer_dir, infer_img):
return images
def run(FLAGS, cfg):
def run(FLAGS, cfg, place):
# Model
main_arch = cfg.architecture
model = create(cfg.architecture)
dataset = cfg.TestReader['dataset']
# data
dataset = cfg.TestDataset
test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
dataset.set_images(test_images)
test_loader, _ = create('TestReader')(dataset, cfg['worker_num'], place)
# TODO: support other metrics
imid2path = dataset.get_imid2path()
from ppdet.utils.coco_eval import get_category_info
anno_file = dataset.get_anno()
with_background = dataset.with_background
with_background = cfg.with_background
use_default_label = dataset.use_default_label
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
......@@ -143,11 +144,8 @@ def run(FLAGS, cfg):
# Init Model
load_weight(model, cfg.weights)
# Data Reader
test_reader = create_reader(cfg.TestDataset, cfg.TestReader)
# Run Infer
for iter_id, data in enumerate(test_reader()):
for iter_id, data in enumerate(test_loader):
# forward
model.eval()
outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer')
......@@ -208,7 +206,9 @@ def main():
check_gpu(cfg.use_gpu)
check_version()
run(FLAGS, cfg)
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg, place)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册