Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/parallel_wavegan/layers/pqmf.py
694 views
1
# -*- coding: utf-8 -*-
2
3
# Copyright 2020 Tomoki Hayashi
4
# MIT License (https://opensource.org/licenses/MIT)
5
6
"""Pseudo QMF modules."""
7
8
import numpy as np
import torch
import torch.nn.functional as F

try:
    # Window functions live in scipy.signal.windows; the top-level alias
    # scipy.signal.kaiser was deprecated and removed in SciPy 1.13.
    from scipy.signal.windows import kaiser
except ImportError:  # very old SciPy without the windows submodule
    from scipy.signal import kaiser
13
14
15
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    The ideal low-pass impulse response ``sin(omega_c * n) / (pi * n)``, with
    ``omega_c = pi * cutoff_ratio`` and ``n`` centred on ``taps / 2``, is
    evaluated through ``np.sinc``. The centre tap therefore takes its limit
    value ``cutoff_ratio`` directly instead of producing a NaN that has to be
    patched afterwards. The result is then multiplied by a Kaiser window.

    Args:
        taps (int): The number of filter taps (the filter order); must be even.
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is not in (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be an even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # make initial filter: sin(pi*c*n) / (pi*n) == c * sinc(c*n), because
    # np.sinc(x) = sin(pi*x) / (pi*x) with np.sinc(0) == 1; the centre tap
    # (n == 0) is therefore exactly cutoff_ratio, with no 0/0 to fix up.
    n = np.arange(taps + 1) - 0.5 * taps
    h_i = cutoff_ratio * np.sinc(cutoff_ratio * n)

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    h = h_i * w

    return h
49
50
51
class PQMF(torch.nn.Module):
    """PQMF module.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.

        Builds cosine-modulated analysis and synthesis filter banks from a
        Kaiser-windowed prototype filter. Also builds a one-hot kernel that
        performs the decimation / zero-insertion steps.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super().__init__()

        # define filter coefficients: row k of each bank is the prototype
        # filter modulated by a cosine at the centre of subband k.
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        k = np.arange(subbands)[:, None]  # (subbands, 1) for broadcasting
        # NOTE(review): the modulation is centred at (taps - 1) / 2, whereas
        # the prototype (length taps + 1) is symmetric about taps / 2. Kept
        # as-is so the generated coefficients do not change for existing
        # models -- confirm intent before altering.
        phase = (2 * k + 1) * (np.pi / (2 * subbands)) * (
            np.arange(taps + 1) - ((taps - 1) / 2)
        )
        offset = (-1) ** k * np.pi / 4
        h_analysis = 2 * h_proto * np.cos(phase + offset)
        h_synthesis = 2 * h_proto * np.cos(phase - offset)

        # convert to tensor: analysis is (subbands, 1, taps + 1), i.e. one
        # input channel fanned out to `subbands` outputs; synthesis is
        # (1, subbands, taps + 1), i.e. `subbands` inputs summed to one output.
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffers (not trainable parameters)
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # filter for downsampling & upsampling: channel k maps to channel k
        # with a single 1 at time offset 0, so a stride-`subbands` conv keeps
        # every subbands-th sample and the transposed conv inserts zeros.
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        idx = torch.arange(subbands)
        updown_filter[idx, idx, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info: taps // 2 zeros on each side so that filtering
        # with the (taps + 1)-tap kernels preserves the time length.
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        # padded length T + taps, filtered by taps + 1 taps -> length T
        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
        # decimate each subband by `subbands`
        return F.conv1d(x, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # zero-insertion upsampling by `subbands`; the kernel is scaled by
        # `subbands`, presumably to offset the gain lost by inserting zeros.
        x = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands
        )
        # filter each subband and sum them into a single channel
        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
130
131