Path: blob/main/modules/parallel_wavegan/models/parallel_wavegan.py
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Parallel WaveGAN Modules."""

import logging
import math

import torch
from torch import nn

from modules.parallel_wavegan.layers import Conv1d
from modules.parallel_wavegan.layers import Conv1d1x1
from modules.parallel_wavegan.layers import ResidualBlock
from modules.parallel_wavegan.layers import upsample
from modules.parallel_wavegan import models


class ParallelWaveGANGenerator(torch.nn.Module):
    """Parallel WaveGAN Generator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
                 use_pitch_embed=False,
                 ):
        """Initialize Parallel WaveGAN Generator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for auxiliary feature conv.
            aux_context_window (int): Context window size for auxiliary feature.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            upsample_conditional_features (bool): Whether to use upsampling network.
            upsample_net (str): Upsampling network architecture.
            upsample_params (dict): Upsampling network parameters.
            use_pitch_embed (bool): Whether to embed pitch and fuse it with the
                auxiliary features as additional conditioning.

        """
        super(ParallelWaveGANGenerator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define conv + upsampling network
        if upsample_conditional_features:
            upsample_params.update({
                "use_causal_conv": use_causal_conv,
            })
            if upsample_net == "MelGANGenerator":
                assert aux_context_window == 0
                upsample_params.update({
                    "use_weight_norm": False,  # not to apply twice
                    "use_final_nonlinear_activation": False,
                })
                self.upsample_net = getattr(models, upsample_net)(**upsample_params)
            else:
                if upsample_net == "ConvInUpsampleNetwork":
                    upsample_params.update({
                        "aux_channels": aux_channels,
                        "aux_context_window": aux_context_window,
                    })
                self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        self.use_pitch_embed = use_pitch_embed
        if use_pitch_embed:
            self.pitch_embed = nn.Embedding(300, aux_channels, 0)
            self.c_proj = nn.Linear(2 * aux_channels, aux_channels)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, c=None, pitch=None, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, C_in, T).
            c (Tensor): Local conditioning auxiliary features (B, C, T').
            pitch (Tensor): Local conditioning pitch (B, T').

        Returns:
            Tensor: Output tensor (B, C_out, T).

        """
        # perform upsampling
        if c is not None and self.upsample_net is not None:
            if self.use_pitch_embed:
                p = self.pitch_embed(pitch)
                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
            c = self.upsample_net(c)
            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)


class ParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 use_weight_norm=True,
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of conv layers.
            layers (int): Number of conv layers.
            conv_channels (int): Number of channels in conv layers.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = torch.nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                Conv1d(conv_in_channels, conv_channels,
                       kernel_size=kernel_size, padding=padding,
                       dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
    """Residual Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        """Initialize Residual Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.

        """
        super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = torch.nn.Sequential(
            Conv1d1x1(in_channels, residual_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
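
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (editor-added, not part of the original module):
# it only illustrates the expected tensor shapes when driving the generator
# with random noise and mel-like conditioning, then scoring the output with
# the discriminator. The batch size, frame count, and hop size below are
# illustrative assumptions; aux_context_window is set to 0 on the assumption
# that ConvInUpsampleNetwork then leaves the frame count unchanged before
# upsampling by prod(upsample_scales) = 256.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, frames, hop_size = 2, 20, 256  # 256 = prod([4, 4, 4, 4])

    generator = ParallelWaveGANGenerator(aux_context_window=0)
    discriminator = ParallelWaveGANDiscriminator()

    z = torch.randn(batch_size, 1, frames * hop_size)  # input noise (B, C_in, T)
    c = torch.randn(batch_size, 80, frames)            # auxiliary features (B, C_aux, T')

    y_hat = generator(z, c)         # waveform (B, C_out, T)
    score = discriminator(y_hat)    # per-sample scores (B, 1, T)
    print(y_hat.shape, score.shape)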