From 43ffc94969756fd9c048b9dfba96587d285b4ffc Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Mon, 18 Jan 2021 15:04:10 +0800 Subject: [PATCH] add automatic weight download (#146) * add automatic weight download --- applications/tools/wav2lip.py | 11 +++++------ ppgan/apps/wav2lip_predictor.py | 8 +++++++- ppgan/models/wav2lip_hq_model.py | 6 ++++-- ppgan/models/wav2lip_model.py | 5 ++++- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/applications/tools/wav2lip.py b/applications/tools/wav2lip.py index 97def54..8a708fc 100644 --- a/applications/tools/wav2lip.py +++ b/applications/tools/wav2lip.py @@ -10,17 +10,16 @@ parser = argparse.ArgumentParser( parser.add_argument('--checkpoint_path', type=str, help='Name of saved checkpoint to load weights from', - required=True) - -parser.add_argument('--face', - type=str, - help='Filepath of video/image that contains faces to use', - required=True) + default=None) parser.add_argument( '--audio', type=str, help='Filepath of video/audio file to use as raw audio source', required=True) +parser.add_argument('--face', + type=str, + help='Filepath of video/image that contains faces to use', + required=True) parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', diff --git a/ppgan/apps/wav2lip_predictor.py b/ppgan/apps/wav2lip_predictor.py index 063c230..4ba30e2 100644 --- a/ppgan/apps/wav2lip_predictor.py +++ b/ppgan/apps/wav2lip_predictor.py @@ -6,11 +6,13 @@ import json, subprocess, random, string from tqdm import tqdm from glob import glob import paddle +from paddle.utils.download import get_weights_path_from_url from ppgan.faceutils import face_detection from ppgan.utils import audio from ppgan.models.generators.wav2lip import Wav2Lip from .base_predictor import BasePredictor +WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams' mel_step_size = 16 @@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor): gen = self.datagen(full_frames.copy(), mel_chunks) model = Wav2Lip() - weights = paddle.load(self.args.checkpoint_path) + if self.args.checkpoint_path is None: + model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL) + weights = paddle.load(model_weights_path) + else: + weights = paddle.load(self.args.checkpoint_path) model.load_dict(weights) model.eval() print("Model loaded") diff --git a/ppgan/models/wav2lip_hq_model.py b/ppgan/models/wav2lip_hq_model.py index 07ffdbc..034e81f 100644 --- a/ppgan/models/wav2lip_hq_model.py +++ b/ppgan/models/wav2lip_hq_model.py @@ -14,6 +14,7 @@ import paddle import paddle.nn.functional as F +from paddle.utils.download import get_weights_path_from_url from .base_model import BaseModel from .builder import MODELS @@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss from ..solver import build_optimizer from ..modules.init import init_weights -lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams' +SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams' @MODELS.register() @@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel): distribution='uniform') if self.is_train: self.nets['netDS'] = build_discriminator(discriminator_sync) - params = paddle.load(lipsync_weight_path) + weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL) + params = paddle.load(weights_path) self.nets['netDS'].load_dict(params) self.nets['netDH'] = build_discriminator(discriminator_hq) diff --git a/ppgan/models/wav2lip_model.py b/ppgan/models/wav2lip_model.py index d5a2c36..852d25b 100644 --- a/ppgan/models/wav2lip_model.py +++ b/ppgan/models/wav2lip_model.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle +from paddle.utils.download import get_weights_path_from_url from .base_model import BaseModel from .builder import MODELS @@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator from ..solver import build_optimizer from ..modules.init import init_weights +SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams' syncnet_T = 5 syncnet_mel_step_size = 16 @@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel): init_weights(self.nets['netG'], distribution='uniform') if self.is_train: self.nets['netD'] = build_discriminator(discriminator) - params = paddle.load(lipsync_weight_path) + weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL) + params = paddle.load(weights_path) self.nets['netD'].load_dict(params) if self.is_train: -- GitLab