Commit 1749be92 authored by H hypox64

GAN code finished!

Parent 796b59d0
......@@ -283,9 +283,9 @@ def cleanmosaic_video_fusion(opt,netG,netM):
mosaic_input[:,:,k*3:(k+1)*3] = impro.resize(img_pool[k][y-size:y+size,x-size:x+size], INPUT_SIZE)
mask_input = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
mosaic_input[:,:,-1] = impro.resize(mask_input, INPUT_SIZE)
mosaic_input_tensor = data.im2tensor(mosaic_input,bgr2rgb=False,gpu_id=opt.gpu_id,use_transform = False,is0_1 = False)
mosaic_input_tensor = data.im2tensor(mosaic_input,bgr2rgb=False,gpu_id=opt.gpu_id)
unmosaic_pred = netG(mosaic_input_tensor)
img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False)
img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
except Exception as e:
print('Warning:',e)
......
......@@ -94,22 +94,14 @@ class BVDNet(nn.Module):
def define_G(N=2, n_blocks=1, gpu_id='-1'):
netG = BVDNet(N = N, n_blocks=n_blocks)
if gpu_id != '-1' and len(gpu_id) == 1:
netG.cuda()
elif gpu_id != '-1' and len(gpu_id) > 1:
netG = nn.DataParallel(netG)
netG.cuda()
# netG.apply(model_util.init_weights)
netG = model_util.todevice(netG,gpu_id)
netG.apply(model_util.init_weights)
return netG
################################Discriminator################################
def define_D(input_nc=6, ndf=64, n_layers_D=3, use_sigmoid=False, num_D=4, gpu_id='-1'):
def define_D(input_nc=6, ndf=64, n_layers_D=1, use_sigmoid=False, num_D=3, gpu_id='-1'):
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, num_D)
if gpu_id != '-1' and len(gpu_id) == 1:
netD.cuda()
elif gpu_id != '-1' and len(gpu_id) > 1:
netD = nn.DataParallel(netD)
netD.cuda()
netD = model_util.todevice(netD,gpu_id)
netD.apply(model_util.init_weights)
return netD
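Note: the duplicated cuda/DataParallel branches in define_G and define_D are now delegated to model_util.todevice (added further down in this commit), and define_G additionally applies model_util.init_weights. A minimal usage sketch, assuming the repository's models package is importable and gpu_id is passed as a string as in the training script:

# Sketch only: build the BVDNet generator/discriminator pair with the new helpers.
# '0' -> single GPU, '0,1' -> DataParallel, '-1' -> CPU (handled inside model_util.todevice).
from models import BVDNet

netG = BVDNet.define_G(N=2, n_blocks=1, gpu_id='0')
netD = BVDNet.define_D(input_nc=6, n_layers_D=1, num_D=3, gpu_id='0')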
......@@ -191,16 +183,16 @@ class GANLoss(nn.Module):
if self.mode == 'D':
loss = 0
for i in range(len(dis_fake)):
loss += self.lossf(dis_fake[i][0],dis_real[i][0])
loss += self.lossf(dis_fake[i][-1],dis_real[i][-1])
elif self.mode =='G':
loss = 0
weight = 2**len(dis_fake)
for i in range(len(dis_fake)):
weight = weight/2
loss += weight*self.lossf(dis_fake[i][0])
loss += weight*self.lossf(dis_fake[i][-1])
return loss
else:
if self.mode == 'D':
return self.lossf(dis_fake[0],dis_real[0])
return self.lossf(dis_fake[-1],dis_real[-1])
elif self.mode =='G':
return self.lossf(dis_fake[0])
return self.lossf(dis_fake[-1])
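Note: the switch from index [0] to [-1] follows the pix2pixHD-style MultiscaleDiscriminator output, where each scale yields a list whose last element is the real/fake prediction map (earlier elements are intermediate features when they are returned). A minimal sketch of the 'D' aggregation with hypothetical tensors; hinge_loss_d stands in for self.lossf:

import torch

def hinge_loss_d(pred_fake, pred_real):
    # stand-in for the hinge discriminator loss used as self.lossf in GANLoss('D')
    return torch.relu(1 - pred_real).mean() + torch.relu(1 + pred_fake).mean()

# hypothetical multiscale output: 3 scales, each a list ending with the prediction map
dis_real = [[torch.randn(1, 1, 30 // 2**s, 30 // 2**s)] for s in range(3)]
dis_fake = [[torch.randn(1, 1, 30 // 2**s, 30 // 2**s)] for s in range(3)]

loss = 0
for scale_real, scale_fake in zip(dis_real, dis_fake):
    loss = loss + hinge_loss_d(scale_fake[-1], scale_real[-1])   # [-1]: final prediction per scale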
......@@ -2,7 +2,7 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from . import components
from . import model_util
import warnings
warnings.filterwarnings(action='ignore')
......@@ -43,7 +43,7 @@ class DiceLoss(nn.Module):
class resnet18(torch.nn.Module):
def __init__(self, pretrained=True):
super().__init__()
self.features = components.resnet18(pretrained=pretrained)
self.features = model_util.resnet18(pretrained=pretrained)
self.conv1 = self.features.conv1
self.bn1 = self.features.bn1
self.relu = self.features.relu
......@@ -70,7 +70,7 @@ class resnet18(torch.nn.Module):
class resnet101(torch.nn.Module):
def __init__(self, pretrained=True):
super().__init__()
self.features = components.resnet101(pretrained=pretrained)
self.features = model_util.resnet101(pretrained=pretrained)
self.conv1 = self.features.conv1
self.bn1 = self.features.bn1
self.relu = self.features.relu
......
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = norm_layer(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = norm_layer(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
\ No newline at end of file
import torch
from . import model_util
from .pix2pix_model import define_G
from .pix2pixHD_model import define_G as define_G_HD
from .unet_model import UNet
from .video_model import MosaicNet
from .videoHD_model import MosaicNet as MosaicNet_HD
from .BiSeNet_model import BiSeNet
......@@ -11,19 +11,6 @@ def show_paramsnumber(net,netname='net'):
parameters = round(parameters/1e6,2)
print(netname+' parameters: '+str(parameters)+'M')
def __patch_instance_norm_state_dict(state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def pix2pix(opt):
# print(opt.model_path,opt.netG)
......@@ -33,9 +20,8 @@ def pix2pix(opt):
netG = define_G(3, 3, 64, opt.netG, norm='batch',use_dropout=True, init_type='normal', gpu_ids=[])
show_paramsnumber(netG,'netG')
netG.load_state_dict(torch.load(opt.model_path))
netG = model_util.todevice(netG,opt.gpu_id)
netG.eval()
if opt.gpu_id != -1:
netG.cuda()
return netG
......@@ -57,11 +43,11 @@ def style(opt):
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
__patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
model_util.patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
netG.load_state_dict(state_dict)
if opt.gpu_id != -1:
netG.cuda()
netG = model_util.todevice(netG,opt.gpu_id)
netG.eval()
return netG
def video(opt):
......@@ -71,9 +57,8 @@ def video(opt):
netG = MosaicNet(3*25+1, 3,norm = 'batch')
show_paramsnumber(netG,'netG')
netG.load_state_dict(torch.load(opt.model_path))
netG = model_util.todevice(netG,opt.gpu_id)
netG.eval()
if opt.gpu_id != -1:
netG.cuda()
return netG
def bisenet(opt,type='roi'):
......@@ -86,7 +71,6 @@ def bisenet(opt,type='roi'):
net.load_state_dict(torch.load(opt.model_path))
elif type == 'mosaic':
net.load_state_dict(torch.load(opt.mosaic_position_model_path))
net = model_util.todevice(net,opt.gpu_id)
net.eval()
if opt.gpu_id != -1:
net.cuda()
return net
......@@ -8,7 +8,9 @@ from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from torchvision import models
import torch.utils.model_zoo as model_zoo
################################## IO ##################################
def save(net,path,gpu_id):
if isinstance(net, nn.DataParallel):
torch.save(net.module.cpu().state_dict(),path)
......@@ -17,6 +19,29 @@ def save(net,path,gpu_id):
if gpu_id != '-1':
net.cuda()
def todevice(net,gpu_id):
if gpu_id != '-1' and len(gpu_id) == 1:
net.cuda()
elif gpu_id != '-1' and len(gpu_id) > 1:
net = nn.DataParallel(net)
net.cuda()
return net
# patch InstanceNorm checkpoints prior to 0.4
def patch_instance_norm_state_dict(state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
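Note: these two helpers replace the per-loader device logic and the private __patch_instance_norm_state_dict removed from loadmodel.py above. A minimal sketch of how a loader uses them (mirroring the style() change in this commit); netG is assumed to be an already-constructed generator and the checkpoint path is hypothetical:

import torch

state_dict = torch.load('checkpoints/style.pth', map_location='cpu')   # hypothetical path
for key in list(state_dict.keys()):                 # copy the keys: the helper mutates the dict
    patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
netG.load_state_dict(state_dict)
netG = todevice(netG, gpu_id='-1')                  # '-1' keeps the model on the CPU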
################################## initialization ##################################
def get_norm_layer(norm_type='instance',mod = '2d'):
if norm_type == 'batch':
......@@ -60,6 +85,7 @@ def init_weights(net, init_type='normal', gain=0.02):
net.apply(init_func)
################################## Network structure ##################################
################################## ResnetBlock ##################################
class ResnetBlockSpectralNorm(nn.Module):
def __init__(self, dim, padding_type, activation=nn.LeakyReLU(0.2), use_dropout=False):
super(ResnetBlockSpectralNorm, self).__init__()
......@@ -99,6 +125,193 @@ class ResnetBlockSpectralNorm(nn.Module):
out = x + self.conv_block(x)
return out
################################## Resnet ##################################
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = norm_layer(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = norm_layer(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
################################## Loss function ##################################
class HingeLossD(nn.Module):
def __init__(self):
......@@ -114,7 +327,7 @@ class HingeLossG(nn.Module):
super(HingeLossG, self).__init__()
def forward(self, dis_fake):
loss_fake = F.relu(-torch.mean(dis_fake))
loss_fake = -torch.mean(dis_fake)
return loss_fake
class VGGLoss(nn.Module):
......
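Note: dropping F.relu from HingeLossG matches the standard hinge-GAN formulation, where the margin is applied only on the discriminator side; relu(-mean(dis_fake)) clamps the generator loss and its gradient to zero as soon as the discriminator scores fakes positively. A reference sketch of the standard objectives with illustrative tensors (not repo code):

import torch
import torch.nn.functional as F

dis_real = torch.randn(4, 1, 30, 30)      # hypothetical discriminator outputs
dis_fake = torch.randn(4, 1, 30, 30)
loss_D = F.relu(1.0 - dis_real).mean() + F.relu(1.0 + dis_fake).mean()   # margin only on D
loss_G = -dis_fake.mean()                 # no relu: gradient survives once D(fake) > 0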
......@@ -7,11 +7,11 @@ from util import data
import torch
import numpy as np
def run_segment(img,net,size = 360,gpu_id = 0):
def run_segment(img,net,size = 360,gpu_id = '-1'):
img = impro.resize(img,size)
img = data.im2tensor(img,gpu_id = gpu_id, bgr2rgb = False,use_transform = False , is0_1 = True)
img = data.im2tensor(img,gpu_id = gpu_id, bgr2rgb = False, is0_1 = True)
mask = net(img)
mask = data.tensor2im(mask, gray=True,rgb2bgr = False, is0_1 = True)
mask = data.tensor2im(mask, gray=True, is0_1 = True)
return mask
def run_pix2pix(img,net,opt):
......@@ -50,12 +50,12 @@ def run_styletransfer(opt, net, img):
else:
canny_low = opt.canny-int(opt.canny/2)
canny_high = opt.canny+int(opt.canny/2)
img = cv2.Canny(img,opt.canny-50,opt.canny+50)
img = cv2.Canny(img,canny_low,canny_high)
if opt.only_edges:
return img
img = data.im2tensor(img,gpu_id=opt.gpu_id,gray=True,use_transform = False,is0_1 = False)
img = data.im2tensor(img,gpu_id=opt.gpu_id,gray=True)
else:
img = data.im2tensor(img,gpu_id=opt.gpu_id,gray=False,use_transform = True)
img = data.im2tensor(img,gpu_id=opt.gpu_id)
img = net(img)
img = data.tensor2im(img)
return img
......
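Note: the Canny call previously ignored the computed canny_low/canny_high and hard-coded a ±50 band around opt.canny; the fix wires the computed thresholds through. A minimal sketch of the intended behaviour with illustrative values:

import cv2
import numpy as np

img = np.zeros((256, 256), dtype=np.uint8)   # placeholder grayscale frame
canny = 100                                  # stands in for opt.canny
canny_low  = canny - canny // 2              # 50
canny_high = canny + canny // 2              # 150
edges = cv2.Canny(img, canny_low, canny_high)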
......@@ -31,15 +31,19 @@ opt.parser.add_argument('--beta2',type=float,default=0.999, help='')
opt.parser.add_argument('--finesize',type=int,default=256, help='')
opt.parser.add_argument('--loadsize',type=int,default=286, help='')
opt.parser.add_argument('--batchsize',type=int,default=1, help='')
opt.parser.add_argument('--no_gan', action='store_true', help='if specified, do not use gan')
opt.parser.add_argument('--n_layers_D',type=int,default=1, help='')
opt.parser.add_argument('--num_D',type=int,default=3, help='')
opt.parser.add_argument('--lambda_L2',type=float,default=100, help='')
opt.parser.add_argument('--lambda_VGG',type=float,default=1, help='')
opt.parser.add_argument('--lambda_GAN',type=float,default=1, help='')
opt.parser.add_argument('--lambda_D',type=float,default=1, help='')
opt.parser.add_argument('--load_thread',type=int,default=4, help='number of thread for loading data')
opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
opt.parser.add_argument('--dataset_test',type=str,default='./datasets/face_test/', help='')
opt.parser.add_argument('--n_epoch',type=int,default=200, help='')
opt.parser.add_argument('--save_freq',type=int,default=100000, help='')
opt.parser.add_argument('--save_freq',type=int,default=10000, help='')
opt.parser.add_argument('--continue_train', action='store_true', help='')
opt.parser.add_argument('--savename',type=str,default='face', help='')
opt.parser.add_argument('--showresult_freq',type=int,default=1000, help='')
......@@ -84,16 +88,16 @@ TBGlobalWriter = SummaryWriter(tensorboard_savedir)
'''
if opt.gpu_id != '-1' and len(opt.gpu_id) == 1:
torch.backends.cudnn.benchmark = True
netG = BVDNet.define_G(opt.N,gpu_id=opt.gpu_id)
netD = BVDNet.define_D(gpu_id=opt.gpu_id)
netG = BVDNet.define_G(opt.N,gpu_id=opt.gpu_id)
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
lossfun_L2 = nn.MSELoss()
lossfun_VGG = model_util.VGGLoss(opt.gpu_id)
lossfun_GAND = BVDNet.GANLoss('D')
lossfun_GANG = BVDNet.GANLoss('G')
if not opt.no_gan:
netD = BVDNet.define_D(n_layers_D=opt.n_layers_D,num_D=opt.num_D,gpu_id=opt.gpu_id)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
lossfun_GAND = BVDNet.GANLoss('D')
lossfun_GANG = BVDNet.GANLoss('G')
'''
--------------------------Init DataLoader--------------------------
......@@ -130,33 +134,42 @@ for train_iter in range(Videodataloader_train.n_iter):
# Fake Generator
out = netG(mosaic_stream,previous_frame)
# Discriminator
dis_real = netD(torch.cat((mosaic_stream[:,:,opt.N],ori_stream[:,:,opt.N].detach()),dim=1))
dis_fake_D = netD(torch.cat((mosaic_stream[:,:,opt.N],out.detach()),dim=1))
loss_D = lossfun_GAND(dis_fake_D,dis_real) * opt.lambda_GAN
if not opt.no_gan:
dis_real = netD(torch.cat((mosaic_stream[:,:,opt.N],ori_stream[:,:,opt.N].detach()),dim=1))
dis_fake_D = netD(torch.cat((mosaic_stream[:,:,opt.N],out.detach()),dim=1))
loss_D = lossfun_GAND(dis_fake_D,dis_real) * opt.lambda_GAN * opt.lambda_D
# Generator
dis_fake_G = netD(torch.cat((mosaic_stream[:,:,opt.N],out),dim=1))
loss_L2 = lossfun_L2(out,ori_stream[:,:,opt.N]) * opt.lambda_L2
loss_VGG = lossfun_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
loss_GANG = lossfun_GANG(dis_fake_G) * opt.lambda_GAN
loss_G = loss_L2+loss_VGG+loss_GANG
loss_G = loss_L2+loss_VGG
if not opt.no_gan:
dis_fake_G = netD(torch.cat((mosaic_stream[:,:,opt.N],out),dim=1))
loss_GANG = lossfun_GANG(dis_fake_G) * opt.lambda_GAN
loss_G = loss_G + loss_GANG
############### Backward Pass ####################
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
previous_predframe_tmp = out.detach().cpu().numpy()
if not opt.no_gan:
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
TBGlobalWriter.add_scalars('loss/train', {'L2':loss_L2.item(),'VGG':loss_VGG.item(),
'loss_D':loss_D.item(),'loss_G':loss_G.item()}, train_iter)
previous_predframe_tmp = out.detach().cpu().numpy()
if not opt.no_gan:
TBGlobalWriter.add_scalars('loss/train', {'L2':loss_L2.item(),'VGG':loss_VGG.item(),
'loss_D':loss_D.item(),'loss_G':loss_G.item()}, train_iter)
else:
TBGlobalWriter.add_scalars('loss/train', {'L2':loss_L2.item(),'VGG':loss_VGG.item()}, train_iter)
# save network
if train_iter%opt.save_freq == 0 and train_iter != 0:
model_util.save(netG, os.path.join('checkpoints',opt.savename,str(train_iter)+'_G.pth'), opt.gpu_id)
model_util.save(netD, os.path.join('checkpoints',opt.savename,str(train_iter)+'_D.pth'), opt.gpu_id)
if not opt.no_gan:
model_util.save(netD, os.path.join('checkpoints',opt.savename,str(train_iter)+'_D.pth'), opt.gpu_id)
# Image quality evaluation
if train_iter%(opt.showresult_freq//10) == 0:
......@@ -213,7 +226,7 @@ for train_iter in range(Videodataloader_train.n_iter):
mosaic_stream.append(_mosaic)
if step == 0:
previous = impro.imread(os.path.join(opt.dataset_test,video,'image',frames[opt.N*opt.S-1]),loadsize=opt.finesize,rgb=True)
previous = data.im2tensor(previous,bgr2rgb = False, gpu_id = opt.gpu_id,use_transform = False, is0_1 = False)
previous = data.im2tensor(previous,bgr2rgb = False, gpu_id = opt.gpu_id, is0_1 = False)
mosaic_stream = (np.array(mosaic_stream).astype(np.float32)/255.0-0.5)/0.5
mosaic_stream = mosaic_stream.reshape(1,opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
mosaic_stream = data.to_tensor(mosaic_stream, opt.gpu_id)
......
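Note: taken together, the training-loop hunks above gate every discriminator step behind the new --no_gan flag while keeping the L2+VGG objective unconditional. A condensed restatement of the resulting update order (names follow the script, target stands in for ori_stream[:,:,opt.N], and all objects are assumed to exist as in the loop above):

out = netG(mosaic_stream, previous_frame)
loss_G = lossfun_L2(out, target) * opt.lambda_L2 + lossfun_VGG(out, target) * opt.lambda_VGG
if not opt.no_gan:
    cond = mosaic_stream[:, :, opt.N]
    dis_real   = netD(torch.cat((cond, target.detach()), dim=1))
    dis_fake_D = netD(torch.cat((cond, out.detach()), dim=1))     # detach: D loss must not reach G
    loss_D = lossfun_GAND(dis_fake_D, dis_real) * opt.lambda_GAN * opt.lambda_D
    loss_G = loss_G + lossfun_GANG(netD(torch.cat((cond, out), dim=1))) * opt.lambda_GAN

optimizer_G.zero_grad(); loss_G.backward(); optimizer_G.step()    # generator is stepped first
if not opt.no_gan:
    optimizer_D.zero_grad(); loss_D.backward(); optimizer_D.step()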
......@@ -6,11 +6,6 @@ import torchvision.transforms as transforms
import cv2
from . import image_processing as impro
from . import degradater
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
]
)
def to_tensor(data,gpu_id):
data = torch.from_numpy(data)
......@@ -18,8 +13,7 @@ def to_tensor(data,gpu_id):
data = data.cuda()
return data
def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False, batch_index=0):
def tensor2im(image_tensor, gray=False, rgb2bgr = True ,is0_1 = False, batch_index=0):
image_tensor =image_tensor.data
image_numpy = image_tensor[batch_index].cpu().float().numpy()
......@@ -31,7 +25,7 @@ def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 =
if gray:
h, w = image_numpy.shape[1:]
image_numpy = image_numpy.reshape(h,w)
return image_numpy.astype(imtype)
return image_numpy.astype(np.uint8)
# output 3ch
if image_numpy.shape[0] == 1:
......@@ -39,11 +33,10 @@ def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 =
image_numpy = image_numpy.transpose((1, 2, 0))
if rgb2bgr and not gray:
image_numpy = image_numpy[...,::-1]-np.zeros_like(image_numpy)
return image_numpy.astype(imtype)
return image_numpy.astype(np.uint8)
def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = True, gpu_id = 0, use_transform = True,is0_1 = True):
def im2tensor(image_numpy, gray=False,bgr2rgb = True, reshape = True, gpu_id = '-1',is0_1 = False):
if gray:
h, w = image_numpy.shape
image_numpy = (image_numpy/255.0-0.5)/0.5
......@@ -54,15 +47,12 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape =
h, w ,ch = image_numpy.shape
if bgr2rgb:
image_numpy = image_numpy[...,::-1]-np.zeros_like(image_numpy)
if use_transform:
image_tensor = transform(image_numpy)
if is0_1:
image_numpy = image_numpy/255.0
else:
if is0_1:
image_numpy = image_numpy/255.0
else:
image_numpy = (image_numpy/255.0-0.5)/0.5
image_numpy = image_numpy.transpose((2, 0, 1))
image_tensor = torch.from_numpy(image_numpy).float()
image_numpy = (image_numpy/255.0-0.5)/0.5
image_numpy = image_numpy.transpose((2, 0, 1))
image_tensor = torch.from_numpy(image_numpy).float()
if reshape:
image_tensor = image_tensor.reshape(1,ch,h,w)
if gpu_id != '-1':
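Note: with use_transform and the module-level torchvision transform removed, im2tensor now normalizes purely in NumPy: is0_1=True maps pixels to [0,1], otherwise (the new default) to [-1,1], before the HWC-to-CHW transpose. A minimal sketch of the two paths on a placeholder uint8 image:

import numpy as np

frame = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)   # placeholder BGR image
x01 = frame / 255.0                  # is0_1=True  -> values in [0, 1]
x11 = (frame / 255.0 - 0.5) / 0.5    # is0_1=False -> values in [-1, 1] (new default)
chw = x11.transpose(2, 0, 1)         # HWC -> CHW, then torch.from_numpy(...).float() in im2tensor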
......@@ -75,7 +65,6 @@ def shuffledata(data,target):
np.random.set_state(state)
np.random.shuffle(target)
def random_transform_single_mask(img,out_shape):
out_h,out_w = out_shape
img = cv2.resize(img,(int(out_w*random.uniform(1.1, 1.5)),int(out_h*random.uniform(1.1, 1.5))))
......@@ -105,7 +94,7 @@ def get_transform_params():
color_rate = [np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),
np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05)]
flip_rate = np.random.random()
degradate_params = degradater.get_random_degenerate_params(mod='weaker_1')
degradate_params = degradater.get_random_degenerate_params(mod='weaker_2')
rate_dict = {'crop':crop_rate,'rotat':rotat_rate,'color':color_rate,'flip':flip_rate,'degradate':degradate_params}
return {'flag':flag_dict,'rate':rate_dict}
......@@ -113,6 +102,9 @@ def get_transform_params():
def random_transform_single_image(img,finesize,params=None,test_flag = False):
if params is None:
params = get_transform_params()
if params['flag']['degradate']:
img = degradater.degradate(img,params['rate']['degradate'])
if params['flag']['crop']:
h,w = img.shape[:2]
......@@ -135,9 +127,6 @@ def random_transform_single_image(img,finesize,params=None,test_flag = False):
if params['flag']['flip']:
img = img[:,::-1,:]
if params['flag']['degradate']:
img = degradater.degradate(img,params['rate']['degradate'])
#check shape
if img.shape[0]!= finesize or img.shape[1]!= finesize:
img = cv2.resize(img,(finesize,finesize))
......
......@@ -28,17 +28,11 @@ class VideoLoader(object):
feg_mask = impro.imread(os.path.join(video_dir,'mask','00001.png'),mod='gray',loadsize=self.opt.loadsize)
self.mosaic_size,self.mod,self.rect_rat,self.feather = mosaic.get_random_parameter(feg_ori,feg_mask)
self.startpos = [random.randint(0,self.mosaic_size),random.randint(0,self.mosaic_size)]
self.loadsize = self.opt.loadsize
#Init load pool
for i in range(self.opt.S*self.opt.T):
# random
if np.random.random()<0.05:
self.startpos = [random.randint(0,self.mosaic_size),random.randint(0,self.mosaic_size)]
if np.random.random()<0.02:
self.transform_params['rate']['crop'] = [np.random.random(),np.random.random()]
_ori_img = impro.imread(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
_mask = impro.imread(os.path.join(video_dir,'mask','%05d' % (i+1)+'.png' ),mod='gray',loadsize=self.opt.loadsize)
_ori_img = impro.imread(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'),loadsize=self.loadsize,rgb=True)
_mask = impro.imread(os.path.join(video_dir,'mask','%05d' % (i+1)+'.png' ),mod='gray',loadsize=self.loadsize)
_mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
_ori_img = data.random_transform_single_image(_ori_img,opt.finesize,self.transform_params)
_mosaic_img = data.random_transform_single_image(_mosaic_img,opt.finesize,self.transform_params)
......@@ -70,13 +64,21 @@ class VideoLoader(object):
return np.clip((data*0.5+0.5)*255,0,255).astype(np.uint8)
def next(self):
# random
if np.random.random()<0.05:
self.startpos = [random.randint(0,self.mosaic_size),random.randint(0,self.mosaic_size)]
if np.random.random()<0.02:
self.transform_params['rate']['crop'] = [np.random.random(),np.random.random()]
if np.random.random()<0.02:
self.loadsize = np.random.randint(self.opt.finesize,self.opt.loadsize)
if self.t != 0:
self.previous_pred = None
self.ori_load_pool [:self.opt.S*self.opt.T-1] = self.ori_load_pool [1:self.opt.S*self.opt.T]
self.mosaic_load_pool[:self.opt.S*self.opt.T-1] = self.mosaic_load_pool[1:self.opt.S*self.opt.T]
#print(os.path.join(self.video_dir,'origin_image','%05d' % (self.opt.S*self.opt.T+self.t)+'.jpg'))
_ori_img = impro.imread(os.path.join(self.video_dir,'origin_image','%05d' % (self.opt.S*self.opt.T+self.t)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
_mask = impro.imread(os.path.join(self.video_dir,'mask','%05d' % (self.opt.S*self.opt.T+self.t)+'.png' ),mod='gray',loadsize=self.opt.loadsize)
_ori_img = impro.imread(os.path.join(self.video_dir,'origin_image','%05d' % (self.opt.S*self.opt.T+self.t)+'.jpg'),loadsize=self.loadsize,rgb=True)
_mask = impro.imread(os.path.join(self.video_dir,'mask','%05d' % (self.opt.S*self.opt.T+self.t)+'.png' ),mod='gray',loadsize=self.loadsize)
_mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
_ori_img = data.random_transform_single_image(_ori_img,self.opt.finesize,self.transform_params)
_mosaic_img = data.random_transform_single_image(_mosaic_img,self.opt.finesize,self.transform_params)
......
......@@ -98,7 +98,7 @@ def get_random_degenerate_params(mod='strong'):
return params
def degradate(img,params,jpeg_last = False):
def degradate(img,params,jpeg_last = True):
shape = img.shape
if not params:
params = get_random_degenerate_params('original')
......