未验证 提交 00e77dde 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #58 from heavengate/refine_compose

refine compose
...@@ -20,6 +20,7 @@ import argparse ...@@ -20,6 +20,7 @@ import argparse
import numpy as np import numpy as np
from hapi.model import Input, set_device from hapi.model import Input, set_device
from hapi.vision.transforms import Compose
from check import check_gpu, check_version from check import check_gpu, check_version
from modeling import tsm_resnet50 from modeling import tsm_resnet50
......
...@@ -24,6 +24,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -24,6 +24,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from hapi.model import Model, CrossEntropy, Input, set_device from hapi.model import Model, CrossEntropy, Input, set_device
from hapi.metrics import Accuracy from hapi.metrics import Accuracy
from hapi.vision.transforms import Compose
from modeling import tsm_resnet50 from modeling import tsm_resnet50
from check import check_gpu, check_version from check import check_gpu, check_version
......
...@@ -21,24 +21,7 @@ import logging ...@@ -21,24 +21,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop', __all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop',
'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage', 'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage']
'Compose']
class Compose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
class GroupScale(object): class GroupScale(object):
......
...@@ -27,7 +27,7 @@ from paddle.io import DataLoader ...@@ -27,7 +27,7 @@ from paddle.io import DataLoader
from hapi.model import Model, Input, set_device from hapi.model import Model, Input, set_device
from hapi.distributed import DistributedBatchSampler from hapi.distributed import DistributedBatchSampler
from hapi.vision.transforms import BatchCompose from hapi.vision.transforms import Compose, BatchCompose
from modeling import yolov3_darknet53, YoloLoss from modeling import yolov3_darknet53, YoloLoss
from coco import COCODataset from coco import COCODataset
......
...@@ -20,7 +20,6 @@ import traceback ...@@ -20,7 +20,6 @@ import traceback
import numpy as np import numpy as np
__all__ = [ __all__ = [
"Compose",
'ColorDistort', 'ColorDistort',
'RandomExpand', 'RandomExpand',
'RandomCrop', 'RandomCrop',
...@@ -34,37 +33,6 @@ __all__ = [ ...@@ -34,37 +33,6 @@ __all__ = [
] ]
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ColorDistort(object): class ColorDistort(object):
"""Random color distortion. """Random color distortion.
......
...@@ -121,7 +121,7 @@ class Flowers(Dataset): ...@@ -121,7 +121,7 @@ class Flowers(Dataset):
image = np.array(Image.open(io.BytesIO(image))) image = np.array(Image.open(io.BytesIO(image)))
if self.transform is not None: if self.transform is not None:
image, label = self.transform(image, label) image = self.transform(image)
return image, label return image, label
......
...@@ -149,7 +149,7 @@ class MNIST(Dataset): ...@@ -149,7 +149,7 @@ class MNIST(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx] image, label = self.images[idx], self.labels[idx]
if self.transform is not None: if self.transform is not None:
image, label = self.transform(image, label) image = self.transform(image)
return image, label return image, label
def __len__(self): def __len__(self):
......
...@@ -29,8 +29,10 @@ import traceback ...@@ -29,8 +29,10 @@ import traceback
from . import functional as F from . import functional as F
if sys.version_info < (3, 3): if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable Iterable = collections.Iterable
else: else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable Iterable = collections.abc.Iterable
__all__ = [ __all__ = [
...@@ -54,19 +56,44 @@ __all__ = [ ...@@ -54,19 +56,44 @@ __all__ = [
class Compose(object): class Compose(object):
"""Composes several transforms together. """
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args: Args:
transforms (list of ``Transform`` objects): list of transforms to compose. transforms (list of ``Transform`` objects): list of transforms to compose.
Returns:
A compose object which is callable, __call__ for this Compose
object will call each given :attr:`transforms` sequencely.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, ColorJitter, Resize
transform = Compose([ColorJitter(), Resize(size=608)])
flowers = Flowers(mode='test', transform=transform)
for i in range(10):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
def __call__(self, data): def __call__(self, *data):
for f in self.transforms: for f in self.transforms:
try: try:
# multi-fileds in a sample
if isinstance(data, Sequence):
data = f(*data)
# single field in a sample, call transform directly
else:
data = f(data) data = f(data)
except Exception as e: except Exception as e:
stack_info = traceback.format_exc() stack_info = traceback.format_exc()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册