Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/losses/stft.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""STFT-based loss modules."""
16
17
import tensorflow as tf
18
19
20
class TFSpectralConvergence(tf.keras.layers.Layer):
    """Spectral convergence loss.

    Ratio between the Frobenius norm of the spectrogram error and the
    Frobenius norm of the groundtruth spectrogram.
    """

    def __init__(self):
        """Initialize."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Calculate forward propagation.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value; the two trailing axes
                are consumed by the norm, leading axes are kept.
        """
        # Both norms are taken over the last two axes of the inputs.
        matrix_axes = (-2, -1)
        error_norm = tf.norm(y_mag - x_mag, ord="fro", axis=matrix_axes)
        reference_norm = tf.norm(y_mag, ord="fro", axis=matrix_axes)
        return error_norm / reference_norm
38
39
40
class TFLogSTFTMagnitude(tf.keras.layers.Layer):
    """Log STFT magnitude loss module.

    Element-wise absolute difference between the natural logarithms of two
    magnitude spectrograms. No reduction is applied here.
    """

    def __init__(self):
        """Initialize."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Calculate forward propagation.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value, same shape as the inputs
                (element-wise, not reduced).
        """
        log_reference = tf.math.log(y_mag)
        log_prediction = tf.math.log(x_mag)
        return tf.abs(log_reference - log_prediction)
56
57
58
class TFSTFT(tf.keras.layers.Layer):
    """STFT loss module.

    Turns two waveforms into magnitude spectrograms with the configured STFT
    settings and evaluates the spectral convergence and log STFT magnitude
    losses on them.
    """

    def __init__(self, frame_length=600, frame_step=120, fft_length=1024):
        """Initialize.

        Args:
            frame_length (int): Value passed as ``frame_length`` to
                ``tf.signal.stft``.
            frame_step (int): Value passed as ``frame_step`` to
                ``tf.signal.stft``.
            fft_length (int): Value passed as ``fft_length`` to
                ``tf.signal.stft``.
        """
        super().__init__()
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        # NOTE: attribute name keeps its historical spelling ("convergenge")
        # because it is visible to code outside this class.
        self.spectral_convergenge_loss = TFSpectralConvergence()
        self.log_stft_magnitude_loss = TFLogSTFTMagnitude()

    def _stabilized_magnitude(self, signal):
        """Return the clamped STFT magnitude of ``signal``."""
        magnitude = tf.abs(
            tf.signal.stft(
                signals=signal,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )
        # add small number to prevent nan value.
        # compatible with pytorch version.
        smoothed = tf.math.sqrt(magnitude ** 2 + 1e-7)
        return tf.clip_by_value(smoothed, 1e-7, 1e3)

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value (pre-reduce).
            Tensor: Log STFT magnitude loss value (pre-reduce).
        """
        x_mag = self._stabilized_magnitude(x)
        y_mag = self._stabilized_magnitude(y)

        # Both sub-losses take the groundtruth spectrogram first.
        return (
            self.spectral_convergenge_loss(y_mag, x_mag),
            self.log_stft_magnitude_loss(y_mag, x_mag),
        )
105
106
107
class TFMultiResolutionSTFT(tf.keras.layers.Layer):
    """Multi resolution STFT loss module.

    Holds one ``TFSTFT`` loss per (frame_length, frame_step, fft_length)
    triple and averages their outputs.
    """

    def __init__(
        self,
        fft_lengths=(1024, 2048, 512),
        frame_lengths=(600, 1200, 240),
        frame_steps=(120, 240, 50),
    ):
        """Initialize Multi resolution STFT loss module.

        The three sequences are consumed in lockstep: element ``i`` of each
        one configures the ``i``-th ``TFSTFT`` loss.

        Args:
            fft_lengths (list): List of FFT sizes.
            frame_lengths (list): List of window lengths.
            frame_steps (list): List of hop sizes.

        Raises:
            ValueError: If the three sequences differ in length or are empty.
        """
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``
        # and ``zip`` would then silently drop the unmatched entries.
        if not len(frame_lengths) == len(frame_steps) == len(fft_lengths):
            raise ValueError(
                "frame_lengths, frame_steps and fft_lengths must have the "
                "same number of elements, got {}, {} and {}.".format(
                    len(frame_lengths), len(frame_steps), len(fft_lengths)
                )
            )
        # With no resolution at all, ``call`` would divide by zero.
        if len(fft_lengths) == 0:
            raise ValueError("At least one STFT resolution is required.")

        self.stft_losses = [
            TFSTFT(frame_length, frame_step, fft_length)
            for frame_length, frame_step, fft_length in zip(
                frame_lengths, frame_steps, fft_lengths
            )
        ]

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
            Both are averaged over every axis except the first and over
            all resolutions.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_l, mag_l = stft_loss(y, x)
            # Reduce every axis but the first; for a rank-1 tensor the axis
            # list is empty and the values pass through unchanged.
            sc_loss += tf.reduce_mean(sc_l, axis=list(range(1, len(sc_l.shape))))
            mag_loss += tf.reduce_mean(mag_l, axis=list(range(1, len(mag_l.shape))))

        num_resolutions = len(self.stft_losses)
        sc_loss /= num_resolutions
        mag_loss /= num_resolutions

        return sc_loss, mag_loss
150
151