Path: blob/master/notebooks/book1/15/attention_jax.ipynb
Kernel: Python 3
A PyTorch implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/attention_torch.ipynb
Basics of differentiable (soft) attention
We show how to implement soft (differentiable) attention in JAX, based on Section 10.3 of http://d2l.ai/chapter_attention-mechanisms/attention-scoring-functions.html.
In [ ]:
In [1]:
Out[1]:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Masked soft attention
In [2]:
Example. Batch size 2, 2 rows (queries) per example, sequence length 4. The valid lengths are 2 and 3, so the output has shape (2, 2, 4), but entries beyond the valid length along the last (sequence) dimension are 0.
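The defining cell is not shown above; below is a minimal functional sketch of a masked softmax in JAX, following the d2l.ai recipe. The function name, the random test scores, and the helper logic are illustrative, not the notebook's exact code.

```python
import jax.numpy as jnp
from jax import nn, random

def masked_softmax(X, valid_lens):
    """Softmax over the last axis, zeroing out positions beyond valid_lens.

    X has shape (batch, num_queries, num_keys); valid_lens is either a 1-D
    array of length batch or a 2-D array of shape (batch, num_queries).
    """
    if valid_lens is None:
        return nn.softmax(X, axis=-1)
    shape = X.shape
    if valid_lens.ndim == 1:
        # One valid length per example: repeat it for every query row.
        valid_lens = jnp.repeat(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Replace scores at invalid key positions with a large negative value,
    # so they get (numerically) zero weight after the softmax.
    mask = jnp.arange(shape[-1])[None, :] < valid_lens[:, None]
    X = jnp.where(mask, X.reshape(-1, shape[-1]), -1e6)
    return nn.softmax(X, axis=-1).reshape(shape)

# Example: batch 2, 2 query rows per example, 4 keys; valid lengths 2 and 3.
scores = random.uniform(random.PRNGKey(0), (2, 2, 4))
print(masked_softmax(scores, jnp.array([2, 3])))
```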
In [3]:
Out[3]:
[[[0.60708654 0.39291343 0. 0. ]
[0.46680522 0.5331948 0. 0. ]]
[[0.3154995 0.3071782 0.37732235 0. ]
[0.4908877 0.2676804 0.24143189 0. ]]]
Example. Batch size 2, 2 rows (queries) per example, sequence length 4. Now we specify a valid length per row: (1, 3) for the first example and (2, 4) for the second.
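Reusing the masked_softmax sketch above, a per-row specification can be passed as a 2-D array of valid lengths (an illustrative call, not the notebook's exact code):

```python
import jax.numpy as jnp
from jax import random

# Row i of the 2-D valid_lens gives the valid length for each query row of example i.
scores = random.uniform(random.PRNGKey(0), (2, 2, 4))
print(masked_softmax(scores, jnp.array([[1, 3], [2, 4]])))
```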
In [4]:
Out[4]:
[[[1. 0. 0. 0. ]
[0.3412919 0.38983083 0.2688773 0. ]]
[[0.50668186 0.4933181 0. 0. ]
[0.30692047 0.16736333 0.1509518 0.3747644 ]]]
Additive attention
In [5]:
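The cell above presumably defines additive attention, a(q, k) = w_v^T tanh(W_q q + W_k k). Here is a minimal functional sketch of the same idea in JAX, reusing masked_softmax from above; the explicit weight matrices (W_q, W_k, w_v) and the toy data are illustrative, standing in for what is likely a Flax module in the notebook.

```python
import jax.numpy as jnp
from jax import random

def additive_attention(queries, keys, values, valid_lens, W_q, W_k, w_v):
    # Project queries and keys to a common hidden size, then combine via
    # broadcasting: (batch, num_q, 1, h) + (batch, 1, num_kv, h).
    q = queries @ W_q                                   # (batch, num_q, h)
    k = keys @ W_k                                      # (batch, num_kv, h)
    features = jnp.tanh(q[:, :, None, :] + k[:, None, :, :])
    scores = features @ w_v                             # (batch, num_q, num_kv)
    attn = masked_softmax(scores, valid_lens)
    return attn @ values, attn                          # output: (batch, num_q, d_v)

# Toy data as in the text: queries (2, 1, 20), keys (2, 10, 2) of all ones,
# values 0..39 reshaped to (1, 10, 4) and repeated over the batch.
queries = random.normal(random.PRNGKey(0), (2, 1, 20))
keys = jnp.ones((2, 10, 2))
values = jnp.tile(jnp.arange(40, dtype=jnp.float32).reshape(1, 10, 4), (2, 1, 1))
valid_lens = jnp.array([2, 6])

h = 8  # hidden size of the additive scoring MLP (illustrative choice)
k1, k2, k3 = random.split(random.PRNGKey(1), 3)
W_q = random.normal(k1, (20, h))
W_k = random.normal(k2, (2, h))
w_v = random.normal(k3, (h,))

out, attn = additive_attention(queries, keys, values, valid_lens, W_q, W_k, w_v)
print(values.shape)  # (2, 10, 4)
print(out.shape)     # (2, 1, 4)
print(out)
```

Because the keys are identical, the scores are constant across keys, so the masked softmax puts uniform weight on the valid positions and the output is the mean of the first 2 (resp. 6) value rows, as in the printed output below.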
In [6]:
Out[6]:
(2, 10, 4)
(2, 1, 4)
[[[ 2. 3. 4. 5. ]]
[[10. 11. 12.000001 13. ]]]
The heatmap is uniform across the keys, since the keys are all 1s. However, the support is truncated to the valid length.
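As a rough illustration (the notebook likely uses the d2l show_heatmaps helper; this minimal matplotlib version is only a sketch), the weights from the additive-attention sketch above can be visualized directly:

```python
import matplotlib.pyplot as plt

# attn has shape (2, 1, 10): one query per example, 10 keys.
# Rows: the single query of each batch example; columns: the 10 keys.
plt.imshow(attn.reshape(2, 10), cmap='Reds')
plt.xlabel('Keys')
plt.ylabel('Query (one per example)')
plt.colorbar()
plt.show()
```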
In [7]:
In [8]:
Out[8]:
Dot-product attention
In [9]:
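Again the defining cell is not shown; below is a minimal sketch of scaled dot-product attention in JAX, reusing masked_softmax from above and omitting the dropout used in the d2l.ai version. The function name and test call are illustrative.

```python
import math
import jax.numpy as jnp
from jax import random

def dot_product_attention(queries, keys, values, valid_lens=None):
    # Scores are scaled inner products: Q K^T / sqrt(d).
    d = queries.shape[-1]
    scores = queries @ jnp.swapaxes(keys, 1, 2) / math.sqrt(d)
    attn = masked_softmax(scores, valid_lens)
    return attn @ values, attn

# Same keys/values as before, but the queries must share the key dimension (d = 2).
queries = random.normal(random.PRNGKey(0), (2, 1, 2))
keys = jnp.ones((2, 10, 2))
values = jnp.tile(jnp.arange(40, dtype=jnp.float32).reshape(1, 10, 4), (2, 1, 1))

out, attn = dot_product_attention(queries, keys, values, jnp.array([2, 6]))
print(out)  # identical keys give uniform weights over the valid positions
```

With all-ones keys the weights are again uniform over the valid keys, so the output averages the first 2 (resp. 6) value rows, matching the printed output below.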
In [10]:
Out[10]:
[[[ 2. 3. 4. 5. ]]
[[10. 11. 12.000001 13. ]]]
In [11]:
Out[11]: