Commit cca42692 authored by: H hypox64

use cyclegan to convert videos' style

Parent 7e7145a9
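A minimal invocation sketch for the new style mode (the entry-script and model filenames below are placeholders; the flags come from the options added in this commit, and a model file whose name contains 'style' is auto-detected as the 'style' mode):

    python3 deepmosaic.py --mode style --model_path ./pretrained_models/style_vangogh.pth --media_path ./video.mp4 --result_dir ./result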
......@@ -173,4 +173,7 @@ result/
*.avi
*.flv
*.mkv
*.rmvb
\ No newline at end of file
*.rmvb
*.JPG
*.MP4
*.JPEG
\ No newline at end of file
......@@ -6,6 +6,15 @@ from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
def video_init(opt,path):
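# extract the audio track to ./tmp/voice_tmp.mp3 and all frames to ./tmp/video2image/,
# then return the video fps and the sorted list of frame filenames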
util.clean_tempfiles()
fps = ffmpeg.get_video_infos(path)[0]
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
return fps,imagepaths
def addmosaic_img(opt,netS):
path = opt.media_path
print('Add Mosaic:',path)
......@@ -16,44 +25,63 @@ def addmosaic_img(opt,netS):
def addmosaic_video(opt,netS):
path = opt.media_path
util.clean_tempfiles()
fps = ffmpeg.get_video_infos(path)[0]
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
fps,imagepaths = video_init(opt,path)
# get position
positions = []
for imagepath in imagepaths:
print('Find ROI location:',imagepath)
for i,imagepath in enumerate(imagepaths,1):
img = impro.imread(os.path.join('./tmp/video2image',imagepath))
mask,x,y,area = runmodel.get_ROI_position(img,netS,opt)
positions.append([x,y,area])
cv2.imwrite(os.path.join('./tmp/ROI_mask',imagepath),mask)
print('Optimize ROI locations...')
print('\r','Find ROI location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print('\nOptimize ROI locations...')
mask_index = filt.position_medfilt(np.array(positions), 7)
# add mosaic
print('Add mosaic to images...')
for i in range(len(imagepaths)):
mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]))
mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]),'gray')
img = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
img = mosaic.addmosaic(img, mask, opt)
if impro.mask_area(mask)>100:
img = mosaic.addmosaic(img, mask, opt)
cv2.imwrite(os.path.join('./tmp/addmosaic_image',imagepaths[i]),img)
print('\r','Add Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print()
ffmpeg.image2video( fps,
'./tmp/addmosaic_image/output_%05d.'+opt.tempimage_type,
'./tmp/voice_tmp.mp3',
os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.mp4'))
def styletransfer_img(opt,netG):
print('Style Transfer_img:',opt.media_path)
img = impro.imread(opt.media_path)
img = runmodel.run_styletransfer(opt, netG, img)
suffix = os.path.basename(opt.model_path).replace('.pth','').replace('style_','')
cv2.imwrite(os.path.join(opt.result_dir,os.path.splitext(os.path.basename(opt.media_path))[0]+'_'+suffix+'.jpg'),img)
def styletransfer_video(opt,netG):
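# run the CycleGAN generator on every extracted frame, then re-encode the stylized frames
# together with the original audio into <name>_<style>.mp4 under opt.result_dir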
path = opt.media_path
positions = []
fps,imagepaths = video_init(opt,path)
for i,imagepath in enumerate(imagepaths,1):
img = impro.imread(os.path.join('./tmp/video2image',imagepath))
img = runmodel.run_styletransfer(opt, netG, img)
cv2.imwrite(os.path.join('./tmp/style_transfer',imagepath),img)
print('\r','Transfer:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print()
suffix = os.path.basename(opt.model_path).replace('.pth','').replace('style_','')
ffmpeg.image2video( fps,
'./tmp/style_transfer/output_%05d.'+opt.tempimage_type,
'./tmp/voice_tmp.mp3',
os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_'+suffix+'.mp4'))
def cleanmosaic_img(opt,netG,netM):
path = opt.media_path
print('Clean Mosaic:',path)
img_origin = impro.imread(path)
x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
cv2.imwrite('./mask/'+os.path.basename(path), mask)
#cv2.imwrite('./mask/'+os.path.basename(path), mask)
img_result = img_origin.copy()
if size != 0 :
img_mosaic = img_origin[y-size:y+size,x-size:x+size]
......@@ -65,21 +93,16 @@ def cleanmosaic_img(opt,netG,netM):
def cleanmosaic_video_byframe(opt,netG,netM):
path = opt.media_path
util.clean_tempfiles()
fps = ffmpeg.get_video_infos(path)[0]
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
fps,imagepaths = video_init(opt,path)
positions = []
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
# get position
for imagepath in imagepaths:
for i,imagepath in enumerate(imagepaths,1):
img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
x,y,size = runmodel.get_mosaic_position(img_origin,netM,opt)[:3]
positions.append([x,y,size])
print('Find mosaic location:',imagepath)
print('Optimize mosaic locations...')
print('\r','Find mosaic location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print('\nOptimize mosaic locations...')
positions =np.array(positions)
for i in range(3):positions[:,i] = filt.medfilt(positions[:,i],opt.medfilt_num)
......@@ -93,7 +116,8 @@ def cleanmosaic_video_byframe(opt,netG,netM):
img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather)
cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
print('Clean Mosaic:',imagepath)
print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print()
ffmpeg.image2video( fps,
'./tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
'./tmp/voice_tmp.mp3',
......@@ -103,31 +127,22 @@ def cleanmosaic_video_fusion(opt,netG,netM):
path = opt.media_path
N = 25
INPUT_SIZE = 128
util.clean_tempfiles()
fps = ffmpeg.get_video_infos(path)[0]
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
fps,imagepaths = video_init(opt,path)
positions = []
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
# get position
for imagepath in imagepaths:
for i,imagepath in enumerate(imagepaths,1):
img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
# x,y,size = runmodel.get_mosaic_position(img_origin,net_mosaic_pos,opt)[:3]
x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
cv2.imwrite(os.path.join('./tmp/mosaic_mask',imagepath), mask)
positions.append([x,y,size])
print('Find mosaic location:',imagepath)
print('Optimize mosaic locations...')
print('\r','Find mosaic location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print('\nOptimize mosaic locations...')
positions =np.array(positions)
for i in range(3):positions[:,i] = filt.medfilt(positions[:,i],opt.medfilt_num)
# clean mosaic
print('Clean mosaic...')
for i,imagepath in enumerate(imagepaths,0):
print('Clean mosaic:',imagepath)
x,y,size = positions[i][0],positions[i][1],positions[i][2]
img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0)
......@@ -153,7 +168,8 @@ def cleanmosaic_video_fusion(opt,netG,netM):
img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather)
cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print()
ffmpeg.image2video( fps,
'./tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
'./tmp/voice_tmp.mp3',
......
......@@ -10,10 +10,10 @@ class Options():
def initialize(self):
#base
self.parser.add_argument('--use_gpu',type=int,default=1, help='if 0, do not use gpu')
self.parser.add_argument('--use_gpu',type=int,default=1, help='if 0 or -1, do not use gpu')
# self.parser.add_argument('--use_gpu', action='store_true', help='if input it, use gpu')
self.parser.add_argument('--media_path', type=str, default='./hands_test.mp4',help='your videos or images path')
self.parser.add_argument('--mode', type=str, default='auto',help='add or clean mosaic into your media auto | add | clean')
self.parser.add_argument('--mode', type=str, default='auto',help='add or clean mosaic into your media auto | add | clean | style')
self.parser.add_argument('--model_path', type=str, default='./pretrained_models/add_hands_128.pth',help='pretrained model path')
self.parser.add_argument('--result_dir', type=str, default='./result',help='output result will be saved here')
self.parser.add_argument('--tempimage_type', type=str, default='png',help='type of temp image, png | jpg, png is better but occupies more storage space')
......@@ -38,7 +38,7 @@ class Options():
self.initialize()
self.opt = self.parser.parse_args()
if torch.cuda.is_available() and self.opt.use_gpu:
if torch.cuda.is_available() and self.opt.use_gpu > 0:
self.opt.use_gpu = True
else:
self.opt.use_gpu = False
......@@ -49,17 +49,20 @@ class Options():
self.opt.mode = 'add'
elif 'clean' in self.opt.model_path:
self.opt.mode = 'clean'
elif 'style' in self.opt.model_path:
self.opt.mode = 'style'
else:
print('Please input running mode!')
if self.opt.netG == 'auto' and self.opt.mode =='clean':
if 'unet_128' in self.opt.model_path:
model_name = os.path.basename(self.opt.model_path)
if 'unet_128' in model_name:
self.opt.netG = 'unet_128'
elif 'resnet_9blocks' in self.opt.model_path:
elif 'resnet_9blocks' in model_name:
self.opt.netG = 'resnet_9blocks'
elif 'HD' in self.opt.model_path:
elif 'HD' in model_name:
self.opt.netG = 'HD'
elif 'video' in self.opt.model_path:
elif 'video' in model_name:
self.opt.netG = 'video'
else:
print('Type of Generator error!')
......
......@@ -20,7 +20,6 @@ def main():
core.addmosaic_img(opt,netS)
elif util.is_video(file):
core.addmosaic_video(opt,netS)
util.clean_tempfiles(tmp_init = False)
else:
print('This type of file is not supported')
......@@ -40,10 +39,22 @@ def main():
core.cleanmosaic_video_fusion(opt,netG,netM)
else:
core.cleanmosaic_video_byframe(opt,netG,netM)
util.clean_tempfiles(tmp_init = False)
else:
print('This type of file is not supported')
elif opt.mode == 'style':
netG = loadmodel.cyclegan(opt)
for file in files:
opt.media_path = file
if util.is_img(file):
core.styletransfer_img(opt,netG)
elif util.is_video(file):
core.styletransfer_video(opt,netG)
else:
print('This type of file is not supported')
util.clean_tempfiles(tmp_init = False)
main()
# if __name__ == '__main__':
......
......@@ -19,7 +19,7 @@ HD = True # if false make dataset for pix2pix, if True for pix2pix_HD
MASK = True # if True, output mask,too
OUT_SIZE = 256
FOLD_NUM = 2
Bounding = True
Bounding = False
if HD:
train_A_path = os.path.join(output_dir,'train_A')
......@@ -48,7 +48,7 @@ for fold in range(FOLD_NUM):
mask = impro.resize_like(mask, img)
x,y,size,area = impro.boundingSquare(mask, 1.5)
if area > 100:
if Bounding
if Bounding:
img = impro.resize(img[y-size:y+size,x-size:x+size],OUT_SIZE)
mask = impro.resize(mask[y-size:y+size,x-size:x+size],OUT_SIZE)
img_mosaic = mosaic.addmosaic_random(img, mask)
......
......@@ -14,7 +14,7 @@ ir_mask_path = './Irregular_Holes_mask'
img_dir ='/media/hypo/Hypoyun/Datasets/other/face512'
MOD = 'mosaic' #HD | pix2pix | mosaic
MASK = False # if True, output mask,too
BOUNDING = False # if True, the mosaic size will be larger
BOUNDING = True # if True, the mosaic size will be larger
suffix = '_1'
output_dir = os.path.join('./datasets_img',MOD)
util.makedirs(output_dir)
......
......@@ -9,6 +9,19 @@ def show_paramsnumber(net,netname='net'):
parameters = round(parameters/1e6,2)
print(netname+' parameters: '+str(parameters)+'M')
def __patch_instance_norm_state_dict(state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def pix2pix(opt):
# print(opt.model_path,opt.netG)
......@@ -23,8 +36,31 @@ def pix2pix(opt):
netG.cuda()
return netG
def cyclegan(opt):
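# build a 3->3 resnet_9blocks generator with InstanceNorm and load CycleGAN weights;
# old checkpoints are patched below so stale InstanceNorm buffers do not break load_state_dict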
netG = define_G(3, 3, 64, 'resnet_9blocks', norm='instance',use_dropout=False, init_type='normal', gpu_ids=[])
# in order to load old pretrained models
#https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/models/base_model.py
if isinstance(netG, torch.nn.DataParallel):
netG = netG.module
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(opt.model_path, map_location='cpu')
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
__patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
netG.load_state_dict(state_dict)
if opt.use_gpu:
netG.cuda()
return netG
def video(opt):
netG = MosaicNet(3*25+1, 3)
netG = MosaicNet(3*25+1, 3,norm = 'batch')
show_paramsnumber(netG,'netG')
netG.load_state_dict(torch.load(opt.model_path))
netG.eval()
......
......@@ -8,6 +8,21 @@ import functools
from torch.optim import lr_scheduler
def set_requires_grad(nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
###############################################################################
# Helper Functions
###############################################################################
......
......@@ -34,6 +34,15 @@ def run_pix2pix(img,net,opt):
img_fake = data.tensor2im(img_fake)
return img_fake
def run_styletransfer(opt, net, img, outsize = 720):
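# if the input is large, shrink it so its short side equals outsize, then crop height/width
# to multiples of 4 (the resnet generator downsamples and upsamples by a factor of 4)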
if min(img.shape[:2]) >= outsize:
img = impro.resize(img,outsize)
img = img[0:4*int(img.shape[0]/4),0:4*int(img.shape[1]/4),:]
img = data.im2tensor(img,use_gpu=opt.use_gpu)
img = net(img)
img = data.tensor2im(img)
return img
def get_ROI_position(img,net,opt):
mask = run_unet_rectim(img,net,use_gpu = opt.use_gpu)
mask = impro.mask_threshold(mask,opt.mask_extend,opt.mask_threshold)
......@@ -42,9 +51,10 @@ def get_ROI_position(img,net,opt):
def get_mosaic_position(img_origin,net_mosaic_pos,opt,threshold = 128 ):
mask = run_unet_rectim(img_origin,net_mosaic_pos,use_gpu = opt.use_gpu)
mask_1 = mask.copy()
mask = impro.mask_threshold(mask,20,threshold)
#mask_1 = mask.copy()
mask = impro.mask_threshold(mask,30,threshold)
mask = impro.find_best_ROI(mask)
x,y,size,area = impro.boundingSquare(mask,Ex_mul=1.5)
rat = min(img_origin.shape[:2])/224.0
x,y,size = int(rat*x),int(rat*y),int(rat*size)
return x,y,size,mask_1
\ No newline at end of file
return x,y,size,mask
\ No newline at end of file
......@@ -4,15 +4,6 @@ import torch.nn.functional as F
from .unet_parts import *
from .pix2pix_model import *
Norm = 'batch'
if Norm == 'instance':
NormLayer_2d = nn.InstanceNorm2d
NormLayer_3d = nn.InstanceNorm3d
use_bias = True
else:
NormLayer_2d = nn.BatchNorm2d
NormLayer_3d = nn.BatchNorm3d
use_bias = False
class encoder_2d(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
......@@ -20,7 +11,7 @@ class encoder_2d(nn.Module):
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=NormLayer_2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
......@@ -65,7 +56,7 @@ class decoder_2d(nn.Module):
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=NormLayer_2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
......@@ -121,11 +112,11 @@ class decoder_2d(nn.Module):
class conv_3d(nn.Module):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1,norm_layer_3d=nn.BatchNorm3d,use_bias=True):
super(conv_3d, self).__init__()
self.conv = nn.Sequential(
nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias),
NormLayer_3d(outchannel),
norm_layer_3d(outchannel),
nn.ReLU(inplace=True),
)
......@@ -134,12 +125,12 @@ class conv_3d(nn.Module):
return x
class conv_2d(nn.Module):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=1,padding=1):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=1,padding=1,norm_layer_2d=nn.BatchNorm2d,use_bias=True):
super(conv_2d, self).__init__()
self.conv = nn.Sequential(
nn.ReflectionPad2d(padding),
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=use_bias),
NormLayer_2d(outchannel),
norm_layer_2d(outchannel),
nn.ReLU(inplace=True),
)
......@@ -149,14 +140,14 @@ class conv_2d(nn.Module):
class encoder_3d(nn.Module):
def __init__(self,in_channel):
def __init__(self,in_channel,norm_layer_2d,norm_layer_3d,use_bias):
super(encoder_3d, self).__init__()
self.down1 = conv_3d(1, 64, 3, 2, 1)
self.down2 = conv_3d(64, 128, 3, 2, 1)
self.down3 = conv_3d(128, 256, 3, 1, 1)
self.down1 = conv_3d(1, 64, 3, 2, 1,norm_layer_3d,use_bias)
self.down2 = conv_3d(64, 128, 3, 2, 1,norm_layer_3d,use_bias)
self.down3 = conv_3d(128, 256, 3, 1, 1,norm_layer_3d,use_bias)
self.conver2d = nn.Sequential(
nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
NormLayer_2d(256),
norm_layer_2d(256),
nn.ReLU(inplace=True),
)
......@@ -176,17 +167,17 @@ class encoder_3d(nn.Module):
class MosaicNet(nn.Module):
def __init__(self, in_channel, out_channel):
super(MosaicNet, self).__init__()
class ALL(nn.Module):
def __init__(self, in_channel, out_channel,norm_layer_2d,norm_layer_3d,use_bias):
super(ALL, self).__init__()
self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9)
self.encoder_3d = encoder_3d(in_channel)
self.decoder_2d = decoder_2d(4,3,64,n_blocks=9)
self.shortcut_cov = conv_2d(3,64,7,1,3)
self.merge1 = conv_2d(512,256,3,1,1)
self.encoder_2d = encoder_2d(4,-1,64,norm_layer=norm_layer_2d,n_blocks=9)
self.encoder_3d = encoder_3d(in_channel,norm_layer_2d,norm_layer_3d,use_bias)
self.decoder_2d = decoder_2d(4,3,64,norm_layer=norm_layer_2d,n_blocks=9)
self.shortcut_cov = conv_2d(3,64,7,1,3,norm_layer_2d,use_bias)
self.merge1 = conv_2d(512,256,3,1,1,norm_layer_2d,use_bias)
self.merge2 = nn.Sequential(
conv_2d(128,64,3,1,1),
conv_2d(128,64,3,1,1,norm_layer_2d,use_bias),
nn.ReflectionPad2d(3),
nn.Conv2d(64, out_channel, kernel_size=7, padding=0),
nn.Tanh()
......@@ -210,3 +201,17 @@ class MosaicNet(nn.Module):
return x
def MosaicNet(in_channel, out_channel, norm='batch'):
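# factory: choose the 2D/3D norm layers for the whole network; BatchNorm tracks running stats
# and drops the conv bias, InstanceNorm keeps the bias (following the pix2pix/CycleGAN convention)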
if norm == 'batch':
# norm_layer_2d = nn.BatchNorm2d
# norm_layer_3d = nn.BatchNorm3d
norm_layer_2d = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
norm_layer_3d = functools.partial(nn.BatchNorm3d, affine=True, track_running_stats=True)
use_bias = False
elif norm == 'instance':
norm_layer_2d = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
norm_layer_3d = functools.partial(nn.InstanceNorm3d, affine=False, track_running_stats=False)
use_bias = True
return ALL(in_channel, out_channel, norm_layer_2d, norm_layer_3d, use_bias)
......@@ -28,7 +28,7 @@ FINESIZE = 224
CONTINUE = True
use_gpu = True
SAVE_FRE = 1
MAX_LOAD = 35000
MAX_LOAD = 30000
#cudnn.benchmark = True
......
......@@ -12,7 +12,7 @@ sys.path.append("../..")
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
from models import pix2pix_model,video_model,unet_model,loadmodel
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
......@@ -21,19 +21,20 @@ ITER = 10000000
LR = 0.0002
beta1 = 0.5
use_gpu = True
use_gan = True
use_gan = False
use_L2 = False
CONTINUE = True
lambda_L1 = 100.0
lambda_gan = 1
lambda_gan = 0.5
SAVE_FRE = 10000
start_iter = 0
finesize = 256
loadsize = int(finesize*1.2)
batchsize = 1
batchsize = 6
perload_num = 16
savename = 'MosaicNet_instance_gan_256_D5'
# savename = 'MosaicNet_instance_gan_256_hdD'
savename = 'MosaicNet_instance_test'
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint)
......@@ -52,13 +53,14 @@ for video in videos:
#unet_128
#resnet_9blocks
#netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_6blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
netG = video_model.MosaicNet(3*N+1, 3)
netG = video_model.MosaicNet(3*N+1, 3, norm='instance')
loadmodel.show_paramsnumber(netG,'netG')
# netG = unet_model.UNet(3*N+1, 3)
if use_gan:
netD = pix2pixHD_model.define_D(6, 64, 3, norm='instance', use_sigmoid=False, num_D=2)
#netD = pix2pix_model.define_D(3*2+1, 64, 'pixel', norm='instance')
#netD = pix2pix_model.define_D(3*2+1, 64, 'basic', norm='instance')
netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance')
#netD = pix2pix_model.define_D(3*2, 64, 'basic', norm='instance')
#netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance')
if CONTINUE:
if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
......@@ -71,19 +73,22 @@ if CONTINUE:
f = open(os.path.join(dir_checkpoint,'iter'),'r')
start_iter = int(f.read())
f.close()
if use_gpu:
netG.cuda()
if use_gan:
netD.cuda()
cudnn.benchmark = True
optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999))
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if use_gan:
optimizer_D = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999))
criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
# criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor)
netD.train()
if use_gpu:
netG.cuda()
if use_gan:
netD.cuda()
criterionGAN.cuda()
cudnn.benchmark = True
def loaddata():
video_index = random.randint(0,len(videos)-1)
......@@ -151,31 +156,34 @@ for iter in range(start_iter+1,ITER):
inputdata = input_imgs[ran:ran+batchsize].clone()
target = ground_trues[ran:ran+batchsize].clone()
pred = netG(inputdata)
if use_gan:
netD.train()
# print(inputdata[0,3*N,:,:].size())
# print((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size())
real_A = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], inputdata[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
# compute fake images: G(A)
pred = netG(inputdata)
# update D
pix2pix_model.set_requires_grad(netD,True)
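# enable D's gradients for the discriminator step; they are frozen again below before updating G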
optimizer_D.zero_grad()
# Fake
real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB.detach())
loss_D_fake = criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((real_A, target), 1)
pred_real = netD(real_AB)
loss_D_real = criterionGAN(pred_real, True)
# combine loss and calculate gradients
loss_D = (loss_D_fake + loss_D_real) * 0.5
loss_sum[2] += loss_D_fake.item()
loss_sum[3] += loss_D_real.item()
optimizer_D.zero_grad()
# update D's weights
loss_D.backward()
optimizer_D.step()
netD.eval()
# fake_AB = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1)
real_A = torch.cat((inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], inputdata[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
# update G
pix2pix_model.set_requires_grad(netD,False)
optimizer_G.zero_grad()
# First, G(A) should fake the discriminator
real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB)
loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan
......@@ -188,12 +196,12 @@ for iter in range(start_iter+1,ITER):
loss_G = loss_G_GAN + loss_G_L1
loss_sum[0] += loss_G_L1.item()
loss_sum[1] += loss_G_GAN.item()
optimizer_G.zero_grad()
# update G's weights
loss_G.backward()
optimizer_G.step()
else:
pred = netG(inputdata)
if use_L2:
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1
else:
......@@ -204,7 +212,6 @@ for iter in range(start_iter+1,ITER):
loss_G_L1.backward()
optimizer_G.step()
if (iter+1)%100 == 0:
try:
data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
......
......@@ -9,7 +9,7 @@ def video2voice(videopath,voicepath):
os.system('ffmpeg -i "'+videopath+'" -f mp3 '+voicepath)
def image2video(fps,imagepath,voicepath,videopath):
os.system('ffmpeg -y -r '+str(fps)+' -i '+imagepath+' -vcodec libx264 '+'./tmp/video_tmp.mp4')
os.system('ffmpeg -y -r '+str(fps)+' -i '+imagepath+' -vcodec libx264 -b 12M '+'./tmp/video_tmp.mp4')
#os.system('ffmpeg -f image2 -i '+imagepath+' -vcodec libx264 -r '+str(fps)+' ./tmp/video_tmp.mp4')
os.system('ffmpeg -i ./tmp/video_tmp.mp4 -i "'+voicepath+'" -vcodec copy -acodec copy '+videopath)
......
......@@ -17,14 +17,14 @@ def imread(file_path,mod = 'normal'):
# cv_img = cv2.imdecode(np.fromfile(file_path,dtype=np.uint8),-1)
return cv_img
def resize(img,size):
def resize(img,size,interpolation=cv2.INTER_LINEAR):
h, w = img.shape[:2]
if np.min((w,h)) ==size:
return img
if w >= h:
res = cv2.resize(img,(int(size*w/h), size))
res = cv2.resize(img,(int(size*w/h), size),interpolation=interpolation)
else:
res = cv2.resize(img,(size, int(size*h/w)))
res = cv2.resize(img,(size, int(size*h/w)),interpolation=interpolation)
return res
def resize_like(img,img_like):
......@@ -111,6 +111,17 @@ def mergeimage(img1,img2,orgin_image,size = 128):
result_img = cv2.add(new_img1,new_img2)
return result_img
def find_best_ROI(mask):
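# keep only the largest contour of the binary mask: fill it on a blank mask and discard the rest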
contours,hierarchy=cv2.findContours(mask, cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
if len(contours)>0:
areas = []
for contour in contours:
areas.append(cv2.contourArea(contour))
index = areas.index(max(areas))
mask = np.zeros_like(mask)
mask = cv2.fillPoly(mask,[contours[index]],(255))
return mask
def boundingSquare(mask,Ex_mul):
# thresh = mask_threshold(mask,10,threshold)
area = mask_area(mask)
......@@ -152,7 +163,7 @@ def boundingSquare(mask,Ex_mul):
def mask_threshold(mask,blur,threshold):
mask = cv2.threshold(mask,threshold,255,cv2.THRESH_BINARY)[1]
mask = cv2.blur(mask, (blur, blur))
mask = cv2.threshold(mask,threshold/3,255,cv2.THRESH_BINARY)[1]
mask = cv2.threshold(mask,threshold/5,255,cv2.THRESH_BINARY)[1]
return mask
def mask_area(mask):
......
......@@ -64,6 +64,7 @@ def clean_tempfiles(tmp_init=True):
os.makedirs('./tmp/mosaic_mask')
os.makedirs('./tmp/ROI_mask')
os.makedirs('./tmp/ROI_mask_check')
os.makedirs('./tmp/style_transfer')
def file_init(opt):
if not os.path.isdir(opt.result_dir):
......