GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/parallel_wavegan/models/parallel_wavegan.py

# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Parallel WaveGAN Modules."""

import logging
import math

import torch
from torch import nn

from modules.parallel_wavegan.layers import Conv1d
from modules.parallel_wavegan.layers import Conv1d1x1
from modules.parallel_wavegan.layers import ResidualBlock
from modules.parallel_wavegan.layers import upsample
from modules.parallel_wavegan import models

class ParallelWaveGANGenerator(torch.nn.Module):
    """Parallel WaveGAN Generator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
                 use_pitch_embed=False,
                 ):
        """Initialize Parallel WaveGAN Generator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for auxiliary feature conv.
            aux_context_window (int): Context window size for auxiliary feature.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            upsample_conditional_features (bool): Whether to use upsampling network.
            upsample_net (str): Upsampling network architecture.
            upsample_params (dict): Upsampling network parameters.
            use_pitch_embed (bool): Whether to embed a pitch sequence and project it
                onto the auxiliary feature space before upsampling.

        """
        super(ParallelWaveGANGenerator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define conv + upsampling network
        if upsample_conditional_features:
            upsample_params.update({
                "use_causal_conv": use_causal_conv,
            })
            if upsample_net == "MelGANGenerator":
                assert aux_context_window == 0
                upsample_params.update({
                    "use_weight_norm": False,  # not to apply twice
                    "use_final_nonlinear_activation": False,
                })
                self.upsample_net = getattr(models, upsample_net)(**upsample_params)
            else:
                if upsample_net == "ConvInUpsampleNetwork":
                    upsample_params.update({
                        "aux_channels": aux_channels,
                        "aux_context_window": aux_context_window,
                    })
                self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # optional pitch embedding, projected onto the aux feature space
        self.use_pitch_embed = use_pitch_embed
        if use_pitch_embed:
            self.pitch_embed = nn.Embedding(300, aux_channels, 0)
            self.c_proj = nn.Linear(2 * aux_channels, aux_channels)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, c=None, pitch=None, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, C_in, T).
            c (Tensor): Local conditioning auxiliary features (B, C, T').
            pitch (Tensor): Local conditioning pitch (B, T').

        Returns:
            Tensor: Output tensor (B, C_out, T).

        """
        # perform upsampling
        if c is not None and self.upsample_net is not None:
            if self.use_pitch_embed:
                p = self.pitch_embed(pitch)
                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
            c = self.upsample_net(c)
            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        """Compute receptive field size from the dilated convolution configuration."""
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
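
# ----------------------------------------------------------------------------
# Usage sketch (illustrative; assumes the default configuration above and the
# upstream ConvInUpsampleNetwork convention of padding the conditioning by
# aux_context_window frames on each side). The generator maps input noise plus
# upsampled mel features to a waveform of the same length as the noise:
#
#     g = ParallelWaveGANGenerator()        # 30 layers, 3 stacks, hop 4*4*4*4 = 256
#     c = torch.randn(1, 80, 100 + 2 * 2)   # (B, aux_channels, T' + 2 * aux_context_window)
#     x = torch.randn(1, 1, 100 * 256)      # (B, in_channels, T' * hop_size)
#     y = g(x, c)                           # -> (1, 1, 25600)
#
# With layers=30, stacks=3, kernel_size=3 the receptive field is
# (3 - 1) * 3 * (1 + 2 + 4 + ... + 512) + 1 = 2 * 3069 + 1 = 6139 samples.
# ----------------------------------------------------------------------------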


class ParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 use_weight_norm=True,
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of conv layers.
            conv_channels (int): Number of channels in conv layers.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Kernel size must be odd."
        assert dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = torch.nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                Conv1d(conv_in_channels, conv_channels,
                       kernel_size=kernel_size, padding=padding,
                       dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input audio signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
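
# ----------------------------------------------------------------------------
# Usage sketch (illustrative): the discriminator is fully convolutional with
# "same" padding, so it maps a waveform to a score sequence of equal length:
#
#     d = ParallelWaveGANDiscriminator()   # defaults: 10 layers, 64 conv channels
#     y = torch.randn(1, 1, 16000)         # (B, 1, T) real or generated audio
#     p = d(y)                             # -> (1, 1, 16000), one logit per sample
# ----------------------------------------------------------------------------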


class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
    """Residual Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        """Initialize Residual Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.

        """
        super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Kernel size must be odd."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = torch.nn.Sequential(
            Conv1d1x1(in_channels, residual_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks (no auxiliary conditioning, aux_channels=-1)
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input audio signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
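

# Quick shape check (illustrative sketch, assuming the default configurations
# and the conditioning layout noted in the usage sketches above); it runs only
# when this file is executed directly.
if __name__ == "__main__":
    B, T_frames, hop_size = 2, 40, 256  # hop_size = 4 * 4 * 4 * 4 (upsample_scales)
    generator = ParallelWaveGANGenerator()
    discriminator = ResidualParallelWaveGANDiscriminator()

    c = torch.randn(B, 80, T_frames + 2 * 2)    # mel features (+ aux_context_window padding)
    x = torch.randn(B, 1, T_frames * hop_size)  # input noise
    wav = generator(x, c)                       # -> (B, 1, T_frames * hop_size)
    score = discriminator(wav)                  # -> (B, 1, T_frames * hop_size)
    print(wav.shape, score.shape, generator.receptive_field_size)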