Path: blob/main/modules/parallel_wavegan/layers/tf_layers.py
694 views
# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
#  MIT License (https://opensource.org/licenses/MIT)

"""Tensorflow Layer modules compatible with pytorch."""

import tensorflow as tf


class TFReflectionPad1d(tf.keras.layers.Layer):
    """Tensorflow ReflectionPad1d module.

    Reflection-pads axis 1 (the time axis) of a rank-4 tensor laid out as
    (B, T, 1, C); every other axis is left unpadded.
    """

    def __init__(self, padding_size):
        """Initialize TFReflectionPad1d module.

        Args:
            padding_size (int): Number of frames added on each side of the
                time axis. ``tf.pad`` in "REFLECT" mode requires this to be
                at most T - 1 for an input of length T.

        """
        super().__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        # Only axis 1 is padded; the paddings spec has four rows, so x must be rank 4.
        paddings = [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]]
        return tf.pad(x, paddings, "REFLECT")


class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module.

    Emulates a 1-D transposed convolution with a ``Conv2DTranspose`` whose
    kernel size and stride are 1 along the dummy width axis.
    """

    def __init__(self, channels, kernel_size, stride, padding):
        """Initialize TFConvTranspose1d module.

        Args:
            channels (int): Number of output channels (filters).
            kernel_size (int): Kernel size along the time axis.
            stride (int): Stride along the time axis.
            padding (str): Padding type ("same" or "valid").

        """
        super().__init__()
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=(kernel_size, 1),
            strides=(stride, 1),
            padding=padding,
        )

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)


class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module.

    Computes ``shortcut(x) + block(x)``, where ``block`` is
    activation -> reflection pad -> dilated conv -> activation -> 1x1 conv
    and ``shortcut`` is a 1x1 conv.
    """

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size of the dilated convolution. Must be odd.
            channels (int): Number of channels.
            dilation (int): Dilation rate of the dilated convolution.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation layer class name looked up
                in ``tf.keras.layers`` (e.g. "LeakyReLU").
            nonlinear_activation_params (dict): Keyword arguments passed to the
                activation layer constructor.
            padding (str): Padding type ("same" or "valid"). Currently unused:
                the dilated convolution always runs with "valid" padding after
                an explicit reflection pad.

        Raises:
            ValueError: If ``kernel_size`` is even. Symmetric padding cannot
                then preserve the time length that the shortcut sum requires.

        """
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}.")
        # A "valid" dilated conv shrinks the time axis by
        # (kernel_size - 1) * dilation frames; pad half of that on each side so
        # the block output keeps length T and can be added to the shortcut.
        # (The previous hard-coded `dilation` was only correct for kernel_size=3.)
        pad_size = (kernel_size - 1) // 2 * dilation
        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            TFReflectionPad1d(pad_size),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias),
        ]
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        out = x
        for layer in self.block:
            out = layer(out)
        return self.shortcut(x) + out