GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/transformers/mha.py
"""
---
title: Multi-Headed Attention (MHA)
summary: >
  This implements the Multi-Headed Attention used in transformers
  using PyTorch with explanations.
---

# Multi-Headed Attention (MHA)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)

This is a tutorial/implementation of multi-headed attention
from the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
in [PyTorch](https://pytorch.org/).
The implementation is inspired by the [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html).

Here is the [training code](basic/autoregressive_experiment.html) that uses a basic transformer
with MHA for NLP auto-regression.

[Here is an experiment implementation](basic/autoregressive_experiment.html) that trains a simple transformer.
"""

import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker


class PrepareForMultiHeadAttention(nn.Module):
    """
    <a id="PrepareMHA"></a>

    ## Prepare for multi-head attention

    This module does a linear transformation and splits the vector into the given
    number of heads for multi-head attention.
    This is used to transform **key**, **query**, and **value** vectors.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # We apply the linear transformation to the last dimension and split that into
        # the heads.
        head_shape = x.shape[:-1]

        # Linear transform
        x = self.linear(x)

        # Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
        return x

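
# The following is an illustrative sketch, not part of the original implementation.
# `_example_prepare_shapes` is a hypothetical helper added only to show the shape
# transformation performed by `PrepareForMultiHeadAttention`; the sizes are arbitrary
# example values.
def _example_prepare_shapes():
    d_model, heads, d_k = 512, 8, 64
    prepare = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=True)
    # A batch of 32 sequences of length 100
    x = torch.randn(100, 32, d_model)
    # The last dimension is projected by the linear layer and split into heads
    assert prepare(x).shape == (100, 32, heads, d_k)

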
class MultiHeadAttention(nn.Module):
    r"""
    <a id="MHA"></a>

    ## Multi-Head Attention Module

    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.

    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

    In simple terms, it finds keys that match the query, and gets the values of
    those keys.

    It uses the dot-product of query and key as an indicator of how well they match.
    Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
    This is done to avoid large dot-product values causing softmax to
    give very small gradients when $d_k$ is large: if the components of the query and
    key have unit variance, their dot-product has variance $d_k$, and the scaling
    brings it back to unit variance.

    Softmax is calculated along the axis of the sequence (or time).
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        * `dropout_prob` is the dropout probability applied to the attention weights.
        * `bias` decides whether the `query` and `key` projections have a bias parameter.
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that they can be used for logging, or other computations if needed
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys

        This method can be overridden for other variations like relative attention.
        """

        # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

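    # For reference (an illustrative note, not part of the original code): with `query`
    # of shape `[seq_q, batch_size, heads, d_k]` and `key` of shape
    # `[seq_k, batch_size, heads, d_k]`, the einsum in `get_scores` above is equivalent to
    #
    #     q = query.permute(1, 2, 0, 3)                    # `[batch_size, heads, seq_q, d_k]`
    #     k = key.permute(1, 2, 3, 0)                      # `[batch_size, heads, d_k, seq_k]`
    #     scores = torch.matmul(q, k).permute(2, 3, 0, 1)  # `[seq_q, seq_k, batch_size, heads]`
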
    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension.
        If the query dimension is equal to $1$ it will be broadcast.
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # The resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]` and is broadcast across the heads
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collections of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        the query at position `i` has access to the key-value at position `j`.
        """

        # `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)

        # Save attentions if debugging
        tracker.debug('attn', attn)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)
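

# The following is an illustrative usage sketch, not part of the original implementation.
# `_example_usage` is a hypothetical helper added only for demonstration; the sizes below
# are arbitrary example values. It assumes `labml` is installed, since `forward` calls
# `tracker.debug` (assumed here to be safe to call outside an active experiment).
def _example_usage():
    heads, d_model = 8, 512
    seq_len, batch_size = 100, 32
    mha = MultiHeadAttention(heads, d_model)

    # Self-attention: the same tensor is used for `query`, `key` and `value`
    x = torch.randn(seq_len, batch_size, d_model)

    # A causal (auto-regressive) mask of shape `[seq_len, seq_len, 1]`:
    # `mask[i, j, 0]` is `True` only when the query at position `i` may attend to
    # the key at position `j <= i`. It broadcasts over the batch dimension.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

    out = mha(query=x, key=x, value=x, mask=mask)
    # The output keeps the input shape `[seq_len, batch_size, d_model]`
    assert out.shape == (seq_len, batch_size, d_model)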