diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index a3a1013f0eb629f29750f04be111416949196da4..9c239778abc905daec691220009c6f9dc1f91fec 100644 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -18,3 +18,4 @@ from .rrdb_net import RRDBNet from .makeup import GeneratorPSGANAttention from .resnet_ugatit import ResnetUGATITGenerator from .dcgenerator import DCGenerator +from .wav2lip import Wav2Lip diff --git a/ppgan/models/generators/wav2lip.py b/ppgan/models/generators/wav2lip.py new file mode 100644 index 0000000000000000000000000000000000000000..d419d89eec52a2c7156fb28458f20ae86cc4f0e1 --- /dev/null +++ b/ppgan/models/generators/wav2lip.py @@ -0,0 +1,403 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +from paddle.nn import functional as F + +from .builder import GENERATORS +from ...modules.conv import ConvBNRelu +from ...modules.conv import NonNormConv2d +from ...modules.conv import Conv2dTransposeRelu + + +@GENERATORS.register() +class Wav2Lip(nn.Layer): + def __init__(self): + super(Wav2Lip, self).__init__() + + self.face_encoder_blocks = [ + nn.Sequential(ConvBNRelu(6, 16, kernel_size=7, stride=1, + padding=3)), # 96,96 + nn.Sequential( + ConvBNRelu(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 + ConvBNRelu(32, + 32, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(32, + 32, + kernel_size=3, + stride=1, + padding=1, + residual=True)), + nn.Sequential( + ConvBNRelu(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True)), + nn.Sequential( + ConvBNRelu(64, 128, kernel_size=3, stride=2, + padding=1), # 12,12 + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True)), + nn.Sequential( + ConvBNRelu(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True)), + nn.Sequential( + ConvBNRelu(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 + ConvBNRelu(512, + 512, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), + nn.Sequential( + ConvBNRelu(512, 512, kernel_size=3, stride=1, + padding=0), # 1, 1 + ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0)), + ] + + self.audio_encoder = nn.Sequential( + ConvBNRelu(1, 32, kernel_size=3, stride=1, padding=1), + ConvBNRelu(32, + 32, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(32, + 32, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(32, 64, kernel_size=3, stride=(3, 1), padding=1), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, 128, kernel_size=3, stride=3, padding=1), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, 256, kernel_size=3, stride=(3, 2), padding=1), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, 512, kernel_size=3, stride=1, padding=0), + ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), + ) + + self.face_decoder_blocks = [ + nn.Sequential( + ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), ), + nn.Sequential( + Conv2dTransposeRelu(1024, + 512, + kernel_size=3, + stride=1, + padding=0), # 3,3 + ConvBNRelu(512, + 512, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), + nn.Sequential( + Conv2dTransposeRelu(1024, + 512, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + ConvBNRelu(512, + 512, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(512, + 512, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), # 6, 6 + nn.Sequential( + Conv2dTransposeRelu(768, + 384, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + ConvBNRelu(384, + 384, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(384, + 384, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), # 12, 12 + nn.Sequential( + Conv2dTransposeRelu(512, + 256, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), # 24, 24 + nn.Sequential( + Conv2dTransposeRelu(320, + 128, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), # 48, 48 + nn.Sequential( + Conv2dTransposeRelu(160, + 64, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ), + ] # 96,96 + + self.output_block = nn.Sequential( + ConvBNRelu(80, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2D(32, 3, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) + + def forward(self, audio_sequences, face_sequences): + # audio_sequences = (B, T, 1, 80, 16) + B = audio_sequences.shape[0] + + input_dim_size = len(face_sequences.shape) + if input_dim_size > 4: + audio_sequences = paddle.concat([ + audio_sequences[:, i] for i in range(audio_sequences.shape[1]) + ], + axis=0) + face_sequences = paddle.concat([ + face_sequences[:, :, i] for i in range(face_sequences.shape[2]) + ], + axis=0) + + audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 + + feats = [] + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + feats.append(x) + + x = audio_embedding + for f in self.face_decoder_blocks: + x = f(x) + try: + x = paddle.concat((x, feats[-1]), axis=1) + except Exception as e: + print(x.shape) + print(feats[-1].shape) + raise e + + feats.pop() + + x = self.output_block(x) + + if input_dim_size > 4: + x = paddle.split(x, B, axis=0) # [(B, C, H, W)] + outputs = paddle.stack(x, axis=2) # (B, C, T, H, W) + + else: + outputs = x + + return outputs + + +class Wav2LipDiscQual(nn.Layer): + def __init__(self): + super(Wav2LipDiscQual, self).__init__() + + self.face_encoder_blocks = [ + nn.Sequential( + NonNormConv2d(3, 32, kernel_size=7, stride=1, + padding=3)), # 48,96 + nn.Sequential( + NonNormConv2d(32, 64, kernel_size=5, stride=(1, 2), + padding=2), # 48,48 + NonNormConv2d(64, 64, kernel_size=5, stride=1, padding=2)), + nn.Sequential( + NonNormConv2d(64, 128, kernel_size=5, stride=2, + padding=2), # 24,24 + NonNormConv2d(128, 128, kernel_size=5, stride=1, padding=2)), + nn.Sequential( + NonNormConv2d(128, 256, kernel_size=5, stride=2, + padding=2), # 12,12 + NonNormConv2d(256, 256, kernel_size=5, stride=1, padding=2)), + nn.Sequential( + NonNormConv2d(256, 512, kernel_size=3, stride=2, + padding=1), # 6,6 + NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1)), + nn.Sequential( + NonNormConv2d(512, 512, kernel_size=3, stride=2, + padding=1), # 3,3 + NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1), + ), + nn.Sequential( + NonNormConv2d(512, 512, kernel_size=3, stride=1, + padding=0), # 1, 1 + NonNormConv2d(512, 512, kernel_size=1, stride=1, padding=0)), + ] + + self.binary_pred = nn.Sequential( + nn.Conv2D(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) + self.label_noise = .0 + + def get_lower_half(self, face_sequences): + return face_sequences[:, :, face_sequences.shape[2] // 2:] + + def to_2d(self, face_sequences): + B = face_sequences.shape[0] + face_sequences = paddle.concat( + [face_sequences[:, :, i] for i in range(face_sequences.shape[2])], + axis=0) + return face_sequences + + def perceptual_forward(self, false_face_sequences): + false_face_sequences = self.to_2d(false_face_sequences) + false_face_sequences = self.get_lower_half(false_face_sequences) + + false_feats = false_face_sequences + for f in self.face_encoder_blocks: + false_feats = f(false_feats) + + false_pred_loss = F.binary_cross_entropy( + paddle.reshape(self.binary_pred(false_feats), + (len(false_feats), -1)), + paddle.ones((len(false_feats), 1))) + + return false_pred_loss + + def forward(self, face_sequences): + face_sequences = self.to_2d(face_sequences) + face_sequences = self.get_lower_half(face_sequences) + + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + + return paddle.reshape(self.binary_pred(x), (len(x), -1)) diff --git a/ppgan/modules/conv.py b/ppgan/modules/conv.py index a7f5b2ce6dcb4b7861529c8995aaa097d2af5e65..b6018af5883072efad86bd3bc4411cf6da55f25e 100644 --- a/ppgan/modules/conv.py +++ b/ppgan/modules/conv.py @@ -47,7 +47,7 @@ class NonNormConv2d(nn.Layer): return self.act(out) -class Conv2dTranspseRelu(nn.Layer): +class Conv2dTransposeRelu(nn.Layer): def __init__(self, cin, cout,