Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/losses/spectrogram.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Spectrogram-based loss modules."""
16
17
import tensorflow as tf
18
19
20
class TFMelSpectrogram(tf.keras.layers.Layer):
    """Log-mel spectrogram loss layer.

    Converts a ground-truth and a predicted waveform to log-mel spectrograms
    and returns the mean absolute difference between them, one value per
    batch element.
    """

    def __init__(
        self,
        n_mels=80,
        f_min=80.0,
        f_max=7600,
        frame_length=1024,
        frame_step=256,
        fft_length=1024,
        sample_rate=16000,
        **kwargs
    ):
        """Initialize.

        Args:
            n_mels (int): Number of mel bins of the filterbank.
            f_min (float): Lower edge of the filterbank in Hz.
            f_max (float): Upper edge of the filterbank in Hz.
            frame_length (int): STFT window length in samples.
            frame_step (int): STFT hop size in samples.
            fft_length (int): FFT size; the linear spectrogram has
                ``fft_length // 2 + 1`` frequency bins.
            sample_rate (int): Sampling rate of the signals in Hz.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        # Keep every hyper-parameter so that get_config() can report the full
        # constructor configuration.
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        self.sample_rate = sample_rate

        # Filterbank of shape (fft_length // 2 + 1, n_mels), built once and
        # reused for every call.
        self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=n_mels,
            num_spectrogram_bins=fft_length // 2 + 1,
            sample_rate=sample_rate,
            lower_edge_hertz=f_min,
            upper_edge_hertz=f_max,
        )

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Returns:
            dict: Base layer config extended with the spectrogram settings.
        """
        config = super().get_config()
        config.update(
            {
                "n_mels": self.n_mels,
                "f_min": self.f_min,
                "f_max": self.f_max,
                "frame_length": self.frame_length,
                "frame_step": self.frame_step,
                "fft_length": self.fft_length,
                "sample_rate": self.sample_rate,
            }
        )
        return config

    def _calculate_log_mels_spectrogram(self, signals):
        """Compute the log-mel spectrogram of a batch of signals.

        Args:
            signals (Tensor): signal (B, T).
        Returns:
            Tensor: Log-mel spectrogram (B, T', n_mels), where T' is the
                number of STFT frames.
        """
        stfts = tf.signal.stft(
            signals,
            frame_length=self.frame_length,
            frame_step=self.frame_step,
            fft_length=self.fft_length,
        )
        # Magnitude of the complex STFT.
        linear_spectrograms = tf.abs(stfts)
        # Contract the frequency axis with the mel filterbank.
        mel_spectrograms = tf.tensordot(
            linear_spectrograms, self.linear_to_mel_weight_matrix, 1
        )
        # Re-attach the static shape: leading dims of the linear spectrogram
        # followed by the filterbank's mel dimension.
        mel_spectrograms.set_shape(
            linear_spectrograms.shape[:-1].concatenate(
                self.linear_to_mel_weight_matrix.shape[-1:]
            )
        )
        log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)  # prevent nan.
        return log_mel_spectrograms

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).
        Returns:
            Tensor: Mean absolute error between the two log-mel spectrograms,
                averaged over every axis except the batch axis (B,).
        """
        y_mels = self._calculate_log_mels_spectrogram(y)
        x_mels = self._calculate_log_mels_spectrogram(x)
        # Reduce over all non-batch axes so one loss value remains per example.
        return tf.reduce_mean(
            tf.abs(y_mels - x_mels), axis=list(range(1, len(x_mels.shape)))
        )
82
83