Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/parallel_wavegan/layers/causal_conv.py
694 views
1
# -*- coding: utf-8 -*-
2
3
# Copyright 2020 Tomoki Hayashi
4
# MIT License (https://opensource.org/licenses/MIT)
5
6
"""Causal convolution layer modules."""
7
8
9
import torch
10
11
12
class CausalConv1d(torch.nn.Module):
    """1-D convolution whose output at frame t depends only on input frames <= t.

    The input is padded by ``(kernel_size - 1) * dilation`` frames using a
    ``torch.nn`` padding layer, convolved, and the result is cropped back to
    the input length, keeping only the causal part.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 dilation=1, bias=True, pad="ConstantPad1d", pad_params=None):
        """Initialize CausalConv1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            dilation (int): Dilation factor of the convolution.
            bias (bool): Whether the convolution has a bias term.
            pad (str): Name of a padding class in ``torch.nn``
                (e.g. "ConstantPad1d", "ReplicationPad1d").
            pad_params (dict, optional): Extra keyword arguments for the
                padding class. When omitted, ``{"value": 0.0}`` (zero
                padding) is used for "ConstantPad1d" and no extra arguments
                for any other padding class.

        """
        super(CausalConv1d, self).__init__()
        if pad_params is None:
            # Build a fresh dict per call rather than sharing a mutable
            # default; only ConstantPad1d accepts a ``value`` argument.
            pad_params = {"value": 0.0} if pad == "ConstantPad1d" else {}
        # An int padding pads both ends; the right-hand extra frames are
        # discarded by the crop in forward(), which is what makes it causal.
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                    dilation=dilation, bias=bias)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        # Keep the first T frames: each only sees current and past inputs.
        return self.conv(self.pad(x))[:, :, :x.size(2)]
34
35
36
class CausalConvTranspose1d(torch.nn.Module):
    """Transposed 1-D convolution with its last ``stride`` output frames dropped.

    Wraps ``torch.nn.ConvTranspose1d`` (no padding) and trims the trailing
    ``stride`` frames of its output in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        """Create the wrapped transposed convolution.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the transposed-convolution kernel.
            stride (int): Upsampling stride; also the number of trailing
                output frames removed in ``forward``.
            bias (bool): Whether the transposed convolution has a bias term.

        """
        super().__init__()
        self.deconv = torch.nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            bias=bias,
        )
        # Kept so forward() knows how many trailing frames to discard.
        self.stride = stride

    def forward(self, x):
        """Run the transposed convolution and trim its tail.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).

        Returns:
            Tensor: Output tensor (B, out_channels, T_out), where
                T_out = (T_in - 1) * stride + kernel_size - stride
                (i.e. T_in * stride when kernel_size == 2 * stride).

        """
        upsampled = self.deconv(x)
        tail = self.stride
        return upsampled[:, :, :-tail]
57
58