Path: blob/master/labml_nn/transformers/mha.py
"""
---
title: Multi-Headed Attention (MHA)
summary: >
  This implements the Multi-Headed Attention used in transformers
  using PyTorch with explanations.
---

# Multi-Headed Attention (MHA)

[Open In Colab](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)

This is a tutorial/implementation of multi-headed attention
from the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
in [PyTorch](https://pytorch.org/).
The implementation is inspired by [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html).

Here is the [training code](basic/autoregressive_experiment.html) that uses a basic transformer
with MHA for NLP auto-regression.

[Here is an experiment implementation](basic/autoregressive_experiment.html) that trains a simple transformer.
"""

import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker


class PrepareForMultiHeadAttention(nn.Module):
    """
    <a id="PrepareMHA"></a>

    ## Prepare for multi-head attention

    This module does a linear transformation and splits the vector into the given
    number of heads for multi-head attention.
    This is used to transform **key**, **query**, and **value** vectors.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for the linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # We apply the linear transformation to the last dimension and split that into
        # the heads.
        head_shape = x.shape[:-1]

        # Linear transform
        x = self.linear(x)

        # Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
        return x


class MultiHeadAttention(nn.Module):
    r"""
    <a id="MHA"></a>

    ## Multi-Head Attention Module

    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.

    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

    In simple terms, it finds keys that match the query and gets the values of
    those keys.

    It uses the dot-product of query and key as the indicator of how well they match.
    Before taking the $softmax$, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
    This is done to avoid large dot-product values causing softmax to
    give very small gradients when $d_k$ is large.

    Softmax is calculated along the axis of the sequence (or time).
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store the attentions so that they can be used for logging, or for other computations if needed
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys

        This method can be overridden for other variations like relative attention.
        """

        # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension.
        If the query dimension is equal to $1$, it will be broadcast.
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # Resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]`, which broadcasts across heads
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collections of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        the query at position `i` has access to the key-value at position `j`.
        """

        # `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)

        # Save attentions if debugging
        tracker.debug('attn', attn)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)
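# ---
# A small shape and equivalence sketch, added as an illustration for this write-up
# (it is not part of the original module). Under arbitrary example sizes it checks that
# `PrepareForMultiHeadAttention` maps `[seq_len, batch_size, d_model]` to
# `[seq_len, batch_size, heads, d_k]`, and that the einsum in `get_scores` is the same
# as a per-head matrix product $Q K^\top$.
def _scores_sketch():
    seq_len, batch_size, d_model, heads = 5, 2, 32, 4
    d_k = d_model // heads

    # Two independent projections play the roles of the query and key transforms
    to_query = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=True)
    to_key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=True)

    x = torch.randn(seq_len, batch_size, d_model)
    query = to_query(x)
    key = to_key(x)

    # Head-split shape: `[seq_len, batch_size, heads, d_k]`
    assert query.shape == (seq_len, batch_size, heads, d_k)

    # `get_scores` computes $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$;
    # for a fixed batch `b` and head `h` this is an ordinary matrix product
    scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
    b, h = 0, 0
    expected = query[:, b, h, :] @ key[:, b, h, :].T
    assert torch.allclose(scores[:, :, b, h], expected, atol=1e-5)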
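# ---
# A minimal end-to-end usage sketch, added as an illustration for this write-up
# (it is not part of the original module). It runs self-attention with a causal
# (auto-regressive) mask, following the shape conventions documented above:
# inputs have shape `[seq_len, batch_size, d_model]` and the mask has shape
# `[seq_len_q, seq_len_k, batch_size]`, with a batch dimension of size 1 being broadcast.
# Note that `forward` calls `tracker.debug`, so this assumes it is safe to call that
# outside a labml experiment.
def _mha_usage_sketch():
    seq_len, batch_size, d_model, heads = 10, 4, 512, 8

    mha = MultiHeadAttention(heads=heads, d_model=d_model)

    # Self-attention: the same tensor provides the queries, keys and values
    x = torch.randn(seq_len, batch_size, d_model)

    # Lower-triangular mask so that query `i` only attends to keys `j <= i`
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

    out = mha(query=x, key=x, value=x, mask=mask)

    # The output keeps the input shape, and the attention weights are saved in
    # `mha.attn` with shape `[seq_len_q, seq_len_k, batch_size, heads]`
    assert out.shape == (seq_len, batch_size, d_model)
    assert mha.attn.shape == (seq_len, seq_len, batch_size, heads)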