# Path: blob/master/tensorflow_tts/losses/spectrogram.py (1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spectrogram-based loss modules."""

import tensorflow as tf


class TFMelSpectrogram(tf.keras.layers.Layer):
    """Loss layer that compares two waveforms as log-mel spectrograms."""

    def __init__(
        self,
        n_mels=80,
        f_min=80.0,
        f_max=7600,
        frame_length=1024,
        frame_step=256,
        fft_length=1024,
        sample_rate=16000,
        **kwargs
    ):
        """Store the STFT settings and precompute the mel filterbank.

        Args:
            n_mels: Number of mel bins in the filterbank.
            f_min: Lower edge of the filterbank, in Hz.
            f_max: Upper edge of the filterbank, in Hz.
            frame_length: STFT window length, in samples.
            frame_step: STFT hop size, in samples.
            fft_length: FFT size used by the STFT.
            sample_rate: Sampling rate of the input signals, in Hz.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length

        # The filterbank must have one row per STFT frequency bin.
        num_linear_bins = fft_length // 2 + 1
        self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=n_mels,
            num_spectrogram_bins=num_linear_bins,
            sample_rate=sample_rate,
            lower_edge_hertz=f_min,
            upper_edge_hertz=f_max,
        )

    def _calculate_log_mels_spectrogram(self, signals):
        """Convert waveforms into log-mel spectrograms.

        Args:
            signals (Tensor): Waveforms (B, T).

        Returns:
            Tensor: Log-mel spectrogram (B, T', n_mels), where the last
                dimension is the filterbank's number of mel bins.
        """
        mel_basis = self.linear_to_mel_weight_matrix

        magnitudes = tf.abs(
            tf.signal.stft(
                signals,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )

        # Contract the frequency axis with the filterbank's input axis.
        mels = tf.tensordot(magnitudes, mel_basis, 1)

        # tensordot can drop static shape information; restore it from the
        # operands: leading dims of the magnitudes plus the mel-bin dim.
        static_shape = magnitudes.shape[:-1].concatenate(mel_basis.shape[-1:])
        mels.set_shape(static_shape)

        # The small offset keeps log() finite on silent bins.
        return tf.math.log(mels + 1e-6)

    def call(self, y, x):
        """Compute the mean absolute log-mel spectrogram error.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Absolute log-mel difference averaged over every axis
                except the first (batch) axis.
        """
        y_mels = self._calculate_log_mels_spectrogram(y)
        x_mels = self._calculate_log_mels_spectrogram(x)

        abs_error = tf.abs(y_mels - x_mels)
        # Reduce over everything but axis 0.
        reduce_axes = list(range(1, len(x_mels.shape)))
        return tf.reduce_mean(abs_error, axis=reduce_axes)