Unverified commit f6fabeaf, authored by blue-fish, committed by GitHub

Enable CPU training for vocoder (#397)

Parent: ba1a78d8
@@ -118,8 +118,12 @@ class WaveRNN(nn.Module):
     def forward(self, x, mels):
         self.step += 1
         bsize = x.size(0)
-        h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
-        h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+        if torch.cuda.is_available():
+            h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+            h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+        else:
+            h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
+            h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
         mels, aux = self.upsample(mels)
 
         aux_idx = [self.aux_dims * i for i in range(5)]
@@ -209,8 +213,11 @@ class WaveRNN(nn.Module):
                 if self.mode == 'MOL':
                     sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    # x = torch.FloatTensor([[sample]]).cuda()
-                    x = sample.transpose(0, 1).cuda()
+                    if torch.cuda.is_available():
+                        # x = torch.FloatTensor([[sample]]).cuda()
+                        x = sample.transpose(0, 1).cuda()
+                    else:
+                        x = sample.transpose(0, 1)
 
                 elif self.mode == 'RAW' :
                     posterior = F.softmax(logits, dim=1)
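For comparison only, the hunks above gate the hidden-state allocation in WaveRNN.forward on torch.cuda.is_available(). A device-agnostic sketch (not part of this commit) would instead allocate the states on the input's device, covering CPU and GPU with a single code path:

    # Sketch only, not part of the commit: allocate hidden states on the
    # same device as the input batch instead of branching on CUDA.
    def forward(self, x, mels):
        self.step += 1
        bsize = x.size(0)
        h1 = torch.zeros(1, bsize, self.rnn_dims, device=x.device)
        h2 = torch.zeros(1, bsize, self.rnn_dims, device=x.device)
        mels, aux = self.upsample(mels)
        ...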
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import vocoder.hparams as hp
 import numpy as np
 import time
+import torch
 
 
 def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
@@ -32,8 +33,14 @@ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_tr
         hop_length=hp.hop_length,
         sample_rate=hp.sample_rate,
         mode=hp.voc_mode
-    ).cuda()
+    )
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
 
     # Initialize the optimizer
     optimizer = optim.Adam(model.parameters())
     for p in optimizer.param_groups:
@@ -79,7 +86,8 @@ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_tr
         running_loss = 0.
 
         for i, (x, y, m) in enumerate(data_loader, 1):
-            x, m, y = x.cuda(), m.cuda(), y.cuda()
+            if torch.cuda.is_available():
+                x, m, y = x.cuda(), m.cuda(), y.cuda()
 
             # Forward pass
             y_hat = model(x, m)
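As a usage note, the device variable introduced in the earlier train() hunk also allows moving each batch with .to(device) instead of a per-iteration CUDA check; .to(device) is effectively a no-op when the tensors are already on that device. A minimal sketch of that pattern (not part of this commit):

    # Sketch only, not part of the commit: use the device chosen in train()
    # to place each batch, avoiding the per-iteration CUDA check.
    for i, (x, y, m) in enumerate(data_loader, 1):
        x, m, y = x.to(device), m.to(device), y.to(device)

        # Forward pass
        y_hat = model(x, m)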