提交 29458f1b 编写于 作者: H hypox64

batch-training, modify model

上级 9aca31da
......@@ -4,14 +4,20 @@ import datetime
import os
import random
def resize(img,size):
h, w = img.shape[:2]
if w >= h:
res = cv2.resize(img,(int(size*w/h), size))
else:
res = cv2.resize(img,(size, int(size*h/w)))
return res
import sys
sys.path.append("..")
from util import util
from util import image_processing as impro
image_dir = './datasets_img/v2im'
mask_dir = './datasets_img/v2im_mask'
util.makedirs(mask_dir)
files = os.listdir(image_dir)
files_new =files.copy()
print('find image:',len(files))
masks = os.listdir(mask_dir)
print('mask:',len(masks))
# mouse callback function
drawing = False # true if mouse is pressed
......@@ -46,11 +52,7 @@ def makemask(img):
# print('Cost time:',(endtime-starttime))
return mask
files = os.listdir('./origin_image')
files_new =files.copy()
print('find image:',len(files))
masks = os.listdir('./mask')
print('mask:',len(masks))
for i in range(len(masks)):
masks[i]=masks[i].replace('.png','.jpg')
for file in files:
......@@ -59,14 +61,14 @@ for file in files:
files = files_new
# files = list(set(files)) #Distinct
print('remain:',len(files))
random.shuffle (files)
random.shuffle(files)
# files.sort()
cnt = 0
for file in files:
cnt += 1
img = cv2.imread('./origin_image/'+file)
img = resize(img,512)
img = cv2.imread(os.path.join(image_dir,file))
img = impro.resize(img,512)
cv2.namedWindow('image')
cv2.setMouseCallback('image',draw_circle) #MouseCallback
while(1):
......@@ -74,10 +76,10 @@ for file in files:
cv2.imshow('image',img)
k = cv2.waitKey(1) & 0xFF
if k == ord(' '):
img = resize(img,256)
img = impro.resize(img,256)
mask = makemask(img)
cv2.imwrite('./mask/'+os.path.splitext(file)[0]+'.png',mask)
print('./mask/'+os.path.splitext(file)[0]+'.png')
cv2.imwrite(os.path.join(mask_dir,os.path.splitext(file)[0]+'.png'),mask)
print(os.path.join(mask_dir,os.path.splitext(file)[0]+'.png'))
# cv2.destroyAllWindows()
print('remain:',len(files)-cnt)
brushsize = 20
......
......@@ -9,9 +9,10 @@ sys.path.append("..")
from util import util,ffmpeg
from util import image_processing as impro
files = util.Traversal('/media/hypo/Media/download')
files = util.Traversal('./videos')
videos = util.is_videos(files)
output_dir = './dataset/v2im'
output_dir = './datasets_img/v2im'
util.makedirs(output_dir)
FPS = 1
util.makedirs(output_dir)
for video in videos:
......
......@@ -16,7 +16,7 @@ MOD = 'HD' #HD | pix2pix | mosaic
MASK = False # if True, output mask,too
BOUNDING = True # if true the mosaic size will be more big
suffix = ''
output_dir = os.path.join('./dataset_img',MOD)
output_dir = os.path.join('./datasets_img',MOD)
util.makedirs(output_dir)
if MOD == 'HD':
......
......@@ -97,8 +97,8 @@ class decoder_2d(nn.Module):
nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
# model += [nn.ReflectionPad2d(3)]
# model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
# model += [nn.Tanh()]
# model += [nn.Sigmoid()]
......@@ -123,6 +123,20 @@ class conv_3d(nn.Module):
x = self.conv(x)
return x
class conv_2d(nn.Module):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=1,padding=1):
super(conv_2d, self).__init__()
self.conv = nn.Sequential(
nn.ReflectionPad2d(padding),
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.conv(x)
return x
class encoder_3d(nn.Module):
def __init__(self,in_channel):
......@@ -131,21 +145,22 @@ class encoder_3d(nn.Module):
self.down2 = conv_3d(64, 128, 3, 2, 1)
self.down3 = conv_3d(128, 256, 3, 1, 1)
self.conver2d = nn.Sequential(
nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(1),
nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3))
x = self.down1(x)
x = self.down2(x)
x = self.down3(x)
x = x.view(x.size(1),x.size(2),x.size(3),x.size(4))
x = x.view(x.size(0),x.size(1)*x.size(2),x.size(3),x.size(4))
x = self.conver2d(x)
x = x.view(x.size(1),x.size(0),x.size(2),x.size(3))
return x
......@@ -158,30 +173,29 @@ class MosaicNet(nn.Module):
self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9)
self.encoder_3d = encoder_3d(in_channel)
self.decoder_2d = decoder_2d(4,3,64,n_blocks=9)
self.merge1 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(512, 256, 3, 1, 0, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
)
self.shortcut_cov = conv_2d(3,64,7,1,3)
self.merge1 = conv_2d(512,256,3,1,1)
self.merge2 = nn.Sequential(
conv_2d(128,64,3,1,1),
nn.ReflectionPad2d(3),
nn.Conv2d(6, out_channel, kernel_size=7, padding=0),
nn.Sigmoid()
nn.Conv2d(64, out_channel, kernel_size=7, padding=0),
nn.Tanh()
)
def forward(self, x):
N = int((x.size()[1])/3)
x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1)
shortcat_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
shortcut_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
x_2d = self.encoder_2d(x_2d)
x_3d = self.encoder_3d(x)
x = torch.cat((x_2d,x_3d),1)
x = self.merge1(x)
x = self.decoder_2d(x)
x = torch.cat((x,shortcat_2d),1)
shortcut_2d = self.shortcut_cov(shortcut_2d)
x = torch.cat((x,shortcut_2d),1)
x = self.merge2(x)
return x
......
......@@ -18,12 +18,12 @@ import torch.backends.cudnn as cudnn
N = 25
ITER = 10000000
LR = 0.0002
LR = 0.001
beta1 = 0.5
use_gpu = True
use_gan = False
use_L2 = False
CONTINUE = False
use_L2 = True
CONTINUE = True
lambda_L1 = 1.0#100.0
lambda_gan = 1.0
......@@ -31,8 +31,9 @@ SAVE_FRE = 10000
start_iter = 0
finesize = 128
loadsize = int(finesize*1.1)
batchsize = 8
perload_num = 32
savename = 'MosaicNet_noL2'
savename = 'MosaicNet_test'
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint)
......@@ -97,25 +98,32 @@ def loaddata():
ground_true = impro.resize(ground_true,loadsize)
input_img,ground_true = data.random_transform_video(input_img,ground_true,finesize,N)
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
return input_img,ground_true
print('preloading data, please wait 5s...')
input_imgs=[]
ground_trues=[]
# input_imgs=[]
# ground_trues=[]
input_imgs = torch.rand(batchsize,N*3+1,finesize,finesize).cuda()
ground_trues = torch.rand(batchsize,3,finesize,finesize).cuda()
load_cnt = 0
def preload():
global load_cnt
while 1:
try:
input_img,ground_true = loaddata()
input_imgs.append(input_img)
ground_trues.append(ground_true)
if len(input_imgs)>perload_num:
del(input_imgs[0])
del(ground_trues[0])
# input_img,ground_true = loaddata()
# input_imgs.append(input_img)
# ground_trues.append(ground_true)
ran = random.randint(0, batchsize-1)
input_imgs[ran],ground_trues[ran] = loaddata()
# if len(input_imgs)>perload_num:
# del(input_imgs[0])
# del(ground_trues[0])
load_cnt += 1
# time.sleep(0.1)
except Exception as e:
......@@ -125,7 +133,7 @@ import threading
t = threading.Thread(target=preload,args=()) #t为新创建的线程
t.daemon = True
t.start()
while load_cnt < perload_num:
while load_cnt < batchsize*2:
time.sleep(0.1)
netG.train()
......@@ -133,23 +141,26 @@ time_start=time.time()
print("Begin training...")
for iter in range(start_iter+1,ITER):
# input_img,ground_true = loaddata()
ran = random.randint(1, perload_num-2)
input_img = input_imgs[ran]
ground_true = ground_trues[ran]
# inputdata,target = loaddata()
# ran = random.randint(1, perload_num-2)
# inputdata = inputdatas[ran]
# target = targets[ran]
inputdata = input_imgs.clone()
target = ground_trues.clone()
pred = netG(input_img)
pred = netG(inputdata)
if use_gan:
netD.train()
# print(input_img[0,3*N,:,:].size())
# print((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size())
real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
# print(inputdata[0,3*N,:,:].size())
# print((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size())
real_A = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], inputdata[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB.detach())
loss_D_fake = criterionGAN(pred_fake, False)
real_AB = torch.cat((real_A, ground_true), 1)
real_AB = torch.cat((real_A, target), 1)
pred_real = netD(real_AB)
loss_D_real = criterionGAN(pred_real, True)
loss_D = (loss_D_fake + loss_D_real) * 0.5
......@@ -161,16 +172,16 @@ for iter in range(start_iter+1,ITER):
optimizer_D.step()
netD.eval()
# fake_AB = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1)
real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
# fake_AB = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1)
real_A = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], inputdata[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB)
loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan
# Second, G(A) = B
if use_L2:
loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1
else:
loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
loss_G_L1 = criterion_L1(pred, target) * lambda_L1
# combine loss and calculate gradients
loss_G = loss_G_GAN + loss_G_L1
loss_sum[0] += loss_G_L1.item()
......@@ -182,9 +193,9 @@ for iter in range(start_iter+1,ITER):
else:
if use_L2:
loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1
else:
loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
loss_G_L1 = criterion_L1(pred, target) * lambda_L1
loss_sum[0] += loss_G_L1.item()
optimizer_G.zero_grad()
......@@ -194,8 +205,8 @@ for iter in range(start_iter+1,ITER):
if (iter+1)%100 == 0:
try:
data.showresult(input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
ground_true, pred,os.path.join(dir_checkpoint,'result_train.png'))
data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
target, pred,os.path.join(dir_checkpoint,'result_train.png'))
except Exception as e:
print(e)
......@@ -249,28 +260,29 @@ for iter in range(start_iter+1,ITER):
#test
netG.eval()
result = np.zeros((finesize*2,finesize*4,3), dtype='uint8')
test_names = os.listdir('./test')
result = np.zeros((finesize*2,finesize*len(test_names),3), dtype='uint8')
for cnt,test_name in enumerate(test_names,0):
img_names = os.listdir(os.path.join('./test',test_name,'image'))
img_names.sort()
input_img = np.zeros((finesize,finesize,3*N+1), dtype='uint8')
inputdata = np.zeros((finesize,finesize,3*N+1), dtype='uint8')
img_names.sort()
for i in range(0,N):
img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
img = impro.resize(img,finesize)
input_img[:,:,i*3:(i+1)*3] = img
inputdata[:,:,i*3:(i+1)*3] = img
mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
mask = impro.resize(mask,finesize)
mask = impro.mask_threshold(mask,15,128)
input_img[:,:,-1] = mask
result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
pred = netG(input_img)
inputdata[:,:,-1] = mask
result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
pred = netG(inputdata)
pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = True)
pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False)
result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred
cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result)
......
......@@ -3,6 +3,7 @@ import numpy as np
import torch
import torchvision.transforms as transforms
import cv2
from .image_processing import color_adjust
transform = transforms.Compose([
transforms.ToTensor(),
......@@ -29,7 +30,7 @@ def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 =
return image_numpy.astype(imtype)
def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = True, use_gpu = True, use_transform = True):
def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = True, use_gpu = True, use_transform = True,is0_1 = True):
if gray:
h, w = image_numpy.shape
......@@ -44,7 +45,10 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape =
if use_transform:
image_tensor = transform(image_numpy)
else:
image_numpy = image_numpy/255.0
if is0_1:
image_numpy = image_numpy/255.0
else:
image_numpy = (image_numpy/255.0-0.5)/0.5
image_numpy = image_numpy.transpose((2, 0, 1))
image_tensor = torch.from_numpy(image_numpy).float()
if reshape:
......@@ -70,10 +74,19 @@ def random_transform_video(src,target,finesize,N):
target = target[:,::-1,:]
#random color
random_num = 15
bright = random.randint(-random_num*2,random_num*2)
for i in range(N*3): src[:,:,i]=np.clip(src[:,:,i].astype('int')+bright,0,255).astype('uint8')
for i in range(3): target[:,:,i]=np.clip(target[:,:,i].astype('int')+bright,0,255).astype('uint8')
alpha = random.uniform(-0.2,0.2)
beta = random.uniform(-0.2,0.2)
b = random.uniform(-0.1,0.1)
g = random.uniform(-0.1,0.1)
r = random.uniform(-0.1,0.1)
for i in range(N):
src[:,:,i*3:(i+1)*3] = color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r)
target = color_adjust(target,alpha,beta,b,g,r)
# random_num = 15
# bright = random.randint(-random_num*2,random_num*2)
# for i in range(N*3): src[:,:,i]=np.clip(src[:,:,i].astype('int')+bright,0,255).astype('uint8')
# for i in range(3): target[:,:,i]=np.clip(target[:,:,i].astype('int')+bright,0,255).astype('uint8')
return src,target
......@@ -116,10 +129,11 @@ def random_transform_image(img,mask,finesize):
img,mask = img_crop,mask_crop
#random color
random_num = 15
for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+random.randint(-random_num,random_num),0,255).astype('uint8')
bright = random.randint(-random_num*2,random_num*2)
for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+bright,0,255).astype('uint8')
img = color_adjust(img,ran=True)
# random_num = 15
# for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+random.randint(-random_num,random_num),0,255).astype('uint8')
# bright = random.randint(-random_num*2,random_num*2)
# for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+bright,0,255).astype('uint8')
#random flip
if random.random()<0.5:
......@@ -134,7 +148,7 @@ def random_transform_image(img,mask,finesize):
def showresult(img1,img2,img3,name):
size = img1.shape[3]
showimg=np.zeros((size,size*3,3))
showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = True)
showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = True)
showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = True)
showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = False)
showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = False)
showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = False)
cv2.imwrite(name, showimg)
......@@ -45,4 +45,4 @@ def continuous_screenshot(videopath,savedir,fps):
fps: save how many images per second
'''
videoname = os.path.splitext(os.path.basename(videopath))[0]
os.system('ffmpeg -i '+videopath+' -vf fps='+str(fps)+' '+savedir+'/'+videoname+'%05d.jpg')
os.system('ffmpeg -i '+videopath+' -vf fps='+str(fps)+' '+savedir+'/'+videoname+'_%05d.jpg')
import cv2
import numpy as np
import random
def imread(file_path,mod = 'normal'):
'''
......@@ -37,6 +38,35 @@ def ch_one2three(img):
res = cv2.merge([img, img, img])
return res
def color_adjust(img,alpha=1,beta=0,b=0,g=0,r=0,ran = False):
'''
g(x) = (1+α)g(x)+255*β,
g(x) = g(x[:+b*255,:+g*255,:+r*255])
Args:
img : input image
alpha : contrast
beta : brightness
b : blue hue
g : green hue
r : red hue
ran : if True, randomly generated color correction parameters
Retuens:
img : output image
'''
img = img.astype('float')
if ran:
alpha = random.uniform(-0.2,0.2)
beta = random.uniform(-0.2,0.2)
b = random.uniform(-0.1,0.1)
g = random.uniform(-0.1,0.1)
r = random.uniform(-0.1,0.1)
img = (1+alpha)*img+255.0*beta
bgr = [b*255.0,g*255.0,r*255.0]
for i in range(3): img[:,:,i]=img[:,:,i]+bgr[i]
return (np.clip(img,0,255)).astype('uint8')
def makedataset(target_image,orgin_image):
target_image = resize(target_image,256)
orgin_image = resize(orgin_image,256)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册