未验证 提交 f6fabeaf 编写于 作者: B blue-fish 提交者: GitHub

Enable CPU training for vocoder (#397)

上级 ba1a78d8
......@@ -118,8 +118,12 @@ class WaveRNN(nn.Module):
def forward(self, x, mels):
self.step += 1
bsize = x.size(0)
h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
if torch.cuda.is_available():
h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
else:
h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
mels, aux = self.upsample(mels)
aux_idx = [self.aux_dims * i for i in range(5)]
......@@ -209,8 +213,11 @@ class WaveRNN(nn.Module):
if self.mode == 'MOL':
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
# x = torch.FloatTensor([[sample]]).cuda()
x = sample.transpose(0, 1).cuda()
if torch.cuda.is_available():
# x = torch.FloatTensor([[sample]]).cuda()
x = sample.transpose(0, 1).cuda()
else:
x = sample.transpose(0, 1)
elif self.mode == 'RAW' :
posterior = F.softmax(logits, dim=1)
......
......@@ -10,6 +10,7 @@ import torch.nn.functional as F
import vocoder.hparams as hp
import numpy as np
import time
import torch
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
......@@ -32,8 +33,14 @@ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_tr
hop_length=hp.hop_length,
sample_rate=hp.sample_rate,
mode=hp.voc_mode
).cuda()
)
if torch.cuda.is_available():
model = model.cuda()
device = torch.device('cuda')
else:
device = torch.device('cpu')
# Initialize the optimizer
optimizer = optim.Adam(model.parameters())
for p in optimizer.param_groups:
......@@ -79,7 +86,8 @@ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_tr
running_loss = 0.
for i, (x, y, m) in enumerate(data_loader, 1):
x, m, y = x.cuda(), m.cuda(), y.cuda()
if torch.cuda.is_available():
x, m, y = x.cuda(), m.cuda(), y.cuda()
# Forward pass
y_hat = model(x, m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册