Path: blob/master/labml_nn/transformers/models.py
"""1---2title: Transformer Encoder and Decoder Models3summary: >4These are PyTorch implementations of Transformer based encoder and decoder models,5as well as other related modules.6---78# Transformer Encoder and Decoder Models910[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)11"""12import math1314import torch15import torch.nn as nn1617from labml_nn.utils import clone_module_list18from .feed_forward import FeedForward19from .mha import MultiHeadAttention20from .positional_encoding import get_positional_encoding212223class EmbeddingsWithPositionalEncoding(nn.Module):24"""25<a id="EmbeddingsWithPositionalEncoding"></a>2627## Embed tokens and add [fixed positional encoding](positional_encoding.html)28"""2930def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):31super().__init__()32self.linear = nn.Embedding(n_vocab, d_model)33self.d_model = d_model34self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))3536def forward(self, x: torch.Tensor):37pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)38return self.linear(x) * math.sqrt(self.d_model) + pe394041class EmbeddingsWithLearnedPositionalEncoding(nn.Module):42"""43<a id="EmbeddingsWithLearnedPositionalEncoding"></a>4445## Embed tokens and add parameterized positional encodings46"""4748def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):49super().__init__()50self.linear = nn.Embedding(n_vocab, d_model)51self.d_model = d_model52self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)5354def forward(self, x: torch.Tensor):55pe = self.positional_encodings[:x.shape[0]]56return self.linear(x) * math.sqrt(self.d_model) + pe575859class TransformerLayer(nn.Module):60"""61<a id="TransformerLayer"></a>6263## Transformer Layer6465This can act as an encoder layer or a decoder layer. We use pre-norm.66"""6768def __init__(self, *,69d_model: int,70self_attn: MultiHeadAttention,71src_attn: MultiHeadAttention = None,72feed_forward: FeedForward,73dropout_prob: float):74"""75* `d_model` is the token embedding size76* `self_attn` is the self attention module77* `src_attn` is the source attention module (when this is used in a decoder)78* `feed_forward` is the feed forward module79* `dropout_prob` is the probability of dropping out after self attention and FFN80"""81super().__init__()82self.size = d_model83self.self_attn = self_attn84self.src_attn = src_attn85self.feed_forward = feed_forward86self.dropout = nn.Dropout(dropout_prob)87self.norm_self_attn = nn.LayerNorm([d_model])88if self.src_attn is not None:89self.norm_src_attn = nn.LayerNorm([d_model])90self.norm_ff = nn.LayerNorm([d_model])91# Whether to save input to the feed forward layer92self.is_save_ff_input = False9394def forward(self, *,95x: torch.Tensor,96mask: torch.Tensor,97src: torch.Tensor = None,98src_mask: torch.Tensor = None):99# Normalize the vectors before doing self attention100z = self.norm_self_attn(x)101# Run through self attention, i.e. keys and values are from self102self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)103# Add the self attention results104x = x + self.dropout(self_attn)105106# If a source is provided, get results from attention to source.107# This is when you have a decoder layer that pays attention to108# encoder outputs109if src is not None:110# Normalize vectors111z = self.norm_src_attn(x)112# Attention to source. i.e. 
keys and values are from source113attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)114# Add the source attention results115x = x + self.dropout(attn_src)116117# Normalize for feed-forward118z = self.norm_ff(x)119# Save the input to the feed forward layer if specified120if self.is_save_ff_input:121self.ff_input = z.clone()122# Pass through the feed-forward network123ff = self.feed_forward(z)124# Add the feed-forward results back125x = x + self.dropout(ff)126127return x128129130class Encoder(nn.Module):131"""132<a id="Encoder"></a>133134## Transformer Encoder135"""136137def __init__(self, layer: TransformerLayer, n_layers: int):138super().__init__()139# Make copies of the transformer layer140self.layers = clone_module_list(layer, n_layers)141# Final normalization layer142self.norm = nn.LayerNorm([layer.size])143144def forward(self, x: torch.Tensor, mask: torch.Tensor):145# Run through each transformer layer146for layer in self.layers:147x = layer(x=x, mask=mask)148# Finally, normalize the vectors149return self.norm(x)150151152class Decoder(nn.Module):153"""154<a id="Decoder"></a>155156## Transformer Decoder157"""158159def __init__(self, layer: TransformerLayer, n_layers: int):160super().__init__()161# Make copies of the transformer layer162self.layers = clone_module_list(layer, n_layers)163# Final normalization layer164self.norm = nn.LayerNorm([layer.size])165166def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):167# Run through each transformer layer168for layer in self.layers:169x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)170# Finally, normalize the vectors171return self.norm(x)172173174class Generator(nn.Module):175"""176<a id="Generator"></a>177178## Generator179180This predicts the tokens and gives the lof softmax of those.181You don't need this if you are using `nn.CrossEntropyLoss`.182"""183184def __init__(self, n_vocab: int, d_model: int):185super().__init__()186self.projection = nn.Linear(d_model, n_vocab)187188def forward(self, x):189return self.projection(x)190191192class EncoderDecoder(nn.Module):193"""194<a id="EncoderDecoder"></a>195196## Combined Encoder-Decoder197"""198199def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):200super().__init__()201self.encoder = encoder202self.decoder = decoder203self.src_embed = src_embed204self.tgt_embed = tgt_embed205self.generator = generator206207# This was important from their code.208# Initialize parameters with Glorot / fan_avg.209for p in self.parameters():210if p.dim() > 1:211nn.init.xavier_uniform_(p)212213def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):214# Run the source through encoder215enc = self.encode(src, src_mask)216# Run encodings and targets through decoder217return self.decode(enc, src_mask, tgt, tgt_mask)218219def encode(self, src: torch.Tensor, src_mask: torch.Tensor):220return self.encoder(self.src_embed(src), src_mask)221222def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):223return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)224225226
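

# A minimal usage sketch (not part of the original module): wiring the modules
# above into a small translation-style model and running one forward pass.
# This assumes the sibling modules expose `MultiHeadAttention(heads, d_model,
# dropout_prob)` and `FeedForward(d_model, d_ff, dropout)` constructors, and
# that attention masks are boolean tensors of shape
# `[seq_len_q, seq_len_k, batch_size]` (with broadcastable batch dimension);
# check `mha.py` and `feed_forward.py` before relying on these signatures.
if __name__ == '__main__':
    d_model, heads, d_ff, n_layers, dropout = 128, 4, 256, 2, 0.1
    src_vocab, tgt_vocab = 1000, 1000

    # Encoder layer: self attention only
    encoder_layer = TransformerLayer(
        d_model=d_model,
        self_attn=MultiHeadAttention(heads, d_model, dropout),
        feed_forward=FeedForward(d_model, d_ff, dropout),
        dropout_prob=dropout)
    # Decoder layer: self attention plus attention over the encoder outputs
    decoder_layer = TransformerLayer(
        d_model=d_model,
        self_attn=MultiHeadAttention(heads, d_model, dropout),
        src_attn=MultiHeadAttention(heads, d_model, dropout),
        feed_forward=FeedForward(d_model, d_ff, dropout),
        dropout_prob=dropout)

    model = EncoderDecoder(
        Encoder(encoder_layer, n_layers),
        Decoder(decoder_layer, n_layers),
        EmbeddingsWithPositionalEncoding(d_model, src_vocab),
        EmbeddingsWithLearnedPositionalEncoding(d_model, tgt_vocab),
        Generator(tgt_vocab, d_model))

    # Token tensors are `[seq_len, batch_size]`
    seq_len, batch_size = 10, 2
    src = torch.randint(0, src_vocab, (seq_len, batch_size))
    tgt = torch.randint(0, tgt_vocab, (seq_len, batch_size))
    # Source tokens may attend everywhere; target tokens attend only to the past
    src_mask = torch.ones(seq_len, seq_len, 1, dtype=torch.bool)
    tgt_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

    out = model(src, tgt, src_mask, tgt_mask)
    # Project the decoder output to vocabulary logits
    logits = model.generator(out)
    print(logits.shape)  # torch.Size([10, 2, 1000])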