提交 d740acce 编写于 作者: W wjj19950828

fixed ToPILImage

上级 270c5dc0
......@@ -18,7 +18,8 @@ from paddle.vision.transforms import functional as F
class ToPILImage(BaseTransform):
def __init__(self, mode=None, keys=None):
super(ToTensor, self).__init__(keys)
super(ToPILImage, self).__init__(keys)
self.mode = mode
def _apply_image(self, pic):
"""
......@@ -53,7 +54,7 @@ class ToPILImage(BaseTransform):
npimg = pic
if isinstance(pic, paddle.Tensor) and "float" in str(pic.numpy(
).dtype) and mode != 'F':
).dtype) and self.mode != 'F':
pic = pic.mul(255).byte()
if isinstance(pic, paddle.Tensor):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
......@@ -74,40 +75,40 @@ class ToPILImage(BaseTransform):
expected_mode = 'I'
elif npimg.dtype == np.float32:
expected_mode = 'F'
if mode is not None and mode != expected_mode:
if self.mode is not None and self.mode != expected_mode:
raise ValueError(
"Incorrect mode ({}) supplied for input type {}. Should be {}"
.format(mode, np.dtype, expected_mode))
mode = expected_mode
.format(self.mode, np.dtype, expected_mode))
self.mode = expected_mode
elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA']
if mode is not None and mode not in permitted_2_channel_modes:
if self.mode is not None and self.mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs".
format(permitted_2_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'LA'
if self.mode is None and npimg.dtype == np.uint8:
self.mode = 'LA'
elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
if mode is not None and mode not in permitted_4_channel_modes:
if self.mode is not None and self.mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".
format(permitted_4_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGBA'
if self.mode is None and npimg.dtype == np.uint8:
self.mode = 'RGBA'
else:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
if mode is not None and mode not in permitted_3_channel_modes:
if self.mode is not None and self.mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs".
format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
if self.mode is None and npimg.dtype == np.uint8:
self.mode = 'RGB'
if mode is None:
if self.mode is None:
raise TypeError('Input type {} is not supported'.format(
npimg.dtype))
return Image.fromarray(npimg, mode=mode)
return Image.fromarray(npimg, mode=self.mode)
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册