提交 10710230 编写于 作者: R Rafael Valle

train.py: patching score_mask_value formerly inf, not concrete value, for...

train.py: patching score_mask_value formerly inf, not concrete value, for compatibility with pytorch
上级 cd851585
......@@ -2,6 +2,7 @@ import os
import time
import argparse
import math
from numpy import finfo
import torch
from distributed import DistributedDataParallel
......@@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
def load_model(hparams):
model = Tacotron2(hparams).cuda()
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
if hparams.fp16_run:
model = batchnorm_to_float(model.half())
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
if hparams.distributed_run:
model = DistributedDataParallel(model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册