Path: blob/master/notebooks/book1/15/attention_torch.ipynb
Kernel: Python 3
Please find the JAX implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/attention_jax.ipynb
Basics of differentiable (soft) attention
We show how to implement soft attention, based on Section 10.3 of http://d2l.ai/chapter_attention-mechanisms/attention-scoring-functions.html.
In [ ]:
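The setup cell is not shown; below is a minimal sketch of the imports the rest of the notebook is assumed to need (PyTorch, plus matplotlib for the attention heatmaps).
import math

import torch
from torch import nn
import matplotlib.pyplot as plt  # assumed; only needed for the heatmap cells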
Masked soft attention
In [ ]:
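The cell contents are not shown; below is a sketch of a masked softmax along the lines of the d2l.ai implementation cited above: positions beyond each sequence's valid length are set to a large negative value before the softmax, so they receive (numerically) zero weight.
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of X, masking out positions beyond the
    valid length of each sequence. Sketch based on d2l.ai sec 10.3."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One valid length per batch element: repeat it for every row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One valid length per (batch element, row) pair.
        valid_lens = valid_lens.reshape(-1)
    # Mask of shape (batch * rows, seq_len): True where the position is valid.
    mask = torch.arange(shape[-1], device=X.device)[None, :] < valid_lens[:, None]
    X = X.reshape(-1, shape[-1])
    # Invalid positions get a very large negative score, so softmax ~ 0 there.
    X[~mask] = -1e6
    return nn.functional.softmax(X.reshape(shape), dim=-1)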
Example. Batch size 2, feature size 2, sequence length 4. The valid lengths are 2 and 3, so the output has shape (2, 2, 4), with zeros along the length dimension in the invalid locations.
In [ ]:
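A sketch of the call that would produce an output like the one below (the exact numbers differ, since the input is random).
# Random scores of shape (2, 2, 4); valid lengths 2 and 3 per batch element.
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))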
tensor([[[0.6174, 0.3826, 0.0000, 0.0000],
[0.3164, 0.6836, 0.0000, 0.0000]],
[[0.3391, 0.2975, 0.3634, 0.0000],
[0.4018, 0.2755, 0.3227, 0.0000]]])
Example. Batch size 2, feature size 2, sequence length 4. The valid lengths are now specified per row: (1, 3) for the first batch element and (2, 4) for the second.
In [ ]:
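A sketch of the corresponding call, now with a 2D tensor of valid lengths (one per row of each batch element).
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))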
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
[0.3335, 0.2970, 0.3695, 0.0000]],
[[0.4541, 0.5459, 0.0000, 0.0000],
[0.1296, 0.2880, 0.2429, 0.3395]]])
Additive attention
In [ ]:
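A sketch of additive attention along the lines of the d2l.ai implementation: queries and keys are projected to a common hidden size, summed, passed through tanh, and scored with a learned vector w_v.
class AdditiveAttention(nn.Module):
    """Additive (MLP) attention scoring function. Sketch based on d2l.ai sec 10.3."""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # Broadcast so every query is combined with every key:
        # queries -> (batch, num_queries, 1, h), keys -> (batch, 1, num_kv, h)
        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))
        # scores: (batch, num_queries, num_kv)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Weighted average of the values: (batch, num_queries, value_dim)
        return torch.bmm(self.dropout(self.attention_weights), values)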
In [ ]:
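A sketch of a cell that would produce shapes and values like those below, assuming the d2l.ai toy inputs: queries of shape (2, 1, 20), keys all ones of shape (2, 10, 2), values 0..39 shared across the batch, and valid lengths (2, 6).
queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.ones((2, 10, 2))
# Values are 0..39 reshaped to (1, 10, 4) and repeated across the batch.
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
out = attention(queries, keys, values, valid_lens)
print(values.shape)
print(out.shape)
out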
torch.Size([2, 10, 4])
torch.Size([2, 1, 4])
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
The heatmap is uniform across the keys, since the keys are all 1s. However, the support is truncated to the valid length.
In [ ]:
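A sketch of a plotting cell using matplotlib directly (the original likely uses d2l.show_heatmaps); attention.attention_weights has shape (2, 1, 10) here.
# Detach so the weights can be converted to numpy for plotting.
weights = attention.attention_weights.reshape(2, 10).detach()
fig, ax = plt.subplots()
im = ax.imshow(weights, cmap='Reds', aspect='auto')
ax.set_xlabel('Keys')
ax.set_ylabel('Queries')
fig.colorbar(im)
plt.show()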
In [ ]:
Dot-product attention
In [ ]:
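A sketch of scaled dot-product attention along the lines of the d2l.ai implementation: the score of a query-key pair is their inner product divided by sqrt(d).
class DotProductAttention(nn.Module):
    """Scaled dot-product attention. Sketch based on d2l.ai sec 10.3."""

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # scores: (batch, num_queries, num_kv), scaled by sqrt(d)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)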
In [ ]:
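A sketch of the example cell: the queries now have feature size 2 to match the keys; keys, values, and valid_lens are assumed to be reused from the additive-attention example, so the output is the same truncated average of the values.
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)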
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]])
In [ ]: