Path: blob/main/modules/commons/espnet_positional_embedding.py
import math

import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-compute the table up to max_len positions.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            # Reuse the cached table if it is already long enough; only move it
            # to the input's dtype/device when necessary.
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        # Frequencies 1 / 10000^(2i / d_model) for the even embedding indices.
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2 of https://arxiv.org/abs/1809.08895

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        # Learnable scale applied to the positional encoding instead of scaling x.
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor with the (reversed) positional
                embedding added (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x) + self.dropout(pos_emb)
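

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module): shows
# how the three encodings are applied to a (batch, time, d_model) tensor. The
# batch size, sequence length, d_model and dropout rate below are assumptions
# chosen for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, time, d_model = 2, 50, 256
    x = torch.randn(batch, time, d_model)

    abs_pe = PositionalEncoding(d_model=d_model, dropout_rate=0.1)
    scaled_pe = ScaledPositionalEncoding(d_model=d_model, dropout_rate=0.1)
    rel_pe = RelPositionalEncoding(d_model=d_model, dropout_rate=0.1)

    # All three return a tensor of the same shape as the input.
    print(abs_pe(x).shape)     # torch.Size([2, 50, 256])
    print(scaled_pe(x).shape)  # torch.Size([2, 50, 256])
    print(rel_pe(x).shape)     # torch.Size([2, 50, 256])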