Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/hifigan/hifigan.py
697 views
1
import torch
2
import torch.nn.functional as F
3
import torch.nn as nn
4
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
7
from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
8
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
9
import numpy as np
10
11
LRELU_SLOPE = 0.1
12
13
14
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize conv-layer weights in place from N(mean, std).

    Modules whose class name does not contain "Conv" are left untouched.
    Intended for use with ``nn.Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
18
19
20
def apply_weight_norm(m):
    """Wrap conv-like modules with weight normalization (in place).

    Non-conv modules are skipped. Intended for use with ``nn.Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
24
25
26
def get_padding(kernel_size, dilation=1):
    """Return the "same" padding for a stride-1 conv.

    Equals ``dilation * (kernel_size - 1) // 2``, which preserves the input
    length for odd kernel sizes.
    """
    # Integer floor division avoids the lossy float round-trip of
    # int((kernel_size * dilation - dilation) / 2); identical for the
    # non-negative values used here.
    return (kernel_size * dilation - dilation) // 2
28
29
30
class ResBlock1(torch.nn.Module):
    """Residual block with three (dilated conv, plain conv) pairs.

    Each residual step applies leaky-ReLU -> dilated conv -> leaky-ReLU ->
    dilation-1 conv, then adds the input back.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        # First conv of each pair: increasing dilation per the `dilation` tuple.
        dilations = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d,
                               padding=get_padding(kernel_size, d)))
            for d in dilations
        ])
        self.convs1.apply(init_weights)

        # Second conv of each pair: always dilation 1.
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
            for _ in range(3)
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for first, second in zip(self.convs1, self.convs2):
            residual = x
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = first(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = second(x)
            x = x + residual
        return x

    def remove_weight_norm(self):
        # NOTE: inside this method the bare name resolves to the module-level
        # torch.nn.utils.remove_weight_norm import, not to this method.
        for layer in list(self.convs1) + list(self.convs2):
            remove_weight_norm(layer)
68
69
70
class ResBlock2(torch.nn.Module):
    """Lighter residual block: one dilated conv per residual step."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d,
                               padding=get_padding(kernel_size, d)))
            for d in (dilation[0], dilation[1])
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for conv in self.convs:
            residual = x
            x = conv(F.leaky_relu(x, LRELU_SLOPE))
            x = x + residual
        return x

    def remove_weight_norm(self):
        # NOTE: the bare name resolves to torch.nn.utils.remove_weight_norm
        # (module-level import), not to this method.
        for conv in self.convs:
            remove_weight_norm(conv)
92
93
94
class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Initialize 1x1 Conv1d module (kernel 1, no padding, no dilation)."""
        super().__init__(in_channels, out_channels,
                         kernel_size=1, padding=0, dilation=1, bias=bias)
102
103
104
class HifiGanGenerator(torch.nn.Module):
    """HiFi-GAN generator: upsamples an 80-bin mel spectrogram to a waveform.

    Each upsampling stage is a transposed convolution followed by a bank of
    ``num_kernels`` residual blocks whose outputs are averaged. When
    ``h['use_pitch_embed']`` is set, an NSF-style harmonic source signal
    derived from F0 is injected after every upsampling stage.
    """

    def __init__(self, h, c_out=1):
        # h: hyperparameter dict (resblock_kernel_sizes, upsample_rates,
        #    upsample_initial_channel, use_pitch_embed, ...).
        # c_out: number of output audio channels (1 = mono waveform).
        super(HifiGanGenerator, self).__init__()
        self.h = h
        self.num_kernels = len(h['resblock_kernel_sizes'])
        self.num_upsamples = len(h['upsample_rates'])

        if h['use_pitch_embed']:
            self.harmonic_num = 8  # number of harmonic overtones for the NSF source
            # Upsample frame-rate F0 up to sample rate (product of all stage rates).
            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
            self.m_source = SourceModuleHnNSF(
                sampling_rate=h['audio_sample_rate'],
                harmonic_num=self.harmonic_num)
            # One conv per upsampling stage, to bring the sample-rate source
            # signal down to that stage's temporal resolution.
            self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
            # Channel count halves at every stage.
            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
            if h['use_pitch_embed']:
                if i + 1 < len(h['upsample_rates']):
                    # Remaining upsampling factor after this stage: the
                    # sample-rate source must be strided down by this amount.
                    stride_f0 = np.prod(h['upsample_rates'][i + 1:])
                    self.noise_convs.append(Conv1d(
                        1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
                else:
                    # Final stage already runs at sample rate: 1x1 conv only.
                    self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
                self.resblocks.append(resblock(h, ch, k, d))

        # `ch` here is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0=None):
        # x: mel spectrogram — presumably (batch, 80, frames) to match
        #    conv_pre's 80 input channels; TODO confirm against callers.
        # f0: optional frame-rate fundamental frequency per sample in batch.
        if f0 is not None:
            # harmonic-source signal, noise-source signal, uv flag
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            if f0 is not None:
                # Inject the harmonic source at this stage's resolution.
                x_source = self.noise_convs[i](har_source)
                x = x + x_source
            # Run this stage's bank of `num_kernels` resblocks and average.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)  # bound waveform samples to [-1, 1]

        return x

    def remove_weight_norm(self):
        # NOTE: the bare name below resolves to the module-level
        # torch.nn.utils.remove_weight_norm import, not to this method.
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
179
180
181
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into 2-D by `period` and
    applies a stack of 2-D convolutions along the time axis.

    With ``use_cond`` the mel condition is upsampled to waveform rate and
    stacked with the audio as a second input channel.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
        super(DiscriminatorP, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            from utils.hparams import hparams
            t = hparams['hop_size']
            # Upsample mel (frame rate) to waveform rate for channel stacking.
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2

        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        # FIX: paddings were hard-coded as get_padding(5, 1) / (2, 0), which is
        # only correct for the default kernel_size=5. Use kernel_size so the
        # "same"-padding holds for any kernel size (matches the upstream
        # HiFi-GAN implementation); behavior is unchanged for the default.
        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x, mel):
        """Return (flattened score map, list of intermediate feature maps)."""
        fmap = []
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        # 1d to 2d: reflect-pad time to a multiple of `period`, then fold so
        # every column holds samples `period` apart.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
224
225
226
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of DiscriminatorP instances over the prime periods 2,3,5,7,11."""

    def __init__(self, use_cond=False, c_in=1):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            DiscriminatorP(p, use_cond=use_cond, c_in=c_in)
            for p in (2, 3, 5, 7, 11)
        )

    def forward(self, y, y_hat, mel=None):
        """Score real audio `y` and generated audio `y_hat` with every
        sub-discriminator; return per-discriminator scores and feature maps."""
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for disc in self.discriminators:
            score_real, feats_real = disc(y, mel)
            score_fake, feats_fake = disc(y_hat, mel)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_fake)
            fmap_gs.append(feats_fake)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
251
252
253
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of strided, grouped 1-D convolutions
    applied directly to the (possibly average-pooled) waveform."""

    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
        super(DiscriminatorS, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            # Upsample the mel condition to this scale's sample rate and stack
            # it with the audio as a second input channel.
            t = np.prod(upsample_rates)
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # (in_ch, out_ch, kernel, stride, groups, padding) per stage.
        specs = [
            (c_in, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            norm_f(Conv1d(i, o, k, s, groups=g, padding=p))
            for i, o, k, s, g, p in specs
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x, mel):
        """Return (flattened score map, list of intermediate feature maps)."""
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        fmap = []
        for layer in self.convs:
            x = F.leaky_relu(layer(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
287
288
289
class MultiScaleDiscriminator(torch.nn.Module):
    """Three DiscriminatorS instances on progressively 2x average-pooled audio."""

    def __init__(self, use_cond=False, c_in=1):
        super(MultiScaleDiscriminator, self).__init__()
        from utils.hparams import hparams
        hop = hparams['hop_size']
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
                           upsample_rates=[4, 4, hop // 16],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hop // 32],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hop // 64],
                           c_in=c_in),
        ])
        # Halve the sample rate before the 2nd and 3rd discriminators.
        self.meanpools = nn.ModuleList(
            AvgPool1d(4, 2, padding=1) for _ in range(2)
        )

    def forward(self, y, y_hat, mel=None):
        """Score real audio `y` and generated audio `y_hat` at every scale;
        return per-scale scores and feature maps."""
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for idx, disc in enumerate(self.discriminators):
            if idx:
                pool = self.meanpools[idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            score_real, feats_real = disc(y, mel)
            score_fake, feats_fake = disc(y_hat, mel)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_fake)
            fmap_gs.append(feats_fake)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
326
327
328
def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss over all discriminators and layers.

    Sums the mean absolute difference between each real/generated feature-map
    pair and scales the total by 2.
    """
    total = 0
    for feats_real, feats_fake in zip(fmap_r, fmap_g):
        for real, fake in zip(feats_real, feats_fake):
            total = total + (real - fake).abs().mean()

    return total * 2
335
336
337
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN discriminator loss, averaged over discriminators.

    Real scores are pushed toward 1, generated scores toward 0.
    Returns ``(real_loss, generated_loss)``.
    """
    real_total, fake_total = 0, 0
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_total = real_total + torch.mean((1 - real_out) ** 2)
        fake_total = fake_total + torch.mean(fake_out ** 2)
    count = len(disc_real_outputs)
    return real_total / count, fake_total / count
348
349
350
def cond_discriminator_loss(outputs):
    """Mean squared score across conditional-discriminator outputs
    (all scores are pushed toward 0)."""
    total = sum(torch.mean(score ** 2) for score in outputs)
    return total / len(outputs)
357
358
359
def generator_loss(disc_outputs):
    """LSGAN generator loss, averaged over discriminators.

    Pushes every discriminator score on generated audio toward 1.
    """
    total = sum(torch.mean((1 - score) ** 2) for score in disc_outputs)
    return total / len(disc_outputs)
366
367