未验证 提交 43ffc949 编写于 作者: L lijianshe02 提交者: GitHub

add automatic weight download (#146)

* add automatic weight download
上级 edd62113
......@@ -10,17 +10,16 @@ parser = argparse.ArgumentParser(
parser.add_argument('--checkpoint_path',
type=str,
help='Name of saved checkpoint to load weights from',
required=True)
parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
default=None)
parser.add_argument(
'--audio',
type=str,
help='Filepath of video/audio file to use as raw audio source',
required=True)
parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
parser.add_argument('--outfile',
type=str,
help='Video path to save result. See default for an e.g.',
......
......@@ -6,11 +6,13 @@ import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import paddle
from paddle.utils.download import get_weights_path_from_url
from ppgan.faceutils import face_detection
from ppgan.utils import audio
from ppgan.models.generators.wav2lip import Wav2Lip
from .base_predictor import BasePredictor
WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams'
mel_step_size = 16
......@@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor):
gen = self.datagen(full_frames.copy(), mel_chunks)
model = Wav2Lip()
weights = paddle.load(self.args.checkpoint_path)
if self.args.checkpoint_path is None:
model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
weights = paddle.load(model_weights_path)
else:
weights = paddle.load(self.args.checkpoint_path)
model.load_dict(weights)
model.eval()
print("Model loaded")
......
......@@ -14,6 +14,7 @@
import paddle
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
from .base_model import BaseModel
from .builder import MODELS
......@@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss
from ..solver import build_optimizer
from ..modules.init import init_weights
lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams'
SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'
@MODELS.register()
......@@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel):
distribution='uniform')
if self.is_train:
self.nets['netDS'] = build_discriminator(discriminator_sync)
params = paddle.load(lipsync_weight_path)
weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL)
params = paddle.load(weights_path)
self.nets['netDS'].load_dict(params)
self.nets['netDH'] = build_discriminator(discriminator_hq)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
from paddle.utils.download import get_weights_path_from_url
from .base_model import BaseModel
from .builder import MODELS
......@@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator
from ..solver import build_optimizer
from ..modules.init import init_weights
SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'
syncnet_T = 5
syncnet_mel_step_size = 16
......@@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel):
init_weights(self.nets['netG'], distribution='uniform')
if self.is_train:
self.nets['netD'] = build_discriminator(discriminator)
params = paddle.load(lipsync_weight_path)
weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL)
params = paddle.load(weights_path)
self.nets['netD'].load_dict(params)
if self.is_train:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册