提交 9de91414 编写于 作者: H hypox64

just commit, unstable

上级 c21505a2
......@@ -104,10 +104,11 @@ def cleanmosaic_video_byframe(opt):
os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4'))
def cleanmosaic_video_fusion(opt):
net = loadmodel.pix2pix(opt)
net = loadmodel.video(opt)
net_mosaic_pos = loadmodel.unet_clean(opt)
path = opt.media_path
N = 25
INPUT_SIZE = 128
util.clean_tempfiles()
fps = ffmpeg.get_video_infos(path)[0]
......@@ -140,15 +141,15 @@ def cleanmosaic_video_fusion(opt):
if size==0:
cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_origin)
else:
mosaic_input = np.zeros((256,256,3*N+1), dtype='uint8')
mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8')
for j in range(0,N):
img = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+j-12,0,len(imagepaths)-1)]))
img = img[y-size:y+size,x-size:x+size]
img = impro.resize(img,256)
img = impro.resize(img,INPUT_SIZE)
mosaic_input[:,:,j*3:(j+1)*3] = img
mask = impro.resize(mask,np.min(img_origin.shape[:2]))
mask = mask[y-size:y+size,x-size:x+size]
mask = impro.resize(mask, 256)
mask = impro.resize(mask, INPUT_SIZE)
mosaic_input[:,:,-1] = mask
mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
unmosaic_pred = net(mosaic_input)
......
......@@ -17,4 +17,24 @@
1pondo_070315_108_1080p.mp4,00:11:10,00:11:50,00:13:50,00:14:20,00:14:35,00:15:50,00:17:20,00:18:35,00:20:45,00:24:35,00:25:05,00:29:15,00:30:40,00:31:55,00:35:20,00:42:55,00:43:05,00:46:15,00:48:00,00:51:45,00:52:33,00:54:20,00:59:25,00:59:40,01:00:05
071114_842-1pon-whole1_hd.mp4,00:09:50,00:11:25,00:16:35,00:18:20,00:22:10,00:25:25,00:26:35,00:33:50,00:35:40,00:43:10
071715_116-1pon-1080p.mp4,00:10:50,00:11:30,00:12:50,00:15:10,00:16:45,00:17:05,00:25:20,00:26:45,00:28:30,00:30:20,00:32:55,00:34:30,00:37:40,00:38:40,00:40:20,00:41:20,00:44:10,00:47:15,00:55:00,00:59:40,00:59:50
071815_117-1pon-1080p.mp4,00:14:50,00:15:10,00:18:05,00:14:50,00:25:55,00:26:25,00:32:45,00:33:40,00:43:15,00:45:05,00:45:45,00:48:40,00:48:50,00:55:45,01:00:20,01:00:35,01:01:00,01:01:10
\ No newline at end of file
071815_117-1pon-1080p.mp4,00:14:50,00:15:10,00:18:05,00:14:50,00:25:55,00:26:25,00:32:45,00:33:40,00:43:15,00:45:05,00:45:45,00:48:40,00:48:50,00:55:45,01:00:20,01:00:35,01:01:00,01:01:10
080815_130-1pon-1080p,00:14:50,00:17:15,00:17:20,00:23:55,00:25:30,00:25:55,00:28:20,00:28:30,00:30:10,00:31:00,00:33:25,00:33:35,00:33:45,00:33:50,00:39:25,00:39:50,00:40:25,00:44:05,00:45:00,00:45:40,00:45:50,00:46:55,00:49:15,00:49:25,00:46:40,00:50:10,00:50:15,00:51:25,00:51:50,00:53:14,00:53:20,00:54:15,00:56:15,00:56:25,00:56:45,00:57:45,00:57:30,00:58:00,00:56:45,00:56:55,01:00:00,01:00:05,01:00:25,01:00:30
081514_863-1pon-whole1_hd.avi,00:10:30,00:26:00,00:30:00,00:38:21,00:40:15,00:40:30,00:49:10,00:50:05,00:57:10,00:59:00
090614_877-1pon-whole1_hd.mp4,00:04:45,00:05:15,00:12:25,00:12:40,00:15:00,00:15:15,00:16:25,00:20:50,00:21:45,00:26:10,00:33:35,00:35:55,00:37:50,00:37:55,00:38:12,00:39:55,00:41:50,00:44:27,00:44:37,00:46:30,00:47:35,00:47:40,00:48:20,00:59:50
091215_152-1pon-1080p.mp4,00:05:30,00:06:10,00:06:20,00:08:15,00:10:10,00:11:15,00:12:15,00:12:55,00:15:15,00:15:35,00:18:00,00:24:45,00:25:45,00:33:45,00:35:32,00:37:35,00:37:55,00:38:50,00:42:15,00:45:00,00:47:55,00:48:20,00:48:35,00:48:42,00:49:43,00:50:15,00:51:10,00:55:35,00:57:00,00:57:55,01:03:30,01:05:00
092813_670-1pon-whole1_hd.avi,00:16:32,00:19:00,00:22:10,00:23:20,00:23:40,00:30:20,00:32:00,00:35:00,00:36:50,00:41:40,00:44:50,00:52:45,00:54:00
103015_180-1pon-1080p.mp4,00:24:50,00:31:25,00:41:20,00:48:10,00:48:50,00:49:20,00:50:15,00:52:45,00:53:30,01:02:40,01:03:35,01:09:50,01:15:05,01:16:50
110615_185-1pon-1080p.mp4,00:15:00,00:15:40,00:34:15,00:34:50,00:35:30,00:37:05,00:39:35,00:40:30,00:41:40,00:47:35,00:50:15,00:51:01,00:51:35,00:54:15,00:55:40,00:55:50,00:57:20,00:59:35,01:00:00,01:00:25
120310_979-1pon-whole1_hd.avi,00:15:10,00:14:25,00:14:30,00:14:50,00:15:45,00:16:35,00:16:55,00:17:25,00:19:25,00:20:45,00:27:05,00:30:17,00:32:00,00:33:50,00:35:45,00:38:55,00:40:25,00:40:40,00:41:10,00:42:50,00:44:35,00:45:15,00:46:15,00:48:00,00:49:10,00:50:10,00:54:00,00:55:23,00:55:30,00:55:50
021315-806-carib-1080p.mp4,00:13:30,00:15:20,00:17:40,00:21:50,00:22:25,00:24:35,00:28:50,00:28:52,00:31:00,00:37:25,00:37:35,00:38:20,00:38:45,00:43:30,00:48:35,00:51:30,00:51:50,00:52:19,00:56:20,00:58:35
021715-809-carib-1080p.mp4,00:17:30,00:20:35,00:21:00,00:22:00,00:23:55,00:24:15,00:28:40,00:37:20,00:39:05,00:40:05,00:40:50,00:42:45,00:45:00,00:46:40,00:48:00,00:48:20,00:51:30,00:52:10,00:53:35,00:54:10,00:54:20,00:56:45,00:56:55,00:59:10,00:59:35,00:59:55
022715-817-carib-1080p.mp4,00:57:52,00:08:50,00:10:00,00:12:50,00:14:05,00:18:25,00:20:45,00:20:57,00:22:15,00:23:30,00:23:55,00:24:18,00:24:50,00:25:25,00:26:30,00:26:55,00:28:50,00:31:55,00:34:00,00:34:35,00:42:45,00:44:33
030914-558-carib-high_1.mp4,00:10:45,00:12:45,00:14:40,00:16:33,00:19:40,00:21:35,00:21:55,00:23:05,00:26:15,00:27:30,00:29:55,00:31:10,00:31:40,00:36:40,00:41:40,00:42:40,00:44:50,00:49:50,00:52:25,00:53:50,00:54:30,00:55:20,00:55:10,00:57:05,00:57:25,00:59:05,01:00:15,01:02:11,01:03:55,01:05:10
031815-830-carib-1080p.mp4,00:13:15,00:13:25,00:13:55,00:14:40,00:15:40,00:17:30,00:18:20,00:19:10,00:21:00,00:22:10,00:22:25,00:23:25,00:27:10,00:28:33,00:35:05,00:35:40,00:37:50,00:38:00,00:39:35,00:41:35,00:42:40,00:47:40,00:50:33,00:55:50,01:02:10,01:05:20,01:05:30
032016-121-carib-1080p.mp4,00:27:20,00:28:40,00:28:55,00:30:35,00:36:10,00:39:10,00:40:30,00:43:00,00:46:05,00:50:00,00:56:05,00:56:20,00:59:20
032913-301-carib-whole_hd1.wmv,00:06:00,00:09:40,00:11:00,00:13:00,00:15:05,00:16:40,00:18:05,00:20:00,00:39:31,00:34:35,00:44:50,00:47:25,00:49:50,00:51:20,00:54:58,00:56:55,00:59:50,01:00:50
032914-571-carib-high_1.mp4,00:13:30,00:13:55,00:16:40,00:15:25,00:20:40,00:26:45,00:32:05,00:33:15,00:36:40,00:38:55,00:39:00,00:39:25,00:47:30,00:49:20
042514-588-carib-high_1.mp4,00:10:30,00:11:15,00:19:15,00:20:00,00:20:30,00:22:05,00:22:45,00:22:53,00:24:15,00:30:50,00:32:25,00:34:15,00:34:45,00:34:55,00:36:05,00:37:20,00:37:40,00:38:30,00:39:35,00:41:00,00:43:30,00:43:40
052315-884-carib-1080p.mp4,00:09:35,00:14:10,00:14:30,00:14:40,00:17:10,00:17:50,00:19:00,00:20:20,00:21:55,00:22:40,00:23:05,00:24:00,00:26:00,00:27:15,00:30:25,00:32:50,00:37:55,00:39:35,00:40:10,00:41:40,00:43:15,00:43:40,00:47:55,00:49:30,00:49:55,00:58:55,01:00:40
053114-612-carib-high_1.mp4,00:08:35,00:13:35,00:15:25,00:16:40,00:20:35,00:22:25,00:26:10,00:29:10,00:32:55,00:34:10,00:37:05,00:37:40,00:39:40,00:40:52,00:42:08,00:42:15
062615-908-carib-1080p.mp4,00:13:45,00:14:40,00:15:45,00:16:11,00:17:00,00:22:10,00:23:40,00:26:10,00:27:15,00:27:50,00:31:30,00:35:00,00:40:20,00:43:10,00:44:35,00:47:17,00:50:25,00:51:15,00:52:20,00:54:10,00:55:30,01:00:20
\ No newline at end of file
......@@ -13,21 +13,20 @@ files = util.Traversal('/media/hypo/Media/download')
videos = util.is_videos(files)
video_times = []
useable_videos = []
video_dict = {}
reader = csv.reader(open('./csv/video_used_time.csv'))
for line in reader:
useable_videos.append(line[0])
video_times.append(line[1:])
video_dict[line[0]]=line[1:]
in_cnt = 0
out_cnt = 502
out_cnt = 1
for video in videos:
if os.path.basename(video) in useable_videos:
# print(video)
for i in range(len(video_times[in_cnt])):
for i in range(len(video_dict[os.path.basename(video)])):
ffmpeg.cut_video(video, video_dict[os.path.basename(video)][i], '00:00:05', './video/'+'%04d'%out_cnt+'.mp4')
out_cnt +=1
in_cnt += 1
......@@ -8,72 +8,77 @@ sys.path.append("..")
from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt
from util import image_processing as impro
from options import Options
from cores import options
opt = Options().getparse()
opt = options.Options().getparse()
util.file_init(opt)
videos = os.listdir('./video')
videos.sort()
opt.model_path = '../pretrained_models/add_youknow_128.pth'
opt.use_gpu = True
Ex = 1.4
Area_Type = 'normal'
suffix = ''
net = loadmodel.unet(opt)
for path in videos:
try:
path = os.path.join('./video',path)
util.clean_tempfiles()
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
path = os.path.join('./video',path)
util.clean_tempfiles()
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort()
# get position
positions = []
img_ori_example = impro.imread(os.path.join('./tmp/video2image',imagepaths[0]))
mask_avg = np.zeros((impro.resize(img_ori_example, 128)).shape[:2])
for imagepath in imagepaths:
imagepath = os.path.join('./tmp/video2image',imagepath)
print('Find ROI location:',imagepath)
img = impro.imread(imagepath)
x,y,size,mask = runmodel.get_mosaic_position(img,net,opt,threshold = 64)
cv2.imwrite(os.path.join('./tmp/ROI_mask',
os.path.basename(imagepath)),mask)
positions.append([x,y,size])
mask_avg = mask_avg + mask
print('Optimize ROI locations...')
mask_index = filt.position_medfilt(np.array(positions), 13)
# get position
positions = []
img_ori_example = impro.imread(os.path.join('./tmp/video2image',imagepaths[0]))
mask_avg = np.zeros((impro.resize(img_ori_example, 128)).shape[:2])
for imagepath in imagepaths:
imagepath = os.path.join('./tmp/video2image',imagepath)
print('Find ROI location:',imagepath)
img = impro.imread(imagepath)
x,y,size,mask = runmodel.get_mosaic_position(img,net,opt,threshold = 64)
cv2.imwrite(os.path.join('./tmp/ROI_mask',
os.path.basename(imagepath)),mask)
positions.append([x,y,size])
mask_avg = mask_avg + mask
print('Optimize ROI locations...')
mask_index = filt.position_medfilt(np.array(positions), 13)
mask = np.clip(mask_avg/len(imagepaths),0,255).astype('uint8')
mask = impro.mask_threshold(mask,20,32)
x,y,size,area = impro.boundingSquare(mask,Ex_mul=Ex)
rat = min(img_ori_example.shape[:2])/128.0
x,y,size = int(rat*x),int(rat*y),int(rat*size)
cv2.imwrite(os.path.join('./tmp/ROI_mask_check',
'test_show.png'),mask)
if size !=0 :
mask_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix+'/mask'
ori_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix+'/ori'
mosaic_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix+'/mosaic'
os.makedirs('./dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix)
os.makedirs(mask_path)
os.makedirs(ori_path)
os.makedirs(mosaic_path)
print('Add mosaic to images...')
mosaic_size = mosaic.get_autosize(img_ori_example,mask,area_type = Area_Type)*random.uniform(1,2)
models = ['squa_avg','rect_avg','squa_mid']
mosaic_type = random.randint(0,len(models)-1)
rect_rat = random.uniform(1.2,1.6)
for i in range(len(imagepaths)):
mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]),mod = 'gray')
img_ori = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
img_mosaic = mosaic.addmosaic_normal(img_ori,mask,mosaic_size,model = models[mosaic_type],rect_rat=rect_rat)
mask = impro.resize(mask, min(img_ori.shape[:2]))
mask = np.clip(mask_avg/len(imagepaths),0,255).astype('uint8')
mask = impro.mask_threshold(mask,20,32)
x,y,size,area = impro.boundingSquare(mask,Ex_mul=1.5)
rat = min(img_ori_example.shape[:2])/128.0
x,y,size = int(rat*x),int(rat*y),int(rat*size)
cv2.imwrite(os.path.join('./tmp/ROI_mask_check',
'test_show.png'),mask)
if size !=0 :
mask_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mask'
ori_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/ori'
mosaic_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mosaic'
os.makedirs('./dataset/'+os.path.splitext(os.path.basename(path))[0]+'')
os.makedirs(mask_path)
os.makedirs(ori_path)
os.makedirs(mosaic_path)
print('Add mosaic to images...')
mosaic_size = mosaic.get_autosize(img_ori_example,mask,area_type = 'bounding')*random.uniform(1,2)
models = ['squa_avg','rect_avg','squa_mid']
mosaic_type = random.randint(0,len(models)-1)
rect_rat = random.uniform(1.2,1.6)
for i in range(len(imagepaths)):
mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]))
img_ori = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
img_mosaic = mosaic.addmosaic_normal(img_ori,mask,mosaic_size,model = models[mosaic_type],rect_rat=rect_rat)
mask = impro.resize(mask, min(img_ori.shape[:2]))
img_ori_crop = impro.resize(img_ori[y-size:y+size,x-size:x+size],256)
img_mosaic_crop = impro.resize(img_mosaic[y-size:y+size,x-size:x+size],256)
mask_crop = impro.resize(mask[y-size:y+size,x-size:x+size],256)
img_ori_crop = impro.resize(img_ori[y-size:y+size,x-size:x+size],256)
img_mosaic_crop = impro.resize(img_mosaic[y-size:y+size,x-size:x+size],256)
mask_crop = impro.resize(mask[y-size:y+size,x-size:x+size],256)
cv2.imwrite(os.path.join(ori_path,os.path.basename(imagepaths[i])),img_ori_crop)
cv2.imwrite(os.path.join(mosaic_path,os.path.basename(imagepaths[i])),img_mosaic_crop)
cv2.imwrite(os.path.join(mask_path,os.path.basename(imagepaths[i])),mask_crop)
\ No newline at end of file
cv2.imwrite(os.path.join(ori_path,os.path.basename(imagepaths[i])),img_ori_crop)
cv2.imwrite(os.path.join(mosaic_path,os.path.basename(imagepaths[i])),img_mosaic_crop)
cv2.imwrite(os.path.join(mask_path,os.path.basename(imagepaths[i])),mask_crop)
except Exception as e:
print(e)
\ No newline at end of file
......@@ -2,13 +2,12 @@ import torch
from .pix2pix_model import define_G
from .pix2pixHD_model import define_G as define_G_HD
from .unet_model import UNet
from .video_model import HypoNet
def pix2pix(opt):
# print(opt.model_path,opt.netG)
if opt.netG == 'HD':
netG = define_G_HD(3, 3, 64, 'global' ,4)
elif opt.netG == 'video':
netG = define_G(3*25+1, 3, 128, 'unet_128', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
else:
netG = define_G(3, 3, 64, opt.netG, norm='batch',use_dropout=True, init_type='normal', gpu_ids=[])
......@@ -18,6 +17,15 @@ def pix2pix(opt):
netG.cuda()
return netG
def video(opt):
    """Load the multi-frame (video) mosaic-removal generator.

    Builds a HypoNet taking 25 stacked RGB frames plus one mask channel
    (3*25+1 input channels) and producing a 3-channel image, restores the
    weights from ``opt.model_path``, and returns the network in eval mode
    (moved to GPU when ``opt.use_gpu`` is set).
    """
    netG = HypoNet(3*25+1, 3)
    # map_location='cpu' lets GPU-trained checkpoints load on CPU-only
    # machines; the .cuda() below still moves the model when requested.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG
def unet_clean(opt):
net = UNet(n_channels = 3, n_classes = 1)
net.load_state_dict(torch.load(opt.mosaic_position_model_path))
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_parts import *
from .pix2pix_model import *
class encoder_2d(nn.Module):
    """2D encoder head of the video model.

    A reflection-padded 7x7 stem followed by two stride-2 downsampling
    convolutions; with the default ngf=64 it maps an input_nc-channel
    image to a 256-channel feature map at 1/4 resolution. (The ResNet
    blocks and the upsampling path of the original CycleGAN-style
    generator were moved into decoder_2d.)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Build the encoder.

        Parameters:
            input_nc (int)  -- channels of the input image
            output_nc (int) -- unused here (kept for signature parity)
            ngf (int)       -- filters of the stem convolution
            norm_layer      -- normalization layer class or functools.partial
            use_dropout, n_blocks, padding_type -- unused here (kept for
                signature parity with the full ResNet generator)
        """
        assert(n_blocks >= 0)
        super(encoder_2d, self).__init__()
        # InstanceNorm carries no affine bias by default, so the convs need one.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]
        for step in range(2):  # two stride-2 downsampling stages
            cin = ngf * (2 ** step)
            layers += [nn.Conv2d(cin, cin * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(cin * 2),
                       nn.ReLU(True)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
class decoder_2d(nn.Module):
    """2D decoder of the video model.

    n_blocks ResNet blocks at the 256-channel bottleneck, two
    nearest-neighbour upsampling stages (Upsample + reflection pad +
    3x3 conv — preferred over ConvTranspose2d to avoid checkerboard
    artifacts, see https://distill.pub/2016/deconv-checkerboard/ and
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190),
    then a 7x7 output conv followed by Sigmoid.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Build the decoder.

        Parameters:
            input_nc (int)  -- unused here (kept for signature parity)
            output_nc (int) -- channels of the produced image
            ngf (int)       -- base filter count (bottleneck is ngf*4)
            norm_layer      -- normalization layer class or functools.partial
            use_dropout     -- forwarded to the ResNet blocks
            n_blocks (int)  -- number of ResNet blocks
            padding_type    -- padding used inside the ResNet blocks
        """
        super(decoder_2d, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        layers = []
        bottleneck = ngf * 4  # feature width after the encoder's two downsamplings
        for _ in range(n_blocks):
            layers += [ResnetBlock(bottleneck, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        width = bottleneck
        for _ in range(2):  # two upsampling stages back to full resolution
            layers += [nn.Upsample(scale_factor=2, mode='nearest'),
                       nn.ReflectionPad2d(1),
                       nn.Conv2d(width, width // 2, kernel_size=3, stride=1, padding=0),
                       norm_layer(width // 2),
                       nn.ReLU(True)]
            width //= 2
        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                   nn.Sigmoid()]  # Sigmoid (not Tanh): outputs lie in [0, 1]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
class conv_3d(nn.Module):
    """Downsampling 3D conv unit (Conv3d -> BatchNorm3d -> ReLU).

    The convolution carries no bias since BatchNorm immediately
    re-centres its output; the default stride of 2 halves every
    depth/spatial dimension.
    """

    def __init__(self, inchannel, outchannel, kernel_size=3, stride=2, padding=1):
        super(conv_3d, self).__init__()
        ops = [nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)]
        ops.append(nn.BatchNorm3d(outchannel))
        ops.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        return self.conv(x)
class encoder_3d(nn.Module):
    """3D (temporal) encoder: treats the stacked frame channels as a depth
    axis and reduces them to a single 2D feature map.

    NOTE(review): the .view() reshapes in forward drop and re-insert the
    batch dimension purely by position, which is only valid for batch
    size 1 — confirm callers never pass larger batches.
    """

    def __init__(self,in_channel):
        # in_channel: channel count of the stacked 2D input (frames*3 + mask);
        # only used to size the depth-collapsing 2D conv below.
        super(encoder_3d, self).__init__()
        self.down1 = conv_3d(1, 64, 3, 2, 1)
        self.down2 = conv_3d(64, 128, 3, 2, 1)
        self.down3 = conv_3d(128, 256, 3, 1, 1)   # stride 1: depth stays at ~in_channel/4
        # self.down4 = conv_3d(256, 512, 3, 2, 1)
        # collapses the remaining int(in_channel/4) depth slices to one map
        self.conver2d = nn.Sequential(
            nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, 1, C, H, W): channels become the 3D depth axis
        x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3))
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        # x = self.down4(x)
        # drop the batch axis: feature channels become the 2D batch dim (assumes B == 1)
        x = x.view(x.size(1),x.size(2),x.size(3),x.size(4))
        x = self.conver2d(x)
        # fold back to (1, channels, H, W)
        x = x.view(x.size(1),x.size(0),x.size(2),x.size(3))
        # print(x.size())
        # x = self.avgpool(x)
        return x
# input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'
class HypoNet(nn.Module):
    """Video mosaic-removal generator combining a 2D and a 3D encoder.

    Input x: (B, 3*N+1, H, W) — N stacked RGB frames plus one extra
    channel. The 2D encoder sees the centre frame plus one extra channel,
    the 3D encoder sees the whole stack; their 256-channel feature maps
    are summed, fused by a 1x1 conv, and decoded to a 3-channel image.
    """

    def __init__(self, in_channel, out_channel):
        # NOTE: in_channel/out_channel only size the 3D encoder; the 2D
        # encoder/decoder widths are hard-coded (4-channel in, 3-channel out).
        super(HypoNet, self).__init__()
        self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9)
        self.encoder_3d = encoder_3d(in_channel)
        self.decoder_2d = decoder_2d(4,3,64,n_blocks=9)
        # 1x1 conv fusion applied after summing the two encoder outputs
        self.merge = nn.Sequential(
            nn.Conv2d(256, 256, 1, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        N = int((x.size()[1])/3)  # number of stacked frames (channels // 3)
        # centre frame (3 channels) + one extra channel.
        # NOTE(review): x[:, N-1:N] selects channel N-1, which lies inside
        # frame (N-1)//3 — not the trailing mask channel x[:, 3*N:].
        # Looks unintended; confirm whether the mask was meant here.
        x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1)
        x_2d = self.encoder_2d(x_2d)
        x_3d = self.encoder_3d(x)
        x = x_2d + x_3d
        x = self.merge(x)
        # print(x.size())
        x = self.decoder_2d(x)
        return x
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_parts import *
class conv_3d(nn.Module):
    """3D convolution block: bias-free Conv3d followed by BatchNorm3d and
    an in-place ReLU (default stride 2 downsamples all dimensions)."""

    def __init__(self, inchannel, outchannel, kernel_size=3, stride=2, padding=1):
        super(conv_3d, self).__init__()
        stages = [
            nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm3d(outchannel),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class encoder_3d(nn.Module):
    """3D (temporal) encoder for the UNet-style HypoNet: four stride-2 3D
    convolutions, then a 2D conv that collapses the remaining depth
    slices to a single 512-channel feature map.

    NOTE(review): the .view() reshapes in forward drop and re-insert the
    batch axis by position — only valid for batch size 1; confirm callers.
    """

    def __init__(self,in_channel):
        # in_channel: channels of the stacked 2D input (frames*3 + mask)
        super(encoder_3d, self).__init__()
        self.down1 = conv_3d(1, 64, 3, 2, 1)
        self.down2 = conv_3d(64, 128, 3, 2, 1)
        self.down3 = conv_3d(128, 256, 3, 2, 1)
        self.down4 = conv_3d(256, 512, 3, 2, 1)
        # int(in_channel/16)+1: depth left after four stride-2 halvings
        self.conver2d = nn.Sequential(
            nn.Conv2d(int(in_channel/16)+1, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, 1, C, H, W): channels become the 3D depth axis
        x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3))
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        # drop the batch axis (assumes B == 1)
        x = x.view(x.size(1),x.size(2),x.size(3),x.size(4))
        x = self.conver2d(x)
        # fold back to (1, channels, H, W)
        x = x.view(x.size(1),x.size(0),x.size(2),x.size(3))
        # print(x.size())
        # x = self.avgpool(x)
        return x
class encoder_2d(nn.Module):
    """UNet encoder half built from unet_parts (inconv + four down blocks).

    forward returns all five intermediate feature maps so decoder_2d can
    use them as skip connections.
    """

    def __init__(self, in_channel):
        super(encoder_2d, self).__init__()
        self.inc = inconv(in_channel, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)

    def forward(self, x):
        # x1..x5: progressively downsampled features (x5 is the bottleneck)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        return x1,x2,x3,x4,x5
class decoder_2d(nn.Module):
    """UNet decoder half: four up blocks with skip connections followed by
    the output conv (blocks from unet_parts; bilinear=False selects
    transposed-conv upsampling)."""

    def __init__(self, out_channel):
        super(decoder_2d, self).__init__()
        self.up1 = up(1024, 256,bilinear=False)
        self.up2 = up(512, 128,bilinear=False)
        self.up3 = up(256, 64,bilinear=False)
        self.up4 = up(128, 64,bilinear=False)
        self.outc = outconv(64, out_channel)

    def forward(self,x5,x4,x3,x2,x1):
        # x5 is the bottleneck; x4..x1 are the encoder's skip features
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x
class HypoNet(nn.Module):
    """UNet-style video mosaic-removal generator.

    Input x: (B, 3*N+1, H, W) — N stacked RGB frames plus one extra
    channel. The 2D UNet encoder sees the centre frame plus one extra
    channel; the 3D encoder summarises the whole stack and its output is
    added to the UNet bottleneck before decoding with skip connections.
    """

    def __init__(self, in_channel, out_channel):
        super(HypoNet, self).__init__()
        self.encoder_2d = encoder_2d(4)
        self.encoder_3d = encoder_3d(in_channel)
        self.decoder_2d = decoder_2d(out_channel)

    def forward(self, x):
        N = int((x.size()[1])/3)  # number of stacked frames (channels // 3)
        # NOTE(review): x[:, N-1:N] selects channel N-1 (inside frame
        # (N-1)//3), not the trailing mask channel x[:, 3*N:] — confirm
        # whether the mask was meant here.
        x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1)
        # print(x_2d.size())
        x_3d = self.encoder_3d(x)
        x1,x2,x3,x4,x5 = self.encoder_2d(x_2d)
        x5 = x5 + x_3d  # fuse temporal features into the bottleneck
        x_2d = self.decoder_2d(x5,x4,x3,x2,x1)
        return x_2d
import sys
import os
import random
import datetime
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch import optim
from unet import UNet
def resize(img, size):
    """Scale *img* so its shorter side equals *size*, keeping aspect ratio.

    cv2.resize expects the target as (width, height).
    """
    h, w = img.shape[:2]
    if h > w:
        return cv2.resize(img, (size, int(size*h/w)))
    return cv2.resize(img, (int(size*w/h), size))
def Totensor(img, use_gpu=True):
    """Convert a numpy array to a float32 torch tensor.

    Moves the tensor to the GPU when *use_gpu* is True.
    (The unused local ``size`` from the original was removed.)
    """
    img = torch.from_numpy(img).float()
    if use_gpu:
        img = img.cuda()
    return img
def random_color(img, random_num):
    """Randomly jitter each colour channel, then overall brightness, in place.

    Each of the three channels is shifted by an independent draw from
    [-random_num, random_num]; afterwards every channel gets a common
    brightness offset drawn from [-2*random_num, 2*random_num]. Values
    are clipped to [0, 255]. The input array is modified and returned.
    """
    for c in range(3):
        offset = random.randint(-random_num, random_num)
        img[:, :, c] = np.clip(img[:, :, c].astype('int') + offset, 0, 255).astype('uint8')
    bright = random.randint(-random_num * 2, random_num * 2)
    for c in range(3):
        img[:, :, c] = np.clip(img[:, :, c].astype('int') + bright, 0, 255).astype('uint8')
    return img
def Toinputshape(imgs, masks, finesize):
    """Augment a batch and convert it to network input layout.

    Every (image, mask) pair is run through random_transform, after which
    the image becomes a (3, finesize, finesize) float array in [0, 1] and
    the mask a (1, finesize, finesize) float array in [0, 1].

    Returns (images, masks) as numpy arrays of shape
    (batch, 3, finesize, finesize) and (batch, 1, finesize, finesize).
    """
    out_imgs = []
    out_masks = []
    for idx in range(len(imgs)):
        img, mask = random_transform(imgs[idx], masks[idx], finesize)
        # single mask channel, leading channel axis, scaled to [0, 1]
        out_masks.append(mask[:, :, 0].reshape(1, finesize, finesize) / 255.0)
        # HWC uint8 -> CHW float in [0, 1]
        out_imgs.append(img.transpose((2, 0, 1)) / 255.0)
    return np.array(out_imgs), np.array(out_masks)
def random_transform(img,mask,finesize):
    """Apply the same random augmentation chain to an image/mask pair.

    Steps: aspect-ratio jitter (about +/-10%) while scaling the short side,
    random crop to finesize, occasional (p=0.2) rotation by a random
    multiple of 90 degrees, random colour jitter (image only), and a
    random horizontal or vertical flip (p=0.5). Returns the transformed
    (img, mask), both finesize x finesize.
    """
    # randomsize = int(finesize*(1.2+0.2*random.random())+2)
    h,w = img.shape[:2]
    loadsize = min((h,w))
    # jittered aspect ratio; note cv2.resize takes (width, height)
    a = (float(h)/float(w))*random.uniform(0.9, 1.1)
    if h<w:
        mask = cv2.resize(mask, (int(loadsize/a),loadsize))
        img = cv2.resize(img, (int(loadsize/a),loadsize))
    else:
        mask = cv2.resize(mask, (loadsize,int(loadsize*a)))
        img = cv2.resize(img, (loadsize,int(loadsize*a)))
    # mask = randomsize(mask,loadsize)
    # img = randomsize(img,loadsize)
    #random crop (same window for img and mask)
    h,w = img.shape[:2]
    h_move = int((h-finesize)*random.random())
    w_move = int((w-finesize)*random.random())
    # print(h,w,h_move,w_move)
    img_crop = img[h_move:h_move+finesize,w_move:w_move+finesize]
    mask_crop = mask[h_move:h_move+finesize,w_move:w_move+finesize]
    #random rotation (p=0.2): both rotated by the same multiple of 90 degrees
    if random.random()<0.2:
        h,w = img_crop.shape[:2]
        M = cv2.getRotationMatrix2D((w/2,h/2),90*int(4*random.random()),1)
        img = cv2.warpAffine(img_crop,M,(w,h))
        mask = cv2.warpAffine(mask_crop,M,(w,h))
    else:
        img,mask = img_crop,mask_crop
    #random color (image only; the mask keeps its values)
    img=random_color(img, 15)
    #random flip (p=0.5; then vertical or horizontal with equal chance)
    if random.random()<0.5:
        if random.random()<0.5:
            img = cv2.flip(img,0)
            mask = cv2.flip(mask,0)
        else:
            img = cv2.flip(img,1)
            mask = cv2.flip(mask,1)
    return img,mask
def randomresize(img):
    """Soften detail by upscaling a random amount (1x-1.2x) and scaling back."""
    base = np.min(img.shape[:2])
    scaled = resize(img, int(base * random.uniform(1, 1.2)))
    return resize(scaled, base)
def batch_generator(images, masks, batchsize):
    """Split parallel lists of images and masks into batches.

    Produces full batches of *batchsize* items; a final smaller batch
    holds the remainder, so every item appears exactly once.

    Returns (image_batches, mask_batches) as two lists of list slices.
    """
    image_batches, mask_batches = [], []
    full, rest = divmod(len(images), batchsize)
    for b in range(full):
        lo = b * batchsize
        image_batches.append(images[lo:lo + batchsize])
        mask_batches.append(masks[lo:lo + batchsize])
    if rest != 0:
        image_batches.append(images[len(images) - rest:])
        mask_batches.append(masks[len(images) - rest:])
    return image_batches, mask_batches
def loadimage(dir_img,dir_mask,loadsize,eval_p):
    """Load all image/mask pairs and split them into train/eval sets.

    Images are read from dir_img/<stem>.jpg with masks from
    dir_mask/<stem>.png (stems derived by stripping a 4-character
    extension), resized so the short side equals loadsize, and shuffled.
    The last eval_p fraction becomes the eval split.

    Returns (train_images, train_masks, eval_images, eval_masks).
    """
    t1 = datetime.datetime.now()
    imgnames = os.listdir(dir_img)
    print('images num:',len(imgnames))
    random.shuffle(imgnames)
    # strip the 4-char extension ('.jpg') to get the common file stem
    imgnames = (f[:-4] for f in imgnames)
    images = []
    masks = []
    for imgname in imgnames:
        img = cv2.imread(dir_img+imgname+'.jpg')
        mask = cv2.imread(dir_mask+imgname+'.png')
        img = resize(img,loadsize)
        mask = resize(mask,loadsize)
        images.append(img)
        masks.append(mask)
    # first (1-eval_p) fraction for training, the rest for evaluation
    train_images,train_masks = images[0:int(len(masks)*(1-eval_p))],masks[0:int(len(masks)*(1-eval_p))]
    eval_images,eval_masks = images[int(len(masks)*(1-eval_p)):len(masks)],masks[int(len(masks)*(1-eval_p)):len(masks)]
    t2 = datetime.datetime.now()
    print('load data cost time:',(t2 - t1).seconds,'s')
    return train_images,train_masks,eval_images,eval_masks
def showresult(img, mask, mask_pred):
    """Save a side-by-side preview (image | mask | prediction) to ./result.jpg.

    Picks one random element of the batch; the image fills the left
    panel, while the ground-truth mask and the predicted mask are drawn
    into the green channel of the middle and right panels.
    """
    img = img.cpu().detach().numpy() * 255
    mask = mask.cpu().detach().numpy() * 255
    mask_pred = mask_pred.cpu().detach().numpy() * 255
    pick = int(img.shape[0] * random.random())  # random batch element
    size = img.shape[3]
    canvas = np.zeros((size, size * 3, 3))
    canvas[:, 0:size] = img[pick].transpose((1, 2, 0))
    canvas[:, size:size * 2, 1] = mask[pick].reshape(size, size)
    canvas[:, size * 2:size * 3, 1] = mask_pred[pick].reshape(size, size)
    cv2.imwrite('./result.jpg', canvas)
# ---- training hyper-parameters ----
LR = 0.001            # Adam learning rate
EPOCHS = 100
BATCHSIZE = 12
LOADSIZE = 144        # images are loaded at this short-side size...
FINESIZE = 128        # ...and randomly cropped to this size
CONTINUE = True       # resume from checkpoints/last.pth
use_gpu = True
SAVE_FRE = 5          # save a numbered checkpoint every SAVE_FRE epochs
cudnn.benchmark = False
dir_img = './origin_image/'
dir_mask = './mask/'
dir_checkpoint = 'checkpoints/'
print('loading data......')
# load everything into memory up front and pre-batch it
train_images,train_masks,eval_images,eval_masks = loadimage(dir_img,dir_mask,LOADSIZE,0.2)
dataset_eval_images,dataset_eval_masks = batch_generator(eval_images,eval_masks,BATCHSIZE)
dataset_train_images,dataset_train_masks = batch_generator(train_images,train_masks,BATCHSIZE)
net = UNet(n_channels = 3, n_classes = 1)
if CONTINUE:
    net.load_state_dict(torch.load(dir_checkpoint+'last.pth'))
if use_gpu:
    net.cuda()
# optimizer = optim.SGD(net.parameters(),
#                   lr=LR,
#                   momentum=0.9,
#                   weight_decay=0.0005)
optimizer = torch.optim.Adam(net.parameters(), lr=LR, betas=(0.9, 0.99))
criterion = nn.BCELoss()
# criterion = nn.L1Loss()
print('begin training......')
for epoch in range(EPOCHS):
    starttime = datetime.datetime.now()
    print('Epoch {}/{}.'.format(epoch + 1, EPOCHS))
    net.train()
    # net is moved back to the GPU each epoch because the checkpoint
    # save at the end of the loop calls net.cpu().
    if use_gpu:
        net.cuda()
    epoch_loss = 0
    for i,(img,mask) in enumerate(zip(dataset_train_images,dataset_train_masks)):
        # print(epoch,i,img.shape,mask.shape)
        # augmentation + layout conversion happens per batch, on the fly
        img,mask = Toinputshape(img, mask, FINESIZE)
        img = Totensor(img,use_gpu)
        mask = Totensor(mask,use_gpu)
        mask_pred = net(img)
        loss = criterion(mask_pred, mask)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i%10 == 0:
            showresult(img,mask,mask_pred)
    # torch.cuda.empty_cache()
    # # net.eval()
    # NOTE(review): the eval() call above is commented out, so the net
    # stays in train() mode during evaluation (BatchNorm uses batch
    # statistics) — confirm this is intended.
    epoch_loss_eval = 0
    with torch.no_grad():
        for i,(img,mask) in enumerate(zip(dataset_eval_images,dataset_eval_masks)):
            # print(epoch,i,img.shape,mask.shape)
            img,mask = Toinputshape(img, mask, FINESIZE)
            img = Totensor(img,use_gpu)
            mask = Totensor(mask,use_gpu)
            mask_pred = net(img)
            loss = criterion(mask_pred, mask)
            epoch_loss_eval += loss.item()
    # torch.cuda.empty_cache()
    endtime = datetime.datetime.now()
    print('--- Epoch train_loss: {0:.6f} eval_loss: {1:.6f} Cost time: {2:} s'.format(
        epoch_loss/len(dataset_train_images),
        epoch_loss_eval/len(dataset_eval_images),
        (endtime - starttime).seconds)),
    # checkpoints are always written from CPU state
    torch.save(net.cpu().state_dict(),dir_checkpoint+'last.pth')
    # print('--- Epoch loss: {0:.6f}'.format(epoch_loss/i))
    # print('Cost time: ',(endtime - starttime).seconds,'s')
    if (epoch+1)%SAVE_FRE == 0:
        torch.save(net.cpu().state_dict(),dir_checkpoint+'epoch'+str(epoch+1)+'.pth')
        print('network saved.')
    # torch.save(net.cpu().state_dict(),dir_checkpoint+'last.pth')
    # print('network saved.')
import os
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time
import sys
sys.path.append("..")
sys.path.append("../..")
from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
from models import pix2pix_model,video_model,unet_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
# ---- training hyper-parameters for the video mosaic-removal net ----
N = 25                 # frames stacked per training sample
ITER = 10000000        # total training iterations
LR = 0.0002
beta1 = 0.5            # Adam beta1
use_gpu = True
use_gan = False        # optionally train with a PatchGAN discriminator
use_L2 = False         # L2 instead of L1 reconstruction loss
CONTINUE = False       # resume from dir_checkpoint/last_G.pth
lambda_L1 = 100.0      # reconstruction loss weight
lambda_gan = 1.0       # adversarial loss weight
SAVE_FRE = 10000       # checkpoint save frequency (iterations)
start_iter = 0
SIZE = 128             # training crop size
savename = 'MosaicNet'
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint)
loss_sum = [0.,0.,0.,0.]
loss_plot = [[],[]]
item_plot = []
opt = Options().getparse()
# index the per-video frame counts of the prepared dataset
videos = os.listdir('./dataset')
videos.sort()
lengths = []
for video in videos:
    video_images = os.listdir('./dataset/'+video+'/ori')
    lengths.append(len(video_images))
#unet_128
#resnet_9blocks
#netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_6blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
netG = video_model.HypoNet(3*N+1, 3)
# netG = unet_model.UNet(3*N+1, 3)
if use_gan:
    netD = pix2pix_model.define_D(3*2+1, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
    #netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
if CONTINUE:
    netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
    if use_gan:
        netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
    # the iteration counter is persisted in a plain text file
    with open(os.path.join(dir_checkpoint,'iter'),'r') as f:
        start_iter = int(f.read())
if use_gpu:
    netG.cuda()
    if use_gan:
        netD.cuda()
    cudnn.benchmark = True
optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999))
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if use_gan:
    # BUGFIX: the discriminator optimizer was built over netG.parameters(),
    # so netD would never have been updated; it must optimize netD.
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=LR,betas=(beta1, 0.999))
    criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
def random_transform(src, target):
    """Jointly augment a stacked input sample and its ground-truth frame.

    Args:
        src:    HxWx(3*n+1) uint8 array — n stacked BGR frames followed by
                a single mask channel in the last position.
        target: HxWx3 uint8 ground-truth (clean) frame.

    Returns:
        (src, target) after an identical random horizontal flip and a random
        brightness shift applied to the image channels. The mask channel is
        flipped but never brightness-shifted.
    """
    # Random horizontal flip, applied identically to input and target.
    if random.random() < 0.5:
        src = src[:, ::-1, :]
        target = target[:, ::-1, :]

    # Random brightness shift in [-30, 30], vectorized over all image
    # channels at once (the original looped channel by channel).
    random_num = 15
    bright = random.randint(-random_num * 2, random_num * 2)
    # Derive the number of image channels from the array itself instead of
    # relying on the module-level constant N (last channel is the mask).
    img_channels = src.shape[2] - 1
    src[:, :, :img_channels] = np.clip(
        src[:, :, :img_channels].astype('int') + bright, 0, 255).astype('uint8')
    target[:, :, :] = np.clip(target.astype('int') + bright, 0, 255).astype('uint8')
    return src, target
def showresult(img1, img2, img3, name):
    """Save a 1x3 montage (img1 | img2 | img3) of one random batch sample.

    Each argument is a NCHW float tensor in [0, 1]; one random sample is
    scaled back to [0, 255] and the three panels are written side by side
    into the checkpoint directory under `name`.
    """
    panels = [(t.cpu().detach().numpy() * 255) for t in (img1, img2, img3)]
    batchsize = panels[0].shape[0]
    size = panels[0].shape[3]
    sample = int(batchsize * random.random())
    montage = np.zeros((size, size * 3, 3))
    for col, panel in enumerate(panels):
        montage[0:size, size * col:size * (col + 1)] = panel[sample].transpose((1, 2, 0))
    cv2.imwrite(os.path.join(dir_checkpoint, name), montage)
def loaddata():
    """Assemble one random training sample.

    Picks a random video and a random center frame, stacks N consecutive
    mosaic frames plus the binarized center-frame mask into a
    (SIZE, SIZE, 3*N+1) uint8 array, loads the matching clean center frame,
    augments both together, and returns them as network-ready tensors.
    """
    vid = random.randint(0, len(videos) - 1)
    clip = videos[vid]
    center = random.randint(N, lengths[vid] - N)
    clip_root = './dataset/' + clip

    stacked = np.zeros((SIZE, SIZE, 3 * N + 1), dtype='uint8')
    # N mosaic frames centered on `center` (offsets -12 .. +12 for N=25).
    for offset in range(0, N):
        frame = cv2.imread(clip_root + '/mosaic/output_' + '%05d' % (center + offset - int(N / 2)) + '.png')
        stacked[:, :, offset * 3:(offset + 1) * 3] = impro.resize(frame, SIZE)

    # Binarized mask of the center frame goes into the last channel.
    mask = cv2.imread(clip_root + '/mask/output_' + '%05d' % (center) + '.png', 0)
    mask = impro.resize(mask, SIZE)
    stacked[:, :, -1] = impro.mask_threshold(mask, 15, 128)

    clean = cv2.imread(clip_root + '/ori/output_' + '%05d' % (center) + '.png')
    clean = impro.resize(clean, SIZE)

    stacked, clean = random_transform(stacked, clean)
    stacked = data.im2tensor(stacked, bgr2rgb=False, use_gpu=opt.use_gpu, use_transform=False)
    clean = data.im2tensor(clean, bgr2rgb=False, use_gpu=opt.use_gpu, use_transform=False)
    return stacked, clean
print('preloading data, please wait 5s...')

# Ring buffers of pre-loaded samples, filled by a background thread so
# training never blocks on disk I/O.
input_imgs=[]
ground_trues=[]
load_cnt = 0

def preload():
    # NOTE(review): load_cnt is incremented once per thread start, not per
    # sample, and is never read afterwards — looks vestigial; confirm.
    global load_cnt
    load_cnt += 1
    while 1:
        try:
            input_img,ground_true = loaddata()
            input_imgs.append(input_img)
            ground_trues.append(ground_true)
            # Keep only the ~10 most recent samples (bounded memory).
            if len(input_imgs)>10:
                del(input_imgs[0])
                del(ground_trues[0])
            # time.sleep(0.1)
        except Exception as e:
            # Best-effort loader: report and keep going (e.g. unreadable frame).
            print("error:",e)

import threading
t = threading.Thread(target=preload,args=())  # background loader thread
t.daemon = True  # do not keep the interpreter alive on exit
t.start()
time.sleep(5)  # wait for the first samples to be loaded
# ---------------------------------------------------------------------------
# Main training loop.
# Each iteration: take a preloaded sample, run the generator, optionally
# train the discriminator (GAN mode), update the generator, and periodically
# log/plot losses, save checkpoints and render a qualitative test montage.
# ---------------------------------------------------------------------------
netG.train()
time_start=time.time()
print("Begin training...")
for iter in range(start_iter+1,ITER):
    # input_img,ground_true = loaddata()
    # Grab a random sample from the background-loader ring buffer.
    ran = random.randint(1, 8)
    input_img = input_imgs[ran]
    ground_true = ground_trues[ran]
    pred = netG(input_img)

    if use_gan:
        # ---------------- discriminator update ----------------
        netD.train()
        # real_A = center mosaic frame + mask channel.
        real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1)
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB.detach())  # detach: do not backprop into G here
        loss_D_fake = criterionGAN(pred_fake, False)
        real_AB = torch.cat((real_A, ground_true), 1)
        pred_real = netD(real_AB)
        loss_D_real = criterionGAN(pred_real, True)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[2] += loss_D_fake.item()
        loss_sum[3] += loss_D_real.item()
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---------------- generator update ----------------
        netD.eval()
        real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1)
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB)
        loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan
        # Reconstruction term (optionally L1+L2).
        if use_L2:
            loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1
        else:
            loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
        loss_G = loss_G_GAN + loss_G_L1
        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
    else:
        # Plain reconstruction training (no adversarial loss).
        if use_L2:
            loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1
        else:
            loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
        loss_sum[0] += loss_G_L1.item()
        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    # Side-by-side preview (input | ground truth | prediction) every 100 iters.
    if (iter+1)%100 == 0:
        try:
            showresult(input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], ground_true, pred,'result_train.png')
        except Exception as e:
            print(e)

    # Console log + loss curve every 1000 iterations.
    if (iter+1)%1000 == 0:
        time_end = time.time()
        if use_gan:
            print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,3),' G_loss:', round(loss_sum[1]/1000,3),
                ' D_f:',round(loss_sum[2]/1000,3),' D_r:',round(loss_sum[3]/1000,3),' time:',round((time_end-time_start)/1000,2))
            if (iter+1)/1000 >= 10:  # skip the noisy warm-up iterations in the plot
                loss_plot[0].append(loss_sum[0]/1000)
                loss_plot[1].append(loss_sum[1]/1000)
                item_plot.append(iter+1)
                try:
                    plt.plot(item_plot,loss_plot[0])
                    plt.plot(item_plot,loss_plot[1])
                    plt.savefig(os.path.join(dir_checkpoint,'loss.png'))
                    plt.close()
                except Exception as e:
                    print("error:",e)
        else:
            print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,3),' time:',round((time_end-time_start)/1000,2))
            if (iter+1)/1000 >= 10:
                loss_plot[0].append(loss_sum[0]/1000)
                item_plot.append(iter+1)
                try:
                    plt.plot(item_plot,loss_plot[0])
                    plt.savefig(os.path.join(dir_checkpoint,'loss.png'))
                    plt.close()
                except Exception as e:
                    print("error:",e)
        loss_sum = [0.,0.,0.,0.]
        time_start=time.time()

    # Checkpoint + quick qualitative test every SAVE_FRE iterations.
    if (iter+1)%SAVE_FRE == 0:
        # Rotate the previous "last" checkpoint to an iteration-numbered file.
        if iter+1 != SAVE_FRE:
            os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1-SAVE_FRE)+'G.pth'))
        torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth'))
        if use_gan:
            if iter+1 != SAVE_FRE:
                os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1-SAVE_FRE)+'D.pth'))
            torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth'))
        if use_gpu:
            # .cpu() above moved the nets off the GPU; move them back.
            netG.cuda()
            if use_gan:
                netD.cuda()
        f = open(os.path.join(dir_checkpoint,'iter'),'w+')
        f.write(str(iter+1))
        f.close()
        # torch.save(netG.cpu().state_dict(),dir_checkpoint+'iter'+str(iter+1)+'.pth')
        print('network saved.')

        # Qualitative test on the clips under ./test — assumes at most 4
        # clips, each with an image/ folder of >= N frames and a mask.png.
        netG.eval()
        result = np.zeros((SIZE*2,SIZE*4,3), dtype='uint8')
        test_names = os.listdir('./test')
        for cnt,test_name in enumerate(test_names,0):
            img_names = os.listdir(os.path.join('./test',test_name,'image'))
            input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8')
            img_names.sort()
            for i in range(0,N):
                img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
                img = impro.resize(img,SIZE)
                input_img[:,:,i*3:(i+1)*3] = img
            mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
            mask = impro.resize(mask,SIZE)
            mask = impro.mask_threshold(mask,15,128)
            input_img[:,:,-1] = mask
            # Top row: center input frame; bottom row: prediction.
            result[0:SIZE,SIZE*cnt:SIZE*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
            input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
            with torch.no_grad():  # inference only: no autograd graph needed
                pred = netG(input_img)
            pred = (pred.cpu().detach().numpy()*255)[0].transpose((1, 2, 0))
            result[SIZE:SIZE*2,SIZE*cnt:SIZE*(cnt+1),:] = pred
        cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result)
        # BUGFIX: was netG.eval() — after the first checkpoint, training
        # would continue with the network stuck in eval mode (normalization
        # layers / dropout frozen). Restore train mode.
        netG.train()
\ No newline at end of file
import os
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time
import sys
sys.path.append("..")
from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
from models import pix2pix_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
# ---- Hyper-parameters (older pix2pix-based trainer) ----
N = 25              # number of consecutive frames stacked into one sample
ITER = 1000000      # total training iterations
LR = 0.0002         # Adam learning rate
use_gpu = True
CONTINUE = True     # resume from last_G.pth / last_D.pth and ./iter
# BATCHSIZE = 4
dir_checkpoint = 'checkpoints/'
SAVE_FRE = 5000     # checkpoint save frequency (iterations)
start_iter = 0
SIZE = 256          # spatial size (pixels) of the network input/output
lambda_L1 = 100.0   # weight of the L1 reconstruction loss

opt = Options().getparse()
opt.use_gpu=True

# Each sub-directory of ./dataset is one video clip; record how many
# original frames each clip provides (./dataset/<video>/ori).
videos = os.listdir('./dataset')
videos.sort()
lengths = []
for video in videos:
    video_images = os.listdir('./dataset/'+video+'/ori')
    lengths.append(len(video_images))
# Generator / discriminator from the pix2pix backbones: G maps the stacked
# frames + mask to one clean frame; D judges (center frame, clean frame) pairs.
netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_9blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
netD = pix2pix_model.define_D(3*2, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])

if CONTINUE:
    # Resume network weights and the iteration counter from the last save.
    netG.load_state_dict(torch.load(dir_checkpoint+'last_G.pth'))
    netD.load_state_dict(torch.load(dir_checkpoint+'last_D.pth'))
    f = open('./iter','r')
    start_iter = int(f.read())
    f.close()

if use_gpu:
    netG.cuda()
    netD.cuda()
    cudnn.benchmark = True

optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR)
# BUGFIX: the discriminator optimizer must update netD's parameters.
# It previously optimized netG.parameters(), so the discriminator was
# never trained and the adversarial loss was meaningless.
optimizer_D = torch.optim.Adam(netD.parameters(), lr=LR)
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
criterionGAN = pix2pix_model.GANLoss('lsgan').cuda()
def showresult(img1, img2, img3, name):
    """Write a 1x3 montage (img1 | img2 | img3) of one random batch sample.

    Each argument is a NCHW float tensor in [0, 1]; one random sample is
    scaled back to [0, 255] and the three panels are saved side by side
    to the path given by `name`.
    """
    panels = [(t.cpu().detach().numpy() * 255) for t in (img1, img2, img3)]
    batchsize = panels[0].shape[0]
    size = panels[0].shape[3]
    sample = int(batchsize * random.random())
    montage = np.zeros((size, size * 3, 3))
    for col, panel in enumerate(panels):
        montage[0:size, size * col:size * (col + 1)] = panel[sample].transpose((1, 2, 0))
    cv2.imwrite(name, montage)
def loaddata():
    """Load one random training sample (older 256px pipeline).

    Picks a random video and a random center frame, stacks N consecutive
    mosaic frames plus the binarized center-frame mask into a
    (SIZE, SIZE, 3*N+1) uint8 array, and returns it together with the clean
    center frame, both converted to network-ready tensors.
    """
    video_index = random.randint(0,len(videos)-1)
    video = videos[video_index]
    img_index = random.randint(N,lengths[video_index]- N)
    input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8')
    # N mosaic frames centered on img_index (offsets -12 .. +12 for N=25).
    for i in range(0,N):
        img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
        img = impro.resize(img,SIZE)
        input_img[:,:,i*3:(i+1)*3] = img
    mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0)
    # FIX: resize the mask to SIZE instead of a hard-coded 256, so the mask
    # channel stays consistent with the frames if SIZE is ever changed
    # (no behavior change while SIZE == 256).
    mask = impro.resize(mask,SIZE)
    mask = impro.mask_threshold(mask,15,128)
    input_img[:,:,-1] = mask
    input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
    ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png')
    ground_true = impro.resize(ground_true,SIZE)
    ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
    return input_img,ground_true
# Ring buffers of pre-loaded samples, filled by a background thread so
# training never blocks on disk I/O.
input_imgs=[]
ground_trues=[]

def preload():
    """Background loader: keeps the ~10 most recent samples in the buffers."""
    while 1:
        input_img,ground_true = loaddata()
        input_imgs.append(input_img)
        ground_trues.append(ground_true)
        if len(input_imgs)>10:
            del(input_imgs[0])
            del(ground_trues[0])

import threading
t=threading.Thread(target=preload,args=())  # background loader thread
# BUGFIX: mark the loader as a daemon thread. Without this, the non-daemon
# `while 1` thread keeps the interpreter alive forever, so the process
# cannot exit cleanly when training ends or is interrupted.
t.daemon = True
t.start()
time.sleep(3)  # wait for the first samples to be loaded
# ---------------------------------------------------------------------------
# Main training loop (pix2pix-style GAN training, older variant).
# ---------------------------------------------------------------------------
netG.train()
loss_sum = [0.,0.]   # running sums: [L1 loss, total G loss]; reset every 100 iters
loss_plot = [[],[]]
item_plot = []
time_start=time.time()
print("Begin training...")
for iter in range(start_iter+1,ITER):
    # input_img,ground_true = loaddata()
    # Grab a random sample from the background-loader ring buffer.
    ran = random.randint(0, 9)
    input_img = input_imgs[ran]
    ground_true = ground_trues[ran]
    pred = netG(input_img)
    # ---------------- discriminator update ----------------
    # NOTE(review): int((N+1)/2) == 13 for N=25, but the stacked frames run
    # from img_index-12 to img_index+12, so the center frame is channel
    # block 12 (int((N-1)/2)) — this looks off by one frame; the newer
    # trainer in this commit uses int((N-1)/2). Confirm intended.
    fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1)
    pred_fake = netD(fake_AB.detach())  # detach: do not backprop into G here
    loss_D_fake = criterionGAN(pred_fake, False)
    real_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true), 1)
    pred_real = netD(real_AB)
    loss_D_real = criterionGAN(pred_real, True)
    loss_D = (loss_D_fake + loss_D_real) * 0.5
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()
    # ---------------- generator update ----------------
    fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1)
    pred_fake = netD(fake_AB)
    loss_G_GAN = criterionGAN(pred_fake, True)
    # Second, G(A) = B
    loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
    # combine loss and calculate gradients
    loss_G = loss_G_GAN + loss_G_L1
    loss_sum[0] += loss_G_L1.item()
    loss_sum[1] += loss_G.item()
    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()
    # a = netD(ground_true)
    # print(a.size())
    # loss = criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)
    # # loss = criterion_L2(pred, ground_true)
    # loss_sum += loss.item()
    # optimizer_G.zero_grad()
    # loss.backward()
    # optimizer_G.step()
    # Side-by-side preview (input | ground truth | prediction) every 100 iters.
    if (iter+1)%100 == 0:
        showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true, pred,'./result_train.png')
    # Console log + loss curve every 100 iterations.
    if (iter+1)%100 == 0:
        time_end=time.time()
        print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/100,4),'G_loss:', round(loss_sum[1]/100,4),'time:',round((time_end-time_start)/100,4))
        if (iter+1)/100 >= 10:  # skip the noisy warm-up iterations in the plot
            loss_plot[0].append(loss_sum[0]/100)
            loss_plot[1].append(loss_sum[1]/100)
            item_plot.append(iter+1)
            plt.plot(item_plot,loss_plot[0])
            plt.plot(item_plot,loss_plot[1])
            plt.savefig('./loss.png')
            plt.close()
        loss_sum = [0.,0.]
        #show test result
        # netG.eval()
        # input_img = np.zeros((SIZE,SIZE,3*N), dtype='uint8')
        # imgs = os.listdir('./test')
        # for i in range(0,N):
        #     img = cv2.imread('./test/'+imgs[i])
        #     img = impro.resize(img,SIZE)
        #     input_img[:,:,i*3:(i+1)*3] = img
        # input_img = im2tensor(input_img,use_gpu)
        # ground_true = cv2.imread('./test/output_'+'%05d'%13+'.png')
        # ground_true = impro.resize(ground_true,SIZE)
        # ground_true = im2tensor(ground_true,use_gpu)
        # pred = netG(input_img)
        # showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:],pred,pred,'./result_test.png')
        netG.train()
        time_start=time.time()
    # Checkpoint every SAVE_FRE iterations.
    if (iter+1)%SAVE_FRE == 0:
        torch.save(netG.cpu().state_dict(),dir_checkpoint+'last_G.pth')
        torch.save(netD.cpu().state_dict(),dir_checkpoint+'last_D.pth')
        if use_gpu:
            # .cpu() above moved the nets off the GPU; move them back.
            netG.cuda()
            netD.cuda()
        f = open('./iter','w+')
        f.write(str(iter+1))
        f.close()
        # torch.save(netG.cpu().state_dict(),dir_checkpoint+'iter'+str(iter+1)+'.pth')
        print('network saved.')
......@@ -2,7 +2,9 @@ import cv2
import numpy as np
def imread(file_path,mod = 'normal'):
'''
mod = 'normal' | 'gray' | 'all'
'''
if mod == 'normal':
cv_img = cv2.imread(file_path)
elif mod == 'gray':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册