From cececbbff223c4f52e25e929a3fd6d0b81dea3dc Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 6 Feb 2020 12:13:29 +0000 Subject: [PATCH] add reader data_format --- plsc/entry.py | 5 +++-- plsc/utils/base64_reader.py | 25 +++++++++++++++++-------- plsc/utils/jpeg_reader.py | 25 +++++++++++++++++-------- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/plsc/entry.py b/plsc/entry.py index 4061867..a719467 100644 --- a/plsc/entry.py +++ b/plsc/entry.py @@ -693,7 +693,8 @@ class Entry(object): if self.predict_reader is None: predict_reader = paddle.batch(reader.arc_train(self.dataset_dir, - self.num_classes), + self.num_classes, + data_format=self.data_format), batch_size=self.train_batch_size) else: predict_reader = self.predict_reader @@ -925,7 +926,7 @@ class Entry(object): if self.train_reader is None: train_reader = paddle.batch(reader.arc_train( - self.dataset_dir, self.num_classes), + self.dataset_dir, self.num_classes, data_format=self.data_format), batch_size=self.train_batch_size) else: train_reader = self.train_reader diff --git a/plsc/utils/base64_reader.py b/plsc/utils/base64_reader.py index 820694e..7277a3b 100644 --- a/plsc/utils/base64_reader.py +++ b/plsc/utils/base64_reader.py @@ -172,7 +172,8 @@ def process_image(sample, color_jitter, rotate, rand_mirror, - normalize): + normalize, + data_format='NCHW'): img_data = base64.b64decode(sample[0]) img = Image.open(StringIO(img_data)) @@ -198,6 +199,9 @@ def process_image(sample, assert sample[1] < class_dim, \ "label of train dataset should be less than the class_dim." + + if data_format == 'NHWC': + img = img.transpose((1, 2, 0)) return img, sample[1] @@ -208,7 +212,8 @@ def arc_iterator(data_dir, color_jitter=False, rotate=False, rand_mirror=False, - normalize=False): + normalize=False, + data_format='NCHW'): trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) @@ -237,11 +242,12 @@ def arc_iterator(data_dir, color_jitter=color_jitter, rotate=rotate, rand_mirror=rand_mirror, - normalize=normalize) + normalize=normalize, + data_format=data_format) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) -def load_bin(path, image_size): +def load_bin(path, image_size, data_format ='NCHW'): if six.PY2: bins, issame_list = pickle.load(open(path, 'rb')) else: @@ -267,6 +273,8 @@ def load_bin(path, image_size): img = np.array(img).astype('float32').transpose((2, 0, 1)) img -= img_mean img /= img_std + if data_format == 'NHWC': + img = img.transpose((1, 2, 0)) data_list[flip][i][:] = img if i % 1000 == 0: print('loading bin', i) @@ -274,7 +282,7 @@ def load_bin(path, image_size): return data_list, issame_list -def train(data_dir, num_classes): +def train(data_dir, num_classes, data_format ='NCHW'): file_path = os.path.join(data_dir, 'file_list.txt') return arc_iterator(data_dir, file_path, @@ -282,16 +290,17 @@ def train(data_dir, num_classes): color_jitter=False, rotate=False, rand_mirror=True, - normalize=True) + normalize=True, + data_format=data_format) -def test(data_dir, datasets): +def test(data_dir, datasets, data_format ='NCHW'): test_list = [] test_name_list = [] for name in datasets.split(','): path = os.path.join(data_dir, name+".bin") if os.path.exists(path): - data_set = load_bin(path, (DATA_DIM, DATA_DIM)) + data_set = load_bin(path, (DATA_DIM, DATA_DIM), data_format=data_format) test_list.append(data_set) test_name_list.append(name) print('test', name) diff --git a/plsc/utils/jpeg_reader.py b/plsc/utils/jpeg_reader.py index f8bb850..554c6e0 100644 --- a/plsc/utils/jpeg_reader.py +++ b/plsc/utils/jpeg_reader.py @@ -184,7 +184,8 @@ def process_image_imagepath(sample, color_jitter, rotate, rand_mirror, - normalize): + normalize, + data_format='NCHW'): imgpath = sample[0] img = Image.open(imgpath) @@ -211,6 +212,9 @@ def process_image_imagepath(sample, assert sample[1] < class_dim, \ "label of train dataset should be less than the class_dim." + if data_format == 'NHWC': + img = img.transpose((1, 2, 0)) + return img, sample[1] @@ -221,7 +225,8 @@ def arc_iterator(data, color_jitter=False, rotate=False, rand_mirror=False, - normalize=False): + normalize=False, + data_format ='NCHW'): def reader(): if shuffle: random.shuffle(data) @@ -235,11 +240,12 @@ def arc_iterator(data, color_jitter=color_jitter, rotate=rotate, rand_mirror=rand_mirror, - normalize=normalize) + normalize=normalize, + data_format=data_format) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) -def load_bin(path, image_size): +def load_bin(path, image_size, data_format ='NCHW'): if six.PY2: bins, issame_list = pickle.load(open(path, 'rb')) else: @@ -265,6 +271,8 @@ def load_bin(path, image_size): img = np.array(img).astype('float32').transpose((2, 0, 1)) img -= img_mean img /= img_std + if data_format == 'NHWC': + img = img.transpose((1, 2, 0)) data_list[flip][i][:] = img if i % 1000 == 0: print('loading bin', i) @@ -272,7 +280,7 @@ def load_bin(path, image_size): return data_list, issame_list -def arc_train(data_dir, class_dim): +def arc_train(data_dir, class_dim, data_format ='NCHW'): train_image_list = get_train_image_list(data_dir) return arc_iterator(train_image_list, shuffle=True, @@ -281,16 +289,17 @@ def arc_train(data_dir, class_dim): color_jitter=False, rotate=False, rand_mirror=True, - normalize=True) + normalize=True, + data_format=data_format) -def test(data_dir, datasets): +def test(data_dir, datasets, data_format ='NCHW'): test_list = [] test_name_list = [] for name in datasets.split(','): path = os.path.join(data_dir, name+".bin") if os.path.exists(path): - data_set = load_bin(path, (DATA_DIM, DATA_DIM)) + data_set = load_bin(path, (DATA_DIM, DATA_DIM), data_format=data_format) test_list.append(data_set) test_name_list.append(name) print('test', name) -- GitLab