Path: blob/master/notebooks/book1/15/multi_head_attention_jax.ipynb
Kernel: Python 3
Please find torch implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/multi_head_attention_torch.ipynb
Multi-head attention.
We show how to implement multi-head attention. Based on sec. 10.5 of http://d2l.ai/chapter_attention-mechanisms/multihead-attention.html.
In [ ]:
In [2]:
import jax
import jax.numpy as jnp  # JAX NumPy
from jax import lax
import math
from IPython import display

try:
    from flax import linen as nn  # The Linen API
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn  # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np  # Ordinary NumPy

try:
    import optax  # Optimizers
except ModuleNotFoundError:
    %pip install -qq optax
    import optax  # Optimizers

rng = jax.random.PRNGKey(0)
Implementation
Utility functions.
In [3]:
def transpose_qkv(X, num_heads):
    # Shape of input `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
    # Shape of output `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape((X.shape[0], X.shape[1], num_heads, -1))

    # Shape of output `X`:
    # (`batch_size`, `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    X = jnp.transpose(X, (0, 2, 1, 3))

    # Shape of `output`:
    # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    return X.reshape((-1, X.shape[2], X.shape[3]))


def transpose_output(X, num_heads):
    """Reverse the operation of `transpose_qkv`"""
    X = X.reshape((-1, num_heads, X.shape[1], X.shape[2]))
    X = jnp.transpose(X, (0, 2, 1, 3))
    return X.reshape((X.shape[0], X.shape[1], -1))
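As a quick sanity check, transpose_output should invert transpose_qkv. The shapes below (batch of 2, sequence length 3, 8 hidden units, 4 heads) are arbitrary values chosen for illustration.
In [ ]:
# Sanity check with arbitrary shapes: splitting into heads and merging back
# should recover the original tensor.
batch, seq_len, num_hiddens_demo, num_heads_demo = 2, 3, 8, 4
X_demo = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, num_hiddens_demo))
X_split = transpose_qkv(X_demo, num_heads_demo)
print(X_split.shape)  # (batch * num_heads, seq_len, num_hiddens / num_heads) = (8, 3, 2)
X_back = transpose_output(X_split, num_heads_demo)
print(jnp.allclose(X_back, X_demo))  # True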
Main function.
In [4]:
def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.shape[1]
    mask = jnp.arange((maxlen), dtype=jnp.float32)[None, :] < valid_len[:, None]
    X = jnp.where(~mask, value, X)
    return X


def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.softmax(X, axis=-1)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = jnp.repeat(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.softmax(X.reshape(shape), axis=-1)
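To see the masking in action, we can apply masked_softmax to a random score tensor with valid lengths 2 and 3: the weights beyond each valid length come out (numerically) zero, and each row still sums to 1. The shapes here are arbitrary illustrative values.
In [ ]:
# Illustrative example with arbitrary shapes: a (2, 2, 4) score tensor and a
# 1D `valid_lens`, so all queries in one batch element share a valid length.
scores_demo = jax.random.uniform(jax.random.PRNGKey(2), (2, 2, 4))
masked_softmax(scores_demo, jnp.array([2, 3]))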
In [5]:
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""

    dropout: float

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value
    # dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    @nn.compact
    def __call__(self, queries, keys, values, valid_lens=None, deterministic=True):
        d = queries.shape[-1]
        scores = queries @ (keys.swapaxes(1, 2)) / math.sqrt(d)
        attention_weights = masked_softmax(scores, valid_lens)
        dropout_layer = nn.Dropout(self.dropout, deterministic=deterministic)
        return dropout_layer(attention_weights) @ values
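As a small sketch with arbitrary shapes, we can run scaled dot product attention on a toy batch of 2 examples, each with 1 query, 10 key-value pairs, and value dimension 4. Since deterministic=True by default, dropout is disabled and no dropout rng is needed.
In [ ]:
# Toy example with arbitrary shapes (batch 2, 1 query, 10 key-value pairs,
# value dim 4). The module has no learnable parameters, so `params` is empty.
queries_demo = jnp.ones((2, 1, 2))
keys_demo = jnp.ones((2, 10, 2))
values_demo = jnp.repeat(jnp.arange(40, dtype=jnp.float32).reshape(1, 10, 4), 2, axis=0)
attn = DotProductAttention(dropout=0.5)
params = attn.init(jax.random.PRNGKey(3), queries_demo, keys_demo, values_demo, jnp.array([2, 6]))
attn.apply(params, queries_demo, keys_demo, values_demo, jnp.array([2, 6]))  # shape (2, 1, 4)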
In [6]:
class MultiHeadAttention(nn.Module):
    num_hiddens: int
    num_heads: int
    dropout: float
    bias: bool = False

    @nn.compact
    def __call__(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(nn.Dense(self.num_hiddens, use_bias=self.bias)(queries), self.num_heads)
        keys = transpose_qkv(nn.Dense(self.num_hiddens, use_bias=self.bias)(keys), self.num_heads)
        values = transpose_qkv(nn.Dense(self.num_hiddens, use_bias=self.bias)(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = DotProductAttention(self.dropout)(queries, keys, values, valid_lens)

        # Shape of `output_concat`:
        # (`batch_size`, no. of queries, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return nn.Dense(self.num_hiddens, use_bias=self.bias)(output_concat)
Example
The shape of the multi-head attention output is (batch_size, num_queries, num_hiddens).
In [7]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
Y = jnp.ones((batch_size, num_kvpairs, num_hiddens))
variables = attention.init(jax.random.PRNGKey(0), X, Y, Y, valid_lens)
output = attention.apply(variables, X, Y, Y, valid_lens)
output.shape
Out[7]:
(2, 4, 100)
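valid_lens may also have shape (batch_size, no. of queries), giving each query its own valid length. As a sketch, we can reuse X as the queries, keys, and values (self-attention) with arbitrary per-query lengths; the output shape is unchanged.
In [ ]:
# Per-query valid lengths (arbitrary values), reusing X for self-attention.
valid_lens_2d = jnp.array([[1, 2, 3, 4], [1, 1, 2, 2]])
variables_2d = attention.init(jax.random.PRNGKey(1), X, X, X, valid_lens_2d)
attention.apply(variables_2d, X, X, X, valid_lens_2d).shape  # (2, 4, 100)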
In [ ]: