提交 cececbbf 编写于 作者: D danleifeng

Add a data_format option (NCHW or NHWC, default NCHW) to the data readers

上级 b85e4841
...@@ -693,7 +693,8 @@ class Entry(object): ...@@ -693,7 +693,8 @@ class Entry(object):
if self.predict_reader is None: if self.predict_reader is None:
predict_reader = paddle.batch(reader.arc_train(self.dataset_dir, predict_reader = paddle.batch(reader.arc_train(self.dataset_dir,
self.num_classes), self.num_classes,
data_format=self.data_format),
batch_size=self.train_batch_size) batch_size=self.train_batch_size)
else: else:
predict_reader = self.predict_reader predict_reader = self.predict_reader
...@@ -925,7 +926,7 @@ class Entry(object): ...@@ -925,7 +926,7 @@ class Entry(object):
if self.train_reader is None: if self.train_reader is None:
train_reader = paddle.batch(reader.arc_train( train_reader = paddle.batch(reader.arc_train(
self.dataset_dir, self.num_classes), self.dataset_dir, self.num_classes, data_format=self.data_format),
batch_size=self.train_batch_size) batch_size=self.train_batch_size)
else: else:
train_reader = self.train_reader train_reader = self.train_reader
......
...@@ -172,7 +172,8 @@ def process_image(sample, ...@@ -172,7 +172,8 @@ def process_image(sample,
color_jitter, color_jitter,
rotate, rotate,
rand_mirror, rand_mirror,
normalize): normalize,
data_format='NCHW'):
img_data = base64.b64decode(sample[0]) img_data = base64.b64decode(sample[0])
img = Image.open(StringIO(img_data)) img = Image.open(StringIO(img_data))
...@@ -198,6 +199,9 @@ def process_image(sample, ...@@ -198,6 +199,9 @@ def process_image(sample,
assert sample[1] < class_dim, \ assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim." "label of train dataset should be less than the class_dim."
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
return img, sample[1] return img, sample[1]
...@@ -208,7 +212,8 @@ def arc_iterator(data_dir, ...@@ -208,7 +212,8 @@ def arc_iterator(data_dir,
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
rand_mirror=False, rand_mirror=False,
normalize=False): normalize=False,
data_format='NCHW'):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
...@@ -237,11 +242,12 @@ def arc_iterator(data_dir, ...@@ -237,11 +242,12 @@ def arc_iterator(data_dir,
color_jitter=color_jitter, color_jitter=color_jitter,
rotate=rotate, rotate=rotate,
rand_mirror=rand_mirror, rand_mirror=rand_mirror,
normalize=normalize) normalize=normalize,
data_format=data_format)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size): def load_bin(path, image_size, data_format ='NCHW'):
if six.PY2: if six.PY2:
bins, issame_list = pickle.load(open(path, 'rb')) bins, issame_list = pickle.load(open(path, 'rb'))
else: else:
...@@ -267,6 +273,8 @@ def load_bin(path, image_size): ...@@ -267,6 +273,8 @@ def load_bin(path, image_size):
img = np.array(img).astype('float32').transpose((2, 0, 1)) img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean img -= img_mean
img /= img_std img /= img_std
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
data_list[flip][i][:] = img data_list[flip][i][:] = img
if i % 1000 == 0: if i % 1000 == 0:
print('loading bin', i) print('loading bin', i)
...@@ -274,7 +282,7 @@ def load_bin(path, image_size): ...@@ -274,7 +282,7 @@ def load_bin(path, image_size):
return data_list, issame_list return data_list, issame_list
def train(data_dir, num_classes): def train(data_dir, num_classes, data_format ='NCHW'):
file_path = os.path.join(data_dir, 'file_list.txt') file_path = os.path.join(data_dir, 'file_list.txt')
return arc_iterator(data_dir, return arc_iterator(data_dir,
file_path, file_path,
...@@ -282,16 +290,17 @@ def train(data_dir, num_classes): ...@@ -282,16 +290,17 @@ def train(data_dir, num_classes):
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
rand_mirror=True, rand_mirror=True,
normalize=True) normalize=True,
data_format=data_format)
def test(data_dir, datasets): def test(data_dir, datasets, data_format ='NCHW'):
test_list = [] test_list = []
test_name_list = [] test_name_list = []
for name in datasets.split(','): for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin") path = os.path.join(data_dir, name+".bin")
if os.path.exists(path): if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM)) data_set = load_bin(path, (DATA_DIM, DATA_DIM), data_format=data_format)
test_list.append(data_set) test_list.append(data_set)
test_name_list.append(name) test_name_list.append(name)
print('test', name) print('test', name)
......
...@@ -184,7 +184,8 @@ def process_image_imagepath(sample, ...@@ -184,7 +184,8 @@ def process_image_imagepath(sample,
color_jitter, color_jitter,
rotate, rotate,
rand_mirror, rand_mirror,
normalize): normalize,
data_format='NCHW'):
imgpath = sample[0] imgpath = sample[0]
img = Image.open(imgpath) img = Image.open(imgpath)
...@@ -211,6 +212,9 @@ def process_image_imagepath(sample, ...@@ -211,6 +212,9 @@ def process_image_imagepath(sample,
assert sample[1] < class_dim, \ assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim." "label of train dataset should be less than the class_dim."
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
return img, sample[1] return img, sample[1]
...@@ -221,7 +225,8 @@ def arc_iterator(data, ...@@ -221,7 +225,8 @@ def arc_iterator(data,
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
rand_mirror=False, rand_mirror=False,
normalize=False): normalize=False,
data_format ='NCHW'):
def reader(): def reader():
if shuffle: if shuffle:
random.shuffle(data) random.shuffle(data)
...@@ -235,11 +240,12 @@ def arc_iterator(data, ...@@ -235,11 +240,12 @@ def arc_iterator(data,
color_jitter=color_jitter, color_jitter=color_jitter,
rotate=rotate, rotate=rotate,
rand_mirror=rand_mirror, rand_mirror=rand_mirror,
normalize=normalize) normalize=normalize,
data_format=data_format)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size): def load_bin(path, image_size, data_format ='NCHW'):
if six.PY2: if six.PY2:
bins, issame_list = pickle.load(open(path, 'rb')) bins, issame_list = pickle.load(open(path, 'rb'))
else: else:
...@@ -265,6 +271,8 @@ def load_bin(path, image_size): ...@@ -265,6 +271,8 @@ def load_bin(path, image_size):
img = np.array(img).astype('float32').transpose((2, 0, 1)) img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean img -= img_mean
img /= img_std img /= img_std
if data_format == 'NHWC':
img = img.transpose((1, 2, 0))
data_list[flip][i][:] = img data_list[flip][i][:] = img
if i % 1000 == 0: if i % 1000 == 0:
print('loading bin', i) print('loading bin', i)
...@@ -272,7 +280,7 @@ def load_bin(path, image_size): ...@@ -272,7 +280,7 @@ def load_bin(path, image_size):
return data_list, issame_list return data_list, issame_list
def arc_train(data_dir, class_dim): def arc_train(data_dir, class_dim, data_format ='NCHW'):
train_image_list = get_train_image_list(data_dir) train_image_list = get_train_image_list(data_dir)
return arc_iterator(train_image_list, return arc_iterator(train_image_list,
shuffle=True, shuffle=True,
...@@ -281,16 +289,17 @@ def arc_train(data_dir, class_dim): ...@@ -281,16 +289,17 @@ def arc_train(data_dir, class_dim):
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
rand_mirror=True, rand_mirror=True,
normalize=True) normalize=True,
data_format=data_format)
def test(data_dir, datasets): def test(data_dir, datasets, data_format ='NCHW'):
test_list = [] test_list = []
test_name_list = [] test_name_list = []
for name in datasets.split(','): for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin") path = os.path.join(data_dir, name+".bin")
if os.path.exists(path): if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM)) data_set = load_bin(path, (DATA_DIM, DATA_DIM), data_format=data_format)
test_list.append(data_set) test_list.append(data_set)
test_name_list.append(name) test_name_list.append(name)
print('test', name) print('test', name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册