Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/network/vocoders/pwg.py
701 views
1
import glob
2
import re
3
import librosa
4
import torch
5
import yaml
6
from sklearn.preprocessing import StandardScaler
7
from torch import nn
8
from modules.parallel_wavegan.models import ParallelWaveGANGenerator
9
from modules.parallel_wavegan.utils import read_hdf5
10
from utils.hparams import hparams
11
from utils.pitch_utils import f0_to_coarse
12
from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
13
import numpy as np
14
15
16
def load_pwg_model(config_path, checkpoint_path, stats_path):
    """Build a ParallelWaveGAN generator and load its weights from a checkpoint.

    Two checkpoint layouts are handled:

    * "official" checkpoints (no ``'state_dict'`` key): the weights are read from
      ``ckpt['model']['generator']``, and a ``StandardScaler`` is rebuilt from
      ``stats_path``. The stats file is read with ``read_hdf5`` when
      ``config['format'] == 'hdf5'`` and with ``np.load`` when it is ``'npy'``.
    * custom checkpoints (with a ``'state_dict'`` key): the generator weights are
      stored under the ``model_gen.`` prefix, and no scaler is built.

    Args:
        config_path: Path to a YAML config. ``config['generator_params']`` is
            passed to ``ParallelWaveGANGenerator``.
        checkpoint_path: Checkpoint file read with ``torch.load`` onto the CPU.
        stats_path: Feature statistics file. It is only read for official
            checkpoints.

    Returns:
        A tuple ``(model, scaler, config, device)``:

        * ``model`` has weight norm removed, is in eval mode, and has been moved
          to ``device`` (CUDA when available, otherwise CPU).
        * ``scaler`` is a ``StandardScaler`` or ``None``.

    Raises:
        ValueError: If an official checkpoint's ``config['format']`` is neither
            ``"hdf5"`` nor ``"npy"``.
    """
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. That is only
    # acceptable for trusted, local config files; prefer yaml.safe_load otherwise.
    with open(config_path, encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ParallelWaveGANGenerator(**config["generator_params"])

    # Read the checkpoint once and reuse it. Checkpoints can be large, so it should
    # not be deserialized again in each branch.
    ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
    if 'state_dict' not in ckpt_dict:  # official vocoder
        model.load_state_dict(ckpt_dict["model"]["generator"])
        scaler = StandardScaler()
        if config["format"] == "hdf5":
            scaler.mean_ = read_hdf5(stats_path, "mean")
            scaler.scale_ = read_hdf5(stats_path, "scale")
        elif config["format"] == "npy":
            stats = np.load(stats_path)  # row 0 -> mean, row 1 -> scale
            scaler.mean_ = stats[0]
            scaler.scale_ = stats[1]
        else:
            raise ValueError("support only hdf5 or npy format.")
    else:  # custom PWG vocoder
        # Attach the generator as `model_gen` so its parameter names carry the
        # "model_gen." prefix used in the checkpoint. strict=False ignores the
        # checkpoint's other (presumably training-only) entries.
        fake_task = nn.Module()
        fake_task.model_gen = model
        load_result = fake_task.load_state_dict(ckpt_dict["state_dict"], strict=False)
        # strict=False would also silently skip generator weights that are absent.
        # Surface that case instead of returning a partly random model.
        if load_result.missing_keys:
            print(f"| WARNING: {len(load_result.missing_keys)} generator parameters "
                  f"not found in {checkpoint_path}, e.g. {load_result.missing_keys[:3]}")
        scaler = None

    model.remove_weight_norm()
    model = model.eval().to(device)
    print(f"| Loaded model parameters from {checkpoint_path}.")
    print(f"| PWG device: {device}.")
    return model, scaler, config, device
51
52
53
@register_vocoder
class PWG(BaseVocoder):
    """ParallelWaveGAN vocoder: converts mel spectrograms (plus optional f0) to audio.

    ``__init__`` chooses the checkpoint from ``hparams['vocoder_ckpt']``:

    * Empty string: load the newest ``wavegan_pretrained/checkpoint-<N>steps.pkl``
      with ``wavegan_pretrained/stats.h5``, keeping the scaler returned by
      ``load_pwg_model``.
    * Otherwise: load the newest ``<vocoder_ckpt>/model_ckpt_steps_<N>.ckpt``; the
      scaler returned by ``load_pwg_model`` is discarded (``self.scaler = None``).
    """

    def __init__(self):
        if hparams['vocoder_ckpt'] == '':  # load LJSpeech PWG pretrained model
            base_dir = 'wavegan_pretrained'
            ckpt = self._latest_ckpt(base_dir, 'checkpoint-*steps.pkl',
                                     r'checkpoint-(\d+)steps\.pkl')
            config_path = f'{base_dir}/config.yaml'
            print('| load PWG: ', ckpt)
            self.model, self.scaler, self.config, self.device = load_pwg_model(
                config_path=config_path,
                checkpoint_path=ckpt,
                stats_path=f'{base_dir}/stats.h5',
            )
        else:
            base_dir = hparams['vocoder_ckpt']
            print(base_dir)
            config_path = f'{base_dir}/config.yaml'
            ckpt = self._latest_ckpt(base_dir, 'model_ckpt_steps_*.ckpt',
                                     r'model_ckpt_steps_(\d+)\.ckpt')
            print('| load PWG: ', ckpt)
            # Custom checkpoints are always used without a scaler.
            self.scaler = None
            self.model, _, self.config, self.device = load_pwg_model(
                config_path=config_path,
                checkpoint_path=ckpt,
                stats_path=f'{base_dir}/stats.h5',
            )

    @staticmethod
    def _latest_ckpt(base_dir, glob_name, name_regex):
        """Return the checkpoint in ``base_dir`` with the highest step number.

        Args:
            base_dir: Directory to search. Glob metacharacters in it are escaped,
                so it is treated as a literal path.
            glob_name: File-name glob that selects the candidates.
            name_regex: Regex that must fully match a candidate's base name and
                capture the step count in group 1. Non-matching candidates are
                skipped.

        Matching the base name, instead of a regex built from ``base_dir``, keeps
        this working when ``base_dir`` contains regex metacharacters. It also
        works on Windows, where ``glob`` joins results with a backslash.

        Raises:
            FileNotFoundError: If no candidate has a parseable step number.
        """
        import os

        candidates = []
        for path in glob.glob(f'{glob.escape(base_dir)}/{glob_name}'):
            match = re.fullmatch(name_regex, os.path.basename(path))
            if match is not None:
                candidates.append((int(match.group(1)), path))
        if not candidates:
            raise FileNotFoundError(
                f"| no checkpoint matching '{glob_name}' found in '{base_dir}'")
        return max(candidates, key=lambda step_path: step_path[0])[1]

    def spec2wav(self, mel, **kwargs):
        """Synthesize audio from a mel spectrogram with the PWG generator.

        Args:
            mel: 2-D array indexed ``[frame, bin]``. Axis 0 sizes the noise input
                (``frames * config['hop_size']``) and is edge-padded by
                ``aux_context_window`` frames on each side. It is normalized with
                ``self.scaler`` first when a scaler is set.
            **kwargs: Optional ``f0``, a 1-D per-frame pitch array. It is quantized
                with ``f0_to_coarse`` and padded like ``mel``. Presumably it has one
                value per mel frame (TODO: confirm).

        Returns:
            The generator output flattened to a 1-D numpy array.
        """
        config = self.config
        device = self.device
        # The generator's auxiliary conv consumes aux_context_window extra frames
        # on each side, so pad both ends by that many frames.
        pad_size = (config["generator_params"]["aux_context_window"],
                    config["generator_params"]["aux_context_window"])
        c = mel
        if self.scaler is not None:
            c = self.scaler.transform(c)

        with torch.no_grad():
            z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
            c = np.pad(c, (pad_size, (0, 0)), "edge")
            # (frames, bins) -> (1, bins, frames)
            c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
            p = kwargs.get('f0')
            if p is not None:
                p = f0_to_coarse(p)
                p = np.pad(p, (pad_size,), "edge")
                p = torch.LongTensor(p[None, :]).to(device)
            y = self.model(z, c, p).view(-1)
        wav_out = y.cpu().numpy()
        return wav_out

    @staticmethod
    def wav2spec(wav_fn, return_linear=False):
        """Extract spectral features from an audio file using the ``hparams`` settings.

        This delegates to ``process_utterance`` with ``vocoder='pwg'``. The element
        at index 1 (and index 2 when ``return_linear`` is true) is transposed
        before being returned; element 0 is returned unchanged.
        """
        from preprocessing.data_gen_utils import process_utterance
        res = process_utterance(
            wav_fn, fft_size=hparams['fft_size'],
            hop_size=hparams['hop_size'],
            win_length=hparams['win_size'],
            num_mels=hparams['audio_num_mel_bins'],
            fmin=hparams['fmin'],
            fmax=hparams['fmax'],
            sample_rate=hparams['audio_sample_rate'],
            loud_norm=hparams['loud_norm'],
            min_level_db=hparams['min_level_db'],
            return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10)))
        if return_linear:
            return res[0], res[1].T, res[2].T  # [T, 80], [T, n_fft]
        else:
            return res[0], res[1].T

    @staticmethod
    def wav2mfcc(wav_fn):
        """Return 13 MFCCs plus their first- and second-order deltas as ``(frames, 39)``.

        The audio is loaded at ``hparams['audio_sample_rate']``. STFT sizes come from
        ``hparams['fft_size']``, ``hparams['hop_size']`` and ``hparams['win_size']``.
        """
        fft_size = hparams['fft_size']
        hop_size = hparams['hop_size']
        win_length = hparams['win_size']
        sample_rate = hparams['audio_sample_rate']
        wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
        mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
                                    n_fft=fft_size, hop_length=hop_size,
                                    win_length=win_length, pad_mode="constant", power=1.0)
        mfcc_delta = librosa.feature.delta(mfcc, order=1)
        mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
        # Stack (13, T) x 3 -> (39, T), then transpose to (T, 39).
        mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
        return mfcc
138
139