GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/15/positional_encoding_jax.ipynb
Kernel: Python 3

Author: Susnato Dhar (GitHub: https://github.com/susnato)

Positional encoding for transformers.

  • Much of the code in this notebook closely follows the original notebook, which was written for PyTorch; I converted it to be JAX-compatible.
    try:
        import flax
    except ModuleNotFoundError:
        %pip install -qq flax
        import flax
    import jax
    from flax import linen as nn
    import matplotlib.pyplot as plt

    !wget https://raw.githubusercontent.com/d2l-ai/d2l-en/master/d2l/torch.py -q -O d2l.py
    import d2l
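    class PositionalEncoding(nn.Module):
        num_hiddens: int
        dropout: float
        max_len = 1000

        @nn.compact
        def __call__(self, X):
            dropout = nn.Dropout(self.dropout)
            # Table of encodings, shape (1, max_len, num_hiddens).
            P = jax.numpy.zeros((1, self.max_len, self.num_hiddens))
            # Angle for each (position, column pair): position / 10000^(2j / num_hiddens).
            x = jax.numpy.arange(self.max_len, dtype=jax.numpy.float32).reshape(-1, 1) / jax.numpy.power(
                10000, jax.numpy.arange(0, self.num_hiddens, 2, dtype=jax.numpy.float32) / self.num_hiddens
            )
            # Even columns hold the sine, odd columns the cosine of the same angle.
            P = P.at[:, :, 0::2].set(jax.numpy.sin(x))
            P = P.at[:, :, 1::2].set(jax.numpy.cos(x))
            # Add the encodings for the first X.shape[1] positions to the input.
            X = X + P[:, : X.shape[1], :]
            # deterministic=True disables dropout, so no dropout RNG is needed.
            return dropout(X, deterministic=True), P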
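For reference, the table `P` built in `__call__` can be written as a formula. This is read directly off the code above, with i the row (position), j the index of a column pair, and d standing for `num_hiddens` (assumed even, since the sine and cosine slices must have the same width):

```latex
P_{i,\,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right), \qquad
P_{i,\,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right),
\qquad i = 0, \dots, \texttt{max\_len} - 1, \quad j = 0, \dots, d/2 - 1
```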
    encoding_dim, num_steps = 32, 60
    pos_encoding = PositionalEncoding(encoding_dim, 0)
    x = jax.numpy.zeros((1, num_steps, encoding_dim))
    init_variables = pos_encoding.init(jax.random.PRNGKey(0), x)
    X, P = pos_encoding.apply(init_variables, x)
    P = P[:, : X.shape[1], :]
    d2l.plot(
        jax.numpy.arange(num_steps),
        P[0, :, 6:10].T,
        xlabel="Row (position)",
        figsize=(6, 2.5),
        legend=["Col %d" % d for d in jax.numpy.arange(6, 10)],
    )
    plt.savefig("positionalEncodingSinusoids.pdf", dpi=300)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    [Figure: columns 6 to 9 of P plotted against row (position)]
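To help read the plot of columns 6 to 9, the period of each plotted sinusoid can be computed from the same expression the class uses for the angle. This is a small sketch rather than part of the original notebook; the names `pair_index` and `wavelengths` are introduced here for illustration only.

```python
import jax.numpy as jnp

encoding_dim = 32  # same value as in the plotting cell
cols = jnp.arange(6, 10)  # the plotted columns

# Columns 2j and 2j+1 share one frequency, so integer-divide by 2 to get j.
pair_index = cols // 2
# The angle grows by 1 / 10000**(2j / d) per position, so one full period
# spans 2 * pi * 10000**(2j / d) positions.
wavelengths = 2 * jnp.pi * jnp.power(10000.0, 2 * pair_index / encoding_dim)

for c, w in zip(cols.tolist(), wavelengths.tolist()):
    print(f"Col {c}: period of about {w:.1f} positions")
```

Columns 6 and 7 repeat roughly every 35.3 positions and columns 8 and 9 roughly every 62.8. Over the 60 plotted positions the first pair therefore completes more than one full period, and the second pair just under one.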
    plt.figure(figsize=(3.5, 4))
    plt.imshow(P[0, :, :], cmap="Blues")
    plt.xlabel("Column (encoding dimension)")
    plt.ylabel("Row (position)")
    plt.colorbar()
    plt.savefig("positionalEncodingHeatmap.pdf", dpi=300)
    [Figure: heatmap of P, with column (encoding dimension) on the x-axis and row (position) on the y-axis]
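The structure of the heatmap can also be cross-checked numerically. The sketch below rebuilds the same table without Flax and asserts a few properties that follow from the sine and cosine construction. `sinusoidal_table` is a hypothetical helper written for this check, not a function from the notebook or from any library, and it assumes an even encoding dimension.

```python
import jax.numpy as jnp


def sinusoidal_table(max_len, num_hiddens):
    """Return a (max_len, num_hiddens) table of sinusoidal encodings.

    Hypothetical helper; assumes num_hiddens is even.
    """
    positions = jnp.arange(max_len, dtype=jnp.float32).reshape(-1, 1)
    scales = jnp.power(
        10000.0, jnp.arange(0, num_hiddens, 2, dtype=jnp.float32) / num_hiddens
    )
    angles = positions / scales  # shape (max_len, num_hiddens // 2)
    table = jnp.zeros((max_len, num_hiddens))
    table = table.at[:, 0::2].set(jnp.sin(angles))  # even columns: sine
    table = table.at[:, 1::2].set(jnp.cos(angles))  # odd columns: cosine
    return table


table = sinusoidal_table(60, 32)

# Position 0 has angle 0 everywhere: sine columns are 0, cosine columns are 1.
assert jnp.allclose(table[0, 0::2], 0.0)
assert jnp.allclose(table[0, 1::2], 1.0)
# Each (sine, cosine) pair lies on the unit circle.
assert jnp.allclose(table[:, 0::2] ** 2 + table[:, 1::2] ** 2, 1.0, atol=1e-5)
# The last cosine column barely moves over 60 positions (very low frequency).
assert float(table[:, -1].min()) > 0.999
```

The last check reflects why the rightmost columns of the heatmap look almost constant: their angle changes by only about 0.01 radians across all 60 rows.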