Path: blob/main/modules/parallel_wavegan/layers/upsample.py
# -*- coding: utf-8 -*-

"""Upsampling module.

This code is modified from https://github.com/r9y9/wavenet_vocoder.

"""

import numpy as np
import torch
import torch.nn.functional as F

from . import Conv1d


class Stretch2d(torch.nn.Module):
    """Stretch2d module."""

    def __init__(self, x_scale, y_scale, mode="nearest"):
        """Initialize Stretch2d module.

        Args:
            x_scale (int): X scaling factor (Time axis in spectrogram).
            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
            mode (str): Interpolation mode.

        """
        super(Stretch2d, self).__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)


class Conv2d(torch.nn.Conv2d):
    """Conv2d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv2d module."""
        super(Conv2d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        self.weight.data.fill_(1. / np.prod(self.kernel_size))
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class UpsampleNetwork(torch.nn.Module):
    """Upsampling network module."""

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 ):
        """Initialize upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super(UpsampleNetwork, self).__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_scales:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Even number freq axis kernel size is not supported."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, C, T).

        Returns:
            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).

        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            if self.use_causal_conv and isinstance(f, Conv2d):
                c = f(c)[..., :c.size(-1)]
            else:
                c = f(c)
        return c.squeeze(1)  # (B, C, T')


class ConvInUpsampleNetwork(torch.nn.Module):
    """Convolution + upsampling network module."""

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False,
                 ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super(ConvInUpsampleNetwork, self).__init__()
        self.aux_context_window = aux_context_window
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        # To capture wide-context information in conditional features
        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
        self.upsample = UpsampleNetwork(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, C, T').

        Returns:
            Tensor: Upsampled tensor (B, C, T),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.

        """
        c_ = self.conv_in(c)
        c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)
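

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream file). It assumes an
# 80-dim mel-spectrogram input, a hop size of 256 factored as [4, 4, 4, 4],
# and an auxiliary context window of 2 frames; run it with
# `python -m parallel_wavegan.layers.upsample` so the relative import resolves.
if __name__ == "__main__":
    upsample_net = ConvInUpsampleNetwork(
        upsample_scales=[4, 4, 4, 4],
        aux_channels=80,
        aux_context_window=2,
    )
    # 10 target frames plus a 2-frame context on each side (conv_in uses no padding).
    mel = torch.randn(1, 80, 10 + 2 * 2)
    out = upsample_net(mel)
    # (14 - 2 * 2) * prod([4, 4, 4, 4]) = 10 * 256 = 2560 samples per channel.
    print(out.shape)  # torch.Size([1, 80, 2560])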