Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/network/vocoders/nsf_hifigan.py
701 views
1
import os
2
import torch
3
from modules.nsf_hifigan.models import load_model, Generator
4
from modules.nsf_hifigan.nvSTFT import load_wav_to_torch, STFT
5
from utils.hparams import hparams
6
from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
7
8
@register_vocoder
class NsfHifiGAN(BaseVocoder):
    """NSF-HiFiGAN vocoder: converts log10-scaled mel spectrograms to waveforms.

    Wraps a pretrained (NSF-)HiFiGAN generator loaded from
    ``hparams['vocoder_ckpt']``. When ``hparams['use_nsf']`` is truthy and an
    F0 curve is supplied, the NSF (source-filter) conditioning path is used.
    """

    def __init__(self, device=None):
        """Load the HifiGAN checkpoint onto *device* ('cuda' if available by default).

        Sets ``self.model`` (generator) and ``self.h`` (vocoder config). If the
        checkpoint file does not exist, both stay ``None`` and an error is printed.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Robustness: always define model/h so a missing checkpoint later fails
        # with a clear "is None" state instead of an AttributeError.
        self.model = None
        self.h = None
        model_path = hparams['vocoder_ckpt']
        if os.path.exists(model_path):
            print('| Load HifiGAN: ', model_path)
            self.model, self.h = load_model(model_path, device=self.device)
        else:
            print('Error: HifiGAN model file is not found!')

    def _check_mismatch(self):
        """Print a warning for every hparams key that disagrees with the vocoder config.

        Deduplicates the identical check sequences previously inlined in both
        ``spec2wav_torch`` and ``spec2wav``; output format is unchanged.
        """
        checks = [
            ('audio_sample_rate', self.h.sampling_rate),
            ('audio_num_mel_bins', self.h.num_mels),
            ('fft_size', self.h.n_fft),
            ('win_size', self.h.win_size),
            ('hop_size', self.h.hop_size),
            ('fmin', self.h.fmin),
            ('fmax', self.h.fmax),
        ]
        for key, vocoder_val in checks:
            if vocoder_val != hparams[key]:
                print('Mismatch parameters: hparams[\'%s\']=' % key, hparams[key], '!=', vocoder_val, '(vocoder)')

    def spec2wav_torch(self, mel, **kwargs):  # mel: [B, T, bins]
        """Vocode a batched mel tensor to a flattened waveform tensor (stays on device).

        :param mel: log10 mel spectrogram tensor of shape [B, T, bins]
        :param kwargs: may contain 'f0' ([B, T]) used when hparams['use_nsf'] is set
        :return: 1-D waveform tensor
        """
        self._check_mismatch()
        with torch.no_grad():
            c = mel.transpose(2, 1)  # [B, T, bins] -> [B, bins, T] (channel-first for the generator)
            # log10 mel -> natural-log mel; 2.30259 ~= ln(10)
            c = 2.30259 * c
            f0 = kwargs.get('f0')  # [B, T]
            if f0 is not None and hparams.get('use_nsf'):
                y = self.model(c, f0).view(-1)
            else:
                y = self.model(c).view(-1)
        return y

    def spec2wav(self, mel, **kwargs):
        """Vocode a single (unbatched) numpy/array mel to a numpy waveform.

        :param mel: log10 mel spectrogram, shape [T, bins]
        :param kwargs: may contain 'f0' (1-D array) used when hparams['use_nsf'] is set
        :return: 1-D numpy waveform
        """
        self._check_mismatch()
        with torch.no_grad():
            # [T, bins] -> [1, bins, T] on the model device.
            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(self.device)
            # log10 mel -> natural-log mel; 2.30259 ~= ln(10)
            c = 2.30259 * c
            f0 = kwargs.get('f0')
            if f0 is not None and hparams.get('use_nsf'):
                f0 = torch.FloatTensor(f0[None, :]).to(self.device)
                y = self.model(c, f0).view(-1)
            else:
                y = self.model(c).view(-1)
        wav_out = y.cpu().numpy()
        return wav_out

    @staticmethod
    def wav2spec(inp_path, device=None):
        """Load a wav file and compute its log10 mel spectrogram with hparams settings.

        :param inp_path: path to the input audio file
        :param device: torch device for the STFT ('cuda' if available by default)
        :return: (waveform as 1-D numpy array, mel as [T, bins] numpy array)
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        sampling_rate = hparams['audio_sample_rate']
        num_mels = hparams['audio_num_mel_bins']
        n_fft = hparams['fft_size']
        win_size = hparams['win_size']
        hop_size = hparams['hop_size']
        fmin = hparams['fmin']
        fmax = hparams['fmax']
        stft = STFT(sampling_rate, num_mels, n_fft, win_size, hop_size, fmin, fmax)
        with torch.no_grad():
            wav_torch, _ = load_wav_to_torch(inp_path, target_sr=stft.target_sr)
            mel_torch = stft.get_mel(wav_torch.unsqueeze(0).to(device)).squeeze(0).T
            # natural-log mel -> log10 mel; 0.434294 ~= 1/ln(10)
            mel_torch = 0.434294 * mel_torch
            return wav_torch.cpu().numpy(), mel_torch.cpu().numpy()