From f6fabeaf3effdf2f023db98b37bdabd1ec3d81e5 Mon Sep 17 00:00:00 2001
From: blue-fish <67130644+blue-fish@users.noreply.github.com>
Date: Sat, 4 Jul 2020 07:00:22 -0700
Subject: [PATCH] Enable CPU training for vocoder (#397)

---
 vocoder/models/fatchord_version.py | 15 +++++++++++----
 vocoder/train.py                   | 14 +++++++++++---
 2 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/vocoder/models/fatchord_version.py b/vocoder/models/fatchord_version.py
index 429572b..70ef1e3 100644
--- a/vocoder/models/fatchord_version.py
+++ b/vocoder/models/fatchord_version.py
@@ -118,8 +118,12 @@ class WaveRNN(nn.Module):
     def forward(self, x, mels):
         self.step += 1
         bsize = x.size(0)
-        h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
-        h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+        if torch.cuda.is_available():
+            h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+            h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+        else:
+            h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
+            h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
         mels, aux = self.upsample(mels)
 
         aux_idx = [self.aux_dims * i for i in range(5)]
@@ -209,8 +213,11 @@ class WaveRNN(nn.Module):
                 if self.mode == 'MOL':
                     sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    # x = torch.FloatTensor([[sample]]).cuda()
-                    x = sample.transpose(0, 1).cuda()
+                    if torch.cuda.is_available():
+                        # x = torch.FloatTensor([[sample]]).cuda()
+                        x = sample.transpose(0, 1).cuda()
+                    else:
+                        x = sample.transpose(0, 1)
 
                 elif self.mode == 'RAW' :
                     posterior = F.softmax(logits, dim=1)
diff --git a/vocoder/train.py b/vocoder/train.py
index 8782380..4912469 100644
--- a/vocoder/train.py
+++ b/vocoder/train.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import vocoder.hparams as hp
 import numpy as np
 import time
+import torch
 
 
 def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
@@ -32,8 +33,14 @@ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_tr
         hop_length=hp.hop_length,
         sample_rate=hp.sample_rate,
         mode=hp.voc_mode
-    ).cuda()
-
+    )
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+
     # Initialize the optimizer
     optimizer = optim.Adam(model.parameters())
     for p in optimizer.param_groups:
@@ -79,7 +86,8 @@ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_tr
         running_loss = 0.
 
         for i, (x, y, m) in enumerate(data_loader, 1):
-            x, m, y = x.cuda(), m.cuda(), y.cuda()
+            if torch.cuda.is_available():
+                x, m, y = x.cuda(), m.cuda(), y.cuda()
 
             # Forward pass
             y_hat = model(x, m)
--
GitLab
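
Note: the patch above guards each .cuda() call with torch.cuda.is_available(). An equivalent, more compact pattern is to resolve the device once and create or move tensors with it. The snippet below is only an illustrative sketch, not part of the patch; bsize and rnn_dims are placeholder values standing in for WaveRNN's batch size and hidden dimension.

    import torch

    # Pick the device once; falls back to CPU when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Placeholder sizes for illustration only.
    bsize, rnn_dims = 32, 512

    # Allocating the hidden states directly on the chosen device removes the
    # per-call cuda()/cpu() branching that the patch adds in forward().
    h1 = torch.zeros(1, bsize, rnn_dims, device=device)
    h2 = torch.zeros(1, bsize, rnn_dims, device=device)

    # The same idea applies in the training loop: move each batch with .to(device)
    # instead of guarding every .cuda() call.
    x = torch.randn(bsize, 10).to(device)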