Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/network/vocoders/hifigan.py
701 views
1
import glob
2
import json
3
import os
4
import re
5
6
import librosa
7
import torch
8
9
import utils
10
from modules.hifigan.hifigan import HifiGanGenerator
11
from utils.hparams import hparams, set_hparams
12
from network.vocoders.base_vocoder import register_vocoder
13
from network.vocoders.pwg import PWG
14
from network.vocoders.vocoder_utils import denoise
15
16
17
def load_model(config_path, file_path):
    """Load a HifiGAN generator checkpoint and move it to the best device.

    Two checkpoint layouts are supported:
      * ``.pth``  -- the file is a pickled, ready-to-use model object.
      * ``.ckpt`` -- the file holds a state dict; a fresh ``HifiGanGenerator``
        is constructed from the config and the weights are loaded into it.

    The companion config may be either ``.yaml`` (parsed via ``set_hparams``)
    or ``.json``.

    Args:
        config_path: path to the vocoder config file (.yaml or .json).
        file_path: path to the checkpoint file (.pth or .ckpt).

    Returns:
        (model, config, device): the eval-mode generator on ``device``,
        the parsed config, and the torch device selected (CUDA if available).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ext = os.path.splitext(file_path)[-1]
    if ext == '.pth':
        if '.yaml' in config_path:
            config = set_hparams(config_path, global_hparams=False)
        elif '.json' in config_path:
            # Use a context manager so the config file handle is closed
            # promptly (the original leaked it via json.load(open(...))).
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        model = torch.load(file_path, map_location="cpu")
    elif ext == '.ckpt':
        ckpt_dict = torch.load(file_path, map_location="cpu")
        # The key holding the generator weights differs between the
        # yaml-configured and json-configured training pipelines.
        if '.yaml' in config_path:
            config = set_hparams(config_path, global_hparams=False)
            state = ckpt_dict["state_dict"]["model_gen"]
        elif '.json' in config_path:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            state = ckpt_dict["generator"]
        model = HifiGanGenerator(config)
        model.load_state_dict(state, strict=True)
        # Weight norm is a training-time reparameterization; removing it
        # folds the weights for faster inference.
        model.remove_weight_norm()
    model = model.eval().to(device)
    print(f"| Loaded model parameters from {file_path}.")
    print(f"| HifiGAN device: {device}.")
    return model, config, device
41
42
43
total_time = 0
44
45
46
@register_vocoder
class HifiGAN(PWG):
    """HifiGAN vocoder, registered in the project's vocoder registry.

    Inherits the PWG vocoder interface and overrides checkpoint loading
    (``__init__``) and mel-to-waveform synthesis (``spec2wav``).
    """

    def __init__(self):
        base_dir = hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        if os.path.exists(config_path):
            # Pick the checkpoint with the highest training step. Matching
            # the step number against the basename (instead of embedding
            # base_dir into the regex, as the original did) avoids breakage
            # when the path contains regex metacharacters or backslashes.
            file_path = sorted(
                glob.glob(f'{base_dir}/model_ckpt_steps_*.*'),
                key=lambda p: int(
                    re.findall(r'model_ckpt_steps_(\d+)',
                               os.path.basename(p))[0]))[-1]
            print('| load HifiGAN: ', file_path)
            self.model, self.config, self.device = load_model(
                config_path=config_path, file_path=file_path)
        else:
            config_path = f'{base_dir}/config.json'
            ckpt = f'{base_dir}/generator_v1'
            if os.path.exists(config_path):
                # BUG FIX: the original passed the undefined name
                # `file_path` here (NameError on this branch); the intended
                # checkpoint is the local `ckpt`.
                self.model, self.config, self.device = load_model(
                    config_path=config_path, file_path=ckpt)

    def spec2wav(self, mel, **kwargs):
        """Synthesize a waveform from a mel spectrogram.

        Args:
            mel: mel spectrogram array; unsqueezed and transposed to
                (1, n_mels, T) before being fed to the generator —
                so presumably shaped (T, n_mels) on input; TODO confirm
                against callers.
            **kwargs: may carry 'f0' (a per-frame contour). When present
                and hparams enables 'use_nsf', it is passed to the model
                as an extra conditioning input.

        Returns:
            1-D numpy waveform; denoised via ``denoise`` when
            hparams['vocoder_denoise_c'] > 0.
        """
        device = self.device
        with torch.no_grad():
            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
            with utils.Timer('hifigan', print_time=hparams['profile_infer']):
                f0 = kwargs.get('f0')
                if f0 is not None and hparams.get('use_nsf'):
                    f0 = torch.FloatTensor(f0[None, :]).to(device)
                    y = self.model(c, f0).view(-1)
                else:
                    y = self.model(c).view(-1)
        wav_out = y.cpu().numpy()
        if hparams.get('vocoder_denoise_c', 0.0) > 0:
            wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c'])
        return wav_out
84
85