GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/commons/espnet_positional_embedding.py
import math
import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)
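
# The table built by extend_pe() is the standard Transformer sinusoidal encoding:
#   pe[pos, 2i]   = sin(pos / 10000**(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000**(2i / d_model))
# forward() multiplies the input by sqrt(d_model) (self.xscale) before adding
# the encoding, so token embeddings and positional terms have comparable scale.
# With reverse=True the positions run from x.size(1) - 1 down to 0 instead.
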
class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.
    See Sec. 3.2 https://arxiv.org/abs/1809.08895
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)
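
# Unlike the base class, ScaledPositionalEncoding leaves the input unscaled and
# instead multiplies the positional table by a learnable scalar `alpha`
# (initialised to 1.0), following Sec. 3.2 of the paper cited in the docstring.
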
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`) with the (reversed)
                positional embedding added; dropout is applied to both terms.
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x) + self.dropout(pos_emb)
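
# A minimal usage sketch for the three modules above; it runs only when the
# file is executed directly. The batch size, sequence length, d_model and
# dropout_rate below are illustrative values, not prescribed by the module.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, time, d_model = 2, 100, 256
    x = torch.randn(batch, time, d_model)

    abs_pe = PositionalEncoding(d_model=d_model, dropout_rate=0.1)
    scaled_pe = ScaledPositionalEncoding(d_model=d_model, dropout_rate=0.1)
    rel_pe = RelPositionalEncoding(d_model=d_model, dropout_rate=0.1)

    # All three keep the (batch, time, d_model) shape of their input.
    print(abs_pe(x).shape)     # torch.Size([2, 100, 256])
    print(scaled_pe(x).shape)  # torch.Size([2, 100, 256])
    print(rel_pe(x).shape)     # torch.Size([2, 100, 256])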