未验证 提交 ded7b372 编写于 作者: B blue-fish 提交者: GitHub

Low memory inference fix (#536)

* For low_mem, use spawned workers instead of forked workers (resolves #36)
Used implementation from @lilydjwg: https://github.com/CorentinJ/Real-Time-Voice-Cloning/issues/36#issuecomment-529380190

* Different method of passing the seed for low_mem inference
Resolves #491, #529, #535
上级 8f71d678
...@@ -2,6 +2,7 @@ from synthesizer.tacotron2 import Tacotron2 ...@@ -2,6 +2,7 @@ from synthesizer.tacotron2 import Tacotron2
from synthesizer.hparams import hparams from synthesizer.hparams import hparams
from multiprocess.pool import Pool # You're free to use either one from multiprocess.pool import Pool # You're free to use either one
#from multiprocessing import Pool # #from multiprocessing import Pool #
from multiprocess.context import SpawnContext
from synthesizer import audio from synthesizer import audio
from pathlib import Path from pathlib import Path
from typing import Union, List from typing import Union, List
...@@ -97,16 +98,16 @@ class Synthesizer: ...@@ -97,16 +98,16 @@ class Synthesizer:
# Low memory inference mode: load the model upon every request. The model has to be # Low memory inference mode: load the model upon every request. The model has to be
# loaded in a separate process to be able to release GPU memory (a simple workaround # loaded in a separate process to be able to release GPU memory (a simple workaround
# to tensorflow's intricacies) # to tensorflow's intricacies)
specs, alignments = Pool(1).starmap(Synthesizer._one_shot_synthesize_spectrograms, specs, alignments = Pool(1, context=SpawnContext()).starmap(Synthesizer._one_shot_synthesize_spectrograms,
[(self.checkpoint_fpath, embeddings, texts)])[0] [(self.checkpoint_fpath, embeddings, texts, self._seed)])[0]
return (specs, alignments) if return_alignments else specs return (specs, alignments) if return_alignments else specs
@staticmethod @staticmethod
def _one_shot_synthesize_spectrograms(checkpoint_fpath, embeddings, texts): def _one_shot_synthesize_spectrograms(checkpoint_fpath, embeddings, texts, seed):
# Load the model and forward the inputs # Load the model and forward the inputs
tf.compat.v1.reset_default_graph() tf.compat.v1.reset_default_graph()
model = Tacotron2(checkpoint_fpath, hparams, seed=self._seed) model = Tacotron2(checkpoint_fpath, hparams, seed=seed)
specs, alignments = model.my_synthesize(embeddings, texts) specs, alignments = model.my_synthesize(embeddings, texts)
# Detach the outputs (not doing so will cause the process to hang) # Detach the outputs (not doing so will cause the process to hang)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册