GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/fastspeech/pe.py
"""Pitch extractor (PE) modules: a convolutional prenet/encoder over mel-spectrogram
frames feeding a PitchPredictor that regresses frame-level F0 plus a voiced/unvoiced flag."""
import torch
from torch import nn

from modules.commons.common_layers import *  # also supplies Linear, ConvNorm, LayerNorm used below
from utils.hparams import hparams
from modules.fastspeech.tts_modules import PitchPredictor
from utils.pitch_utils import denorm_f0


class Prenet(nn.Module):
    def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
        super(Prenet, self).__init__()
        padding = kernel // 2
        self.layers = []
        self.strides = strides if strides is not None else [1] * n_layers
        for l in range(n_layers):
            self.layers.append(nn.Sequential(
                nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
                nn.ReLU(),
                nn.BatchNorm1d(out_dim)
            ))
            in_dim = out_dim
        self.layers = nn.ModuleList(self.layers)
        self.out_proj = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """
        :param x: mel input [B, T, 80]
        :return: per-layer hiddens [L, B, T, H], projected output [B, T, H]
        """
        padding_mask = x.abs().sum(-1).eq(0).detach()  # [B, T], True on all-zero (padded) frames
        nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :]  # [B, 1, T]
        x = x.transpose(1, 2)  # [B, 80, T]
        hiddens = []
        for i, l in enumerate(self.layers):
            nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]  # keep mask aligned with strided convs
            x = l(x) * nonpadding_mask_TB
            hiddens.append(x)
        hiddens = torch.stack(hiddens, 0)  # [L, B, H, T]
        hiddens = hiddens.transpose(2, 3)  # [L, B, T, H]
        x = self.out_proj(x.transpose(1, 2))  # [B, T, H]
        x = x * nonpadding_mask_TB.transpose(1, 2)
        return hiddens, x
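
# Usage sketch (not part of the original module): a rough shape check for Prenet with dummy
# values; the batch size of 2 and the 100 mel frames below are arbitrary assumptions.
# >>> prenet = Prenet(in_dim=80, out_dim=256)
# >>> hiddens, out = prenet(torch.randn(2, 100, 80))
# >>> hiddens.shape  # torch.Size([3, 2, 100, 256]), one entry per conv layer
# >>> out.shape      # torch.Size([2, 100, 256])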


class ConvBlock(nn.Module):
    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
        super().__init__()
        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
        self.norm = norm
        if self.norm == 'bn':
            self.norm = nn.BatchNorm1d(n_chans)
        elif self.norm == 'in':
            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
        elif self.norm == 'gn':
            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
        elif self.norm == 'ln':
            self.norm = LayerNorm(n_chans // 16, n_chans)
        elif self.norm == 'wn':
            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        :param x: [B, C, T]
        :return: [B, C, T]
        """
        x = self.conv(x)
        if not isinstance(self.norm, str):
            # At this point self.norm is a module, so the string comparisons below never
            # match and normalization is applied in the final branch; 'none' and 'wn'
            # leave self.norm as a string and skip this block entirely.
            if self.norm == 'none':
                pass
            elif self.norm == 'ln':
                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
            else:
                x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x
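
# Usage sketch (not part of the original module): with the default norm='gn', a ConvBlock is
# ConvNorm -> GroupNorm -> ReLU -> Dropout. Assuming the shared ConvNorm uses 'same'-style
# padding (as in Tacotron-style implementations), a stride-1 block preserves [B, C, T]:
# >>> block = ConvBlock(idim=256, n_chans=256, kernel_size=3)
# >>> block(torch.randn(2, 256, 100)).shape  # torch.Size([2, 256, 100])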


class ConvStacks(nn.Module):
    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
                 dropout=0, strides=None, res=True):
        super().__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.res = res
        self.in_proj = Linear(idim, n_chans)
        if strides is None:
            strides = [1] * n_layers
        else:
            assert len(strides) == n_layers
        for idx in range(n_layers):
            self.conv.append(ConvBlock(
                n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
        self.out_proj = Linear(n_chans, odim)

    def forward(self, x, return_hiddens=False):
        """
        :param x: [B, T, H]
        :return: [B, T, odim], plus stacked hiddens [B, L, C, T] if return_hiddens
        """
        x = self.in_proj(x)
        x = x.transpose(1, -1)  # (B, C, Tmax)
        hiddens = []
        for f in self.conv:
            x_ = f(x)
            x = x + x_ if self.res else x_  # (B, C, Tmax), residual connection when self.res
            hiddens.append(x)
        x = x.transpose(1, -1)
        x = self.out_proj(x)  # (B, Tmax, odim)
        if return_hiddens:
            hiddens = torch.stack(hiddens, 1)  # [B, L, C, T]
            return x, hiddens
        return x
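
# Usage sketch (not part of the original module): PitchExtractor builds ConvStacks with
# idim == n_chans == odim == hidden_size, so it keeps the hidden width; with the defaults
# shown above it instead projects 80 -> 32 channels. The dummy sizes below are arbitrary,
# and the T dimension being preserved assumes 'same'-padded ConvNorm with stride 1.
# >>> stack = ConvStacks(idim=80, n_layers=5, n_chans=256, odim=32)
# >>> stack(torch.randn(2, 100, 80)).shape   # torch.Size([2, 100, 32])
# >>> _, h = stack(torch.randn(2, 100, 80), return_hiddens=True)
# >>> h.shape                                # torch.Size([2, 5, 256, 100])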


class PitchExtractor(nn.Module):
    def __init__(self, n_mel_bins=80, conv_layers=2):
        super().__init__()
        self.hidden_size = hparams['hidden_size']
        self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        self.conv_layers = conv_layers

        self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
        if self.conv_layers > 0:
            self.mel_encoder = ConvStacks(
                idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
        self.pitch_predictor = PitchPredictor(
            self.hidden_size, n_chans=self.predictor_hidden,
            n_layers=5, dropout_rate=0.1, odim=2,
            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    def forward(self, mel_input=None):
        ret = {}
        mel_hidden = self.mel_prenet(mel_input)[1]
        if self.conv_layers > 0:
            mel_hidden = self.mel_encoder(mel_hidden)

        ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)  # [B, T, 2]: (normalized f0, uv logit)

        pitch_padding = mel_input.abs().sum(-1) == 0  # all-zero mel frames are treated as padding
        use_uv = hparams['pitch_type'] == 'frame'  # and hparams['use_uv']
        ret['f0_denorm_pred'] = denorm_f0(
            pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
            hparams, pitch_padding=pitch_padding)
        return ret
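
# Usage sketch (not part of the original module): PitchExtractor expects the global `hparams`
# from utils.hparams to be populated before construction; this file reads 'hidden_size',
# 'predictor_hidden', 'ffn_padding', 'predictor_kernel' and 'pitch_type', and denorm_f0
# reads additional pitch-normalization settings from the same hparams.
# >>> model = PitchExtractor(n_mel_bins=80, conv_layers=2)
# >>> out = model(mel)                      # mel: [B, T, 80], all-zero frames count as padding
# >>> out['pitch_pred'].shape               # [B, T, 2]  (normalized f0, uv logit)
# >>> out['f0_denorm_pred'].shape           # [B, T]     denormalized f0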