Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/parallel_wavegan/layers/upsample.py
694 views
1
# -*- coding: utf-8 -*-
2
3
"""Upsampling module.
4
5
This code is modified from https://github.com/r9y9/wavenet_vocoder.
6
7
"""
8
9
import numpy as np
10
import torch
11
import torch.nn.functional as F
12
13
from . import Conv1d
14
15
16
class Stretch2d(torch.nn.Module):
    """Stretch the last two axes of a 4-D tensor via ``F.interpolate``."""

    def __init__(self, x_scale, y_scale, mode="nearest"):
        """Store the stretching configuration.

        Args:
            x_scale (int): Factor applied to the last axis (time axis in spectrogram).
            y_scale (int): Factor applied to the second-to-last axis (frequency axis in spectrogram).
            mode (str): Interpolation mode passed through to ``F.interpolate``.

        """
        super().__init__()
        self.mode = mode
        self.y_scale = y_scale
        self.x_scale = x_scale

    def forward(self, x):
        """Stretch ``x`` along its frequency and time axes.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        # F.interpolate expects factors ordered like the spatial axes: (F, T).
        factors = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=factors, mode=self.mode)
45
46
47
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` initialised as a uniform averaging kernel."""

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to ``torch.nn.Conv2d``."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Set each weight to 1 / (kernel height * kernel width) and zero the bias.

        ``torch.nn.Conv2d.__init__`` invokes this override, so freshly built
        modules start out averaging over their kernel window.
        """
        n_taps = np.prod(self.kernel_size)
        self.weight.data.fill_(1. / n_taps)
        if self.bias is None:
            return
        torch.nn.init.constant_(self.bias, 0.0)
59
60
61
class UpsampleNetwork(torch.nn.Module):
    """Upsampling network module.

    Treats a (B, C, T) feature sequence as a single-channel 2-D map and, for
    each entry in ``upsample_scales``, stretches the time axis with
    ``Stretch2d``, smooths it with a ``Conv2d`` and optionally applies an
    activation from ``torch.nn``.
    """

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params=None,
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 ):
        """Initialize upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Name of an activation class in ``torch.nn``
                (e.g. "ReLU"); ``None`` disables the activation.
            nonlinear_activation_params (dict): Keyword arguments for the activation.
                ``None`` (the default) means no arguments.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
                Must be odd.
            use_causal_conv (bool): Whether the smoothing convolutions are causal
                along the time axis.

        Raises:
            ValueError: If ``freq_axis_kernel_size`` is even.

        """
        super(UpsampleNetwork, self).__init__()
        # A None sentinel avoids the shared-mutable-default pitfall of `={}`.
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {}
        # Only odd kernels allow symmetric "same" padding along the frequency axis.
        # Raise explicitly, because `assert` is stripped when Python runs with -O.
        if (freq_axis_kernel_size - 1) % 2 != 0:
            raise ValueError("Not support even number freq axis kernel size.")
        freq_axis_padding = (freq_axis_kernel_size - 1) // 2

        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        # Layer order per scale (stretch, conv, [activation]) is kept stable so
        # state_dict keys of existing checkpoints still match.
        for scale in upsample_scales:
            # interpolation layer: stretch the time axis by `scale`
            self.up_layers += [Stretch2d(scale, 1, interpolate_mode)]

            # conv layer
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                # Pad the time axis by the full kernel width minus one; forward()
                # drops the right-hand surplus so each output position depends
                # only on the current and earlier positions.
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            self.up_layers += [Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, C, T).

        Returns:
            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).

        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            if self.use_causal_conv and isinstance(f, Conv2d):
                # Causal padding lengthens the output; keep the first c.size(-1) steps.
                c = f(c)[..., :c.size(-1)]
            else:
                c = f(c)
        return c.squeeze(1)  # (B, C, T')
123
124
125
class ConvInUpsampleNetwork(torch.nn.Module):
    """Convolution + upsampling network module.

    Applies an unpadded ``Conv1d`` over the auxiliary features (so the input is
    expected to already carry ``aux_context_window`` frames of context) and then
    upsamples the result with ``UpsampleNetwork``.
    """

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params=None,
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False
                 ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
                ``None`` (the default) means no arguments.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super(ConvInUpsampleNetwork, self).__init__()
        # A None sentinel avoids the shared-mutable-default pitfall of `={}`;
        # normalise to a dict before forwarding so UpsampleNetwork always gets one.
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {}
        self.aux_context_window = aux_context_window
        # Trimming in forward() is only meaningful (and `:-0` only safe to avoid)
        # when there is a non-zero context window.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        # To capture wide-context information in conditional features
        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
        self.upsample = UpsampleNetwork(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, C, T').

        Returns:
            Tensor: Upsampled tensor (B, C, T),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.

        """
        c_ = self.conv_in(c)
        # In the causal case the shorter kernel leaves aux_context_window extra
        # trailing frames; drop them so the length matches the non-causal path.
        c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)
184
185