Path: blob/main/modules/parallel_wavegan/layers/pqmf.py
694 views
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Pseudo QMF modules."""

import numpy as np
import torch
import torch.nn.functional as F

# `scipy.signal.kaiser` was only a deprecated alias of the window sub-module and
# is not available in recent SciPy releases; `scipy.signal.windows.kaiser` is the
# same function under its supported name.
from scipy.signal.windows import kaiser


def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    The prototype is the ideal low-pass impulse response
    ``sin(pi * cutoff_ratio * n) / (pi * n)`` (``n`` measured from the filter
    centre) multiplied by a Kaiser window. This method is based on
    `A Kaiser window approach for the design of prototype filters of cosine
    modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps. Must be even; the returned filter
            has ``taps + 1`` coefficients and is centred on index ``taps // 2``.
        cutoff_ratio (float): Cut-off frequency ratio, in the open interval (0, 1).
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is not in (0, 1)
            (only while Python assertions are enabled, i.e. not under ``-O``).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be an even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # sample index relative to the filter centre, so the response is symmetric
    omega_c = np.pi * cutoff_ratio
    n = np.arange(taps + 1) - 0.5 * taps

    # make initial filter; the centre sample (n == 0) evaluates to 0 / 0 = nan,
    # so silence that warning and patch the value right after
    with np.errstate(invalid="ignore"):
        h_i = np.sin(omega_c * n) / (np.pi * n)
    # limit of sin(omega_c * n) / (pi * n) as n -> 0 is omega_c / pi == cutoff_ratio
    h_i[taps // 2] = cutoff_ratio

    # apply kaiser window
    return h_i * kaiser(taps + 1, beta)


class PQMF(torch.nn.Module):
    """PQMF module.

    ``analysis`` splits a single-channel signal into ``subbands`` decimated
    sub-band signals and ``synthesis`` merges them back, using cosine-modulated
    copies of the prototype from ``design_prototype_filter``.
    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps (must be even).
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super().__init__()

        # define filter coefficient
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        n = np.arange(taps + 1)
        # NOTE(review): the modulation below is centred on (taps - 1) / 2, while
        # the prototype (taps + 1 coefficients) is centred on taps / 2. The usual
        # cosine-modulation formula uses (N - 1) / 2 for a length-N filter, i.e.
        # taps / 2 here -- verify against the cited paper. Left unchanged because
        # models trained with this module presumably depend on these exact
        # filters; confirm before changing.
        for k in range(subbands):
            modulation = (2 * k + 1) * (np.pi / (2 * subbands)) * (n - ((taps - 1) / 2))
            phase = (-1) ** k * np.pi / 4
            # analysis and synthesis filters differ only in the sign of the phase term
            h_analysis[k] = 2 * h_proto * np.cos(modulation + phase)
            h_synthesis[k] = 2 * h_proto * np.cos(modulation - phase)

        # convert to tensor: analysis weight is (subbands, 1, taps + 1) so one
        # input channel fans out to `subbands` channels; synthesis weight is
        # (1, subbands, taps + 1) so all sub-bands are summed into one channel
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffer (stored in state_dict, moved with the
        # module by .to(), but not trainable parameters)
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # filter for downsampling & upsampling: channel k maps to channel k via a
        # kernel whose only non-zero tap is the first one, so a stride-`subbands`
        # conv keeps every `subbands`-th sample and the matching transposed conv
        # inserts `subbands - 1` zeros after each sample
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        for k in range(subbands):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info: taps // 2 zeros on each side make a convolution with
        # the (taps + 1)-coefficient filters preserve the signal length
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        # band-split into `subbands` channels at full rate (length preserved) ...
        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
        # ... then keep every `subbands`-th sample of each channel
        return F.conv1d(x, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # zero-insertion upsampling of every sub-band; the factor `subbands`
        # restores the amplitude lost by inserting zeros
        x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
        # filter each sub-band and sum them into a single output channel
        return F.conv1d(self.pad_fn(x), self.synthesis_filter)