GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/parallel_wavegan/models/melgan.py
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""MelGAN Modules."""

import logging

import numpy as np
import torch

from modules.parallel_wavegan.layers import CausalConv1d
from modules.parallel_wavegan.layers import CausalConvTranspose1d
from modules.parallel_wavegan.layers import ResidualStack


class MelGANGenerator(torch.nn.Module):
    """MelGAN generator module."""

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 kernel_size=7,
                 channels=512,
                 bias=True,
                 upsample_scales=[8, 8, 2, 2],
                 stack_kernel_size=3,
                 stacks=3,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_final_nonlinear_activation=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 ):
        """Initialize MelGANGenerator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer.
            bias (bool): Whether to add bias parameter in convolution layers.
            upsample_scales (list): List of upsampling scales.
            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_final_nonlinear_activation (bool): Whether to apply Tanh after the final conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(MelGANGenerator, self).__init__()

        # check hyper parameters are valid
        assert channels >= np.prod(upsample_scales)
        assert channels % (2 ** len(upsample_scales)) == 0
        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Even-number kernel size is not supported."

        # add initial layer
        layers = []
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(in_channels, channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]

        for i, upsample_scale in enumerate(upsample_scales):
            # add upsampling layer
            layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
            if not use_causal_conv:
                layers += [
                    torch.nn.ConvTranspose1d(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_scale * 2,
                        stride=upsample_scale,
                        padding=upsample_scale // 2 + upsample_scale % 2,
                        output_padding=upsample_scale % 2,
                        bias=bias,
                    )
                ]
            else:
                layers += [
                    CausalConvTranspose1d(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_scale * 2,
                        stride=upsample_scale,
                        bias=bias,
                    )
                ]

            # add residual stack
            for j in range(stacks):
                layers += [
                    ResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=channels // (2 ** (i + 1)),
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        pad=pad,
                        pad_params=pad_params,
                        use_causal_conv=use_causal_conv,
                    )
                ]

        # add final layer
        layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]
        if use_final_nonlinear_activation:
            layers += [torch.nn.Tanh()]

        # define the model as a single function
        self.melgan = torch.nn.Sequential(*layers)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, 1, T * prod(upsample_scales)).

        """
        return self.melgan(c)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows official implementation manner.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)


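# Illustrative usage sketch (not part of the original file): with the default
# arguments above, the generator upsamples an 80-band mel spectrogram of length T
# to a waveform of length T * prod(upsample_scales) = 256 * T, e.g.:
#
#     generator = MelGANGenerator()
#     mel = torch.randn(2, 80, 100)   # (B, in_channels, T) dummy mel spectrogram
#     wav = generator(mel)            # -> torch.Size([2, 1, 25600])

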
class MelGANDiscriminator(torch.nn.Module):
    """MelGAN discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 ):
        """Initialize MelGAN discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer,
                and the first and the second kernel sizes will be used for the last two layers.
                For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
                and the last two layers' kernel sizes will be 5 and 3, respectively.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.

        """
        super(MelGANDiscriminator, self).__init__()
        self.layers = torch.nn.ModuleList()

        # check kernel size is valid
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1

        # add first layer
        self.layers += [
            torch.nn.Sequential(
                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]

        # add downsample layers
        in_chs = channels
        for downsample_scale in downsample_scales:
            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
            self.layers += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chs, out_chs,
                        kernel_size=downsample_scale * 10 + 1,
                        stride=downsample_scale,
                        padding=downsample_scale * 5,
                        groups=in_chs // 4,
                        bias=bias,
                    ),
                    getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                )
            ]
            in_chs = out_chs

        # add final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_chs, out_chs, kernel_sizes[0],
                    padding=(kernel_sizes[0] - 1) // 2,
                    bias=bias,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        self.layers += [
            torch.nn.Conv1d(
                out_chs, out_channels, kernel_sizes[1],
                padding=(kernel_sizes[1] - 1) // 2,
                bias=bias,
            ),
        ]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of output tensors of each layer.

        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs += [x]

        return outs


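# Illustrative usage sketch (not part of the original file): a single
# MelGANDiscriminator returns the feature maps of every layer, which the
# multi-scale wrapper below exposes for adversarial and feature-matching losses:
#
#     discriminator = MelGANDiscriminator()
#     wav = torch.randn(2, 1, 25600)   # (B, 1, T) dummy waveform
#     feats = discriminator(wav)       # list of 7 tensors with the default settings;
#                                      # the last one has shape (B, out_channels, T // 256)

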
class MelGANMultiScaleDiscriminator(torch.nn.Module):
    """MelGAN multi-scale discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 scales=3,
                 downsample_pooling="AvgPool1d",
                 # follow the official implementation setting
                 downsample_pooling_params={
                     "kernel_size": 4,
                     "stride": 2,
                     "padding": 1,
                     "count_include_pad": False,
                 },
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_weight_norm=True,
                 ):
        """Initialize MelGAN multi-scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            scales (int): Number of multi-scale discriminators.
            downsample_pooling (str): Pooling module name for downsampling of the inputs.
            downsample_pooling_params (dict): Parameters for the above pooling module.
            kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer,
                and the first and the second kernel sizes will be used for the last two layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_weight_norm (bool): Whether to use weight norm.

        """
        super(MelGANMultiScaleDiscriminator, self).__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for _ in range(scales):
            self.discriminators += [
                MelGANDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    channels=channels,
                    max_downsample_channels=max_downsample_channels,
                    bias=bias,
                    downsample_scales=downsample_scales,
                    nonlinear_activation=nonlinear_activation,
                    nonlinear_activation_params=nonlinear_activation_params,
                    pad=pad,
                    pad_params=pad_params,
                )
            ]
        self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of list of each discriminator outputs, which consists of each layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
            x = self.pooling(x)

        return outs

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows official implementation manner.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)


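if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). It assumes the
    # modules.parallel_wavegan.layers package from this repository is importable.
    # A dummy 80-band mel spectrogram is vocoded by the generator and scored by
    # the multi-scale discriminator, using the default hyperparameters above.
    generator = MelGANGenerator()
    discriminator = MelGANMultiScaleDiscriminator()

    mel = torch.randn(2, 80, 100)   # (B, in_channels, T) dummy mel spectrogram
    wav = generator(mel)            # (B, 1, T * 256) with the default upsample_scales
    outs = discriminator(wav)       # list of `scales` lists of per-layer feature maps

    print(wav.shape)                 # expected: torch.Size([2, 1, 25600])
    print(len(outs), len(outs[0]))   # expected: 3 discriminators x 7 layer outputs each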