# GitHub Repository: prophesier/diff-svc
# Path: blob/main/modules/fastspeech/tts_modules.py

import logging
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from modules.commons.espnet_positional_embedding import RelPositionalEncoding
from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
from utils.hparams import hparams

DEFAULT_MAX_SOURCE_POSITIONS = 2000
DEFAULT_MAX_TARGET_POSITIONS = 2000


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = EncSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size if kernel_size is not None else hparams['enc_ffn_kernel_size'],
            padding=hparams['ffn_padding'],
            norm=norm, act=hparams['ffn_act'])

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)


######################
# fastspeech modules
######################
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.
    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.
        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
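

# Illustrative sketch (not part of the original file): with dim=1, this LayerNorm
# normalizes channel-first (B, C, T) activations, as produced by Conv1d, without
# reshaping at the call site. Shapes here are hypothetical.
def _layer_norm_example():
    ln = LayerNorm(8, dim=1)   # normalize over the channel dimension
    x = torch.randn(2, 8, 50)  # (B, C, T)
    return ln(x)               # same shape, normalized over C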


class DurationPredictor(torch.nn.Module):
    """Duration predictor module.
    This is the duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    It predicts the duration of each input token, in the log domain, from the hidden embeddings of the encoder.
    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    Note:
        The output domain differs between `forward` and `inference`: `forward` returns durations
        in the log domain, while `inference` returns them in the linear domain.
    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
        """Initialize duration predictor module.
        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.
        """
        super(DurationPredictor, self).__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                                       if padding == 'SAME'
                                       else (kernel_size - 1, 0), 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        if hparams['dur_loss'] in ['mse', 'huber']:
            odims = 1
        elif hparams['dur_loss'] == 'mog':
            odims = 15
        elif hparams['dur_loss'] == 'crf':
            odims = 32
            from torchcrf import CRF
            self.crf = CRF(odims, batch_first=True)
        self.linear = torch.nn.Linear(n_chans, odims)

    def _forward(self, xs, x_masks=None, is_inference=False):
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
            if x_masks is not None:
                xs = xs * (1 - x_masks.float())[:, None, :]

        xs = self.linear(xs.transpose(1, -1))  # [B, T, C]
        if x_masks is not None:
            xs = xs * (1 - x_masks.float())[:, :, None]  # (B, T, C)
        if is_inference:
            return self.out2dur(xs), xs
        else:
            if hparams['dur_loss'] in ['mse']:
                xs = xs.squeeze(-1)  # (B, Tmax)
        return xs

    def out2dur(self, xs):
        if hparams['dur_loss'] in ['mse']:
            # NOTE: calculate in log domain
            xs = xs.squeeze(-1)  # (B, Tmax)
            dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()  # avoid negative value
        elif hparams['dur_loss'] == 'mog':
            raise NotImplementedError
        elif hparams['dur_loss'] == 'crf':
            dur = torch.LongTensor(self.crf.decode(xs)).cuda()
        return dur

    def forward(self, xs, x_masks=None):
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).
        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """Inference duration.
        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
        Returns:
            LongTensor: Batch of predicted durations in linear domain (B, Tmax).
        """
        return self._forward(xs, x_masks, True)
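

# Illustrative sketch (not part of the original file): hypothetical shapes showing
# the two calling conventions. Assumes hparams['dur_loss'] == 'mse' so the module
# has a single output channel.
def _duration_predictor_example():
    dp = DurationPredictor(idim=256)
    xs = torch.randn(2, 10, 256)       # (B, Tmax, idim) encoder states
    masks = torch.zeros(2, 10).bool()  # (B, Tmax), True on padded positions
    log_dur = dp(xs, masks)            # (B, Tmax), log domain
    dur, _ = dp.inference(xs, masks)   # (B, Tmax) LongTensor, linear domain
    return log_dur, dur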


class LengthRegulator(torch.nn.Module):
    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
            3. token_mask = [[1,1,0,0,0,0,0],
                             [0,0,1,1,0,0,0],
                             [0,0,0,0,1,1,1]]
            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
                                         [0,0,2,2,0,0,0],
                                         [0,0,0,0,3,3,3]]
            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]

        :param dur: Batch of durations of each input token (B, T_txt)
        :param dur_padding: Batch of padding of each input token (B, T_txt)
        :param alpha: duration rescale coefficient
        :return:
            mel2ph (B, T_speech)
        """
        assert alpha > 0
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
        dur_cumsum = torch.cumsum(dur, 1)
        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)

        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2ph = (token_idx * token_mask.long()).sum(1)
        return mel2ph
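

# Illustrative check (not part of the original file), matching the docstring
# example: dur = [[2, 2, 3]] expands to mel2ph = [[1, 1, 2, 2, 3, 3, 3]], i.e.
# output frame t belongs to the 1-based input token mel2ph[t]; 0 marks padding.
def _length_regulator_example():
    lr = LengthRegulator()
    dur = torch.LongTensor([[2, 2, 3]])  # (B, T_txt)
    return lr(dur)                       # tensor([[1, 1, 2, 2, 3, 3, 3]])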


class PitchPredictor(torch.nn.Module):
    def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
                 dropout_rate=0.1, padding='SAME'):
        """Initialize pitch predictor module.
        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
        """
        super(PitchPredictor, self).__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                                       if padding == 'SAME'
                                       else (kernel_size - 1, 0), 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        self.linear = torch.nn.Linear(n_chans, odim)
        self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
        self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))

    def forward(self, xs):
        """
        :param xs: [B, T, H]
        :return: [B, T, odim]
        """
        positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
        xs = xs + positions
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
        # NOTE: calculate in log domain
        xs = self.linear(xs.transpose(1, -1))  # (B, Tmax, odim)
        return xs


class EnergyPredictor(PitchPredictor):
    pass
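

# Illustrative sketch (not part of the original file): the predictor adds a scaled
# sinusoidal positional embedding before the conv stack. odim defaults to 2, which
# in this codebase plausibly corresponds to an f0 value plus a voiced/unvoiced
# channel; shapes here are hypothetical. EnergyPredictor behaves identically.
def _pitch_predictor_example():
    pp = PitchPredictor(idim=256)
    xs = torch.randn(2, 100, 256)  # (B, T, H)
    return pp(xs)                  # (B, T, odim)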


def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
    B, _ = mel2ph.shape
    # count the frames assigned to each token; index 0 collects padding frames
    dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
    dur = dur[:, 1:]
    if max_dur is not None:
        dur = dur.clamp(max=max_dur)
    return dur
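

# Illustrative check (not part of the original file): mel2ph_to_dur inverts the
# LengthRegulator mapping by counting frames per token via scatter_add. With
# mel2ph = [[1, 1, 2, 2, 3, 3, 3]] and T_txt = 3 it recovers [[2, 2, 3]].
def _mel2ph_to_dur_example():
    mel2ph = torch.LongTensor([[1, 1, 2, 2, 3, 3, 3]])
    return mel2ph_to_dur(mel2ph, T_txt=3)  # tensor([[2, 2, 3]])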


class FFTBlocks(nn.Module):
    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
                 use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout if dropout is not None else hparams['dropout']
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
            for _ in range(self.num_layers)
        ])
        if self.use_last_norm:
            if norm == 'ln':
                self.layer_norm = nn.LayerNorm(embed_dim)
            elif norm == 'bn':
                self.layer_norm = BatchNorm1dTBC(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]
        :return: [B, T, C] or [L, B, T, C]
        """
        # padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        padding_mask = x.abs().sum(-1).eq(0).detach() if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x
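

# Illustrative sketch (not part of the original file): FFTBlocks consumes zero-padded
# (B, T, C) features; frames whose feature vector is all zeros are treated as padding
# when no mask is passed. Assumes hparams supplies 'ffn_padding' and 'ffn_act' for the
# inner EncSALayer; dropout is given explicitly to avoid reading hparams['dropout'].
def _fft_blocks_example():
    blocks = FFTBlocks(hidden_size=256, num_layers=4, dropout=0.1)
    x = torch.randn(2, 100, 256)  # (B, T, C)
    return blocks(x)              # (B, T, C)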


class FastspeechEncoder(FFTBlocks):
    """
    Compared to FFTBlocks:
    - the input is [B, T, H], not [B, T, C]
    - supports "relative" positional encoding
    """

    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False)  # use_pos_embed_alpha for compatibility
        # self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        if hparams.get('rel_pos'):
            self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
        else:
            self.embed_positions = SinusoidalPositionalEmbedding(
                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

    def forward(self, hubert):
        """
        :param hubert: [B, T, H]
        :return: [B, T, C]
        """
        # encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
        encoder_padding_mask = (hubert == 0).all(-1)
        x = self.forward_embedding(hubert)  # [B, T, H]
        x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
        return x

    def forward_embedding(self, hubert):
        # embed tokens and positions
        x = self.embed_scale * hubert
        if hparams['use_pos_embed']:
            positions = self.embed_positions(hubert)
            x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x
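

# Illustrative sketch (not part of the original file): hypothetical wiring of the
# encoder on HuBERT features; all-zero frames are treated as padding. Assumes the
# relevant hparams keys ('hidden_size', 'enc_ffn_kernel_size', 'dec_layers',
# 'use_pos_embed', 'rel_pos', 'ffn_padding', 'ffn_act') are configured.
def _fastspeech_encoder_example(hubert):
    encoder = FastspeechEncoder()
    return encoder(hubert)  # (B, T, H) -> (B, T, hidden_size)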


class FastspeechDecoder(FFTBlocks):
    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        num_heads = hparams['num_heads'] if num_heads is None else num_heads
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
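

# Illustrative sketch (not part of the original file): the decoder is a plain
# FFTBlocks stack configured from hparams; hypothetical usage, assuming
# 'num_heads', 'hidden_size', 'dec_ffn_kernel_size' and 'dec_layers' are set.
def _fastspeech_decoder_example(x):
    decoder = FastspeechDecoder()
    return decoder(x)  # (B, T, hidden_size) -> (B, T, hidden_size)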