Path: blob/master/notebooks/book1/15/multi_head_attention_torch.ipynb
A JAX implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/multi_head_attention_jax.ipynb
Multi-head attention.
We show how to implement multi-head attention. Based on sec. 10.5 of http://d2l.ai/chapter_attention-mechanisms/multihead-attention.html.
In [1]:
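# NOTE: the code cells were stripped from this export. The cells below are
# hedged reconstructions following d2l.ai sec 10.5, not the original source.
# This cell most likely held the imports.
import math

import torch
from torch import nn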
Implementation
Utility functions.
In [2]:
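# Plausible reconstruction (following d2l.ai sec 10.5) of the reshaping
# helpers that split the queries/keys/values across heads and merge them back.
def transpose_qkv(X, num_heads):
    """(batch, n, num_hiddens) -> (batch * num_heads, n, num_hiddens / num_heads)."""
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)  # (batch, num_heads, n, head_dim)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)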
Main function.
In [3]:
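# Plausible reconstruction of the main class, following d2l.ai sec 10.5.
# Each head applies scaled dot-product attention to a
# (num_hiddens / num_heads)-dimensional slice; the heads are computed in
# parallel by folding them into the batch dimension. The submodule names
# match the repr in Out[6] below.
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # After transposing: (batch * num_heads, n, num_hiddens / num_heads).
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Copy each valid length num_heads times, since the heads are
            # folded into the batch dimension.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)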
In [4]:
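# This cell's original contents are unknown; plausibly the masking helpers
# from d2l.ai sec 10.3 that the dot-product attention below relies on.
def sequence_mask(X, valid_len, value=0):
    """Set entries beyond each row's valid length to `value`."""
    maxlen = X.size(1)
    mask = torch.arange(maxlen, dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring positions beyond valid_lens."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Masked positions get a large negative score, so softmax sends them to ~0.
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)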
In [5]:
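# Again a hedged guess at this cell: scaled dot-product attention as in
# d2l.ai sec 10.3, whose repr appears in Out[6] below.
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Scale by sqrt(d) so the variance of the scores is independent of d.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)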
Example
The shape of the multi-head attention output is (batch_size, num_queries, num_hiddens).
In [6]:
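# The cell producing Out[6] presumably built the model. The parameter values
# below are inferred from the repr (num_hiddens=100, dropout=0.5) together
# with d2l.ai's choice of num_heads=5.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()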
Out[6]:
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
In [7]:
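# A sketch of the forward pass that yields Out[7]: batch_size=2 and
# num_queries=4 follow from the output shape, while num_kvpairs=6 and the
# valid lengths are the values used in d2l.ai sec 10.5.
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape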
Out[7]:
torch.Size([2, 4, 100])