提交 5d31ba1c 编写于 作者: H hypox64

add train part

上级 36383527
import os
import numpy as np
import cv2
import random
import sys
sys.path.append("..")
from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt
from util import image_processing as impro
from options import Options
opt = Options().getparse()
util.file_init(opt)
videos = os.listdir('./video')
videos.sort()
opt.model_path = '../pretrained_models/add_youknow_128.pth'
opt.use_gpu = True
net = loadmodel.unet(opt)
for path in videos:
path = os.path.join('./video',path)
util.clean_tempfiles()
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
# get position
positions = []
img_ori_example = impro.imread(os.path.join('./tmp/video2image',imagepaths[0]))
mask_avg = np.zeros((impro.resize(img_ori_example, 128)).shape[:2])
for imagepath in imagepaths:
imagepath = os.path.join('./tmp/video2image',imagepath)
print('Find ROI location:',imagepath)
img = impro.imread(imagepath)
x,y,size,mask = runmodel.get_mosaic_position(img,net,opt,threshold = 64)
cv2.imwrite(os.path.join('./tmp/ROI_mask',
os.path.basename(imagepath)),mask)
positions.append([x,y,size])
mask_avg = mask_avg + mask
print('Optimize ROI locations...')
mask_index = filt.position_medfilt(np.array(positions), 13)
mask = np.clip(mask_avg/len(imagepaths),0,255).astype('uint8')
mask = impro.mask_threshold(mask,20,32)
x,y,size,area = impro.boundingSquare(mask,Ex_mul=1.5)
rat = min(img_ori_example.shape[:2])/128.0
x,y,size = int(rat*x),int(rat*y),int(rat*size)
cv2.imwrite(os.path.join('./tmp/ROI_mask_check',
'test_show.png'),mask)
if size !=0 :
mask_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mask'
ori_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/ori'
mosaic_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mosaic'
os.makedirs('./dataset/'+os.path.splitext(os.path.basename(path))[0]+'')
os.makedirs(mask_path)
os.makedirs(ori_path)
os.makedirs(mosaic_path)
print('Add mosaic to images...')
mosaic_size = mosaic.get_autosize(img_ori_example,mask,area_type = 'bounding')*random.uniform(1,2)
models = ['squa_avg','rect_avg','squa_mid']
mosaic_type = random.randint(0,len(models)-1)
rect_rat = random.uniform(1.2,1.6)
for i in range(len(imagepaths)):
mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]))
img_ori = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
img_mosaic = mosaic.addmosaic_normal(img_ori,mask,mosaic_size,model = models[mosaic_type],rect_rat=rect_rat)
mask = impro.resize(mask, min(img_ori.shape[:2]))
img_ori_crop = impro.resize(img_ori[y-size:y+size,x-size:x+size],256)
img_mosaic_crop = impro.resize(img_mosaic[y-size:y+size,x-size:x+size],256)
mask_crop = impro.resize(mask[y-size:y+size,x-size:x+size],256)
cv2.imwrite(os.path.join(ori_path,os.path.basename(imagepaths[i])),img_ori_crop)
cv2.imwrite(os.path.join(mosaic_path,os.path.basename(imagepaths[i])),img_mosaic_crop)
cv2.imwrite(os.path.join(mask_path,os.path.basename(imagepaths[i])),mask_crop)
\ No newline at end of file
import numpy as np
import cv2
import os
from torchvision import transforms
from PIL import Image
import random
import sys
sys.path.append("..")
import util.image_processing as impro
from util import util,mosaic
import datetime
ir_mask_path = './Irregular_Holes_mask'
# img_path = 'D:/MyProject_new/face_512'
img_path ='/media/hypo/Hypoyun/Hypoyun/手机摄影/20190219'
output_dir = './datasets'
util.makedirs(output_dir)
HD = True #if false make dataset for pix2pix, if Ture for pix2pix_HD
MASK = True
if HD:
train_A_path = os.path.join(output_dir,'train_A')
train_B_path = os.path.join(output_dir,'train_B')
util.makedirs(train_A_path)
util.makedirs(train_B_path)
else:
train_path = os.path.join(output_dir,'train')
util.makedirs(train_path)
if MASK:
mask_path = os.path.join(output_dir,'mask')
util.makedirs(mask_path)
transform_mask = transforms.Compose([
transforms.RandomResizedCrop(size=512, scale=(0.5,1)),
transforms.RandomHorizontalFlip(),
])
transform_img = transforms.Compose([
transforms.Resize(512),
transforms.RandomCrop(512)
])
mask_names = os.listdir(ir_mask_path)
img_names = os.listdir(img_path)
print('Find images:',len(img_names))
for i,img_name in enumerate(img_names,1):
try:
img = Image.open(os.path.join(img_path,img_name))
img = transform_img(img)
img = np.array(img)
img = img[...,::-1]
mask = Image.open(os.path.join(ir_mask_path,random.choices(mask_names)[0]))
mask = transform_mask(mask)
mask = np.array(mask)
mosaic_img = mosaic.addmosaic_random(img, mask)
if HD:
cv2.imwrite(os.path.join(train_A_path,'%05d' % i+'.jpg'), mosaic_img)
cv2.imwrite(os.path.join(train_B_path,'%05d' % i+'.jpg'), img)
else:
merge_img = impro.makedataset(mosaic_img, img)
cv2.imwrite(os.path.join(train_path,'%05d' % i+'.jpg'), merge_img)
if MASK:
cv2.imwrite(os.path.join(mask_path,'%05d' % i+'.png'), mask)
print("Processing:",img_name," ","Remain:",len(img_names)-i)
except Exception as e:
print(img_name,e)
import os
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time
import sys
sys.path.append("..")
from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
from models import pix2pix_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
N = 25
ITER = 1000000
LR = 0.0002
use_gpu = True
CONTINUE = True
# BATCHSIZE = 4
dir_checkpoint = 'checkpoints/'
SAVE_FRE = 5000
start_iter = 0
SIZE = 256
lambda_L1 = 100.0
opt = Options().getparse()
opt.use_gpu=True
videos = os.listdir('./dataset')
videos.sort()
lengths = []
for video in videos:
video_images = os.listdir('./dataset/'+video+'/ori')
lengths.append(len(video_images))
netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_9blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
netD = pix2pix_model.define_D(3*2, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
if CONTINUE:
netG.load_state_dict(torch.load(dir_checkpoint+'last_G.pth'))
netD.load_state_dict(torch.load(dir_checkpoint+'last_D.pth'))
f = open('./iter','r')
start_iter = int(f.read())
f.close()
if use_gpu:
netG.cuda()
netD.cuda()
cudnn.benchmark = True
optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR)
optimizer_D = torch.optim.Adam(netG.parameters(), lr=LR)
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
criterionGAN = pix2pix_model.GANLoss('lsgan').cuda()
def showresult(img1,img2,img3,name):
img1 = (img1.cpu().detach().numpy()*255)
img2 = (img2.cpu().detach().numpy()*255)
img3 = (img3.cpu().detach().numpy()*255)
batchsize = img1.shape[0]
size = img1.shape[3]
ran =int(batchsize*random.random())
showimg=np.zeros((size,size*3,3))
showimg[0:size,0:size] =img1[ran].transpose((1, 2, 0))
showimg[0:size,size:size*2] = img2[ran].transpose((1, 2, 0))
showimg[0:size,size*2:size*3] = img3[ran].transpose((1, 2, 0))
cv2.imwrite(name, showimg)
def loaddata():
video_index = random.randint(0,len(videos)-1)
video = videos[video_index]
img_index = random.randint(N,lengths[video_index]- N)
input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8')
for i in range(0,N):
# print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
img = impro.resize(img,SIZE)
input_img[:,:,i*3:(i+1)*3] = img
mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0)
mask = impro.resize(mask,256)
mask = impro.mask_threshold(mask,15,128)
input_img[:,:,-1] = mask
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png')
ground_true = impro.resize(ground_true,SIZE)
# ground_true = im2tensor(ground_true,use_gpu)
ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
return input_img,ground_true
input_imgs=[]
ground_trues=[]
def preload():
while 1:
input_img,ground_true = loaddata()
input_imgs.append(input_img)
ground_trues.append(ground_true)
if len(input_imgs)>10:
del(input_imgs[0])
del(ground_trues[0])
import threading
t=threading.Thread(target=preload,args=()) #t为新创建的线程
t.start()
time.sleep(3) #wait frist load
netG.train()
loss_sum = [0.,0.]
loss_plot = [[],[]]
item_plot = []
time_start=time.time()
print("Begin training...")
for iter in range(start_iter+1,ITER):
# input_img,ground_true = loaddata()
ran = random.randint(0, 9)
input_img = input_imgs[ran]
ground_true = ground_trues[ran]
pred = netG(input_img)
fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1)
pred_fake = netD(fake_AB.detach())
loss_D_fake = criterionGAN(pred_fake, False)
real_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true), 1)
pred_real = netD(real_AB)
loss_D_real = criterionGAN(pred_real, True)
loss_D = (loss_D_fake + loss_D_real) * 0.5
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1)
pred_fake = netD(fake_AB)
loss_G_GAN = criterionGAN(pred_fake, True)
# Second, G(A) = B
loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
# combine loss and calculate gradients
loss_G = loss_G_GAN + loss_G_L1
loss_sum[0] += loss_G_L1.item()
loss_sum[1] += loss_G.item()
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
# a = netD(ground_true)
# print(a.size())
# loss = criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)
# # loss = criterion_L2(pred, ground_true)
# loss_sum += loss.item()
# optimizer_G.zero_grad()
# loss.backward()
# optimizer_G.step()
if (iter+1)%100 == 0:
showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true, pred,'./result_train.png')
if (iter+1)%100 == 0:
time_end=time.time()
print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/100,4),'G_loss:', round(loss_sum[1]/100,4),'time:',round((time_end-time_start)/100,4))
if (iter+1)/100 >= 10:
loss_plot[0].append(loss_sum[0]/100)
loss_plot[1].append(loss_sum[1]/100)
item_plot.append(iter+1)
plt.plot(item_plot,loss_plot[0])
plt.plot(item_plot,loss_plot[1])
plt.savefig('./loss.png')
plt.close()
loss_sum = [0.,0.]
#show test result
# netG.eval()
# input_img = np.zeros((SIZE,SIZE,3*N), dtype='uint8')
# imgs = os.listdir('./test')
# for i in range(0,N):
# # print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
# img = cv2.imread('./test/'+imgs[i])
# img = impro.resize(img,SIZE)
# input_img[:,:,i*3:(i+1)*3] = img
# input_img = im2tensor(input_img,use_gpu)
# ground_true = cv2.imread('./test/output_'+'%05d'%13+'.png')
# ground_true = impro.resize(ground_true,SIZE)
# ground_true = im2tensor(ground_true,use_gpu)
# pred = netG(input_img)
# showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:],pred,pred,'./result_test.png')
netG.train()
time_start=time.time()
if (iter+1)%SAVE_FRE == 0:
torch.save(netG.cpu().state_dict(),dir_checkpoint+'last_G.pth')
torch.save(netD.cpu().state_dict(),dir_checkpoint+'last_D.pth')
if use_gpu:
netG.cuda()
netD.cuda()
f = open('./iter','w+')
f.write(str(iter+1))
f.close()
# torch.save(netG.cpu().state_dict(),dir_checkpoint+'iter'+str(iter+1)+'.pth')
print('network saved.')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册