From 29458f1bb7e64517b84b8e4461a8d9df936535f2 Mon Sep 17 00:00:00 2001 From: hypox64 Date: Wed, 15 Jan 2020 23:22:25 +0800 Subject: [PATCH] batch-training, modify model --- make_datasets/draw_mask.py | 40 ++++---- make_datasets/get_image_from_video.py | 5 +- .../use_irregular_holes_mask_make_dataset.py | 2 +- models/video_model.py | 46 ++++++---- train/clean/train.py | 92 +++++++++++-------- util/data.py | 40 +++++--- util/ffmpeg.py | 2 +- util/image_processing.py | 30 ++++++ 8 files changed, 165 insertions(+), 92 deletions(-) diff --git a/make_datasets/draw_mask.py b/make_datasets/draw_mask.py index 69201ab..75a6950 100644 --- a/make_datasets/draw_mask.py +++ b/make_datasets/draw_mask.py @@ -4,14 +4,20 @@ import datetime import os import random -def resize(img,size): - h, w = img.shape[:2] - if w >= h: - res = cv2.resize(img,(int(size*w/h), size)) - else: - res = cv2.resize(img,(size, int(size*h/w))) - - return res +import sys +sys.path.append("..") +from util import util +from util import image_processing as impro + +image_dir = './datasets_img/v2im' +mask_dir = './datasets_img/v2im_mask' +util.makedirs(mask_dir) + +files = os.listdir(image_dir) +files_new =files.copy() +print('find image:',len(files)) +masks = os.listdir(mask_dir) +print('mask:',len(masks)) # mouse callback function drawing = False # true if mouse is pressed @@ -46,11 +52,7 @@ def makemask(img): # print('Cost time:',(endtime-starttime)) return mask -files = os.listdir('./origin_image') -files_new =files.copy() -print('find image:',len(files)) -masks = os.listdir('./mask') -print('mask:',len(masks)) + for i in range(len(masks)): masks[i]=masks[i].replace('.png','.jpg') for file in files: @@ -59,14 +61,14 @@ for file in files: files = files_new # files = list(set(files)) #Distinct print('remain:',len(files)) -random.shuffle (files) +random.shuffle(files) # files.sort() cnt = 0 for file in files: cnt += 1 - img = cv2.imread('./origin_image/'+file) - img = resize(img,512) + img = cv2.imread(os.path.join(image_dir,file)) + img = impro.resize(img,512) cv2.namedWindow('image') cv2.setMouseCallback('image',draw_circle) #MouseCallback while(1): @@ -74,10 +76,10 @@ for file in files: cv2.imshow('image',img) k = cv2.waitKey(1) & 0xFF if k == ord(' '): - img = resize(img,256) + img = impro.resize(img,256) mask = makemask(img) - cv2.imwrite('./mask/'+os.path.splitext(file)[0]+'.png',mask) - print('./mask/'+os.path.splitext(file)[0]+'.png') + cv2.imwrite(os.path.join(mask_dir,os.path.splitext(file)[0]+'.png'),mask) + print(os.path.join(mask_dir,os.path.splitext(file)[0]+'.png')) # cv2.destroyAllWindows() print('remain:',len(files)-cnt) brushsize = 20 diff --git a/make_datasets/get_image_from_video.py b/make_datasets/get_image_from_video.py index b7f96be..fffcde5 100644 --- a/make_datasets/get_image_from_video.py +++ b/make_datasets/get_image_from_video.py @@ -9,9 +9,10 @@ sys.path.append("..") from util import util,ffmpeg from util import image_processing as impro -files = util.Traversal('/media/hypo/Media/download') +files = util.Traversal('./videos') videos = util.is_videos(files) -output_dir = './dataset/v2im' +output_dir = './datasets_img/v2im' +util.makedirs(output_dir) FPS = 1 util.makedirs(output_dir) for video in videos: diff --git a/make_datasets/use_irregular_holes_mask_make_dataset.py b/make_datasets/use_irregular_holes_mask_make_dataset.py index 9321aef..37d92a4 100644 --- a/make_datasets/use_irregular_holes_mask_make_dataset.py +++ b/make_datasets/use_irregular_holes_mask_make_dataset.py @@ -16,7 +16,7 @@ MOD = 'HD' #HD | pix2pix | mosaic MASK = False # if True, output mask,too BOUNDING = True # if true the mosaic size will be more big suffix = '' -output_dir = os.path.join('./dataset_img',MOD) +output_dir = os.path.join('./datasets_img',MOD) util.makedirs(output_dir) if MOD == 'HD': diff --git a/models/video_model.py b/models/video_model.py index 6802e9b..8363aa7 100644 --- a/models/video_model.py +++ b/models/video_model.py @@ -97,8 +97,8 @@ class decoder_2d(nn.Module): nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] - model += [nn.ReflectionPad2d(3)] - model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + # model += [nn.ReflectionPad2d(3)] + # model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] # model += [nn.Tanh()] # model += [nn.Sigmoid()] @@ -123,6 +123,20 @@ class conv_3d(nn.Module): x = self.conv(x) return x +class conv_2d(nn.Module): + def __init__(self,inchannel,outchannel,kernel_size=3,stride=1,padding=1): + super(conv_2d, self).__init__() + self.conv = nn.Sequential( + nn.ReflectionPad2d(padding), + nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=False), + nn.BatchNorm2d(outchannel), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.conv(x) + return x + class encoder_3d(nn.Module): def __init__(self,in_channel): @@ -131,21 +145,22 @@ class encoder_3d(nn.Module): self.down2 = conv_3d(64, 128, 3, 2, 1) self.down3 = conv_3d(128, 256, 3, 1, 1) self.conver2d = nn.Sequential( - nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(1), + nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, x): + x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3)) x = self.down1(x) x = self.down2(x) x = self.down3(x) - x = x.view(x.size(1),x.size(2),x.size(3),x.size(4)) + x = x.view(x.size(0),x.size(1)*x.size(2),x.size(3),x.size(4)) + x = self.conver2d(x) - x = x.view(x.size(1),x.size(0),x.size(2),x.size(3)) return x @@ -158,30 +173,29 @@ class MosaicNet(nn.Module): self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9) self.encoder_3d = encoder_3d(in_channel) self.decoder_2d = decoder_2d(4,3,64,n_blocks=9) - self.merge1 = nn.Sequential( - nn.ReflectionPad2d(1), - nn.Conv2d(512, 256, 3, 1, 0, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(inplace=True), - ) + self.shortcut_cov = conv_2d(3,64,7,1,3) + self.merge1 = conv_2d(512,256,3,1,1) self.merge2 = nn.Sequential( + conv_2d(128,64,3,1,1), nn.ReflectionPad2d(3), - nn.Conv2d(6, out_channel, kernel_size=7, padding=0), - nn.Sigmoid() + nn.Conv2d(64, out_channel, kernel_size=7, padding=0), + nn.Tanh() ) def forward(self, x): N = int((x.size()[1])/3) x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1) - shortcat_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:] + shortcut_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:] x_2d = self.encoder_2d(x_2d) + x_3d = self.encoder_3d(x) x = torch.cat((x_2d,x_3d),1) x = self.merge1(x) x = self.decoder_2d(x) - x = torch.cat((x,shortcat_2d),1) + shortcut_2d = self.shortcut_cov(shortcut_2d) + x = torch.cat((x,shortcut_2d),1) x = self.merge2(x) return x diff --git a/train/clean/train.py b/train/clean/train.py index 67bf51a..b251ba9 100644 --- a/train/clean/train.py +++ b/train/clean/train.py @@ -18,12 +18,12 @@ import torch.backends.cudnn as cudnn N = 25 ITER = 10000000 -LR = 0.0002 +LR = 0.001 beta1 = 0.5 use_gpu = True use_gan = False -use_L2 = False -CONTINUE = False +use_L2 = True +CONTINUE = True lambda_L1 = 1.0#100.0 lambda_gan = 1.0 @@ -31,8 +31,9 @@ SAVE_FRE = 10000 start_iter = 0 finesize = 128 loadsize = int(finesize*1.1) +batchsize = 8 perload_num = 32 -savename = 'MosaicNet_noL2' +savename = 'MosaicNet_test' dir_checkpoint = 'checkpoints/'+savename util.makedirs(dir_checkpoint) @@ -97,25 +98,32 @@ def loaddata(): ground_true = impro.resize(ground_true,loadsize) input_img,ground_true = data.random_transform_video(input_img,ground_true,finesize,N) - input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) - ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) + input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False) + ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False) return input_img,ground_true print('preloading data, please wait 5s...') -input_imgs=[] -ground_trues=[] +# input_imgs=[] +# ground_trues=[] +input_imgs = torch.rand(batchsize,N*3+1,finesize,finesize).cuda() +ground_trues = torch.rand(batchsize,3,finesize,finesize).cuda() load_cnt = 0 + def preload(): global load_cnt while 1: try: - input_img,ground_true = loaddata() - input_imgs.append(input_img) - ground_trues.append(ground_true) - if len(input_imgs)>perload_num: - del(input_imgs[0]) - del(ground_trues[0]) + # input_img,ground_true = loaddata() + # input_imgs.append(input_img) + # ground_trues.append(ground_true) + ran = random.randint(0, batchsize-1) + input_imgs[ran],ground_trues[ran] = loaddata() + + + # if len(input_imgs)>perload_num: + # del(input_imgs[0]) + # del(ground_trues[0]) load_cnt += 1 # time.sleep(0.1) except Exception as e: @@ -125,7 +133,7 @@ import threading t = threading.Thread(target=preload,args=()) #t为新创建的线程 t.daemon = True t.start() -while load_cnt < perload_num: +while load_cnt < batchsize*2: time.sleep(0.1) netG.train() @@ -133,23 +141,26 @@ time_start=time.time() print("Begin training...") for iter in range(start_iter+1,ITER): - # input_img,ground_true = loaddata() - ran = random.randint(1, perload_num-2) - input_img = input_imgs[ran] - ground_true = ground_trues[ran] + # inputdata,target = loaddata() + # ran = random.randint(1, perload_num-2) + # inputdata = inputdatas[ran] + # target = targets[ran] + + inputdata = input_imgs.clone() + target = ground_trues.clone() - pred = netG(input_img) + pred = netG(inputdata) if use_gan: netD.train() - # print(input_img[0,3*N,:,:].size()) - # print((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size()) - real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1) + # print(inputdata[0,3*N,:,:].size()) + # print((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size()) + real_A = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], inputdata[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1) fake_AB = torch.cat((real_A, pred), 1) pred_fake = netD(fake_AB.detach()) loss_D_fake = criterionGAN(pred_fake, False) - real_AB = torch.cat((real_A, ground_true), 1) + real_AB = torch.cat((real_A, target), 1) pred_real = netD(real_AB) loss_D_real = criterionGAN(pred_real, True) loss_D = (loss_D_fake + loss_D_real) * 0.5 @@ -161,16 +172,16 @@ for iter in range(start_iter+1,ITER): optimizer_D.step() netD.eval() - # fake_AB = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1) - real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1) + # fake_AB = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1) + real_A = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], inputdata[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1) fake_AB = torch.cat((real_A, pred), 1) pred_fake = netD(fake_AB) loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan # Second, G(A) = B if use_L2: - loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1 + loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1 else: - loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1 + loss_G_L1 = criterion_L1(pred, target) * lambda_L1 # combine loss and calculate gradients loss_G = loss_G_GAN + loss_G_L1 loss_sum[0] += loss_G_L1.item() @@ -182,9 +193,9 @@ for iter in range(start_iter+1,ITER): else: if use_L2: - loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1 + loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1 else: - loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1 + loss_G_L1 = criterion_L1(pred, target) * lambda_L1 loss_sum[0] += loss_G_L1.item() optimizer_G.zero_grad() @@ -194,8 +205,8 @@ for iter in range(start_iter+1,ITER): if (iter+1)%100 == 0: try: - data.showresult(input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], - ground_true, pred,os.path.join(dir_checkpoint,'result_train.png')) + data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], + target, pred,os.path.join(dir_checkpoint,'result_train.png')) except Exception as e: print(e) @@ -249,28 +260,29 @@ for iter in range(start_iter+1,ITER): #test netG.eval() - result = np.zeros((finesize*2,finesize*4,3), dtype='uint8') + test_names = os.listdir('./test') + result = np.zeros((finesize*2,finesize*len(test_names),3), dtype='uint8') for cnt,test_name in enumerate(test_names,0): img_names = os.listdir(os.path.join('./test',test_name,'image')) img_names.sort() - input_img = np.zeros((finesize,finesize,3*N+1), dtype='uint8') + inputdata = np.zeros((finesize,finesize,3*N+1), dtype='uint8') img_names.sort() for i in range(0,N): img = impro.imread(os.path.join('./test',test_name,'image',img_names[i])) img = impro.resize(img,finesize) - input_img[:,:,i*3:(i+1)*3] = img + inputdata[:,:,i*3:(i+1)*3] = img mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray') mask = impro.resize(mask,finesize) mask = impro.mask_threshold(mask,15,128) - input_img[:,:,-1] = mask - result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] - input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) - pred = netG(input_img) + inputdata[:,:,-1] = mask + result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] + inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False) + pred = netG(inputdata) - pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = True) + pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False) result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result) diff --git a/util/data.py b/util/data.py index 50aeab7..4b0402d 100755 --- a/util/data.py +++ b/util/data.py @@ -3,6 +3,7 @@ import numpy as np import torch import torchvision.transforms as transforms import cv2 +from .image_processing import color_adjust transform = transforms.Compose([ transforms.ToTensor(), @@ -29,7 +30,7 @@ def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = return image_numpy.astype(imtype) -def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = True, use_gpu = True, use_transform = True): +def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = True, use_gpu = True, use_transform = True,is0_1 = True): if gray: h, w = image_numpy.shape @@ -44,7 +45,10 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = if use_transform: image_tensor = transform(image_numpy) else: - image_numpy = image_numpy/255.0 + if is0_1: + image_numpy = image_numpy/255.0 + else: + image_numpy = (image_numpy/255.0-0.5)/0.5 image_numpy = image_numpy.transpose((2, 0, 1)) image_tensor = torch.from_numpy(image_numpy).float() if reshape: @@ -70,10 +74,19 @@ def random_transform_video(src,target,finesize,N): target = target[:,::-1,:] #random color - random_num = 15 - bright = random.randint(-random_num*2,random_num*2) - for i in range(N*3): src[:,:,i]=np.clip(src[:,:,i].astype('int')+bright,0,255).astype('uint8') - for i in range(3): target[:,:,i]=np.clip(target[:,:,i].astype('int')+bright,0,255).astype('uint8') + alpha = random.uniform(-0.2,0.2) + beta = random.uniform(-0.2,0.2) + b = random.uniform(-0.1,0.1) + g = random.uniform(-0.1,0.1) + r = random.uniform(-0.1,0.1) + for i in range(N): + src[:,:,i*3:(i+1)*3] = color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r) + target = color_adjust(target,alpha,beta,b,g,r) + + # random_num = 15 + # bright = random.randint(-random_num*2,random_num*2) + # for i in range(N*3): src[:,:,i]=np.clip(src[:,:,i].astype('int')+bright,0,255).astype('uint8') + # for i in range(3): target[:,:,i]=np.clip(target[:,:,i].astype('int')+bright,0,255).astype('uint8') return src,target @@ -116,10 +129,11 @@ def random_transform_image(img,mask,finesize): img,mask = img_crop,mask_crop #random color - random_num = 15 - for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+random.randint(-random_num,random_num),0,255).astype('uint8') - bright = random.randint(-random_num*2,random_num*2) - for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+bright,0,255).astype('uint8') + img = color_adjust(img,ran=True) + # random_num = 15 + # for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+random.randint(-random_num,random_num),0,255).astype('uint8') + # bright = random.randint(-random_num*2,random_num*2) + # for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+bright,0,255).astype('uint8') #random flip if random.random()<0.5: @@ -134,7 +148,7 @@ def random_transform_image(img,mask,finesize): def showresult(img1,img2,img3,name): size = img1.shape[3] showimg=np.zeros((size,size*3,3)) - showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = True) - showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = True) - showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = True) + showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = False) + showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = False) + showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = False) cv2.imwrite(name, showimg) diff --git a/util/ffmpeg.py b/util/ffmpeg.py index 76d254e..c4fd88a 100755 --- a/util/ffmpeg.py +++ b/util/ffmpeg.py @@ -45,4 +45,4 @@ def continuous_screenshot(videopath,savedir,fps): fps: save how many images per second ''' videoname = os.path.splitext(os.path.basename(videopath))[0] - os.system('ffmpeg -i '+videopath+' -vf fps='+str(fps)+' '+savedir+'/'+videoname+'%05d.jpg') + os.system('ffmpeg -i '+videopath+' -vf fps='+str(fps)+' '+savedir+'/'+videoname+'_%05d.jpg') diff --git a/util/image_processing.py b/util/image_processing.py index f945f7f..d03e771 100755 --- a/util/image_processing.py +++ b/util/image_processing.py @@ -1,5 +1,6 @@ import cv2 import numpy as np +import random def imread(file_path,mod = 'normal'): ''' @@ -37,6 +38,35 @@ def ch_one2three(img): res = cv2.merge([img, img, img]) return res +def color_adjust(img,alpha=1,beta=0,b=0,g=0,r=0,ran = False): + ''' + g(x) = (1+α)g(x)+255*β, + g(x) = g(x[:+b*255,:+g*255,:+r*255]) + + Args: + img : input image + alpha : contrast + beta : brightness + b : blue hue + g : green hue + r : red hue + ran : if True, randomly generated color correction parameters + Retuens: + img : output image + ''' + img = img.astype('float') + if ran: + alpha = random.uniform(-0.2,0.2) + beta = random.uniform(-0.2,0.2) + b = random.uniform(-0.1,0.1) + g = random.uniform(-0.1,0.1) + r = random.uniform(-0.1,0.1) + img = (1+alpha)*img+255.0*beta + bgr = [b*255.0,g*255.0,r*255.0] + for i in range(3): img[:,:,i]=img[:,:,i]+bgr[i] + + return (np.clip(img,0,255)).astype('uint8') + def makedataset(target_image,orgin_image): target_image = resize(target_image,256) orgin_image = resize(orgin_image,256) -- GitLab