Commit f51c8e2c authored by: H hypox64

Multiprocessing when loading train data

Parent 84f6eb31
@@ -182,4 +182,5 @@ nohup.out
 *.JPG
 *.MP4
 *.JPEG
-*.exe
\ No newline at end of file
+*.exe
+*.npy
\ No newline at end of file
 import os
+import time
 import numpy as np
 import cv2
@@ -15,7 +16,7 @@ def video_init(opt,path):
     if opt.fps !=0:
         fps = opt.fps
     ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
-    ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type,fps)
+    ffmpeg.video2image(path,'./tmp/video2image/output_%06d.'+opt.tempimage_type,fps)
     imagepaths=os.listdir('./tmp/video2image')
     imagepaths.sort()
     return fps,imagepaths,height,width
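Note on the %05d to %06d change: the frame-number pattern passed to ffmpeg has to be identical when frames are extracted and when they are re-encoded, and six digits keeps the zero-padded names lexicographically sortable for clips up to 999,999 frames. A minimal sketch of that round trip, calling the ffmpeg CLI through subprocess (the paths, fps value and .jpg extension are illustrative; this is not the project's ffmpeg wrapper):

import subprocess

FRAME_PATTERN = './tmp/video2image/output_%06d.jpg'  # must match on both sides

def extract_frames(video_path, fps=30):
    # video -> numbered frames (output_000001.jpg, output_000002.jpg, ...)
    subprocess.run(['ffmpeg', '-y', '-i', video_path, '-r', str(fps), FRAME_PATTERN], check=True)

def encode_frames(out_path, fps=30):
    # numbered frames (+ the extracted audio) -> video; a different pattern here would find no frames
    subprocess.run(['ffmpeg', '-y', '-r', str(fps), '-i', FRAME_PATTERN,
                    '-i', './tmp/voice_tmp.mp3', '-vcodec', 'libx264', out_path], check=True)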
@@ -41,21 +42,26 @@ def addmosaic_video(opt,netS):
         mask,x,y,size,area = runmodel.get_ROI_position(img,netS,opt)
         positions.append([x,y,area])
         cv2.imwrite(os.path.join('./tmp/ROI_mask',imagepath),mask)
-        print('\r','Find ROI location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        print('Find ROI location:')
+        print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print('\nOptimize ROI locations...')
     mask_index = filt.position_medfilt(np.array(positions), 7)

     # add mosaic
+    print('Add Mosaic:')
     for i in range(len(imagepaths)):
         mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]),'gray')
         img = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
         if impro.mask_area(mask)>100:
-            img = mosaic.addmosaic(img, mask, opt)
+            try:#Avoid unknown errors
+                img = mosaic.addmosaic(img, mask, opt)
+            except Exception as e:
+                print('Warning:',e)
         cv2.imwrite(os.path.join('./tmp/addmosaic_image',imagepaths[i]),img)
-        print('\r','Add Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        print('\r',str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print()
     ffmpeg.image2video( fps,
-                './tmp/addmosaic_image/output_%05d.'+opt.tempimage_type,
+                './tmp/addmosaic_image/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                 os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.mp4'))
@@ -73,16 +79,16 @@ def styletransfer_video(opt,netG):
     path = opt.media_path
     positions = []
     fps,imagepaths = video_init(opt,path)[:2]
+    print('Transfer:')
     for i,imagepath in enumerate(imagepaths,1):
         img = impro.imread(os.path.join('./tmp/video2image',imagepath))
         img = runmodel.run_styletransfer(opt, netG, img)
         cv2.imwrite(os.path.join('./tmp/style_transfer',imagepath),img)
-        print('\r','Transfer:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print()
     suffix = os.path.basename(opt.model_path).replace('.pth','').replace('style_','')
     ffmpeg.image2video( fps,
-                './tmp/style_transfer/output_%05d.'+opt.tempimage_type,
+                './tmp/style_transfer/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                 os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_'+suffix+'.mp4'))
@@ -92,16 +98,27 @@ def styletransfer_video(opt,netG):
 def get_mosaic_positions(opt,netM,imagepaths,savemask=True):
     # get mosaic position
     positions = []
+    t1 = time.time()
+    if not opt.no_preview:
+        cv2.namedWindow('mosaic mask', cv2.WINDOW_NORMAL)
+    print('Find mosaic location:')
     for i,imagepath in enumerate(imagepaths,1):
         img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
         x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
+        if not opt.no_preview:
+            cv2.imshow('mosaic mask',mask)
+            cv2.waitKey(1) & 0xFF
         if savemask:
             cv2.imwrite(os.path.join('./tmp/mosaic_mask',imagepath), mask)
         positions.append([x,y,size])
-        print('\r','Find mosaic location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        t2 = time.time()
+        print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),util.counttime(t1,t2,i,len(imagepaths)),end='')
+    if not opt.no_preview:
+        cv2.destroyAllWindows()
     print('\nOptimize mosaic locations...')
     positions =np.array(positions)
     for i in range(3):positions[:,i] = filt.medfilt(positions[:,i],opt.medfilt_num)
+    np.save('./positions.npy', positions)
     return positions

 def cleanmosaic_img(opt,netG,netM):
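get_mosaic_positions now also times each frame, optionally previews the detected mask, and saves the smoothed detections to positions.npy (presumably why *.npy was added to .gitignore). The per-column median filter is what turns jittery per-frame (x, y, size) detections into a stable track. A small illustration of that smoothing step, using scipy.signal.medfilt as a stand-in for the project's filt.medfilt (an assumption about its behaviour; the kernel size here is arbitrary):

import numpy as np
from scipy.signal import medfilt  # stand-in for the project's filt.medfilt

# one row per frame: [x, y, size]; the third frame is a spurious detection
positions = np.array([[120, 80, 40],
                      [121, 81, 41],
                      [400, 300, 5],   # outlier from a bad frame
                      [122, 82, 42],
                      [123, 83, 41]])

for col in range(3):
    positions[:, col] = medfilt(positions[:, col], kernel_size=3)

print(positions)  # the outlier row is pulled back toward its neighbours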
@@ -112,7 +129,7 @@ def cleanmosaic_img(opt,netG,netM):
     x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
     cv2.imwrite('./mask/'+os.path.basename(path), mask)
     img_result = img_origin.copy()
-    if size != 0 :
+    if size > 100 :
         img_mosaic = img_origin[y-size:y+size,x-size:x+size]
         if opt.traditional:
             img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
@@ -127,24 +144,40 @@ def cleanmosaic_video_byframe(opt,netG,netM):
     path = opt.media_path
     fps,imagepaths = video_init(opt,path)[:2]
     positions = get_mosaic_positions(opt,netM,imagepaths,savemask=True)
+    t1 = time.time()
+    if not opt.no_preview:
+        cv2.namedWindow('clean', cv2.WINDOW_NORMAL)
     # clean mosaic
+    print('Clean Mosaic:')
     for i,imagepath in enumerate(imagepaths,0):
         x,y,size = positions[i][0],positions[i][1],positions[i][2]
         img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
         img_result = img_origin.copy()
-        if size != 0:
-            img_mosaic = img_origin[y-size:y+size,x-size:x+size]
-            if opt.traditional:
-                img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
-            else:
-                img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
-            mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0)
-            img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
+        if size > 100:
+            try:#Avoid unknown errors
+                img_mosaic = img_origin[y-size:y+size,x-size:x+size]
+                if opt.traditional:
+                    img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
+                else:
+                    img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
+                mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0)
+                img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
+            except Exception as e:
+                print('Warning:',e)
         cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
-        print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        #preview result
+        if not opt.no_preview:
+            cv2.imshow('clean',img_result)
+            cv2.waitKey(1) & 0xFF
+        t2 = time.time()
+        print('\r',str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),util.counttime(t1,t2,i+1,len(imagepaths)),end='')
     print()
+    if not opt.no_preview:
+        cv2.destroyAllWindows()
+    # to video
     ffmpeg.image2video( fps,
-                './tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
+                './tmp/replace_mosaic/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                 os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4'))
@@ -159,6 +192,7 @@ def cleanmosaic_video_fusion(opt,netG,netM):
     positions = get_mosaic_positions(opt,netM,imagepaths,savemask=True)
     # clean mosaic
+    print('Clean Mosaic:')
     img_pool = np.zeros((height,width,3*N), dtype='uint8')
     for i,imagepath in enumerate(imagepaths,0):
         x,y,size = positions[i][0],positions[i][1],positions[i][2]
@@ -172,24 +206,26 @@ def cleanmosaic_video_fusion(opt,netG,netM):
             img_pool[:,:,0:(N-1)*3] = img_pool[:,:,3:N*3]
             img_pool[:,:,(N-1)*3:] = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+12,0,len(imagepaths)-1)]))
         img_origin = img_pool[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
-        if size==0: # can not find mosaic,
-            cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_origin)
-        else:
-            mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8')
-            mosaic_input[:,:,0:N*3] = impro.resize(img_pool[y-size:y+size,x-size:x+size,:], INPUT_SIZE)
-            mask_input = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
-            mosaic_input[:,:,-1] = impro.resize(mask_input, INPUT_SIZE)
-            mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
-            unmosaic_pred = netG(mosaic_input)
-            img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
-            img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
-            cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
-        print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        img_result = img_origin.copy()
+        if size>100:
+            try:#Avoid unknown errors
+                #reshape to network input shape
+                mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8')
+                mosaic_input[:,:,0:N*3] = impro.resize(img_pool[y-size:y+size,x-size:x+size,:], INPUT_SIZE)
+                mask_input = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
+                mosaic_input[:,:,-1] = impro.resize(mask_input, INPUT_SIZE)
+                mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
+                unmosaic_pred = netG(mosaic_input)
+                img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
+                img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
+            except Exception as e:
+                print('Warning:',e)
+        cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
+        print('\r',str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print()
     ffmpeg.image2video( fps,
-                './tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
+                './tmp/replace_mosaic/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                 os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4'))
\ No newline at end of file
@@ -11,7 +11,6 @@ class Options():
         #base
         self.parser.add_argument('--use_gpu',type=int,default=0, help='if -1, use cpu')
-        # self.parser.add_argument('--use_gpu', action='store_true', help='if input it, use gpu')
         self.parser.add_argument('--media_path', type=str, default='./imgs/ruoruo.jpg',help='your videos or images path')
         self.parser.add_argument('--mode', type=str, default='auto',help='Program running mode. auto | add | clean | style')
         self.parser.add_argument('--model_path', type=str, default='./pretrained_models/mosaic/add_face.pth',help='pretrained model path')
@@ -20,6 +19,7 @@ class Options():
         self.parser.add_argument('--netG', type=str, default='auto',
            help='select model to use for netG(Clean mosaic and Transfer style) -> auto | unet_128 | unet_256 | resnet_9blocks | HD | video')
         self.parser.add_argument('--fps', type=int, default=0,help='read and output fps, if 0-> origin')
+        self.parser.add_argument('--no_preview', action='store_true', help='if specified, do not preview images when processing video')
         self.parser.add_argument('--output_size', type=int, default=0,help='size of output media, if 0 -> origin')
         self.parser.add_argument('--mask_threshold', type=int, default=64,help='threshold of recognize clean or add mosaic position 0~255')
......
@@ -60,7 +60,6 @@ def main():
         util.clean_tempfiles(tmp_init = False)
-# main()
 if __name__ == '__main__':
     try:
         main()
......
@@ -35,7 +35,7 @@ opt.parser.add_argument('--model',type=str,default='BiSeNet', help='BiSeNet or U
 opt.parser.add_argument('--maxepoch',type=int,default=100, help='')
 opt.parser.add_argument('--savefreq',type=int,default=5, help='')
 opt.parser.add_argument('--maxload',type=int,default=1000000, help='')
-opt.parser.add_argument('--continuetrain', action='store_true', help='')
+opt.parser.add_argument('--continue_train', action='store_true', help='')
 opt.parser.add_argument('--startepoch',type=int,default=0, help='')
 opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
 opt.parser.add_argument('--savename',type=str,default='face', help='')
@@ -100,11 +100,11 @@ if opt.model =='UNet':
 elif opt.model =='BiSeNet':
     net = BiSeNet_model.BiSeNet(num_classes=1, context_path='resnet18')
-if opt.continuetrain:
+if opt.continue_train:
     if not os.path.isfile(os.path.join(dir_checkpoint,'last.pth')):
-        opt.continuetrain = False
+        opt.continue_train = False
         print('can not load last.pth, training on init weight.')
-if opt.continuetrain:
+if opt.continue_train:
     net.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last.pth')))
     f = open(os.path.join(dir_checkpoint,'epoch_log.txt'),'r')
     opt.startepoch = int(f.read())
......
@@ -11,6 +11,7 @@ import random
 import torch
 import torch.nn as nn
 import time
+from multiprocessing import Process, Queue
 from util import mosaic,util,ffmpeg,filt,data
 from util import image_processing as impro
@@ -32,17 +33,18 @@ opt.parser.add_argument('--lambda_gan',type=float,default=1, help='')
 opt.parser.add_argument('--finesize',type=int,default=256, help='')
 opt.parser.add_argument('--loadsize',type=int,default=286, help='')
 opt.parser.add_argument('--batchsize',type=int,default=1, help='')
-opt.parser.add_argument('--perload_num',type=int,default=64, help='number of images pool')
 opt.parser.add_argument('--norm',type=str,default='instance', help='')
 opt.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
 opt.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
 opt.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+opt.parser.add_argument('--image_pool',type=int,default=8, help='number of image load pool')
+opt.parser.add_argument('--load_process',type=int,default=4, help='number of process for loading data')
 opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
 opt.parser.add_argument('--maxiter',type=int,default=10000000, help='')
 opt.parser.add_argument('--savefreq',type=int,default=10000, help='')
 opt.parser.add_argument('--startiter',type=int,default=0, help='')
-opt.parser.add_argument('--continuetrain', action='store_true', help='')
+opt.parser.add_argument('--continue_train', action='store_true', help='')
 opt.parser.add_argument('--savename',type=str,default='face', help='')
@@ -89,13 +91,14 @@ if opt.gan:
     else:
         netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
     netD.cuda()
+    netD.train()
 #--------------------------continue train--------------------------
-if opt.continuetrain:
+if opt.continue_train:
     if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
-        opt.continuetrain = False
+        opt.continue_train = False
         print('can not load last_G, training on init weight.')
-if opt.continuetrain:
+if opt.continue_train:
     netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
     if opt.gan:
         netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
@@ -111,7 +114,6 @@ if opt.gan:
     optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
     if opt.hd:
         criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda()
-        # criterionFeat = torch.nn.L1Loss()
         criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt)
         criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu])
     else:
@@ -120,64 +122,27 @@ if opt.gan:
 '''
 --------------------------preload data & data pool--------------------------
 '''
-# def loaddata(video_index):
-#     videoname = videonames[video_index]
-#     img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
-#     input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
-#     # this frame
-#     this_mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index)+'.png'),'gray',loadsize=opt.loadsize)
-#     input_img[:,:,-1] = this_mask
-#     #print(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'))
-#     ground_true = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'),loadsize=opt.loadsize)
-#     mosaic_size,mod,rect_rat,feather = mosaic.get_random_parameter(ground_true,this_mask)
-#     start_pos = mosaic.get_random_startpos(num=N,bisa_p=0.3,bisa_max=mosaic_size,bisa_max_part=3)
-#     # merge other frame
-#     for i in range(0,N):
-#         img = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index+i-int(N/2))+'.jpg'),loadsize=opt.loadsize)
-#         mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index+i-int(N/2))+'.png'),'gray',loadsize=opt.loadsize)
-#         img_mosaic = mosaic.addmosaic_base(img, mask, mosaic_size,model = mod,rect_rat=rect_rat,feather=feather,start_point=start_pos[i])
-#         input_img[:,:,i*3:(i+1)*3] = img_mosaic
-#     # to tensor
-#     input_img,ground_true = data.random_transform_video(input_img,ground_true,opt.finesize,N)
-#     input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
-#     ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
-#     return input_img,ground_true
 print('Preloading data, please wait...')
-if opt.perload_num <= opt.batchsize:
-    opt.perload_num = opt.batchsize*2
-#data pool
-input_imgs = torch.rand(opt.perload_num,N*3+1,opt.finesize,opt.finesize)
-ground_trues = torch.rand(opt.perload_num,3,opt.finesize,opt.finesize)
-load_cnt = 0
-def preload():
-    global load_cnt
+def preload(pool):
+    cnt = 0
+    input_imgs = torch.rand(opt.batchsize,N*3+1,opt.finesize,opt.finesize)
+    ground_trues = torch.rand(opt.batchsize,3,opt.finesize,opt.finesize)
     while 1:
         try:
-            video_index = random.randint(0,video_num-1)
-            videoname = videonames[video_index]
-            img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
-            input_imgs[load_cnt%opt.perload_num],ground_trues[load_cnt%opt.perload_num] = data.load_train_video(videoname,img_index,opt)
-            # input_imgs[load_cnt%opt.perload_num],ground_trues[load_cnt%opt.perload_num] = loaddata(video_index)
-            load_cnt += 1
-            # time.sleep(0.1)
+            for i in range(opt.batchsize):
+                video_index = random.randint(0,video_num-1)
+                videoname = videonames[video_index]
+                img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
+                input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt)
+            cnt += 1
+            pool.put([input_imgs,ground_trues])
         except Exception as e:
-            print("error:",e)
-import threading
-t = threading.Thread(target=preload,args=())
-t.daemon = True
-t.start()
-time_start=time.time()
-while load_cnt < opt.perload_num:
-    time.sleep(0.1)
-time_end=time.time()
-util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
-    'load speed: '+str(round((time_end-time_start)/(opt.perload_num),3))+' s/it',True)
+            print("Error:",videoname,e)
+pool = Queue(opt.image_pool)
+for i in range(opt.load_process):
+    p = Process(target=preload,args=(pool,))
+    p.daemon = True
+    p.start()
 '''
 --------------------------train--------------------------
@@ -185,14 +150,12 @@ util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
 util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
 util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py'))
 netG.train()
-netD.train()
 time_start=time.time()
 print("Begin training...")
 for iter in range(opt.startiter+1,opt.maxiter):
-    ran = random.randint(0, opt.perload_num-opt.batchsize)
-    inputdata = (input_imgs[ran:ran+opt.batchsize].clone()).cuda()
-    target = (ground_trues[ran:ran+opt.batchsize].clone()).cuda()
+    inputdata,target = pool.get()
+    inputdata,target = inputdata.cuda(),target.cuda()
     if opt.gan:
         # compute fake images: G(A)
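This hunk, together with the preload changes above, is the core of the commit: instead of one background thread filling a shared tensor pool that the training loop samples with random indices, several worker processes now assemble whole batches and push them through a bounded multiprocessing.Queue, and the training loop simply blocks on pool.get(). A self-contained sketch of the same producer/consumer pattern, with a dummy load_batch standing in for data.load_train_video and the pool sizes mirroring the new --image_pool/--load_process defaults:

import time
import torch
from multiprocessing import Process, Queue

def load_batch(batchsize=1, channels=16, finesize=64):
    # stand-in for data.load_train_video(): one (input, target) training pair
    inputs = torch.rand(batchsize, channels, finesize, finesize)
    targets = torch.rand(batchsize, 3, finesize, finesize)
    return inputs, targets

def preload(pool):
    # producer: keep the bounded queue topped up; put() blocks while it is full
    while True:
        pool.put(load_batch())

if __name__ == '__main__':
    pool = Queue(8)                      # ~ opt.image_pool
    for _ in range(4):                   # ~ opt.load_process
        p = Process(target=preload, args=(pool,), daemon=True)
        p.start()

    for it in range(20):                 # consumer: the training loop
        inputs, targets = pool.get()     # blocks until a batch is ready
        time.sleep(0.05)                 # stand-in for the forward/backward pass
        print('iter', it, tuple(inputs.shape), tuple(targets.shape))

The bounded queue gives back-pressure in both directions: when the GPU is the bottleneck, producers block on put() instead of growing memory; when loading is the bottleneck, several processes feed the queue in parallel, which a single Python thread cannot do because of the GIL.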
@@ -226,17 +189,6 @@ for iter in range(opt.startiter+1,opt.maxiter):
         fake_AB = torch.cat((real_A, pred), 1)
         pred_fake = netD(fake_AB)
         loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
-        # GAN feature matching loss
-        # if opt.hd:
-        #     real_AB = torch.cat((real_A, target), 1)
-        #     pred_real = netD(real_AB)
-        #     loss_G_GAN_Feat=criterionFeat(pred_fake,pred_real)
-        # loss_G_GAN_Feat = 0
-        # feat_weights = 4.0 / (opt.n_layers_D + 1)
-        # D_weights = 1.0 / opt.num_D
-        # for i in range(opt.num_D):
-        #     for j in range(len(pred_fake[i])-1):
-        #         loss_G_GAN_Feat += D_weights * feat_weights * criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * opt.lambda_feat
         # combine loss and calculate gradients
         if opt.l2:
@@ -273,42 +225,33 @@ for iter in range(opt.startiter+1,opt.maxiter):
         loss_G_L1.backward()
         optimizer_G.step()
-    # save eval result
+    # save train result
     if (iter+1)%1000 == 0:
-        video_index = random.randint(0,video_num-1)
-        videoname = videonames[video_index]
-        img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
-        inputdata,target = data.load_train_video(videoname, img_index, opt)
-        # inputdata,target = loaddata(random.randint(0,video_num-1))
-        inputdata,target = inputdata.cuda(),target.cuda()
-        with torch.no_grad():
-            pred = netG(inputdata)
         try:
             data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
-                target, pred, os.path.join(dir_checkpoint,'result_eval.jpg'))
+                target, pred, os.path.join(dir_checkpoint,'result_train.jpg'))
         except Exception as e:
             print(e)
     # plot
     if (iter+1)%1000 == 0:
         time_end = time.time()
-        if opt.gan:
-            savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
-                iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
+        #if opt.gan:
+        savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
+            iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
         util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
         if (iter+1)/1000 >= 10:
             for i in range(4):loss_plot[i].append(loss_sum[i]/1000)
             item_plot.append(iter+1)
             try:
                 labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
                 for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i])
                 plt.xlabel('iter')
                 plt.legend(loc=1)
                 plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
                 plt.close()
             except Exception as e:
                 print("error:",e)
         loss_sum = [0.,0.,0.,0.,0.,0.]
         time_start=time.time()
@@ -362,4 +305,4 @@ for iter in range(opt.startiter+1,opt.maxiter):
             result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = pred
         cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result)
-        netG.train()
\ No newline at end of file
+        netG.train()
@@ -237,8 +237,6 @@ def replace_mosaic(img_origin,img_fake,mask,x,y,size,no_feather):
         eclosion_num = int(size/5)
         entad = int(eclosion_num/2+2)
-        # mask = np.zeros(img_origin.shape, dtype='uint8')
-        # mask = cv2.rectangle(mask,(x-size+entad,y-size+entad),(x+size-entad,y+size-entad),(255,255,255),-1)
         mask = cv2.resize(mask,(img_origin.shape[1],img_origin.shape[0]))
         mask = ch_one2three(mask)
......
@@ -62,7 +62,8 @@ def makedirs(path):
         print('makedir:',path)

 def clean_tempfiles(tmp_init=True):
     if os.path.isdir('./tmp'):
+        print('Clean temp...')
         shutil.rmtree('./tmp')
     if tmp_init:
         os.makedirs('./tmp')
@@ -86,9 +87,16 @@ def second2stamp(s):
     s = int(s%3600)
     m = int(s/60)
     s = int(s%60)
     return "%02d:%02d:%02d" % (h, m, s)

+def counttime(t1,t2,now_num,all_num):
+    '''
+    t1,t2: time.time()
+    '''
+    used_time = int(t2-t1)
+    all_time = int(used_time/now_num*all_num)
+    return second2stamp(used_time)+'/'+second2stamp(all_time)
 def get_bar(percent,num = 25):
     bar = '['
     for i in range(num):
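The new counttime helper is what the progress lines in core.py now append: an "elapsed/estimated-total" stamp obtained by linearly extrapolating the time spent so far over all items. A short usage sketch (the helpers are re-declared so the snippet runs on its own; the first line of second2stamp is inferred from its return statement, which lies above this hunk's context):

import time

def second2stamp(s):
    h = int(s/3600)     # inferred: hours, not shown in the hunk context
    s = int(s%3600)
    m = int(s/60)
    s = int(s%60)
    return "%02d:%02d:%02d" % (h, m, s)

def counttime(t1, t2, now_num, all_num):
    used_time = int(t2-t1)
    all_time = int(used_time/now_num*all_num)   # linear ETA over all items
    return second2stamp(used_time)+'/'+second2stamp(all_time)

t1 = time.time()
total = 100
for i in range(1, total+1):
    time.sleep(0.02)                            # stand-in for per-frame work
    print('\r', '%d/%d' % (i, total), counttime(t1, time.time(), i, total), end='')
print()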
@@ -97,7 +105,7 @@ def get_bar(percent,num = 25):
         else:
             bar += '-'
     bar += ']'
-    return bar+' '+str(round(percent,2))+'%'
+    return bar+' '+"%.2f"%percent+'%'

 def copyfile(src,dst):
     try:
......