
# JAX

JAX is a version of NumPy that runs fast on CPUs, GPUs, and TPUs by compiling down to XLA. It also has an excellent automatic differentiation library, extending the earlier autograd package, which makes it easy to compute higher-order derivatives, per-example gradients (instead of aggregated gradients), and gradients of complex code (e.g., optimizing an optimizer). The JAX interface is almost identical to NumPy (by design), but with some small differences and many additional features. More details can be found in the other tutorials listed below.
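To make this concrete, here is a small sketch of the NumPy-like interface and the main function transformations (`jit`, `grad`, `vmap`); the toy `loss_fn` and data are purely illustrative:

```python
import jax
import jax.numpy as jnp

# A pure function written with the familiar NumPy API (illustrative only).
def loss_fn(w, x, y):
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))   # toy batch of 8 examples, 3 features
y = jnp.ones(8)
w = jnp.zeros(3)

fast_loss = jax.jit(loss_fn)                 # compile to XLA for CPU/GPU/TPU

grad_fn = jax.grad(loss_fn)                  # dL/dw
hessian_fn = jax.jacfwd(jax.grad(loss_fn))   # higher-order derivatives

# Per-example gradients: vmap over the batch dimension of (x, y).
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(w, x, y)

print(fast_loss(w, x, y))            # scalar loss
print(grad_fn(w, x, y).shape)        # (3,)
print(per_example_grads.shape)       # (8, 3)
```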

Here is a one-slide summary of JAX from a recent 10-minute overview talk by Jake VanderPlas:

*(Figure: JAX summary slide.)*

## Tutorials (blog posts / notebooks)

## Videos / talks

In this tutorial, we focus on core JAX. However, since JAX is quite low level (like NumPy), many libraries are being developed on top of it to provide more specialized functionality. We summarize a few of the ML-related libraries below. See also https://github.com/n2cholas/awesome-jax, which has a more extensive list.

## DNN libraries

JAX is a purely functional library, unlike TensorFlow and PyTorch, which are stateful. The main advantage of functional programming is that we can safely transform the code and/or run it in parallel, without worrying about global state changing behind the scenes. The main disadvantage is that code (especially for DNNs) can be harder to write. To simplify the task, various DNN libraries have been designed, as we list below. In this book, we use Flax; a minimal Flax sketch follows the table.

| Name | Description |
|------|-------------|
| Stax | Barebones library for specifying DNNs |
| Flax | Library for specifying and training DNNs |
| Haiku | Library for specifying DNNs, similar to Sonnet |
| Jraph | Library for graph neural networks |
| Trax | Library for specifying and training DNNs, with a focus on sequence models |
| T5X | T5 (a large seq2seq model) in JAX/Flax |
| HuggingFace Transformers | Transformer models |
| Objax | PyTorch-like library for JAX (stateful/object-oriented, not compatible with other JAX libraries) |
| Elegy | Keras-like library for JAX |
| FlaxVision | Flax version of torchvision |
| Neural Tangents | Library to compute a kernel from a DNN |
| Efficient nets | Efficient CNN classifiers in Flax |
| Progressive GANs | Progressive GANs in Flax |
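As promised above, here is a minimal sketch of Flax's functional style: the model object holds no parameters; the state (`params`) is created by `init` and passed explicitly to `apply`. The `MLP` module below is an illustrative example, not code from the book:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    """A small illustrative two-layer perceptron."""
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = MLP()
x = jnp.ones((4, 10))                         # toy batch: 4 examples, 10 features

# Parameters live outside the model object (functional state).
params = model.init(jax.random.PRNGKey(0), x)

# Applying the model is a pure function of (params, inputs).
y = model.apply(params, x)
print(y.shape)                                # (4, 1)
```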

## RL libraries

| Name | Description |
|------|-------------|
| RLax | Library from DeepMind |
| Coax | Lightweight library from Microsoft for solving OpenAI Gym environments |

## Probabilistic programming languages

| Name | Description |
|------|-------------|
| NumPyro | Library for PPL |
| Oryx | Lightweight library for PPL |
| MCX | Library for PPL |
| JXF | Exponential families in JAX |
| Distrax | Library for probability distributions and bijectors |
| TFP/JAX distributions | JAX port of tfp.distributions |
| BlackJax | Library for HMC |
| Newt | Variational inference for (Markov) GPs |

## Other libraries

There are also many other JAX libraries for tasks that are not about defining DNN models. We list some of them below.

| Name | Description |
|------|-------------|
| Optax | Library for defining gradient-based optimizers (see the sketch below) |
| Chex | Library for debugging and developing reliable JAX code |
| Common loop utilities | Library for writing "beautiful training loops in JAX" |
| KF | (Extended, parallelized) Kalman filtering/smoothing |
| Easy Neural ODEs | Neural ODEs for classification, latent ODEs for time series, and FFJORD for density estimation, with several higher-order adaptive-step numerical solvers (e.g., Heun-Euler, Fehlberg, Cash-Karp, Tanyam, and adaptive-order Adams) |
| GPT-J | Open-source version of GPT-3 using JAX and TPU v3-256 |
| CLIP-Jax | JAX wrapper for inference using the pre-trained OpenAI CLIP model |
| Scenic | Computer vision library |
| Pix | Image pre-processing library |
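As an example of how these libraries fit JAX's functional style, Optax optimizers are pure functions of the parameters and an explicit optimizer state. Here is a minimal sketch; the quadratic `loss` is just a placeholder:

```python
import jax
import jax.numpy as jnp
import optax

# Placeholder objective: minimize ||params||^2.
def loss(params):
    return jnp.sum(params ** 2)

params = jnp.array([1.0, -2.0, 3.0])
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)           # optimizer state is explicit

@jax.jit
def step(params, opt_state):
    grads = jax.grad(loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

for _ in range(100):
    params, opt_state = step(params, opt_state)

print(loss(params))                          # should be close to 0
```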