Path: blob/main/modules/parallel_wavegan/stft_loss.py
694 views
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""STFT-based Loss modules."""

import librosa
import torch

from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft


class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
                 use_mel_loss=False, sample_rate=22050, num_mels=80):
        """Initialize STFT loss module.

        Args:
            fft_size (int): FFT size.
            shift_size (int): Hop size between frames.
            win_length (int): Window length.
            window (str): Name of a window function available on ``torch``.
            use_mel_loss (bool): If True, project the magnitudes through a mel
                filterbank before computing the losses.
            sample_rate (int): Sampling rate used to build the mel filterbank.
            num_mels (int): Number of mel bins of the filterbank.

        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Kept as a plain attribute (not a buffer) so state_dict keys stay unchanged;
        # forward() moves it to the input's device when needed.
        self.window = getattr(torch, window)(win_length)
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        self.use_mel_loss = use_mel_loss
        self.sample_rate = sample_rate
        self.num_mels = num_mels
        # Built lazily on the first forward pass that needs it.
        self.mel_basis = None

    def _mel_projection(self, reference):
        """Return the cached mel filterbank matched to ``reference``'s device and dtype."""
        basis = self.mel_basis
        if basis is None:
            # Keyword arguments are accepted by both old and new librosa releases;
            # recent releases reject the positional form.
            basis = torch.from_numpy(
                librosa.filters.mel(sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels)
            ).T
        if basis.device != reference.device or basis.dtype != reference.dtype:
            # Follow the input instead of assuming a CUDA device is present.
            basis = basis.to(device=reference.device, dtype=reference.dtype)
        self.mel_basis = basis
        return basis

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.

        """
        window = self.window
        if window.device != x.device:
            # The window is created on CPU and does not follow module.to(device).
            window = window.to(x.device)
            self.window = window

        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, window)
        if self.use_mel_loss:
            # NOTE(review): assumes stft() returns magnitudes with the frequency
            # bins on the last axis, as the matmul requires - confirm in losses.stft.
            mel_basis = self._mel_projection(x_mag)
            x_mag = x_mag @ mel_basis
            y_mag = y_mag @ mel_basis

        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_sizes=(1024, 2048, 512),
                 hop_sizes=(120, 240, 50),
                 win_lengths=(600, 1200, 240),
                 window="hann_window",
                 use_mel_loss=False,
                 sample_rate=22050,
                 num_mels=80):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
            use_mel_loss (bool): Forwarded to every STFTLoss.
            sample_rate (int): Sampling rate forwarded to every STFTLoss.
            num_mels (int): Number of mel bins forwarded to every STFTLoss.

        """
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths), \
            "fft_sizes, hop_sizes and win_lengths must have the same length"
        # One STFTLoss per (fft size, hop size, window length) triple.
        self.stft_losses = torch.nn.ModuleList(
            STFTLoss(fs, ss, wl, window, use_mel_loss, sample_rate, num_mels)
            for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
        )

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.

        """
        sc_loss = 0.0
        mag_loss = 0.0
        for loss_fn in self.stft_losses:
            sc_l, mag_l = loss_fn(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        # Average over the resolutions.
        num_resolutions = len(self.stft_losses)
        sc_loss /= num_resolutions
        mag_loss /= num_resolutions

        return sc_loss, mag_loss