diff --git a/.gitignore b/.gitignore
index cf5d361388641f9e888bae9a4902dbb0acc22dff..822f94d6b34e4d0ae0f53420a3502105e9492763 100644
--- a/.gitignore
+++ b/.gitignore
@@ -182,4 +182,5 @@ nohup.out
 *.JPG
 *.MP4
 *.JPEG
-*.exe
\ No newline at end of file
+*.exe
+*.npy
\ No newline at end of file
diff --git a/cores/core.py b/cores/core.py
index d48ae5156681b668c7df7c9fdf480be1a5b324e9..ee5b99d649618fccd70931beb15f089f3dfcc38d 100644
--- a/cores/core.py
+++ b/cores/core.py
@@ -1,4 +1,5 @@
 import os
+import time
 import numpy as np
 import cv2
 
@@ -15,7 +16,7 @@ def video_init(opt,path):
     if opt.fps !=0:
         fps = opt.fps
     ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
-    ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type,fps)
+    ffmpeg.video2image(path,'./tmp/video2image/output_%06d.'+opt.tempimage_type,fps)
    imagepaths=os.listdir('./tmp/video2image')
    imagepaths.sort()
    return fps,imagepaths,height,width
@@ -41,21 +42,26 @@ def addmosaic_video(opt,netS):
         mask,x,y,size,area = runmodel.get_ROI_position(img,netS,opt)
         positions.append([x,y,area])
         cv2.imwrite(os.path.join('./tmp/ROI_mask',imagepath),mask)
-        print('\r','Find ROI location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        print('Find ROI location:')
+        print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print('\nOptimize ROI locations...')
     mask_index = filt.position_medfilt(np.array(positions), 7)
 
     # add mosaic
+    print('Add Mosaic:')
     for i in range(len(imagepaths)):
         mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]),'gray')
         img = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
-        if impro.mask_area(mask)>100:
-            img = mosaic.addmosaic(img, mask, opt)
+        if impro.mask_area(mask)>100:
+            try:#Avoid unknown errors
+                img = mosaic.addmosaic(img, mask, opt)
+            except Exception as e:
+                print('Warning:',e)
         cv2.imwrite(os.path.join('./tmp/addmosaic_image',imagepaths[i]),img)
-        print('\r','Add Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        print('\r',str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print()
     ffmpeg.image2video( fps,
-                './tmp/addmosaic_image/output_%05d.'+opt.tempimage_type,
+                './tmp/addmosaic_image/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                  os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.mp4'))
@@ -73,16 +79,16 @@ def styletransfer_video(opt,netG):
     path = opt.media_path
     positions = []
     fps,imagepaths = video_init(opt,path)[:2]
-
+    print('Transfer:')
     for i,imagepath in enumerate(imagepaths,1):
         img = impro.imread(os.path.join('./tmp/video2image',imagepath))
         img = runmodel.run_styletransfer(opt, netG, img)
         cv2.imwrite(os.path.join('./tmp/style_transfer',imagepath),img)
-        print('\r','Transfer:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print()
     suffix = os.path.basename(opt.model_path).replace('.pth','').replace('style_','')
     ffmpeg.image2video( fps,
-                './tmp/style_transfer/output_%05d.'+opt.tempimage_type,
+                './tmp/style_transfer/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                  os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_'+suffix+'.mp4'))
@@ -92,16 +98,27 @@ def styletransfer_video(opt,netG):
 
 def get_mosaic_positions(opt,netM,imagepaths,savemask=True):
     # get mosaic position
     positions = []
+    t1 = time.time()
+    if not opt.no_preview:
+        cv2.namedWindow('mosaic mask', cv2.WINDOW_NORMAL)
+    print('Find mosaic location:')
     for i,imagepath in enumerate(imagepaths,1):
         img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
         x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
+        if not opt.no_preview:
+            cv2.imshow('mosaic mask',mask)
+            cv2.waitKey(1) & 0xFF
         if savemask:
             cv2.imwrite(os.path.join('./tmp/mosaic_mask',imagepath), mask)
         positions.append([x,y,size])
-        print('\r','Find mosaic location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        t2 = time.time()
+        print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),util.counttime(t1,t2,i,len(imagepaths)),end='')
+    if not opt.no_preview:
+        cv2.destroyAllWindows()
     print('\nOptimize mosaic locations...')
     positions =np.array(positions)
     for i in range(3):positions[:,i] = filt.medfilt(positions[:,i],opt.medfilt_num)
+    np.save('./positions.npy', positions)
     return positions
 
 def cleanmosaic_img(opt,netG,netM):
@@ -112,7 +129,7 @@ def cleanmosaic_img(opt,netG,netM):
     x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
     cv2.imwrite('./mask/'+os.path.basename(path), mask)
     img_result = img_origin.copy()
-    if size != 0 :
+    if size > 100 :
         img_mosaic = img_origin[y-size:y+size,x-size:x+size]
         if opt.traditional:
             img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
@@ -127,24 +144,40 @@ def cleanmosaic_video_byframe(opt,netG,netM):
     path = opt.media_path
     fps,imagepaths = video_init(opt,path)[:2]
     positions = get_mosaic_positions(opt,netM,imagepaths,savemask=True)
+    t1 = time.time()
+    if not opt.no_preview:
+        cv2.namedWindow('clean', cv2.WINDOW_NORMAL)
+
     # clean mosaic
+    print('Clean Mosaic:')
     for i,imagepath in enumerate(imagepaths,0):
         x,y,size = positions[i][0],positions[i][1],positions[i][2]
         img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
         img_result = img_origin.copy()
-        if size != 0:
-            img_mosaic = img_origin[y-size:y+size,x-size:x+size]
-            if opt.traditional:
-                img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
-            else:
-                img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
-            mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0)
-            img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
+        if size > 100:
+            try:#Avoid unknown errors
+                img_mosaic = img_origin[y-size:y+size,x-size:x+size]
+                if opt.traditional:
+                    img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
+                else:
+                    img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
+                mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0)
+                img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
+            except Exception as e:
+                print('Warning:',e)
         cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
-        print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+        #preview result
+        if not opt.no_preview:
+            cv2.imshow('clean',img_result)
+            cv2.waitKey(1) & 0xFF
+        t2 = time.time()
+        print('\r',str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),util.counttime(t1,t2,i+1,len(imagepaths)),end='')
     print()
+    if not opt.no_preview:
+        cv2.destroyAllWindows()
+
+    # to video
     ffmpeg.image2video( fps,
-                './tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
+                './tmp/replace_mosaic/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                  os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4'))
@@ -159,6 +192,7 @@ def cleanmosaic_video_fusion(opt,netG,netM):
     positions = get_mosaic_positions(opt,netM,imagepaths,savemask=True)
 
     # clean mosaic
+    print('Clean Mosaic:')
     img_pool = np.zeros((height,width,3*N), dtype='uint8')
     for i,imagepath in enumerate(imagepaths,0):
         x,y,size = positions[i][0],positions[i][1],positions[i][2]
@@ -172,24 +206,26 @@ def cleanmosaic_video_fusion(opt,netG,netM):
             img_pool[:,:,0:(N-1)*3] = img_pool[:,:,3:N*3]
             img_pool[:,:,(N-1)*3:] = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+12,0,len(imagepaths)-1)]))
         img_origin = img_pool[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
-
-        if size==0: # can not find mosaic,
-            cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_origin)
-        else:
+        img_result = img_origin.copy()
-            mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8')
-            mosaic_input[:,:,0:N*3] = impro.resize(img_pool[y-size:y+size,x-size:x+size,:], INPUT_SIZE)
-            mask_input = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
-            mosaic_input[:,:,-1] = impro.resize(mask_input, INPUT_SIZE)
+        if size>100:
+            try:#Avoid unknown errors
+                #reshape to network input shape
+                mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8')
+                mosaic_input[:,:,0:N*3] = impro.resize(img_pool[y-size:y+size,x-size:x+size,:], INPUT_SIZE)
+                mask_input = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
+                mosaic_input[:,:,-1] = impro.resize(mask_input, INPUT_SIZE)
-            mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
-            unmosaic_pred = netG(mosaic_input)
-            img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
-            img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
-            cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
-        print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
+                mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
+                unmosaic_pred = netG(mosaic_input)
+                img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
+                img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
+            except Exception as e:
+                print('Warning:',e)
+        cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
+        print('\r',str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
     print()
     ffmpeg.image2video( fps,
-                './tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
+                './tmp/replace_mosaic/output_%06d.'+opt.tempimage_type,
                 './tmp/voice_tmp.mp3',
                  os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4'))
\ No newline at end of file
diff --git a/cores/options.py b/cores/options.py
index 05188dfdb5a122f1dfa2f072af9e8788f8990b88..356ffec890a25dc539a6012b06ebc6e26e1d88fa 100644
--- a/cores/options.py
+++ b/cores/options.py
@@ -11,7 +11,6 @@ class Options():
 
         #base
         self.parser.add_argument('--use_gpu',type=int,default=0, help='if -1, use cpu')
-        # self.parser.add_argument('--use_gpu', action='store_true', help='if input it, use gpu')
         self.parser.add_argument('--media_path', type=str, default='./imgs/ruoruo.jpg',help='your videos or images path')
         self.parser.add_argument('--mode', type=str, default='auto',help='Program running mode. auto | add | clean | style')
         self.parser.add_argument('--model_path', type=str, default='./pretrained_models/mosaic/add_face.pth',help='pretrained model path')
@@ -20,6 +19,7 @@ class Options():
         self.parser.add_argument('--netG', type=str, default='auto', help='select model to use for netG(Clean mosaic and Transfer style) -> auto | unet_128 | unet_256 | resnet_9blocks | HD | video')
         self.parser.add_argument('--fps', type=int, default=0,help='read and output fps, if 0-> origin')
+        self.parser.add_argument('--no_preview', action='store_true', help='if specified, do not preview images when processing video')
         self.parser.add_argument('--output_size', type=int, default=0,help='size of output media, if 0 -> origin')
         self.parser.add_argument('--mask_threshold', type=int, default=64,help='threshold of recognize clean or add mosaic position 0~255')
diff --git a/deepmosaic.py b/deepmosaic.py
index 3f91413e4b44c7b946c74b6e9d4866f779d72d8f..f258c6f809c14ce350df23d67dafd377028cb3d0 100644
--- a/deepmosaic.py
+++ b/deepmosaic.py
@@ -60,7 +60,6 @@ def main():
     util.clean_tempfiles(tmp_init = False)
 
 
-# main()
 if __name__ == '__main__':
     try:
         main()
diff --git a/train/add/train.py b/train/add/train.py
index f8c23bdedcf14ec275722d688fef722412668a6a..385b149c72047242be287a49b39dd476f5295f8a 100644
--- a/train/add/train.py
+++ b/train/add/train.py
@@ -35,7 +35,7 @@ opt.parser.add_argument('--model',type=str,default='BiSeNet', help='BiSeNet or U
 opt.parser.add_argument('--maxepoch',type=int,default=100, help='')
 opt.parser.add_argument('--savefreq',type=int,default=5, help='')
 opt.parser.add_argument('--maxload',type=int,default=1000000, help='')
-opt.parser.add_argument('--continuetrain', action='store_true', help='')
+opt.parser.add_argument('--continue_train', action='store_true', help='')
 opt.parser.add_argument('--startepoch',type=int,default=0, help='')
 opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
 opt.parser.add_argument('--savename',type=str,default='face', help='')
@@ -100,11 +100,11 @@ if opt.model =='UNet':
     net = unet_model.UNet(n_channels = 3, n_classes = 1)
 elif opt.model =='BiSeNet':
     net = BiSeNet_model.BiSeNet(num_classes=1, context_path='resnet18')
 
-if opt.continuetrain:
+if opt.continue_train:
     if not os.path.isfile(os.path.join(dir_checkpoint,'last.pth')):
-        opt.continuetrain = False
+        opt.continue_train = False
         print('can not load last.pth, training on init weight.')
-if opt.continuetrain:
+if opt.continue_train:
     net.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last.pth')))
     f = open(os.path.join(dir_checkpoint,'epoch_log.txt'),'r')
     opt.startepoch = int(f.read())
diff --git a/train/clean/train.py b/train/clean/train.py
index b865246f59da75c66ed66184e819c8906ba181a0..7c179e7a8ba46b4838974608677b0d38031ea73b 100644
--- a/train/clean/train.py
+++ b/train/clean/train.py
@@ -11,6 +11,7 @@ import random
 import torch
 import torch.nn as nn
 import time
+from multiprocessing import Process, Queue
 
 from util import mosaic,util,ffmpeg,filt,data
 from util import image_processing as impro
@@ -32,17 +33,18 @@ opt.parser.add_argument('--lambda_gan',type=float,default=1, help='')
 opt.parser.add_argument('--finesize',type=int,default=256, help='')
 opt.parser.add_argument('--loadsize',type=int,default=286, help='')
 opt.parser.add_argument('--batchsize',type=int,default=1, help='')
-opt.parser.add_argument('--perload_num',type=int,default=64, help='number of images pool')
 opt.parser.add_argument('--norm',type=str,default='instance', help='')
 
 opt.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
 opt.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
 opt.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+opt.parser.add_argument('--image_pool',type=int,default=8, help='number of image load pool')
+opt.parser.add_argument('--load_process',type=int,default=4, help='number of process for loading data')
 
 opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
 opt.parser.add_argument('--maxiter',type=int,default=10000000, help='')
 opt.parser.add_argument('--savefreq',type=int,default=10000, help='')
 opt.parser.add_argument('--startiter',type=int,default=0, help='')
-opt.parser.add_argument('--continuetrain', action='store_true', help='')
+opt.parser.add_argument('--continue_train', action='store_true', help='')
 opt.parser.add_argument('--savename',type=str,default='face', help='')
@@ -89,13 +91,14 @@ if opt.gan:
     else:
         netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
     netD.cuda()
+    netD.train()
 
 #--------------------------continue train--------------------------
-if opt.continuetrain:
+if opt.continue_train:
     if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
-        opt.continuetrain = False
+        opt.continue_train = False
         print('can not load last_G, training on init weight.')
-if opt.continuetrain:
+if opt.continue_train:
     netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
     if opt.gan:
         netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
@@ -111,7 +114,6 @@ if opt.gan:
     optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
     if opt.hd:
         criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda()
-        # criterionFeat = torch.nn.L1Loss()
         criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt)
         criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu])
     else:
@@ -120,64 +122,27 @@ if opt.gan:
 
 '''
 --------------------------preload data & data pool--------------------------
 '''
-# def loaddata(video_index):
-
-#     videoname = videonames[video_index]
-#     img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
-
-#     input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
-#     # this frame
-#     this_mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index)+'.png'),'gray',loadsize=opt.loadsize)
-#     input_img[:,:,-1] = this_mask
-#     #print(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'))
-#     ground_true = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'),loadsize=opt.loadsize)
-#     mosaic_size,mod,rect_rat,feather = mosaic.get_random_parameter(ground_true,this_mask)
-#     start_pos = mosaic.get_random_startpos(num=N,bisa_p=0.3,bisa_max=mosaic_size,bisa_max_part=3)
-#     # merge other frame
-#     for i in range(0,N):
-#         img = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index+i-int(N/2))+'.jpg'),loadsize=opt.loadsize)
-#         mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index+i-int(N/2))+'.png'),'gray',loadsize=opt.loadsize)
-#         img_mosaic = mosaic.addmosaic_base(img, mask, mosaic_size,model = mod,rect_rat=rect_rat,feather=feather,start_point=start_pos[i])
-#         input_img[:,:,i*3:(i+1)*3] = img_mosaic
-#     # to tensor
-#     input_img,ground_true = data.random_transform_video(input_img,ground_true,opt.finesize,N)
-#     input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
-#     ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
-
-#     return input_img,ground_true
-
 print('Preloading data, please wait...')
-
-if opt.perload_num <= opt.batchsize:
-    opt.perload_num = opt.batchsize*2
-#data pool
-input_imgs = torch.rand(opt.perload_num,N*3+1,opt.finesize,opt.finesize)
-ground_trues = torch.rand(opt.perload_num,3,opt.finesize,opt.finesize)
-load_cnt = 0
-
-def preload():
-    global load_cnt
+def preload(pool):
+    cnt = 0
+    input_imgs = torch.rand(opt.batchsize,N*3+1,opt.finesize,opt.finesize)
+    ground_trues = torch.rand(opt.batchsize,3,opt.finesize,opt.finesize)
     while 1:
         try:
-            video_index = random.randint(0,video_num-1)
-            videoname = videonames[video_index]
-            img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
-            input_imgs[load_cnt%opt.perload_num],ground_trues[load_cnt%opt.perload_num] = data.load_train_video(videoname,img_index,opt)
-            # input_imgs[load_cnt%opt.perload_num],ground_trues[load_cnt%opt.perload_num] = loaddata(video_index)
-            load_cnt += 1
-            # time.sleep(0.1)
+            for i in range(opt.batchsize):
+                video_index = random.randint(0,video_num-1)
+                videoname = videonames[video_index]
+                img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
+                input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt)
+            cnt += 1
+            pool.put([input_imgs,ground_trues])
         except Exception as e:
-            print("error:",e)
-import threading
-t = threading.Thread(target=preload,args=())
-t.daemon = True
-t.start()
-time_start=time.time()
-while load_cnt < opt.perload_num:
-    time.sleep(0.1)
-time_end=time.time()
-util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
-    'load speed: '+str(round((time_end-time_start)/(opt.perload_num),3))+' s/it',True)
+            print("Error:",videoname,e)
+pool = Queue(opt.image_pool)
+for i in range(opt.load_process):
+    p = Process(target=preload,args=(pool,))
+    p.daemon = True
+    p.start()
 
 '''
 --------------------------train--------------------------
@@ -185,14 +150,12 @@ util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
 util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
 util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py'))
 netG.train()
-netD.train()
 time_start=time.time()
 print("Begin training...")
 for iter in range(opt.startiter+1,opt.maxiter):
 
-    ran = random.randint(0, opt.perload_num-opt.batchsize)
-    inputdata = (input_imgs[ran:ran+opt.batchsize].clone()).cuda()
-    target = (ground_trues[ran:ran+opt.batchsize].clone()).cuda()
+    inputdata,target = pool.get()
+    inputdata,target = inputdata.cuda(),target.cuda()
 
     if opt.gan:
         # compute fake images: G(A)
@@ -226,17 +189,6 @@ for iter in range(opt.startiter+1,opt.maxiter):
         fake_AB = torch.cat((real_A, pred), 1)
         pred_fake = netD(fake_AB)
         loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
-        # GAN feature matching loss
-        # if opt.hd:
-        #     real_AB = torch.cat((real_A, target), 1)
-        #     pred_real = netD(real_AB)
-        #     loss_G_GAN_Feat=criterionFeat(pred_fake,pred_real)
-        #     loss_G_GAN_Feat = 0
-        #     feat_weights = 4.0 / (opt.n_layers_D + 1)
-        #     D_weights = 1.0 / opt.num_D
-        #     for i in range(opt.num_D):
-        #         for j in range(len(pred_fake[i])-1):
-        #             loss_G_GAN_Feat += D_weights * feat_weights * criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * opt.lambda_feat
 
         # combine loss and calculate gradients
         if opt.l2:
@@ -273,42 +225,33 @@ for iter in range(opt.startiter+1,opt.maxiter):
         loss_G_L1.backward()
         optimizer_G.step()
 
-    # save eval result
+    # save train result
     if (iter+1)%1000 == 0:
-        video_index = random.randint(0,video_num-1)
-        videoname = videonames[video_index]
-        img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
-        inputdata,target = data.load_train_video(videoname, img_index, opt)
-
-        # inputdata,target = loaddata(random.randint(0,video_num-1))
-        inputdata,target = inputdata.cuda(),target.cuda()
-        with torch.no_grad():
-            pred = netG(inputdata)
         try:
             data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
-                target, pred, os.path.join(dir_checkpoint,'result_eval.jpg'))
+                target, pred, os.path.join(dir_checkpoint,'result_train.jpg'))
         except Exception as e:
             print(e)
 
     # plot
     if (iter+1)%1000 == 0:
         time_end = time.time()
-        if opt.gan:
-            savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
-                iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
-            util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
-            if (iter+1)/1000 >= 10:
-                for i in range(4):loss_plot[i].append(loss_sum[i]/1000)
-                item_plot.append(iter+1)
-                try:
-                    labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
-                    for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i])
-                    plt.xlabel('iter')
-                    plt.legend(loc=1)
-                    plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
-                    plt.close()
-                except Exception as e:
-                    print("error:",e)
+        #if opt.gan:
+        savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
+            iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
+        util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
+        if (iter+1)/1000 >= 10:
+            for i in range(4):loss_plot[i].append(loss_sum[i]/1000)
+            item_plot.append(iter+1)
+            try:
+                labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
+                for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i])
+                plt.xlabel('iter')
+                plt.legend(loc=1)
+                plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
+                plt.close()
+            except Exception as e:
+                print("error:",e)
         loss_sum = [0.,0.,0.,0.,0.,0.]
         time_start=time.time()
@@ -362,4 +305,4 @@ for iter in range(opt.startiter+1,opt.maxiter):
                 result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = pred
 
             cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result)
-            netG.train()
+            netG.train()
\ No newline at end of file
diff --git a/util/image_processing.py b/util/image_processing.py
index 95783b2056ba5d37b69c5a1cfa92df8e1c5e0042..d19d64319e402f86ffb0dac189ffa8ffcfdea8c9 100755
--- a/util/image_processing.py
+++ b/util/image_processing.py
@@ -237,8 +237,6 @@ def replace_mosaic(img_origin,img_fake,mask,x,y,size,no_feather):
         eclosion_num = int(size/5)
         entad = int(eclosion_num/2+2)
 
-        # mask = np.zeros(img_origin.shape, dtype='uint8')
-        # mask = cv2.rectangle(mask,(x-size+entad,y-size+entad),(x+size-entad,y+size-entad),(255,255,255),-1)
         mask = cv2.resize(mask,(img_origin.shape[1],img_origin.shape[0]))
         mask = ch_one2three(mask)
 
diff --git a/util/util.py b/util/util.py
index 669610e45fa7e8c88cfc388f50e1a9d3c75f313b..a276bbdf3818e43907851afb84d4c1dabd48110d 100755
--- a/util/util.py
+++ b/util/util.py
@@ -62,7 +62,8 @@ def makedirs(path):
         print('makedir:',path)
 
 def clean_tempfiles(tmp_init=True):
-    if os.path.isdir('./tmp'):
+    if os.path.isdir('./tmp'):
+        print('Clean temp...')
         shutil.rmtree('./tmp')
     if tmp_init:
         os.makedirs('./tmp')
@@ -86,9 +87,16 @@ def second2stamp(s):
     s = int(s%3600)
     m = int(s/60)
     s = int(s%60)
-
     return "%02d:%02d:%02d" % (h, m, s)
 
+def counttime(t1,t2,now_num,all_num):
+    '''
+    t1,t2: time.time()
+    '''
+    used_time = int(t2-t1)
+    all_time = int(used_time/now_num*all_num)
+    return second2stamp(used_time)+'/'+second2stamp(all_time)
+
 def get_bar(percent,num = 25):
     bar = '['
     for i in range(num):
@@ -97,7 +105,7 @@ def get_bar(percent,num = 25):
         else:
             bar += '-'
     bar += ']'
-    return bar+' '+str(round(percent,2))+'%'
+    return bar+' '+"%.2f"%percent+'%'
 
 def copyfile(src,dst):
     try: