Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/parallel_wavegan/layers/tf_layers.py
694 views
1
# -*- coding: utf-8 -*-
2
3
# Copyright 2020 MINH ANH (@dathudeptrai)
4
# MIT License (https://opensource.org/licenses/MIT)
5
6
"""Tensorflow Layer modules complatible with pytorch."""
7
8
import tensorflow as tf
9
10
11
class TFReflectionPad1d(tf.keras.layers.Layer):
    """Tensorflow layer that reflection-pads the time axis of a 4D tensor."""

    def __init__(self, padding_size):
        """Initialize TFReflectionPad1d module.

        Args:
            padding_size (int): Number of frames to pad on each side of the
                time axis.

        """
        super().__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Reflection-pad the input symmetrically along the time dimension.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        pad = self.padding_size
        # Pad only axis 1 (time); batch, the dummy axis, and channels stay untouched.
        pad_width = [[0, 0], [pad, pad], [0, 0], [0, 0]]
        return tf.pad(x, pad_width, "REFLECT")
36
37
38
class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module."""

    def __init__(self, channels, kernel_size, stride, padding):
        """Initialize TFConvTranspose1d module.

        Args:
            channels (int): Number of output channels.
            kernel_size (int): Kernel size.
            stride (int): Stride width.
            padding (str): Padding type ("same" or "valid").

        """
        super(TFConvTranspose1d, self).__init__()
        # A Conv2DTranspose with a (kernel_size, 1) kernel applied to a
        # (B, T, 1, C) tensor acts as a 1D transposed convolution over time.
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=(kernel_size, 1),
            strides=(stride, 1),
            padding=padding,
        )

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)
72
73
74
class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module."""

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size.
            channels (int): Number of channels.
            dilation (int): Dilation size.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            padding (str): Padding type ("same" or "valid").
                NOTE(review): this argument is currently unused — the dilated
                convolution below is hard-coded to "valid"; kept for interface
                compatibility with callers.

        """
        super(TFResidualStack, self).__init__()
        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            # Reflection-pad by `dilation` so the dilated "valid" conv preserves T.
            TFReflectionPad1d(dilation),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
        ]
        # 1x1 conv on the residual branch so channel counts match before the add.
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        _x = tf.identity(x)
        # Unused loop index removed; layers are applied sequentially.
        for layer in self.block:
            _x = layer(_x)
        shortcut = self.shortcut(x)
        return shortcut + _x
130
131