GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/encoder.py
import torch
from modules.commons.common_layers import *
from modules.commons.common_layers import Embedding
from modules.commons.common_layers import SinusoidalPositionalEmbedding
from utils.hparams import hparams
import numpy as np
import math
# Explicit imports for names used below; the star import above may already
# provide them, but naming them keeps the module self-contained.
from torch import nn
import torch.nn.functional as F


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        # Normalize over dim 1 by transposing it to the last axis and back.
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
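
# A minimal sketch (not part of the original file): the `dim` argument lets
# LayerNorm normalize channel-first conv activations of shape (B, C, T),
# which stock torch.nn.LayerNorm cannot do without a transpose.
#
#     ln = LayerNorm(384, dim=1)
#     x = torch.randn(8, 384, 100)    # (B, C, T)
#     y = ln(x)                       # same shape, normalized over C
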
class PitchPredictor(torch.nn.Module):
    def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
                 dropout_rate=0.1, padding='SAME'):
        super(PitchPredictor, self).__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                # 'SAME' pads symmetrically; anything else pads causally
                # (left side only), preserving the sequence length either way.
                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                                       if padding == 'SAME'
                                       else (kernel_size - 1, 0), 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        self.linear = torch.nn.Linear(n_chans, odim)
        self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
        self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))

    def forward(self, xs):
        # xs: (B, Tmax, idim)
        positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
        xs = xs + positions
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
        xs = self.linear(xs.transpose(1, -1))  # (B, Tmax, H)
        return xs
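
# A minimal sketch (not part of the original file) of PitchPredictor's shapes;
# idim=256 and the batch/length values are assumptions for illustration:
#
#     predictor = PitchPredictor(256, n_chans=384, odim=2)
#     xs = torch.randn(4, 200, 256)   # (B, Tmax, idim)
#     out = predictor(xs)             # (4, 200, 2); odim=2 corresponds to
#                                     # pitch_type == 'frame' in SvcEncoder
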
class SvcEncoder(nn.Module):
    def __init__(self, dictionary, out_dims=None):
        super().__init__()
        # self.dictionary = dictionary
        self.padding_idx = 0
        self.hidden_size = hparams['hidden_size']
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            self.pitch_predictor = PitchPredictor(
                self.hidden_size,
                n_chans=predictor_hidden,
                n_layers=hparams['predictor_layers'],
                dropout_rate=hparams['predictor_dropout'],
                odim=2 if hparams['pitch_type'] == 'frame' else 1,
                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_energy_embed']:
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
        if hparams['use_spk_id']:
            self.spk_embed_proj = Embedding(hparams['num_spk'], self.hidden_size)
            if hparams['use_split_spk_id']:
                self.spk_embed_f0 = Embedding(hparams['num_spk'], self.hidden_size)
                self.spk_embed_dur = Embedding(hparams['num_spk'], self.hidden_size)
        elif hparams['use_spk_embed']:
            self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        self.pitch_norm = hparams['pitch_norm'] == 'standard'
        self.f0_bin = hparams['f0_bin']
        self.f0_max = hparams['f0_max']
        self.f0_min = hparams['f0_min']

    def forward(self, hubert, mel2ph=None, spk_embed=None, f0=None):
        encoder_out = hubert
        # Prepend a zero row so that mel2ph index 0 (padding) gathers zeros,
        # then expand the unit-level encoder output to frame level.
        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
        mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        rdecoder_inp, f0_denorm, pitch_pred = self.add_pitch(f0, mel2ph)
        # The original hard-coded .cpu() here (and .cuda() in add_pitch);
        # moving to decoder_inp's device keeps the addition device-consistent.
        decoder_inp = decoder_inp + rdecoder_inp.to(decoder_inp.device)
        decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
        return decoder_inp.transpose(1, 2), f0_denorm

    def add_pitch(self, f0, mel2ph):
        pitch_padding = (mel2ph == 0)
        f0_denorm = self.denorm_f0(f0, pitch_padding=pitch_padding)
        f0[pitch_padding] = 0
        pitch = self.f0_to_coarse(f0_denorm)  # Hz -> coarse embedding indices
        pitch_pred = pitch.unsqueeze(-1)
        pitch_embedding = self.pitch_embed(pitch)
        return pitch_embedding, f0_denorm, pitch_pred

    def denorm_f0(self, f0, pitch_padding=None):
        # f0 is stored as log2(Hz); invert the log transform.
        f0 = 2 ** f0
        if pitch_padding is not None:
            f0[pitch_padding] = 0
        return f0

    def f0_to_coarse(self, f0):
        # Mel-scale quantization: mel = 1127 * ln(1 + f0 / 700). Voiced frames
        # are mapped into buckets 1..f0_bin-1; index 0 is kept for padding,
        # matching pitch_embed's padding_idx.
        f0_mel_min = 1127 * math.log(1 + self.f0_min / 700)
        f0_mel_max = 1127 * math.log(1 + self.f0_max / 700)
        f0_mel = 1127 * (1 + f0 / 700).log()
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (self.f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1

        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1
        f0_coarse = (f0_mel + 0.5).long()  # round to the nearest bucket
        return f0_coarse
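
# A minimal sketch (not part of the original file) of the f0 pipeline above,
# assuming hparams has already been populated (e.g. via utils.hparams) and
# using hypothetical bounds f0_min=40, f0_max=1100, f0_bin=256:
#
#     enc = SvcEncoder(dictionary=None)
#     f0 = torch.full((1, 200), math.log2(220.0))   # log2-normalized 220 Hz
#     hz = enc.denorm_f0(f0.clone())                # back to Hz (2 ** f0)
#     coarse = enc.f0_to_coarse(hz)                 # indices in [1, f0_bin - 1]
#
# The coarse indices feed pitch_embed, whose output is added to the
# frame-level HuBERT features in forward().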