diff --git a/applications/tools/wav2lip.py b/applications/tools/wav2lip.py index 97def54a41fe9404858095f767c861ee20e1496b..8a708fc53bd555b4fd6244050614e54509ae8b64 100644 --- a/applications/tools/wav2lip.py +++ b/applications/tools/wav2lip.py @@ -10,17 +10,16 @@ parser = argparse.ArgumentParser( parser.add_argument('--checkpoint_path', type=str, help='Name of saved checkpoint to load weights from', - required=True) - -parser.add_argument('--face', - type=str, - help='Filepath of video/image that contains faces to use', - required=True) + default=None) parser.add_argument( '--audio', type=str, help='Filepath of video/audio file to use as raw audio source', required=True) +parser.add_argument('--face', + type=str, + help='Filepath of video/image that contains faces to use', + required=True) parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', diff --git a/ppgan/apps/wav2lip_predictor.py b/ppgan/apps/wav2lip_predictor.py index 063c2306e55067c966e5e4a0daf20d44f285ff7a..4ba30e2464f1a75460741760138d9d7778fda532 100644 --- a/ppgan/apps/wav2lip_predictor.py +++ b/ppgan/apps/wav2lip_predictor.py @@ -6,11 +6,13 @@ import json, subprocess, random, string from tqdm import tqdm from glob import glob import paddle +from paddle.utils.download import get_weights_path_from_url from ppgan.faceutils import face_detection from ppgan.utils import audio from ppgan.models.generators.wav2lip import Wav2Lip from .base_predictor import BasePredictor +WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams' mel_step_size = 16 @@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor): gen = self.datagen(full_frames.copy(), mel_chunks) model = Wav2Lip() - weights = paddle.load(self.args.checkpoint_path) + if self.args.checkpoint_path is None: + model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL) + weights = paddle.load(model_weights_path) + else: + weights = paddle.load(self.args.checkpoint_path) model.load_dict(weights) model.eval() print("Model loaded") diff --git a/ppgan/models/wav2lip_hq_model.py b/ppgan/models/wav2lip_hq_model.py index 07ffdbc2d0d585cb9e262db69be58587254dc06e..034e81f9ffbd399fd8a716d8f8d8aa7f2905de0e 100644 --- a/ppgan/models/wav2lip_hq_model.py +++ b/ppgan/models/wav2lip_hq_model.py @@ -14,6 +14,7 @@ import paddle import paddle.nn.functional as F +from paddle.utils.download import get_weights_path_from_url from .base_model import BaseModel from .builder import MODELS @@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss from ..solver import build_optimizer from ..modules.init import init_weights -lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams' +SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams' @MODELS.register() @@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel): distribution='uniform') if self.is_train: self.nets['netDS'] = build_discriminator(discriminator_sync) - params = paddle.load(lipsync_weight_path) + weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL) + params = paddle.load(weights_path) self.nets['netDS'].load_dict(params) self.nets['netDH'] = build_discriminator(discriminator_hq) diff --git a/ppgan/models/wav2lip_model.py b/ppgan/models/wav2lip_model.py index d5a2c36daebeee31aebe7656ffd45fcf7b17c698..852d25be04cdb7a2ec5358152bbb3ac390fa9c33 100644 --- a/ppgan/models/wav2lip_model.py +++ b/ppgan/models/wav2lip_model.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle +from paddle.utils.download import get_weights_path_from_url from .base_model import BaseModel from .builder import MODELS @@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator from ..solver import build_optimizer from ..modules.init import init_weights +SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams' syncnet_T = 5 syncnet_mel_step_size = 16 @@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel): init_weights(self.nets['netG'], distribution='uniform') if self.is_train: self.nets['netD'] = build_discriminator(discriminator) - params = paddle.load(lipsync_weight_path) + weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL) + params = paddle.load(weights_path) self.nets['netD'].load_dict(params) if self.is_train: