data.py 5.5 KB
Newer Older
H
HypoX64 已提交
1
import random
HypoX64's avatar
preview  
HypoX64 已提交
2 3 4
import numpy as np
import torch
import torchvision.transforms as transforms
H
HypoX64 已提交
5
import cv2
H
hypox64 已提交
6
from .image_processing import color_adjust
HypoX64's avatar
preview  
HypoX64 已提交
7 8 9 10 11 12 13

# Module-level pipeline used by im2tensor when use_transform=True:
# ToTensor, then a per-channel Normalize with mean 0.5 and std 0.5,
# i.e. each channel value x becomes (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

H
HypoX64 已提交
14
def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr=True, is0_1=False):
    """Convert the first image of a batched tensor into a NumPy image array.

    Args:
        image_tensor: Batched tensor; element 0 is taken and its axes are
            treated as (channel, height, width) (presumably NCHW input).
        imtype: NumPy dtype of the returned array.
        gray: When True, the channel-order reversal is skipped.
        rgb2bgr: Reverse the channel axis of the output (RGB <-> BGR).
        is0_1: True if tensor values are in [0, 1]; otherwise they are
            treated as [-1, 1] and mapped to [0, 1] via (x + 1) / 2.

    Returns:
        Array of shape (height, width, channels) and dtype ``imtype``, with
        values scaled by 255 and clipped to [0, 255]. A single-channel
        input is replicated into 3 identical channels.
    """
    chw = image_tensor.data[0].cpu().float().numpy()

    if chw.shape[0] == 1:
        chw = np.tile(chw, (3, 1, 1))

    hwc = chw.transpose((1, 2, 0))
    unit = hwc if is0_1 else (hwc + 1) / 2.0
    pixels = np.clip(unit * 255.0, 0, 255)

    if rgb2bgr and not gray:
        # Reverse the channels into a fresh contiguous buffer rather than
        # keeping a negative-stride view.
        pixels = pixels[..., ::-1].copy()

    return pixels.astype(imtype)


H
hypox64 已提交
33
def im2tensor(image_numpy, imtype=np.uint8, gray=False, bgr2rgb=True, reshape=True,
              use_gpu=True, use_transform=True, is0_1=True):
    """Convert an image array into a float tensor.

    Args:
        image_numpy: 2-D (h, w) array when ``gray`` is True, otherwise a
            3-D (h, w, ch) array. Values are divided by 255, so they are
            expected to be in [0, 255].
        imtype: Unused; kept for call-site compatibility.
        gray: Single-channel path. It always maps values to [-1, 1] and
            ignores ``bgr2rgb``, ``use_transform`` and ``is0_1``.
        bgr2rgb: Color path only: reverse the channel axis first.
        reshape: Prepend a batch axis, giving (1, ch, h, w) (ch = 1 for gray).
        use_gpu: Move the result to CUDA with ``.cuda()``.
        use_transform: Color path only: run the module-level ``transform``
            pipeline, in which case ``is0_1`` is ignored.
        is0_1: Color path without ``use_transform``: scale to [0, 1] if
            True, otherwise to [-1, 1].

    Returns:
        A float tensor whose shape follows ``reshape`` and whose device
        follows ``use_gpu``.
    """
    if gray:
        h, w = image_numpy.shape
        tensor = torch.from_numpy((image_numpy / 255.0 - 0.5) / 0.5).float()
        if reshape:
            tensor = tensor.reshape(1, 1, h, w)
    else:
        h, w, ch = image_numpy.shape
        if bgr2rgb:
            # Subtracting zeros materialises the reversed view as a new array,
            # so converters that reject negative strides accept it.
            image_numpy = image_numpy[..., ::-1] - np.zeros_like(image_numpy)

        if use_transform:
            tensor = transform(image_numpy)
        else:
            scaled = image_numpy / 255.0
            if not is0_1:
                scaled = (scaled - 0.5) / 0.5
            tensor = torch.from_numpy(scaled.transpose((2, 0, 1))).float()

        if reshape:
            tensor = tensor.reshape(1, ch, h, w)

    return tensor.cuda() if use_gpu else tensor

H
HypoX64 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76

def random_transform_video(src, target, finesize, N):
    """Apply one shared random crop, flip and color jitter to a clip and its target.

    The same crop window, flip decision and color parameters are applied to
    every frame of ``src`` and to ``target``, so the two stay aligned.

    Args:
        src: Array of shape (H, W, 3*N) holding ``N`` frames stacked on the
            channel axis (frame ``i`` is channels ``3*i:3*i+3``).
        target: 3-D array whose height and width are used for the crop;
            presumably the same H and W as ``src`` -- confirm with callers.
        finesize: Side length of the square crop. H and W should be at
            least ``finesize``; smaller inputs give a smaller crop.
        N: Number of frames stacked in ``src``.

    Returns:
        ``(src, target)`` after augmentation. ``src`` is a new array, so the
        caller's input is not modified.
    """
    # Random crop. The offsets are uniform over every valid start position
    # 0..(size - finesize) inclusive. They are clamped at 0 in case the
    # input is smaller than finesize.
    h, w = target.shape[:2]
    h_move = max(0, int((h - finesize + 1) * random.random()))
    w_move = max(0, int((w - finesize + 1) * random.random()))
    target = target[h_move:h_move + finesize, w_move:w_move + finesize, :]
    src = src[h_move:h_move + finesize, w_move:w_move + finesize, :]

    # Random horizontal flip (reverses the width axis).
    if random.random() < 0.5:
        src = src[:, ::-1, :]
        target = target[:, ::-1, :]

    # Color jitter parameters, shared by every frame and the target.
    alpha = random.uniform(-0.3, 0.3)
    beta = random.uniform(-0.2, 0.2)
    b = random.uniform(-0.05, 0.05)
    g = random.uniform(-0.05, 0.05)
    r = random.uniform(-0.05, 0.05)

    # The crop and flip above are views into the caller's array. Copy before
    # writing the adjusted frames back, so the caller's input is not mutated.
    src = src.copy()
    for i in range(N):
        frame = slice(i * 3, (i + 1) * 3)
        src[:, :, frame] = color_adjust(src[:, :, frame], alpha, beta, b, g, r)
    target = color_adjust(target, alpha, beta, b, g, r)

    return src, target


H
hypox64 已提交
94
def random_transform_image(img, mask, finesize, test_flag=False):
    """Randomly resize, crop and augment an image together with its mask.

    Pipeline:
      1. Resize so the shorter side is kept and the aspect ratio is jittered
         by a factor drawn from [0.9, 1.1].
      2. Take a random square crop of side ``finesize`` from both arrays.
      3. If ``test_flag`` is True, return the crops here.
      4. Otherwise apply, in order:
         - a rotation by 0/90/180/270 degrees (20% chance);
         - color adjustment of the image only;
         - a horizontal or vertical flip (50% chance);
         - a down/up-scale "blur" of the image only (50% chance).

    Args:
        img: 3-D image array (indexed with three axes when flipped).
        mask: 2-D mask array (indexed with two axes when flipped).
        finesize: Side length of the square crop.
        test_flag: Return right after cropping, without random augmentation.

    Returns:
        ``(img, mask)``. Flipped outputs may be negative-stride views.
    """
    h, w = img.shape[:2]
    loadsize = min(h, w)
    a = (float(h) / float(w)) * random.uniform(0.9, 1.1)
    # cv2.resize takes the target size as (width, height).
    if h < w:
        dsize = (int(loadsize / a), loadsize)
    else:
        dsize = (loadsize, int(loadsize * a))
    mask = cv2.resize(mask, dsize)
    img = cv2.resize(img, dsize)

    # Random crop. The offsets are uniform over every valid start position
    # 0..(size - finesize) inclusive. They are clamped at 0: the jittered
    # resize can make a side smaller than finesize, and a negative offset
    # would slice only the last few rows/cols.
    h, w = img.shape[:2]
    h_move = max(0, int((h - finesize + 1) * random.random()))
    w_move = max(0, int((w - finesize + 1) * random.random()))
    img_crop = img[h_move:h_move + finesize, w_move:w_move + finesize]
    mask_crop = mask[h_move:h_move + finesize, w_move:w_move + finesize]

    if test_flag:
        return img_crop, mask_crop

    # Random rotation by 90 * k degrees, k in {0, 1, 2, 3}, about the centre.
    if random.random() < 0.2:
        h, w = img_crop.shape[:2]
        M = cv2.getRotationMatrix2D((w / 2, h / 2), 90 * int(4 * random.random()), 1)
        img = cv2.warpAffine(img_crop, M, (w, h))
        mask = cv2.warpAffine(mask_crop, M, (w, h))
    else:
        img, mask = img_crop, mask_crop

    # Color adjustment on the image only. ran=True presumably makes
    # color_adjust draw its own random parameters -- confirm in image_processing.
    img = color_adjust(img, ran=True)

    # Random flip: horizontal or vertical with equal probability.
    if random.random() < 0.5:
        if random.random() < 0.5:
            img = img[:, ::-1, :]
            mask = mask[:, ::-1]
        else:
            img = img[::-1, :, :]
            mask = mask[::-1, :]

    # Random blur: rescale the image by a factor in [0.5, 1.5], then back
    # to finesize x finesize. The mask is left untouched.
    if random.random() > 0.5:
        size_ran = random.uniform(0.5, 1.5)
        img = cv2.resize(img, (int(finesize * size_ran), int(finesize * size_ran)))
        img = cv2.resize(img, (finesize, finesize))

    return img, mask

H
hypox64 已提交
158
def showresult(img1, img2, img3, name, is0_1=False):
    """Write three batched image tensors side by side to the file ``name``.

    Each tensor is converted with ``tensor2im``, which uses only the first
    batch element. Channels are written in the order the tensor holds them
    (``rgb2bgr=False``). NOTE(review): cv2.imwrite expects BGR, so RGB
    tensors would be written with swapped colors -- confirm with callers.

    Args:
        img1, img2, img3: Batched image tensors. They must share the same
            height; non-square images are supported.
        name: Output file path passed to ``cv2.imwrite``.
        is0_1: Forwarded to ``tensor2im`` (value range of the tensors).
    """
    panels = [tensor2im(img, rgb2bgr=False, is0_1=is0_1) for img in (img1, img2, img3)]
    # Join along the width axis. The original code sized a square canvas from
    # img1.shape[3] alone, which failed for non-square inputs. The result
    # keeps tensor2im's uint8 dtype, so imwrite receives 8-bit data directly
    # instead of a float64 canvas.
    showimg = np.concatenate(panels, axis=1)
    cv2.imwrite(name, showimg)