JAX
JAX is a version of NumPy that runs fast on CPU, GPU and TPU, by compiling down to XLA. It also has an excellent automatic differentiation library, extending the earlier autograd package, which makes it easy to compute higher-order derivatives, per-example gradients (instead of aggregated gradients), and gradients of complex code (e.g., optimizing an optimizer). The JAX interface is almost identical to NumPy (by design), apart from some small differences and many additional features. More details can be found in the other tutorials listed below.
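To make the NumPy-like interface and the differentiation features above concrete, here is a minimal sketch (the toy loss, data, and variable names are our own, not taken from any particular tutorial) using `jax.numpy`, `grad`, `vmap` for per-example gradients, and `jit`:

```python
import jax
import jax.numpy as jnp

# NumPy-like API: jnp mirrors np for most array operations.
def loss(w, x, y):
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
y = jnp.ones(8)
w = jnp.zeros(3)

# First- and higher-order derivatives via composable transformations.
g = jax.grad(loss)(w, x, y)        # gradient w.r.t. w, shape (3,)
H = jax.hessian(loss)(w, x, y)     # second-order derivative, shape (3, 3)

# Per-example gradients: vmap the gradient over the batch dimension.
per_example_loss = lambda w, xi, yi: (jnp.dot(xi, w) - yi) ** 2
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(w, x, y)  # shape (8, 3)

# jit compiles the whole computation with XLA.
fast_grad = jax.jit(jax.grad(loss))
print(fast_grad(w, x, y).shape)  # (3,)
```

Because `grad`, `vmap` and `jit` are composable transformations, patterns such as per-example gradients or higher-order derivatives become one-liners.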
Here is a one-slide summary of JAX from a recent 10-minute overview talk by Jake Vanderplas.
Tutorials (blog posts / notebooks)
- From PyTorch to JAX: towards neural net frameworks that purify stateful code. Explains the concepts behind Flax and Haiku.
- CMA-ES in JAX: blog post on fitting DNNs using blackbox optimization.
- JAX on TPU pods: solving y=mx+b with JAX on a TPU pod slice, by Mat Kelcey.
Videos / talks
JAX libraries related to ML
In this tutorial we focus on core JAX. However, since JAX is quite low level (like NumPy), many libraries have been developed on top of it to provide more specialized functionality. We summarize a few of the ML-related libraries below. See also https://github.com/n2cholas/awesome-jax, which has a more extensive list.
DNN libraries
JAX is a purely functional library, unlike TensorFlow and PyTorch, which are stateful. The main advantage of functional programming is that we can safely transform the code, and/or run it in parallel, without worrying about global state changing behind the scenes. The main disadvantage is that code (especially for DNNs) can be harder to write. To simplify the task, various DNN libraries have been designed, as we list below. In this book we use Flax; a minimal usage sketch is shown after the table.
Name | Description |
---|---|
Stax | Barebones library for specifying DNNs |
Flax | Library for specifying and training DNNs |
Haiku | Library for specifying DNNs, similar to Sonnet |
Jraph | Library for graph neural networks |
Trax | Library for specifying and training DNNs, with a focus on sequence models |
T5X | T5 (a large seq2seq model) in JAX/Flax |
HuggingFace Transformers | Transformer models (with Flax/JAX versions of many of them) |
Objax | PyTorch-like library for JAX (stateful/object-oriented, not compatible with other JAX libraries) |
Elegy | Keras-like library for JAX |
FlaxVision | Flax version of torchvision |
Neural Tangents | Library to compute a kernel from a DNN |
Efficient nets | Efficient CNN classifiers in Flax |
Progressive GANs | Progressive GANs in Flax |
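As mentioned above, here is a minimal Flax sketch (the module, layer sizes, and dummy data are our own illustrative choices) showing the functional pattern: parameters are created by `init` and passed explicitly to `apply`, rather than being stored as mutable state inside the model object.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    """A small multi-layer perceptron; the layer sizes are illustrative."""
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = MLP()
x = jnp.ones((4, 3))                           # dummy batch of 4 examples
params = model.init(jax.random.PRNGKey(0), x)  # parameters created explicitly...
y = model.apply(params, x)                     # ...and passed in explicitly (pure function)
print(y.shape)  # (4, 1)
```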
RL libraries
Name | Description |
---|---|
RLax | Library of RL building blocks from DeepMind |
Coax | Lightweight RL library from Microsoft for solving OpenAI Gym environments |
Probabilistic programming languages
Name | Description |
---|---|
NumPyro | Probabilistic programming library built on JAX (a usage sketch follows this table) |
Oryx | Lightweight library for probabilistic programming |
MCX | Library for probabilistic programming |
JXF | Exponential families in JAX |
Distrax | Library for probability distributions and bijectors |
TFP/JAX distributions | JAX port of tfp.distributions |
BlackJax | Library of MCMC samplers, including HMC and NUTS |
Newt | Variational inference for (Markov) GPs |
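To give a flavor of these libraries, here is a minimal NumPyro sketch (the toy model and data are our own) that infers the mean of a Gaussian using NUTS:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(y):
    # Prior over the unknown mean; the likelihood conditions on the observed data.
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

y = jnp.array([1.2, 0.8, 1.1, 0.9])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), y=y)
mcmc.print_summary()
```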
Other libraries
There are also many other JAX libraries for tasks that are not about defining DNN models. We list some of them below.
Name | Description |
---|---|
Optax | Library for defining gradient-based optimizers (a usage sketch follows this table) |
Chex | Library for debugging and developing reliable JAX code |
Common loop utilities | Library for writing "beautiful training loops in JAX" |
KF | (Extended, parallelized) Kalman filtering/smoothing |
Easy Neural ODEs | Neural ODEs for classification, latent ODEs for time series, and FFJORD for density estimation, with several higher-order adaptive-step numerical solvers (e.g., Heun-Euler, Fehlberg, Cash-Karp, Tanyam, and adaptive-order Adams) |
GPT-J | Open-source GPT-3-like model, implemented in JAX and trained on a TPU v3-256 |
CLIP-Jax | JAX wrapper for inference with the pre-trained OpenAI CLIP model |
Scenic | Computer vision library |
Pix | Image pre-processing library |
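Finally, as referenced in the Optax row above, here is a minimal sketch (the toy quadratic loss, parameters, and learning rate are our own choices) of how an Optax optimizer plugs into JAX's functional style: the optimizer state, like the model parameters, is threaded through explicitly.

```python
import jax
import jax.numpy as jnp
import optax

def loss(params, x, y):
    # Toy linear-regression loss; params is an ordinary pytree (a dict here).
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
x = jax.random.normal(jax.random.PRNGKey(0), (16, 3))
y = jnp.ones(16)

optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, x, y):
    grads = jax.grad(loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

for _ in range(100):
    params, opt_state = step(params, opt_state, x, y)
```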