import math

import torch
import torch.nn as nn

from .pix2pixHD_model import *
from .model_util import *
from models import model_util


class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, reflection padding, a spectrally
    normalized conv, and LeakyReLU."""

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super().__init__()
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            # Reflection padding (instead of zero padding inside the conv)
            # matches the border handling used by the encoders below.
            nn.ReflectionPad2d(padding),
            SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size)),
            nn.LeakyReLU(0.2),
        )

    def forward(self, input):
        return self.convup(input)


class Encoder2d(nn.Module):
    """2d encoder: 7x7 stem followed by `n_downsampling` stride-2 stages.

    Channels double at every downsampling stage, starting from `ngf`.
    """

    def __init__(self, input_nc, ngf=64, n_downsampling=3, activation=None):
        super(Encoder2d, self).__init__()
        # Build the default lazily so every encoder owns its own activation
        # module rather than sharing one instance created at import time.
        if activation is None:
            activation = nn.LeakyReLU(0.2)
        model = [nn.ReflectionPad2d(3),
                 SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0)),
                 activation]
        # Downsample: H and W halve, channels double, per stage.
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.ReflectionPad2d(1),
                      SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2,
                                             kernel_size=3, stride=2, padding=0)),
                      activation]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class Encoder3d(nn.Module):
    """3d encoder: 3x3x3 stem followed by `n_downsampling` stride-2 stages.

    Stride-2 3d convs halve the temporal axis as well as H and W.
    """

    def __init__(self, input_nc, ngf=64, n_downsampling=3, activation=None):
        super(Encoder3d, self).__init__()
        # Same lazy-default rationale as Encoder2d.
        if activation is None:
            activation = nn.LeakyReLU(0.2)
        model = [SpectralNorm(nn.Conv3d(input_nc, ngf, kernel_size=3, padding=1)),
                 activation]
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [SpectralNorm(nn.Conv3d(ngf * mult, ngf * mult * 2,
                                             kernel_size=3, stride=2, padding=1)),
                      activation]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class BVDNet(nn.Module):
    """Video generator: fuses a 3d encoding of a frame window (`stream`) with
    a 2d encoding of the previous output frame (`previous`), runs resnet
    blocks at the bottleneck, and decodes back to image resolution with a
    residual shortcut from the window's center frame.
    """

    def __init__(self, N=2, n_downsampling=3, n_blocks=4, input_nc=3,
                 output_nc=3, activation=None):
        super(BVDNet, self).__init__()
        if activation is None:
            activation = nn.LeakyReLU(0.2)
        ngf = 64
        padding_type = 'reflect'
        self.N = N  # index of the shortcut (center) frame inside the window

        ### encoders
        self.encoder3d = Encoder3d(input_nc, ngf, n_downsampling, activation)
        self.encoder2d = Encoder2d(input_nc, ngf, n_downsampling, activation)

        ### resnet blocks at the bottleneck resolution
        mult = 2 ** n_downsampling
        blocks = [ResnetBlockSpectralNorm(ngf * mult, padding_type=padding_type,
                                          activation=activation)
                  for _ in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)

        ### decoder: mirror the downsampling, then 7x7 projection to output_nc
        decoder = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            decoder += [UpBlock(ngf * mult, ngf * mult // 2)]
        decoder += [nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        self.decoder = nn.Sequential(*decoder)
        # Tanh bounds the output to [-1, 1].
        self.limiter = nn.Tanh()

    def forward(self, stream, previous):
        """`stream`: 5d (B, C, T, H, W) frame window; `previous`: 4d previous
        output frame. Returns a 4d frame in [-1, 1]."""
        # Center frame of the window is the residual shortcut.
        this_shortcut = stream[:, :, self.N]
        stream = self.encoder3d(stream)
        # Drop the temporal axis; assumes the 3d encoder reduced it to 1
        # (e.g. a window of 5 frames after three stride-2 stages) —
        # TODO(review): confirm window length against the data loader.
        stream = stream.reshape(stream.size(0), stream.size(1),
                                stream.size(3), stream.size(4))
        previous = self.encoder2d(previous)
        x = self.blocks(stream + previous)
        x = self.decoder(x)
        return self.limiter(x + this_shortcut)


def define_G(N=2, n_blocks=1, gpu_id='-1'):
    """Build, place, and weight-initialize a BVDNet generator."""
    netG = BVDNet(N=N, n_blocks=n_blocks)
    netG = model_util.todevice(netG, gpu_id)
    netG.apply(model_util.init_weights)
    return netG


################################Discriminator################################
def define_D(input_nc=6, ndf=64, n_layers_D=1, use_sigmoid=False, num_D=3,
             gpu_id='-1'):
    """Build, place, and weight-initialize a multiscale discriminator."""
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, num_D)
    netD = model_util.todevice(netD, gpu_id)
    netD.apply(model_util.init_weights)
    return netD


class MultiscaleDiscriminator(nn.Module):
    """Runs `num_D` PatchGAN discriminators on progressively downsampled
    copies of the input; returns one result list per scale."""

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, num_D=3):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        # Register each per-scale discriminator as 'layer0'..'layer{num_D-1}'.
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, use_sigmoid)
            setattr(self, 'layer' + str(i), netD.model)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                       count_include_pad=False)

    def singleD_forward(self, model, input):
        # Wrapped in a list so each scale's result has a uniform shape.
        return [model(input)]

    def forward(self, input):
        result = []
        input_downsampled = input
        # Highest-index discriminator sees the full-resolution input.
        for i in range(self.num_D):
            model = getattr(self, 'layer' + str(self.num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (self.num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: stride-2 conv stack ending in a 1-channel
    patch score map (optionally sigmoid-activated)."""

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        self.n_layers = n_layers
        kw = 4
        # ceil((kw - 1) / 2) == 2 for the 4x4 kernels (stdlib math, no numpy).
        padw = math.ceil((kw - 1) / 2)
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2,
                               padding=padw),
                     nn.LeakyReLU(0.2)]]
        nf = ndf
        # Intermediate stride-2 stages; channels double, capped at 512.
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=kw,
                                                 stride=2, padding=padw)),
                          nn.LeakyReLU(0.2)]]
        # One stride-1 stage before the final 1-channel projection.
        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=kw,
                                             stride=1, padding=padw)),
                      nn.LeakyReLU(0.2)]]
        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]
        layers = []
        for group in sequence:
            layers += group
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)


class GANLoss(nn.Module):
    """Hinge GAN loss. mode='D' scores fake vs real; mode='G' scores fake
    only, with per-scale weights halving from coarse to fine."""

    def __init__(self, mode='D'):
        super(GANLoss, self).__init__()
        if mode == 'D':
            self.lossf = model_util.HingeLossD()
        elif mode == 'G':
            self.lossf = model_util.HingeLossG()
        else:
            # Fail fast instead of leaving self.lossf unset.
            raise ValueError("mode must be 'D' or 'G', got %r" % (mode,))
        self.mode = mode

    def forward(self, dis_fake=None, dis_real=None):
        """Inputs are either single discriminator outputs or per-scale lists
        as produced by MultiscaleDiscriminator; only the final activation
        (`[-1]`) of each scale enters the loss."""
        if isinstance(dis_fake, list):
            if self.mode == 'D':
                loss = 0
                for fake, real in zip(dis_fake, dis_real):
                    loss += self.lossf(fake[-1], real[-1])
            else:  # 'G' — guaranteed by the __init__ check
                loss = 0
                weight = 2 ** len(dis_fake)
                for fake in dis_fake:
                    weight = weight / 2
                    loss += weight * self.lossf(fake[-1])
            return loss
        if self.mode == 'D':
            return self.lossf(dis_fake[-1], dis_real[-1])
        return self.lossf(dis_fake[-1])