Path: blob/main/ch16/ch16-part1-self-attention.py
# coding: utf-8

import sys
from python_environment_check import check_packages
import torch
import torch.nn.functional as F

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add folder to path in order to load from the check_packages.py script:

sys.path.insert(0, '..')

# Check recommended package versions:

d = {
    'torch': '1.9.0',
}
check_packages(d)

# # Chapter 16: Transformers – Improving Natural Language Processing with Attention Mechanisms (Part 1/3)

# **Outline**
#
# - [Adding an attention mechanism to RNNs](#Adding-an-attention-mechanism-to-RNNs)
#   - [Attention helps RNNs with accessing information](#Attention-helps-RNNs-with-accessing-information)
#   - [The original attention mechanism for RNNs](#The-original-attention-mechanism-for-RNNs)
#   - [Processing the inputs using a bidirectional RNN](#Processing-the-inputs-using-a-bidirectional-RNN)
#   - [Generating outputs from context vectors](#Generating-outputs-from-context-vectors)
#   - [Computing the attention weights](#Computing-the-attention-weights)
# - [Introducing the self-attention mechanism](#Introducing-the-self-attention-mechanism)
#   - [Starting with a basic form of self-attention](#Starting-with-a-basic-form-of-self-attention)
#   - [Parameterizing the self-attention mechanism: scaled dot-product attention](#Parameterizing-the-self-attention-mechanism-scaled-dot-product-attention)
# - [Attention is all we need: introducing the original transformer architecture](#Attention-is-all-we-need-introducing-the-original-transformer-architecture)
#   - [Encoding context embeddings via multi-head attention](#Encoding-context-embeddings-via-multi-head-attention)
#   - [Learning a language model: decoder and masked multi-head attention](#Learning-a-language-model-decoder-and-masked-multi-head-attention)
#   - [Implementation details: positional encodings and layer normalization](#Implementation-details-positional-encodings-and-layer-normalization)

# ## Adding an attention mechanism to RNNs

# ### Attention helps RNNs with accessing information

# ### The original attention mechanism for RNNs

# ### Processing the inputs using a bidirectional RNN

# ### Generating outputs from context vectors

# ### Computing the attention weights

# ## Introducing the self-attention mechanism

# ### Starting with a basic form of self-attention

# - Assume we have an input sentence that we encoded via a dictionary, which maps the words to integers as discussed in the RNN chapter:

# input sequence / sentence:
# "Can you help me to translate this sentence"

sentence = torch.tensor(
    [0,  # can
     7,  # you
     1,  # help
     2,  # me
     5,  # to
     6,  # translate
     4,  # this
     3]  # sentence
)

sentence

# - Next, assume we have an embedding of the words, i.e., the words are represented as real vectors.
# - Since we have 8 words, there will be 8 vectors. Each vector is 16-dimensional:

torch.manual_seed(123)
embed = torch.nn.Embedding(10, 16)
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

# - The goal is to compute the context vectors $\boldsymbol{z}^{(i)}=\sum_{j=1}^{T} \alpha_{i j} \boldsymbol{x}^{(j)}$, which involve attention weights $\alpha_{i j}$.
# - In turn, the attention weights $\alpha_{i j}$ involve the $\omega_{i j}$ values.
# - Let's start with the $\omega_{i j}$'s first, which are computed as dot-products:
#
# $$\omega_{i j}=\boldsymbol{x}^{(i)^{\top}} \boldsymbol{x}^{(j)}$$

omega = torch.empty(8, 8)

for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)

# - Actually, let's compute this more efficiently by replacing the nested for-loops with a matrix multiplication:

omega_mat = embedded_sentence.matmul(embedded_sentence.T)

torch.allclose(omega_mat, omega)

# - Next, let's compute the attention weights by normalizing the "omega" values so they sum to 1:
#
# $$\alpha_{i j}=\frac{\exp \left(\omega_{i j}\right)}{\sum_{j=1}^{T} \exp \left(\omega_{i j}\right)}=\operatorname{softmax}\left(\left[\omega_{i j}\right]_{j=1 \ldots T}\right)$$
#
# $$\sum_{j=1}^{T} \alpha_{i j}=1$$

attention_weights = F.softmax(omega, dim=1)
attention_weights.shape

# - We can confirm that the attention weights for each input element (that is, each row) sum up to 1:

attention_weights.sum(dim=1)

# - Now that we have the attention weights, we can compute the context vectors $\boldsymbol{z}^{(i)}=\sum_{j=1}^{T} \alpha_{i j} \boldsymbol{x}^{(j)}$.
# - For instance, to compute the context vector of the 2nd input element (the element at index 1), we can perform the following computation:

x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)
for j in range(8):
    x_j = embedded_sentence[j, :]
    context_vec_2 += attention_weights[1, j] * x_j

context_vec_2

# - Or, more efficiently, using linear algebra and matrix multiplication:

context_vectors = torch.matmul(
    attention_weights, embedded_sentence)

torch.allclose(context_vec_2, context_vectors[1])

# ### Parameterizing the self-attention mechanism: scaled dot-product attention

torch.manual_seed(123)

d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

x_2 = embedded_sentence[1]
query_2 = U_query.matmul(x_2)

key_2 = U_key.matmul(x_2)
value_2 = U_value.matmul(x_2)

keys = U_key.matmul(embedded_sentence.T).T
torch.allclose(key_2, keys[1])

values = U_value.matmul(embedded_sentence.T).T
torch.allclose(value_2, values[1])

omega_23 = query_2.dot(keys[2])
omega_23

omega_2 = query_2.matmul(keys.T)
omega_2

attention_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

# context_vector_2nd = torch.zeros(values[1, :].shape)
# for j in range(8):
#     context_vector_2nd += attention_weights_2[j] * values[j, :]

# context_vector_2nd

context_vector_2 = attention_weights_2.matmul(values)
context_vector_2
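
# - The steps above produced the context vector of the 2nd input element only.
# - As a minimal sketch (an illustration that goes beyond the book code), the same scaled
#   dot-product attention can be computed for all input elements at once; the names
#   `queries_all`, `keys_all`, `values_all`, and so on are purely illustrative:

queries_all = embedded_sentence.matmul(U_query.T)  # row i equals U_query.matmul(x_i)
keys_all = embedded_sentence.matmul(U_key.T)       # identical to `keys` above
values_all = embedded_sentence.matmul(U_value.T)   # identical to `values` above

omega_all = queries_all.matmul(keys_all.T)         # unnormalized scores, shape (8, 8)
attention_weights_all = F.softmax(omega_all / d**0.5, dim=1)
context_vectors_all = attention_weights_all.matmul(values_all)

# row 1 should match the context vector computed step by step above:
torch.allclose(context_vectors_all[1], context_vector_2)
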
# ## Attention is all we need: introducing the original transformer architecture

# ### Encoding context embeddings via multi-head attention

torch.manual_seed(123)

d = embedded_sentence.shape[1]
one_U_query = torch.rand(d, d)

h = 8
multihead_U_query = torch.rand(h, d, d)
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)

multihead_query_2 = multihead_U_query.matmul(x_2)
multihead_query_2.shape

multihead_key_2 = multihead_U_key.matmul(x_2)
multihead_value_2 = multihead_U_value.matmul(x_2)

multihead_key_2[2]

stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
stacked_inputs.shape

multihead_keys = torch.bmm(multihead_U_key, stacked_inputs)
multihead_keys.shape

multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_keys.shape

multihead_keys[2, 1]  # index: [3rd attention head, 2nd key]

multihead_values = torch.matmul(multihead_U_value, stacked_inputs)
multihead_values = multihead_values.permute(0, 2, 1)

# for illustration purposes, use random values as a stand-in for the h=8 context
# vectors of the 2nd input element (one per attention head):
multihead_z_2 = torch.rand(8, 16)

linear = torch.nn.Linear(8*16, 16)
context_vector_2 = linear(multihead_z_2.flatten())
context_vector_2.shape
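
# - Above, `multihead_z_2` was filled with random numbers as a placeholder.
# - As a minimal sketch (an illustration that goes beyond the book code), the actual head
#   outputs for the 2nd input element could be obtained by applying scaled dot-product
#   attention within each head; the variable names below are purely illustrative:

# queries for all inputs and heads; [head, input, :] layout, as for the keys and values above
multihead_queries = torch.bmm(
    multihead_U_query, stacked_inputs).permute(0, 2, 1)

# unnormalized scores of the 2nd query against all keys, one row per head -> shape (8, 8)
multihead_omega_2 = torch.bmm(
    multihead_queries[:, 1, :].unsqueeze(1),   # shape: (8, 1, 16)
    multihead_keys.permute(0, 2, 1)            # shape: (8, 16, 8)
).squeeze(1)

multihead_weights_2 = F.softmax(multihead_omega_2 / d**0.5, dim=1)

# weighted sum of the value vectors within each head -> shape (8, 16),
# i.e., the same shape as the random placeholder used above
multihead_z_2_computed = torch.bmm(
    multihead_weights_2.unsqueeze(1), multihead_values).squeeze(1)

multihead_z_2_computed.shape
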
# ### Learning a language model: decoder and masked multi-head attention

# ### Implementation details: positional encodings and layer normalization
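
# - Positional encodings are not implemented in this part of the chapter.
# - As a minimal sketch (an illustration that goes beyond the book code), the sinusoidal
#   encoding proposed in the original transformer paper could be computed as follows,
#   assuming an even embedding dimension; the function name is purely illustrative:

def sinusoidal_positional_encoding(seq_len, d_model):
    # pe[pos, 2i]   = sin(pos / 10000**(2i / d_model))
    # pe[pos, 2i+1] = cos(pos / 10000**(2i / d_model))
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.pow(
        10000.0, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position / div_term)
    pe[:, 1::2] = torch.cos(position / div_term)
    return pe

# positional information would be added to the word embeddings before the encoder:
(embedded_sentence + sinusoidal_positional_encoding(8, 16)).shape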