Path: blob/master/tensorflow_tts/losses/stft.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""STFT-based loss modules."""

import tensorflow as tf


class TFSpectralConvergence(tf.keras.layers.Layer):
    """Spectral convergence loss."""

    def __init__(self):
        """Initialize."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Calculate forward propagation.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value, ||y_mag - x_mag||_F / ||y_mag||_F
                computed over the last two axes (one value per leading index,
                i.e. shape (B,) for the documented input shape).

        Note:
            The result is not finite if ``y_mag`` is all zeros along the last two
            axes; ``TFSTFT`` avoids this by flooring magnitudes before calling.
        """
        return tf.norm(y_mag - x_mag, ord="fro", axis=(-2, -1)) / tf.norm(
            y_mag, ord="fro", axis=(-2, -1)
        )


class TFLogSTFTMagnitude(tf.keras.layers.Layer):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Calculate forward propagation.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Element-wise log STFT magnitude loss |log(y_mag) - log(x_mag)|,
                with the same shape as the inputs (not yet reduced).
        """
        return tf.abs(tf.math.log(y_mag) - tf.math.log(x_mag))


class TFSTFT(tf.keras.layers.Layer):
    """STFT loss module."""

    def __init__(self, frame_length=600, frame_step=120, fft_length=1024):
        """Initialize.

        Args:
            frame_length (int): Window length (samples) passed to ``tf.signal.stft``.
            frame_step (int): Hop size (samples) passed to ``tf.signal.stft``.
            fft_length (int): FFT size passed to ``tf.signal.stft``.
        """
        super().__init__()
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        # NOTE: "convergenge" is a historical misspelling; the attribute name is
        # kept unchanged because external code may reference it.
        self.spectral_convergenge_loss = TFSpectralConvergence()
        self.log_stft_magnitude_loss = TFLogSTFTMagnitude()

    def _magnitude(self, signal):
        """Return the clipped STFT magnitude of ``signal`` for this resolution."""
        mag = tf.abs(
            tf.signal.stft(
                signals=signal,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )
        # add small number to prevent nan value; compatible with pytorch version.
        # sqrt(mag**2 + 1e-7) keeps magnitudes strictly positive so the log in
        # TFLogSTFTMagnitude and the division in TFSpectralConvergence stay
        # finite; the upper clip caps magnitudes at 1e3.
        return tf.clip_by_value(tf.math.sqrt(mag ** 2 + 1e-7), 1e-7, 1e3)

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value (pre-reduce).
            Tensor: Log STFT magnitude loss value (pre-reduce).
        """
        x_mag = self._magnitude(x)
        y_mag = self._magnitude(y)

        sc_loss = self.spectral_convergenge_loss(y_mag, x_mag)
        mag_loss = self.log_stft_magnitude_loss(y_mag, x_mag)

        return sc_loss, mag_loss


class TFMultiResolutionSTFT(tf.keras.layers.Layer):
    """Multi resolution STFT loss module."""

    def __init__(
        self,
        fft_lengths=(1024, 2048, 512),
        frame_lengths=(600, 1200, 240),
        frame_steps=(120, 240, 50),
    ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_lengths (sequence of int): List of FFT sizes.
            frame_lengths (sequence of int): List of window lengths.
            frame_steps (sequence of int): List of hop sizes.

        Raises:
            ValueError: If the three sequences differ in length or are empty.
        """
        super().__init__()
        if not len(frame_lengths) == len(frame_steps) == len(fft_lengths):
            raise ValueError(
                "fft_lengths, frame_lengths and frame_steps must have the same "
                f"length, got {len(fft_lengths)}, {len(frame_lengths)} and "
                f"{len(frame_steps)}."
            )
        if len(fft_lengths) == 0:
            # An empty configuration would divide by zero in call().
            raise ValueError("At least one STFT resolution is required.")
        self.stft_losses = [
            TFSTFT(frame_length, frame_step, fft_length)
            for frame_length, frame_step, fft_length in zip(
                frame_lengths, frame_steps, fft_lengths
            )
        ]

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(y, x)
            # Average over every axis except the first (batch) axis, so each
            # resolution contributes one value per batch item.
            sc_loss += tf.reduce_mean(sc_l, axis=list(range(1, len(sc_l.shape))))
            mag_loss += tf.reduce_mean(mag_l, axis=list(range(1, len(mag_l.shape))))

        # Unweighted mean over resolutions.
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss