# coding: utf-8


import sys
import torch
import torch.nn.functional as F

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add folder to path in order to load from the check_packages.py script:

sys.path.insert(0, '..')
from python_environment_check import check_packages


# Check recommended package versions:

d = {
    'torch': '1.9.0',
}
check_packages(d)

# # Chapter 16: Transformers – Improving Natural Language Processing with Attention Mechanisms (Part 1/3)

# **Outline**
#
# - [Adding an attention mechanism to RNNs](#Adding-an-attention-mechanism-to-RNNs)
#   - [Attention helps RNNs with accessing information](#Attention-helps-RNNs-with-accessing-information)
#   - [The original attention mechanism for RNNs](#The-original-attention-mechanism-for-RNNs)
#   - [Processing the inputs using a bidirectional RNN](#Processing-the-inputs-using-a-bidirectional-RNN)
#   - [Generating outputs from context vectors](#Generating-outputs-from-context-vectors)
#   - [Computing the attention weights](#Computing-the-attention-weights)
# - [Introducing the self-attention mechanism](#Introducing-the-self-attention-mechanism)
#   - [Starting with a basic form of self-attention](#Starting-with-a-basic-form-of-self-attention)
#   - [Parameterizing the self-attention mechanism: scaled dot-product attention](#Parameterizing-the-self-attention-mechanism-scaled-dot-product-attention)
# - [Attention is all we need: introducing the original transformer architecture](#Attention-is-all-we-need-introducing-the-original-transformer-architecture)
#   - [Encoding context embeddings via multi-head attention](#Encoding-context-embeddings-via-multi-head-attention)
#   - [Learning a language model: decoder and masked multi-head attention](#Learning-a-language-model-decoder-and-masked-multi-head-attention)
#   - [Implementation details: positional encodings and layer normalization](#Implementation-details-positional-encodings-and-layer-normalization)

# ## Adding an attention mechanism to RNNs

# ### Attention helps RNNs with accessing information


# ### The original attention mechanism for RNNs


# ### Processing the inputs using a bidirectional RNN

# ### Generating outputs from context vectors

# ### Computing the attention weights


# ## Introducing the self-attention mechanism

# ### Starting with a basic form of self-attention

# - Assume we have an input sentence that we encoded via a dictionary, which maps the words to integers as discussed in the RNN chapter:


# input sequence / sentence:
# "Can you help me to translate this sentence"

sentence = torch.tensor(
    [0,  # can
     7,  # you
     1,  # help
     2,  # me
     5,  # to
     6,  # translate
     4,  # this
     3]  # sentence
)

sentence


# - Next, assume we have an embedding of the words, i.e., the words are represented as real vectors.
# - Since we have 8 words, there will be 8 vectors. Each vector is 16-dimensional:

torch.manual_seed(123)
embed = torch.nn.Embedding(10, 16)
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

# - The goal is to compute the context vectors $\boldsymbol{z}^{(i)}=\sum_{j=1}^{T} \alpha_{i j} \boldsymbol{x}^{(j)}$, which involve attention weights $\alpha_{i j}$.
# - In turn, the attention weights $\alpha_{i j}$ involve the $\omega_{i j}$ values
# - Let's start with the $\omega_{i j}$'s first, which are computed as dot products:
#
# $$\omega_{i j}=\boldsymbol{x}^{(i)^{\top}} \boldsymbol{x}^{(j)}$$


omega = torch.empty(8, 8)

for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)


# - Actually, let's compute this more efficiently by replacing the nested for-loops with a matrix multiplication:

omega_mat = embedded_sentence.matmul(embedded_sentence.T)

torch.allclose(omega_mat, omega)

# - Next, let's compute the attention weights by normalizing the "omega" values so they sum to 1
#
# $$\alpha_{i j}=\frac{\exp \left(\omega_{i j}\right)}{\sum_{k=1}^{T} \exp \left(\omega_{i k}\right)}=\operatorname{softmax}\left(\left[\omega_{i j}\right]_{j=1 \ldots T}\right)$$
#
# $$\sum_{j=1}^{T} \alpha_{i j}=1$$


attention_weights = F.softmax(omega, dim=1)
attention_weights.shape


# - We can confirm that the rows sum up to 1:

attention_weights.sum(dim=1)

# - Now that we have the attention weights, we can compute the context vectors $\boldsymbol{z}^{(i)}=\sum_{j=1}^{T} \alpha_{i j} \boldsymbol{x}^{(j)}$
# - For instance, to compute the context vector of the 2nd input element (the element at index 1), we can perform the following computation:

x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)
for j in range(8):
    x_j = embedded_sentence[j, :]
    context_vec_2 += attention_weights[1, j] * x_j

context_vec_2


# - Or, more efficiently, using linear algebra and matrix multiplication:

context_vectors = torch.matmul(
    attention_weights, embedded_sentence)

torch.allclose(context_vec_2, context_vectors[1])

# ### Parameterizing the self-attention mechanism: scaled dot-product attention

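# - Below, each input $\boldsymbol{x}^{(i)}$ is projected into a query, a key, and a value vector via three projection matrices (randomly initialized here for illustration; in practice they are learned):
#
# $$\boldsymbol{q}^{(i)}=\boldsymbol{U}_{q} \boldsymbol{x}^{(i)}, \quad \boldsymbol{k}^{(i)}=\boldsymbol{U}_{k} \boldsymbol{x}^{(i)}, \quad \boldsymbol{v}^{(i)}=\boldsymbol{U}_{v} \boldsymbol{x}^{(i)}$$
#
# - The unnormalized weights are then dot products between queries and keys, $\omega_{i j}=\boldsymbol{q}^{(i)^{\top}} \boldsymbol{k}^{(j)}$, scaled by $1 / \sqrt{d}$ before applying the softmax:
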
torch.manual_seed(123)

d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

x_2 = embedded_sentence[1]
query_2 = U_query.matmul(x_2)


key_2 = U_key.matmul(x_2)
value_2 = U_value.matmul(x_2)


keys = U_key.matmul(embedded_sentence.T).T
torch.allclose(key_2, keys[1])


values = U_value.matmul(embedded_sentence.T).T
torch.allclose(value_2, values[1])

# unnormalized attention weight between the 2nd input element's query and the 3rd input element's key:
omega_23 = query_2.dot(keys[2])
omega_23


# all unnormalized attention weights for the 2nd input element's query:
omega_2 = query_2.matmul(keys.T)
omega_2


# normalize via softmax, scaling the dot products by 1/sqrt(d):
attention_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

# equivalent for-loop version of the matrix multiplication below (commented out):
#context_vector_2nd = torch.zeros(values[1, :].shape)
#for j in range(8):
#    context_vector_2nd += attention_weights_2[j] * values[j, :]

#context_vector_2nd


# context vector for the 2nd input element as the attention-weighted sum of the value vectors:
context_vector_2 = attention_weights_2.matmul(values)
context_vector_2

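
# - As an aside, the same scaled dot-product attention can be computed for all 8 input elements at once; the following is a minimal sketch reusing the variables above (the names `queries`, `scores`, `all_weights`, and `all_context_vectors` appear only here):

queries = U_query.matmul(embedded_sentence.T).T   # one query per input element, shape: (8, 16)
scores = queries.matmul(keys.T) / d**0.5          # scaled pairwise dot products, shape: (8, 8)
all_weights = F.softmax(scores, dim=1)            # attention weights; each row sums to 1
all_context_vectors = all_weights.matmul(values)  # shape: (8, 16)

# row 1 should match the context vector computed above for the 2nd input element:
torch.allclose(all_context_vectors[1], context_vector_2)
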
# ## Attention is all we need: introducing the original transformer architecture


# ### Encoding context embeddings via multi-head attention

torch.manual_seed(123)

d = embedded_sentence.shape[1]
one_U_query = torch.rand(d, d)

h = 8
multihead_U_query = torch.rand(h, d, d)
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)


multihead_query_2 = multihead_U_query.matmul(x_2)
multihead_query_2.shape


multihead_key_2 = multihead_U_key.matmul(x_2)
multihead_value_2 = multihead_U_value.matmul(x_2)


# key of the 2nd input element in the 3rd attention head:
multihead_key_2[2]

stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
stacked_inputs.shape


multihead_keys = torch.bmm(multihead_U_key, stacked_inputs)
multihead_keys.shape


multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_keys.shape


multihead_keys[2, 1]  # index: [3rd attention head, 2nd key]

multihead_values = torch.matmul(multihead_U_value, stacked_inputs)
multihead_values = multihead_values.permute(0, 2, 1)


# placeholder standing in for the 8 head-specific context vectors of the 2nd input element
# (in practice, these would come from scaled dot-product attention within each head):
multihead_z_2 = torch.rand(8, 16)


# concatenate the per-head context vectors and project them back to a 16-dimensional vector:
linear = torch.nn.Linear(8*16, 16)
context_vector_2 = linear(multihead_z_2.flatten())
context_vector_2.shape

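
# - Instead of the random placeholder, the head-specific context vectors for the 2nd input element can also be computed explicitly with scaled dot-product attention in each head; a minimal sketch using the variables defined above (the names `scores_2`, `weights_2`, and `z_2` appear only here):

scores_2 = torch.bmm(multihead_keys, multihead_query_2.unsqueeze(2)).squeeze(2)  # shape: (8, 8), per-head dot products
weights_2 = F.softmax(scores_2 / d**0.5, dim=1)                                  # per-head attention weights over the 8 inputs
z_2 = torch.bmm(weights_2.unsqueeze(1), multihead_values).squeeze(1)             # shape: (8, 16), one context vector per head

# map the concatenated per-head context vectors back to a 16-dimensional vector:
linear(z_2.flatten()).shape
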
# ### Learning a language model: decoder and masked multi-head attention

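# - In the decoder, position $i$ may only attend to positions $j \le i$; a minimal sketch of this masking idea is to set the disallowed attention scores to $-\infty$ before the softmax (reusing the `omega_mat` scores from above; the names `mask`, `masked_scores`, and `masked_attention_weights` appear only here):

mask = torch.tril(torch.ones(8, 8)).bool()  # lower-triangular (causal) mask
masked_scores = (omega_mat / d**0.5).masked_fill(~mask, float('-inf'))
masked_attention_weights = F.softmax(masked_scores, dim=1)
masked_attention_weights  # upper triangle is zero; each row still sums to 1
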
# ### Implementation details: positional encodings and layer normalization

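# - A minimal sketch of the two ingredients named above: sinusoidal positional encodings (as in the original transformer paper) added to the 8x16 input embeddings, followed by layer normalization via `torch.nn.LayerNorm` (the names `pos`, `freq`, `pos_encoding`, and `layer_norm` appear only here):

pos = torch.arange(8, dtype=torch.float32).unsqueeze(1)        # positions 0..7, shape: (8, 1)
freq = torch.exp(torch.arange(0, 16, 2, dtype=torch.float32)
                 * (-torch.log(torch.tensor(10000.0)) / 16))   # frequencies for the sine/cosine pairs
pos_encoding = torch.zeros(8, 16)
pos_encoding[:, 0::2] = torch.sin(pos * freq)
pos_encoding[:, 1::2] = torch.cos(pos * freq)

layer_norm = torch.nn.LayerNorm(16)
layer_norm(embedded_sentence + pos_encoding).shape
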
# ---
#
# Readers may ignore the next cell.