GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/tutorials/jax_intro.ipynb


Yet another JAX tutorial

Kevin Murphy ([email protected]). Last update: September 2021.

JAX is a version of NumPy that runs fast on CPU, GPU and TPU, by compiling the computational graph to XLA (Accelerated Linear Algebra). It also has an excellent automatic differentiation library, extending the earlier autograd package. This makes it easy to compute higher-order derivatives, gradients of complex functions (e.g., of functions that contain an iterative solver), etc. The JAX interface is almost identical to NumPy (by design), but with some small differences, and lots of additional features. We give a brief introduction below. For more details, see this list of JAX tutorials.
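As a tiny preview (a sketch we add here, assuming jax is installed as in the setup below), the jnp API mirrors np, and grad differentiates ordinary Python functions:

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)               # works just like np.arange
print(jnp.sum(x ** 2))            # 5.0

dfdx = jax.grad(lambda x: jnp.sum(x ** 2))
print(dfdx(x))                    # [0. 2. 4.]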

Setup

# Standard Python libraries from __future__ import absolute_import, division, print_function, unicode_literals from functools import partial import os import time import numpy as np np.set_printoptions(precision=3) import glob import matplotlib.pyplot as plt import PIL import imageio from typing import Tuple, NamedTuple from IPython import display %matplotlib inline import sklearn
import jax import jax.numpy as jnp from jax import random, vmap, jit, grad, value_and_grad, hessian, jacfwd, jacrev print("jax version {}".format(jax.__version__)) # Check the jax backend print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform)) key = random.PRNGKey(0)
jax version 0.2.19 jax backend gpu

Hardware accelerators

Colab makes it easy to use GPUs and TPUs for speeding up some workflows, especially related to deep learning.

GPUs

Colab offers graphics processing units (GPUs) which can be much faster than CPUs (central processing units), as we illustrate below.

# Check if GPU is available and its model, memory ...etc. !nvidia-smi
Sat Sep 11 03:46:08 2021 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 470.63.01 Driver Version: 460.32.03 CUDA Version: 11.2 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |===============================+======================+======================| | 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | | N/A 66C P0 29W / 70W | 13610MiB / 15109MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
# Check if JAX is using GPU print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform)) # Check the devices available for JAX jax.devices()
jax backend gpu
[GpuDevice(id=0, process_index=0)]

Let's see how JAX can speed up things like matrix-matrix multiplication.

First the numpy/CPU version.

# Parameters for the experiment size = int(1e3) number_of_loops = int(1e2)
# Standard numpy CPU def f(x=None): if not isinstance(x, np.ndarray): x = np.ones((size, size), dtype=np.float32) return np.dot(x, x.T)
%timeit -o -n $number_of_loops f()
100 loops, best of 5: 16.4 ms per loop
<TimeitResult : 100 loops, best of 5: 16.4 ms per loop>
res = _ # get result of last cell time_cpu = res.best print(time_cpu)
0.016404767970000192

Now we look at the JAX version. JAX supports execution on XLA devices, which can be CPU, GPU or even TPU. We add block_until_ready because JAX uses asynchronous execution by default, so without it we would only be timing the dispatch of the operation, not the computation itself.

# JAX device execution # https://github.com/google/jax/issues/1598 def jf(x=None): if not isinstance(x, jnp.ndarray): x = jnp.ones((size, size), dtype=jnp.float32) return jnp.dot(x, x.T) f_gpu = jit(jf, backend="gpu") f_cpu = jit(jf, backend="cpu")
# Time the CPU version %timeit -o -n $number_of_loops f_cpu()
100 loops, best of 5: 25.7 ms per loop
<TimeitResult : 100 loops, best of 5: 25.7 ms per loop>
res = _ time_jcpu = res.best print(time_jcpu)
0.025695679049999854
# Time the GPU version %timeit -o -n $number_of_loops f_gpu().block_until_ready()
The slowest run took 40.61 times longer than the fastest. This could mean that an intermediate result is being cached. 100 loops, best of 5: 485 µs per loop
<TimeitResult : 100 loops, best of 5: 485 µs per loop>
res = _ time_jgpu = res.best print(time_jgpu)
0.00048540602000002764
print("JAX CPU time {:0.6f}, Numpy CPU time {:0.6f}, speedup {:0.6f}".format(time_jcpu, time_cpu, time_cpu / time_jcpu)) print("JAX GPU time {:0.6f}, Numpy CPU time {:0.6f}, speedup {:0.6f}".format(time_jgpu, time_cpu, time_cpu / time_jgpu))
JAX CPU time 0.025696, Numpy CPU time 0.016405, speedup 0.638425 JAX GPU time 0.000485, Numpy CPU time 0.016405, speedup 33.795971

In the above example we see that JAX GPU is much faster than Numpy CPU. However we also see that JAX CPU is slower than Numpy CPU - this can happen with simple functions, but usually JAX provides a speedup, even on CPU, if you JIT compile a complex function (see below).

We can move numpy arrays to the GPU for speed. The result will be transferred back to CPU for printing, saving, etc.

from jax import device_put x = np.ones((size, size)).astype(np.float32) print(type(x)) %timeit -o -n $number_of_loops f(x) x = device_put(x) print(type(x)) %timeit -o -n $number_of_loops jf(x)
<class 'numpy.ndarray'> 100 loops, best of 5: 14.8 ms per loop <class 'jaxlib.xla_extension.DeviceArray'> 100 loops, best of 5: 452 µs per loop
<TimeitResult : 100 loops, best of 5: 452 µs per loop>

TPUs

We can turn on the tensor processing unit (TPU) by selecting it from the Colab runtime settings. Everything else "just works" as before.

import jax.tools.colab_tpu jax.tools.colab_tpu.setup_tpu()

If everything is set up correctly, the following command should return a list of 8 TPU devices.

jax.local_devices()
[GpuDevice(id=0, process_index=0)]

Vmap

We often write a function to process a single vector or matrix, and then want to apply it to a batch of data. Using for loops is slow, and batching the code manually is complex and error-prone. Fortunately we can use the vmap function, which maps our function across a set of inputs, automatically vectorizing (batching) it.

Example: 1d convolution

(This example is from the Deepmind tutorial.)

Consider standard 1d convolution of two vectors.

x = jnp.arange(5) w = jnp.array([2.0, 3.0, 4.0]) def convolve(x, w): output = [] for i in range(1, len(x) - 1): output.append(jnp.dot(x[i - 1 : i + 2], w)) return jnp.array(output) convolve(x, w)
DeviceArray([11., 20., 29.], dtype=float32)

Now suppose we want to convolve multiple vectors with multiple kernels. The simplest way is to use a for loop, but this is slow.

xs = jnp.stack([x, x]) ws = jnp.stack([w, w]) def manually_batched_convolve(xs, ws): output = [] for i in range(xs.shape[0]): output.append(convolve(xs[i], ws[i])) return jnp.stack(output) manually_batched_convolve(xs, ws)
DeviceArray([[11., 20., 29.], [11., 20., 29.]], dtype=float32)

We can manually vectorize the code, but it is complex.

def manually_vectorised_convolve(xs, ws): output = [] for i in range(1, xs.shape[-1] - 1): output.append(jnp.sum(xs[:, i - 1 : i + 2] * ws, axis=1)) return jnp.stack(output, axis=1) manually_vectorised_convolve(xs, ws)
DeviceArray([[11., 20., 29.], [11., 20., 29.]], dtype=float32)

Fortunately vmap can do this for us!

auto_batch_convolve = jax.vmap(convolve) auto_batch_convolve(xs, ws)
DeviceArray([[11., 20., 29.], [11., 20., 29.]], dtype=float32)

Axes

By default, vmap vectorizes over the first axis of each of its inputs. If the first argument has a batch dimension and the second does not, specify in_axes=[0, None], so the second argument is not vectorized over.

jax.vmap(convolve, in_axes=[0, None])(xs, w)
DeviceArray([[11., 20., 29.], [11., 20., 29.]], dtype=float32)

We can also vectorize over other dimensions.

print(xs.shape) xst = jnp.transpose(xs) print(xst.shape) wst = jnp.transpose(ws) auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1) auto_batch_convolve_v2(xst, wst)
(2, 5) (5, 2)
DeviceArray([[11., 11.], [20., 20.], [29., 29.]], dtype=float32)

Example: logistic regression

We now give another example, using binary logistic regression. Let us start with a predictor for a single example $x$.

D = 2 N = 3 w = np.random.normal(size=(D,)) X = np.random.normal(size=(N, D)) def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.0) + 1) def predict_single(x): return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product
print(predict_single(X[0, :])) # works
0.23207036
print(predict_single(X)) # fails
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-28-853aeac052fc> in <module>() 1 ----> 2 print(predict_single(X)) # fails <ipython-input-26-e29aadf4f504> in predict_single(x) 9 10 def predict_single(x): ---> 11 return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product 12 /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in dot(a, b, precision) 4194 return lax.mul(a, b) 4195 if _max(a_ndim, b_ndim) <= 2: -> 4196 return lax.dot(a, b, precision=precision) 4197 4198 if b_ndim == 1: /usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in dot(lhs, rhs, precision, preferred_element_type) 666 else: 667 raise TypeError("Incompatible shapes for dot: got {} and {}.".format( --> 668 lhs.shape, rhs.shape)) 669 670 TypeError: Incompatible shapes for dot: got (2,) and (3, 2).

We can manually vectorize the code by keeping track of the shapes, so that $X w$ multiplies each row of $X$ with $w$.

def predict_batch(X): return sigmoid(jnp.dot(X, w)) # (N,D) * (D,1) = (N,1) # matrix-vector multiply print(predict_batch(X))
[0.232 0.121 0.878]

But it is easier to use vmap.

print(vmap(predict_single)(X))
[0.232 0.121 0.878]

Failure cases

vmap requires that the shapes of all the variables created by the mapped function are the same for all values of the input arguments, as explained here. So vmap cannot be used for arbitrary embarrassingly parallel tasks. Below we give a simple example where this fails, since internally we create a vector whose length depends on the input 'length'.

def example_fun(length, val=4): return jnp.sum(jnp.ones((length,)) * val) xs = jnp.arange(1, 10) # Python map works fine v = list(map(example_fun, xs)) print(v)
[DeviceArray(4., dtype=float32), DeviceArray(8., dtype=float32), DeviceArray(12., dtype=float32), DeviceArray(16., dtype=float32), DeviceArray(20., dtype=float32), DeviceArray(24., dtype=float32), DeviceArray(28., dtype=float32), DeviceArray(32., dtype=float32), DeviceArray(36., dtype=float32)]

The following fails.

v = vmap(example_fun)(xs) print(v)
--------------------------------------------------------------------------- UnfilteredStackTrace Traceback (most recent call last) <ipython-input-32-7fe93a3635da> in <module>() ----> 1 v = vmap(example_fun)(xs) 2 print(v) /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs) 161 try: --> 162 return fun(*args, **kwargs) 163 except Exception as e: /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in batched_fun(*args, **kwargs) 1286 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) -> 1287 ).call_wrapped(*args_flat) 1288 return tree_unflatten(out_tree(), out_flat) /usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 165 try: --> 166 ans = self.f(*args, **dict(self.params, **kwargs)) 167 except: <ipython-input-31-4138480dd486> in example_fun(length, val) 1 def example_fun(length, val=4): ----> 2 return jnp.sum(jnp.ones((length,)) * val) 3 /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in ones(shape, dtype) 3187 shape = (shape,) if ndim(shape) == 0 else shape -> 3188 return lax.full(shape, 1, dtype) 3189 /usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in full(shape, fill_value, dtype) 1594 """ -> 1595 shape = canonicalize_shape(shape) 1596 if np.shape(fill_value): /usr/local/lib/python3.7/dist-packages/jax/core.py in canonicalize_shape(shape, context) 1440 pass -> 1441 raise _invalid_shape_error(shape, context) 1442 UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with val = DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32) batch_dim = 0,). If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions. The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: TypeError Traceback (most recent call last) <ipython-input-32-7fe93a3635da> in <module>() ----> 1 v = vmap(example_fun)(xs) 2 print(v) <ipython-input-31-4138480dd486> in example_fun(length, val) 1 def example_fun(length, val=4): ----> 2 return jnp.sum(jnp.ones((length,)) * val) 3 4 xs = jnp.arange(1,10) 5 /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in ones(shape, dtype) 3186 dtype = float_ if dtype is None else dtype 3187 shape = (shape,) if ndim(shape) == 0 else shape -> 3188 return lax.full(shape, 1, dtype) 3189 3190 TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with val = DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32) batch_dim = 0,). If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

Stochastics

JAX is designed to be deterministic, but in some cases we want to introduce randomness in a controlled way, and to reason about it. We discuss this below.

Random number generation

One of the biggest differences from NumPy is the way JAX treats pseudo-random number generation (PRNG). This is because JAX does not maintain any global state, i.e., it is purely functional. This design "provides reproducible results invariant to compilation boundaries and backends, while also maximizing performance by enabling vectorized generation and parallelization across random calls" (to quote the official page).

For example, consider this NumPy snippet. Each call to np.random.uniform updates the global state. The value of foo() is therefore only guaranteed to be the same every time if we evaluate bar() and baz() in the same order (e.g., left to right). This is why foo1 and foo2 give different answers even though mathematically they shouldn't (we cannot just substitute in the value of a variable and derive the result, so we are violating "referential transparency").

import numpy as np def bar(): return np.random.uniform(size=(3)) def baz(): return np.random.uniform(size=(3)) def foo(seed): np.random.seed(seed) return bar() + 2 * baz() def foo1(seed): np.random.seed(seed) a = bar() b = 2 * baz() return a + b def foo2(seed): np.random.seed(seed) a = 2 * baz() b = bar() return a + b seed = 0 print(foo(seed)) print(foo1(seed)) print(foo2(seed))
[1.639 1.562 1.895] [1.639 1.562 1.895] [1.643 1.854 1.851]

Jax may evaluate parts of expressions such as bar() + baz() in parallel, which would violate reproducibility. To prevent this, the user must pass in an explicit PRNG key to every function that requires a source of randomness. Using the same key will give the same results. See the example below.

key = random.PRNGKey(0) print(random.normal(key, shape=(3,))) # [ 1.81608593 -0.48262325 0.33988902] print(random.normal(key, shape=(3,))) # [ 1.81608593 -0.48262325 0.33988902] ## identical results
[ 1.816 -0.483 0.34 ] [ 1.816 -0.483 0.34 ]

When generating independent samples, it is important to use different keys, to ensure the results are not correlated. We can do this by splitting the key into the 'master' key (which will be used in later parts of the code via further splitting) and a 'subkey', which is used temporarily to generate randomness and then thrown away, as we illustrate below.

# To make a new key, we split the current key into two pieces. key, subkey = random.split(key) print(random.normal(subkey, shape=(3,))) # [ 1.1378783 -1.22095478 -0.59153646] # We can continue to split off new pieces from the global key. key, subkey = random.split(key) print(random.normal(subkey, shape=(3,))) # [-0.06607265 0.16676566 1.17800343]
[ 1.138 -1.221 -0.592] [-0.066 0.167 1.178]

We now reimplement the NumPy example in JAX and show that we get the same result regardless of the order of evaluation of bar and baz.

def bar(key): return jax.random.uniform(key, shape=(3,)) def baz(key): return jax.random.uniform(key, shape=(3,)) def foo(key): subkey1, subkey2 = random.split(key, num=2) return bar(subkey1) + 2 * baz(subkey2) def foo1(key): subkey1, subkey2 = random.split(key, num=2) a = bar(subkey1) b = 2 * baz(subkey2) return a + b def foo2(key): subkey1, subkey2 = random.split(key, num=2) a = 2 * baz(subkey2) b = bar(subkey1) return a + b key = random.PRNGKey(0) key, subkey = random.split(key) print(foo(subkey)) print(foo1(subkey)) print(foo2(subkey))
[2.079 2.002 1.089] [2.079 2.002 1.089] [2.079 2.002 1.089]

In JAX (but not in NumPy), drawing N samples all at once will not give the same results as N individual draws, as we show below.

key = random.PRNGKey(42) subkeys = random.split(key, 3) sequence = np.stack([jax.random.normal(subkey) for subkey in subkeys]) print("individually:", sequence) key = random.PRNGKey(42) print("all at once: ", jax.random.normal(key, shape=(3,)))
individually: [-0.048 0.108 -1.223] all at once: [ 0.187 -1.281 -1.559]
np.random.seed(0) sequence = np.stack([np.random.normal() for i in range(3)]) print("individually:", sequence) np.random.seed(0) print("all at once: ", np.random.normal(size=(3,)))
individually: [1.764 0.4 0.979] all at once: [1.764 0.4 0.979]

Probability distributions

The distrax library is a JAX-native implementation of some parts of the distributions library from TensorFlow Probability (TFP). The main advantage is that the distrax source code is much easier to read and understand. For distributions not in distrax, it is possible to use TFP instead.

Here is a brief example.

%%capture !pip install git+git://github.com/deepmind/distrax.git
import distrax import jax import jax.numpy as jnp from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions key = jax.random.PRNGKey(1234) mu = jnp.array([-1.0, 0.0, 1.0]) sigma = jnp.array([0.1, 0.2, 0.3]) dist_distrax = distrax.MultivariateNormalDiag(mu, sigma) dist_tfp = tfd.MultivariateNormalDiag(mu, sigma) samples = dist_distrax.sample(seed=key) # Both print 1.775 print(dist_distrax.log_prob(samples)) print(dist_tfp.log_prob(samples))
1.7750063 1.7750063

Autograd

In this section, we illustrate automatic differentiation using JAX. For details, see this video or The Autodiff Cookbook.

Derivatives

We can compute $(\nabla f)(x)$ using grad(f)(x). For example, consider

$f(x) = x^3 + 2x^2 - 3x + 1$

$f'(x) = 3x^2 + 4x - 3$

$f''(x) = 6x + 4$

$f'''(x) = 6$

$f^{iv}(x) = 0$

f = lambda x: x**3 + 2 * x**2 - 3 * x + 1 dfdx = jax.grad(f) d2fdx = jax.grad(dfdx) d3fdx = jax.grad(d2fdx) d4fdx = jax.grad(d3fdx) print(dfdx(1.0)) print(d2fdx(1.0)) print(d3fdx(1.0)) print(d4fdx(1.0))
4.0 10.0 6.0 0.0

Partial derivatives

\begin{align}
f(x,y) &= x^2 + y \\
\frac{\partial f}{\partial x} &= 2x \\
\frac{\partial f}{\partial y} &= 1
\end{align}
def f(x, y): return x**2 + y # Partial derivatives x = 2.0 y = 3.0 v, gx = value_and_grad(f, argnums=0)(x, y) print(v) print(gx) gy = grad(f, argnums=1)(x, y) print(gy)
7.0 4.0 1.0

Gradients

Linear function: multi-input, scalar output.

\begin{align}
f(x; a) &= a^T x \\
\nabla_x f(x;a) &= a
\end{align}
def fun1d(x): return jnp.dot(a, x)[0] Din = 3 Dout = 1 a = np.random.normal(size=(Dout, Din)) x = np.random.normal(size=(Din,)) g = grad(fun1d)(x) assert np.allclose(g, a) # It is often useful to get the function value and gradient at the same time val_grad_fn = jax.value_and_grad(fun1d) v, g = val_grad_fn(x) print(v) print(g) assert np.allclose(v, fun1d(x)) assert np.allclose(a, g)
1.9472518 [ 2.241 1.868 -0.977]

Linear function: multi-input, multi-output.

\begin{align}
f(x;A) &= A x \\
\frac{\partial f(x;A)}{\partial x} &= A
\end{align}
# We construct a multi-output linear function. # We check forward and reverse mode give same Jacobians. def fun(x): return jnp.dot(A, x) Din = 3 Dout = 4 A = np.random.normal(size=(Dout, Din)) x = np.random.normal(size=(Din,)) Jf = jacfwd(fun)(x) Jr = jacrev(fun)(x) assert np.allclose(Jf, Jr) assert np.allclose(Jf, A)

Quadratic form.

\begin{align}
f(x;A) &= x^T A x \\
\nabla_x f(x;A) &= (A+A^T) x
\end{align}
D = 4 A = np.random.normal(size=(D, D)) x = np.random.normal(size=(D,)) quadfun = lambda x: jnp.dot(x, jnp.dot(A, x)) g = grad(quadfun)(x) assert np.allclose(g, jnp.dot(A + A.T, x))

Chain rule applied to sigmoid function.

\begin{align}
\mu(x;w) &= \sigma(w^T x) \\
\nabla_w \mu(x;w) &= \sigma'(w^T x) x \\
\sigma'(a) &= \sigma(a)(1-\sigma(a))
\end{align}
D = 4 w = np.random.normal(size=(D,)) x = np.random.normal(size=(D,)) y = 0 def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.0) + 1) def mu(w): return sigmoid(jnp.dot(w, x)) def deriv_mu(w): return mu(w) * (1 - mu(w)) * x deriv_mu_jax = grad(mu) print(deriv_mu(w)) print(deriv_mu_jax(w)) assert np.allclose(deriv_mu(w), deriv_mu_jax(w), atol=1e-3)
[-0.13 -0.017 -0.072 0.031] [-0.13 -0.017 -0.072 0.031]

Auxiliary return values

A function can return its value and other auxiliary results; the latter are not differentiated.

def f(x, y): return x**2 + y, 42 x = 2.0 y = 3.0 (v, aux), g = value_and_grad(f, has_aux=True)(x, y) print(v) print(aux) print(g)
7.0 42 4.0

Jacobians

Example: Linear function: multi-input, multi-output.

\begin{align}
f(x;A) &= A x \\
\frac{\partial f(x;A)}{\partial x} &= A
\end{align}
# We construct a multi-output linear function. # We check forward and reverse mode give same Jacobians. def fun(x): return jnp.dot(A, x) Din = 3 Dout = 4 A = np.random.normal(size=(Dout, Din)) x = np.random.normal(size=(Din,)) Jf = jacfwd(fun)(x) Jr = jacrev(fun)(x) assert np.allclose(Jf, Jr)

Hessians

Quadratic form.

\begin{align}
f(x;A) &= x^T A x \\
\nabla_x^2 f(x;A) &= A + A^T
\end{align}
D = 4 A = np.random.normal(size=(D, D)) x = np.random.normal(size=(D,)) quadfun = lambda x: jnp.dot(x, jnp.dot(A, x)) H1 = hessian(quadfun)(x) assert np.allclose(H1, A + A.T) def my_hessian(fun): return jacfwd(jacrev(fun)) H2 = my_hessian(quadfun)(x) assert np.allclose(H1, H2)

Example: Binary logistic regression

def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.0) + 1) def predict_single(w, x): return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product def predict_batch(w, X): return sigmoid(jnp.dot(X, w)) # (N,D) * (D,1) = (N,1) # matrix-vector multiply # negative log likelihood def loss(weights, inputs, targets): preds = predict_batch(weights, inputs) logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets) return -jnp.sum(logprobs) D = 2 N = 3 w = jax.random.normal(key, shape=(D,)) X = jax.random.normal(key, shape=(N, D)) y = jax.random.choice(key, 2, shape=(N,)) # uniform binary labels # logits = jnp.dot(X, w) # y = jax.random.categorical(key, logits) print(loss(w, X, y)) # Gradient function grad_fun = grad(loss) # Gradient of each example in the batch - 2 different ways grad_fun_w = partial(grad_fun, w) grads = vmap(grad_fun_w)(X, y) print(grads) assert grads.shape == (N, D) grads2 = vmap(grad_fun, in_axes=(None, 0, 0))(w, X, y) assert np.allclose(grads, grads2) # Gradient for entire batch grad_sum = jnp.sum(grads, axis=0) assert grad_sum.shape == (D,) print(grad_sum)
2.384371 [[ 0.596 -0.621] [ 0.315 0.837] [-0.233 -0.448]] [ 0.677 -0.232]
# Textbook implementation of gradient def NLL_grad(weights, batch): X, y = batch N = X.shape[0] mu = predict_batch(weights, X) g = jnp.sum(jnp.dot(jnp.diag(mu - y), X), axis=0) return g grad_sum_batch = NLL_grad(w, (X, y)) print(grad_sum_batch) assert np.allclose(grad_sum, grad_sum_batch)
[ 0.677 -0.232]
# We can also compute Hessians, as we illustrate below. hessian_fun = hessian(loss) # Hessian on one example H0 = hessian_fun(w, X[0, :], y[0]) print("Hessian(example 0)\n{}".format(H0)) # Hessian for batch Hbatch = vmap(hessian_fun, in_axes=(None, 0, 0))(w, X, y) print("Hbatch shape {}".format(Hbatch.shape)) Hbatch_sum = jnp.sum(Hbatch, axis=0) print("Hbatch sum\n {}".format(Hbatch_sum))
Hessian(example 0) [[ 0.185 -0.193] [-0.193 0.201]] Hbatch shape (3, 2, 2) Hbatch sum [[0.357 0.225] [0.225 1.239]]
# Textbook implementation of Hessian def NLL_hessian(weights, batch): X, y = batch mu = predict_batch(weights, X) S = jnp.diag(mu * (1 - mu)) H = jnp.dot(jnp.dot(X.T, S), X) return H H2 = NLL_hessian(w, (X, y)) assert np.allclose(Hbatch_sum, H2, atol=1e-2)

Vector Jacobian Products (VJP) and Jacobian Vector Products (JVP)

Suppose $f_1(x) = W x$ for fixed $W$, so $J(x) = W$, and $u^T J(x) = W^T u$. Instead of computing $J$ explicitly and then multiplying by $u$, we can do this in one operation.

n = 3 m = 2 W = jax.random.normal(key, shape=(m, n)) x = jax.random.normal(key, shape=(n,)) u = jax.random.normal(key, shape=(m,)) def f1(x): return jnp.dot(W, x) J1 = jacfwd(f1)(x) print(J1.shape) assert np.allclose(J1, W) tmp1 = jnp.dot(u.T, J1) print(tmp1) (val, jvp_fun) = jax.vjp(f1, x) tmp2 = jvp_fun(u) assert np.allclose(tmp1, tmp2) tmp3 = np.dot(W.T, u) assert np.allclose(tmp1, tmp3)
(2, 3) [ 0.888 -0.538 -0.539]

Suppose $f_2(W) = W x$ for fixed $x$. Now $J(W) = \text{something complex}$, but $u^T J(W) = J(W)^T u = u x^T$.

def f2(W): return jnp.dot(W, x) J2 = jacfwd(f2)(W) print(J2.shape) tmp1 = jnp.dot(u.T, J2) print(tmp1) print(tmp1.shape) (val, jvp_fun) = jax.vjp(f2, W) tmp2 = jvp_fun(u) assert np.allclose(tmp1, tmp2) tmp3 = np.outer(u, x) assert np.allclose(tmp1, tmp3)
(2, 2, 3) [[-0.009 -0.032 -0.474] [ 0.005 0.019 0.286]] (2, 3)

Stop-gradient

Sometimes we want to take the gradient of a complex expression with respect to some parameters $\theta$, but treating $\theta$ as a constant for some parts of the expression. For example, consider the TD(0) update in reinforcement learning, which has the following form:

$\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})$

where $s$ is the state, $r$ is the reward, and $v$ is the value function. This update is not the gradient of any loss function. However, if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored, it can be written as the gradient of the pseudo loss function

$L(\theta) = [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2$

since

$\nabla_{\theta} L(\theta) = 2 [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})] \nabla v_{\theta}(s_{t-1})$.

We can implement this in JAX using stop_gradient, as we show below.

def td_loss(theta, s_prev, r_t, s_t): v_prev = value_fn(theta, s_prev) target = r_t + value_fn(theta, s_t) return 0.5 * (jax.lax.stop_gradient(target) - v_prev) ** 2 td_update = jax.grad(td_loss) # An example transition. s_prev = jnp.array([1.0, 2.0, -1.0]) r_t = jnp.array(1.0) s_t = jnp.array([2.0, 1.0, 0.0]) # Value function and initial parameters value_fn = lambda theta, state: jnp.dot(theta, state) theta = jnp.array([0.1, -0.1, 0.0]) print(td_update(theta, s_prev, r_t, s_t))
[-1.2 -2.4 1.2]

Straight through estimator

The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to differentiate, we simply pretend during the backward pass that $f$ is the identity function, so gradients pass through $f$, ignoring the $f'$ term. This can be implemented neatly using jax.lax.stop_gradient.

Here is an example of a non-differentiable function that converts a soft probability distribution to a one-hot vector (discretization).

def onehot(labels, num_classes): y = labels[..., None] == jnp.arange(num_classes)[None] return y.astype(jnp.float32) def quantize(y_soft): y_hard = onehot(jnp.argmax(y_soft), 3)[0] return y_hard y_soft = np.array([0.1, 0.2, 0.7]) print(quantize(y_soft))
[0. 0. 1.]

Now suppose we define some linear function of the quantized variable of the form $f(y) = w^T q(y)$. If $w = [1,2,3]$ and $q(y) = [0,0,1]$, we get $f(y) = 3$. But the gradient is 0 because $q$ is not differentiable.

def f(y): w = jnp.array([1, 2, 3]) yq = quantize(y) return jnp.dot(w, yq) print(f(y_soft)) print(grad(f)(y_soft))
3.0 [0. 0. 0.]

To use the straight-through estimator, we replace $q(y)$ with $y + SG(q(y)-y)$, where SG is stop gradient. In the forward pass, we have $y + q(y) - y = q(y)$. In the backward pass, the gradient of SG is 0, so we effectively replace $q(y)$ with $y$. So in the backward pass we have
\begin{align}
f(y) &= w^T q(y) \approx w^T y \\
\nabla_y f(y) &\approx w
\end{align}

def f_ste(y): w = jnp.array([1, 2, 3]) yq = quantize(y) yy = y + jax.lax.stop_gradient(yq - y) # gives yq on fwd, and y on backward return jnp.dot(w, yy) print(f_ste(y_soft)) print(grad(f_ste)(y_soft))
3.0 [1. 2. 3.]

Per-example gradients

In some applications, we want to compute the gradient for every example in a batch, not just the sum of gradients over the batch. This is hard in other frameworks like TF and PyTorch but easy in JAX, as we show below.

def loss(w, x): return jnp.dot(w, x) w = jnp.ones((3,)) x0 = jnp.array([1.0, 2.0, 3.0]) x1 = 2 * x0 X = jnp.stack([x0, x1]) print(X.shape) perex_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0))) print(perex_grads(w, X))
(2, 3) [[1. 2. 3.] [2. 4. 6.]]

To explain the above code in more depth, note that the vmap converts the function loss to take a batch of inputs for each of its arguments, and returns a batch of outputs. To make it work with a single weight vector, we specify in_axes=(None,0), meaning the first argument (w) is not replicated, and the second argument (x) is replicated along dimension 0.

gradfn = jax.grad(loss) W = jnp.stack([w, w]) print(jax.vmap(gradfn)(W, X)) print(jax.vmap(gradfn, in_axes=(None, 0))(w, X))
[[1. 2. 3.] [2. 4. 6.]] [[1. 2. 3.] [2. 4. 6.]]

Optimization

The Optax library implements many common optimizers. Below is a simple example.

%%capture !pip install git+git://github.com/deepmind/optax.git import optax
num_weights = 2 params = {"w": jnp.ones((num_weights,))} num_ex = 3 xs = 2 * jnp.ones((num_ex, num_weights)) ys = jnp.ones(num_ex) compute_loss_single = lambda params, x, y: optax.l2_loss(params["w"].dot(x), y) compute_loss = lambda params, xs, ys: jnp.sum(jax.vmap(compute_loss_single, in_axes=[None, 0, 0])(params, xs, ys)) print("original params ", params) print("loss ", compute_loss(params, xs, ys)) # create a stateful optimizer learning_rate = 0.1 optimizer = optax.adam(learning_rate) opt_state = optimizer.init(params) print("original state ", opt_state) # compute gradients grads = jax.grad(compute_loss)(params, xs, ys) print("grads ", grads) # updage params (and optstate) given gradients updates, opt_state = optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) print("updated params ", params) print("updated state ", opt_state)
original params {'w': DeviceArray([1., 1.], dtype=float32)} loss 13.5 original state [ScaleByAdamState(count=DeviceArray(0, dtype=int32), mu={'w': DeviceArray([0., 0.], dtype=float32)}, nu={'w': DeviceArray([0., 0.], dtype=float32)}), EmptyState()] grads {'w': DeviceArray([18., 18.], dtype=float32)} updated params {'w': DeviceArray([0.9, 0.9], dtype=float32)} updated state [ScaleByAdamState(count=DeviceArray(1, dtype=int32), mu={'w': DeviceArray([1.8, 1.8], dtype=float32)}, nu={'w': DeviceArray([0.324, 0.324], dtype=float32)}), EmptyState()]

JIT (just in time compilation)

In this section, we illustrate how to use the Jax JIT compiler to make code go much faster (even on a CPU). It does this by compiling the computational graph into low-level XLA primitives, potentially fusing multiple sequential operations into a single op. However, it does not work on arbitrary Python code, as we explain below.

def slow_f(x): # Element-wise ops see a large benefit from fusion return x * x + x * 2.0 x = jnp.ones((5000, 5000)) %timeit slow_f(x) fast_f = jit(slow_f) %timeit fast_f(x) assert np.allclose(slow_f(x), fast_f(x))
100 loops, best of 5: 3.11 ms per loop The slowest run took 15.74 times longer than the fastest. This could mean that an intermediate result is being cached. 1000 loops, best of 5: 918 µs per loop

We can also add the @jit decorator in front of a function.

@jit def faster_f(x): return x * x + x * 2.0 %timeit faster_f(x) assert np.allclose(faster_f(x), fast_f(x))
The slowest run took 20.30 times longer than the fastest. This could mean that an intermediate result is being cached. 1000 loops, best of 5: 918 µs per loop

How it works: Jaxprs and tracing

In this section, we briefly explain the mechanics behind JIT, which will help you understand when it does not work.

First, consider this function.

def f(x): y = jnp.ones((1, 5)) * x return y

When a function is first executed (applied to an argument), it is converted to an intermediate representation called a JAX expression, or jaxpr, by a process called tracing, as we show below.

print(f(3.0)) print(jax.make_jaxpr(f)(3.0))
[[3. 3. 3. 3. 3.]] { lambda ; a. let b = broadcast_in_dim[ broadcast_dimensions=( ) shape=(1, 5) ] 1.0 c = convert_element_type[ new_dtype=float32 weak_type=False ] a d = mul b c in (d,) }

The XLA JIT compiler can then convert the jaxpr to code that runs fast on a CPU, GPU or TPU; the original python code is no longer needed.

f_jit = jit(f) print(f_jit(3.0))
[[3. 3. 3. 3. 3.]]

However, the jaxpr is created by tracing the function for a specific value. If different code is executed depending on the value of the input arguments, the resulting jaxpr will be different, so the function cannot be JITed, as we illustrate below.

def f(x): if x > 0: return x else: return 2 * x print(f(3.0)) f_jit = jit(f) print(f_jit(3.0))
3.0
--------------------------------------------------------------------------- UnfilteredStackTrace Traceback (most recent call last) <ipython-input-91-545a2f514f83> in <module>() 10 f_jit = jit(f) ---> 11 print(f_jit(3.0)) /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs) 161 try: --> 162 return fun(*args, **kwargs) 163 except Exception as e: /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs) 407 device=device, backend=backend, name=flat_fun.__name__, --> 408 donated_invars=donated_invars, inline=inline) 409 out_pytree_def = out_tree() /usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params) 1613 def bind(self, fun, *args, **params): -> 1614 return call_bind(self, fun, *args, **params) 1615 /usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params) 1604 tracers = map(top_trace.full_raise, args) -> 1605 outs = primitive.process(top_trace, fun, tracers, params) 1606 return map(full_lower, apply_todos(env_trace_todo(), outs)) /usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params) 1616 def process(self, trace, fun, tracers, params): -> 1617 return trace.process_call(self, fun, tracers, params) 1618 /usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params) 612 def process_call(self, primitive, f, tracers, params): --> 613 return primitive.impl(f, *tracers, **params) 614 process_map = process_call /usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***) 619 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, --> 620 *unsafe_map(arg_spec, args)) 621 try: /usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args) 261 else: --> 262 ans = call(fun, *args) 263 cache[key] = (ans, fun.stores) /usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs) 696 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final( --> 697 fun, abstract_args, pe.debug_info_final(fun, "jit")) 698 if any(isinstance(c, core.Tracer) for c in consts): /usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info) 1284 with core.new_sublevel(): -> 1285 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) 1286 del fun, main /usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals) 1262 in_tracers = map(trace.new_arg, in_avals) -> 1263 ans = fun.call_wrapped(*in_tracers) 1264 out_tracers = map(trace.full_raise, ans) /usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 165 try: --> 166 ans = self.f(*args, **dict(self.params, **kwargs)) 167 except: <ipython-input-91-545a2f514f83> in f(x) 2 def f(x): ----> 3 if x > 0: 4 return x /usr/local/lib/python3.7/dist-packages/jax/core.py in __bool__(self) 541 def __nonzero__(self): return self.aval._nonzero(self) --> 542 def __bool__(self): return self.aval._bool(self) 543 def __int__(self): return self.aval._int(self) /usr/local/lib/python3.7/dist-packages/jax/core.py in error(self, arg) 986 def error(self, arg): --> 987 raise ConcretizationTypeError(arg, fname_context) 988 return error UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value 
encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> The problem arose with the `bool` function. While tracing the function f at <ipython-input-91-545a2f514f83>:2 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: ConcretizationTypeError Traceback (most recent call last) <ipython-input-91-545a2f514f83> in <module>() 9 10 f_jit = jit(f) ---> 11 print(f_jit(3.0)) <ipython-input-91-545a2f514f83> in f(x) 1 2 def f(x): ----> 3 if x > 0: 4 return x 5 else: ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> The problem arose with the `bool` function. While tracing the function f at <ipython-input-91-545a2f514f83>:2 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

Jit will create a new compiled version for each different ShapedArray, but will reuse the code for different values of the same shape. If the code path depends on the concrete value, we can either just jit a subfunction (whose code path is constant), or we can create a different jaxpr for each concrete value of the input arguments as we explain below.
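To see this, here is a small sketch (an illustration we add here, not a cell from the original notebook): the Python print only fires while tracing, so it reveals when a new jaxpr is created.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # This print is a Python side effect, so it only runs during tracing,
    # i.e. when jit compiles a new version for a new ShapedArray.
    print("tracing f for shape", x.shape, "dtype", x.dtype)
    return x ** 2

f(jnp.ones(3))      # traces and compiles for float32[3]
f(jnp.arange(3.0))  # same shape and dtype: cached, no print
f(jnp.ones(5))      # new shape: traced and compiled again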

Static argnum

Note that JIT compilation requires that the control flow through the function can be determined by the shape (but not concrete value) of its inputs. The function below violates this, since when x<0, it takes one branch, whereas when x>0, it takes the other.

@jit def f(x): if x > 0: return x else: return 2 * x # This will fail! try: print(f(3)) except Exception as e: print("ERROR:", e)
ERROR: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> The problem arose with the `bool` function. While tracing the function f at <ipython-input-92-94e6eda28128>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

We can fix this by telling JAX to trace the control flow through the function using concrete values of some of its arguments. JAX will then compile different versions, depending on the input values. See below for an example.

def f(x): if x > 0: return x else: return 2 * x f = jit(f, static_argnums=(0,)) print(f(3))
3
@partial(jit, static_argnums=(0,)) def f(x): if x > 0: return x else: return 2 * x print(f(3))
3

Jit and vmap

Unfortunately, the static argnum method fails when the function is passed to vmap, because vmap replaces that argument with an abstract (non-hashable) batch tracer, which cannot be treated as a concrete static value.

xs = jnp.arange(5) @partial(jit, static_argnums=(0,)) def f(x): if x > 0: return x else: return 2 * x ys = vmap(f)(xs)
--------------------------------------------------------------------------- UnfilteredStackTrace Traceback (most recent call last) <ipython-input-95-00960fe99363> in <module>() 9 ---> 10 ys = vmap(f)(xs) /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs) 161 try: --> 162 return fun(*args, **kwargs) 163 except Exception as e: /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in batched_fun(*args, **kwargs) 1286 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) -> 1287 ).call_wrapped(*args_flat) 1288 return tree_unflatten(out_tree(), out_flat) /usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 165 try: --> 166 ans = self.f(*args, **dict(self.params, **kwargs)) 167 except: UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'jax.interpreters.batching.BatchTracer'>, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with val = DeviceArray([0, 1, 2, 3, 4], dtype=int32) batch_dim = 0. The error was: TypeError: unhashable type: 'BatchTracer' The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: ValueError Traceback (most recent call last) <ipython-input-95-00960fe99363> in <module>() 8 return 2 * x 9 ---> 10 ys = vmap(f)(xs) /usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 164 165 try: --> 166 ans = self.f(*args, **dict(self.params, **kwargs)) 167 except: 168 # Some transformations yield from inside context managers, so we have to ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'jax.interpreters.batching.BatchTracer'>, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with val = DeviceArray([0, 1, 2, 3, 4], dtype=int32) batch_dim = 0. The error was: TypeError: unhashable type: 'BatchTracer'

Side effects

Since the jaxpr is created only once, if your function has global side effects, such as calling print, they will only happen once (at trace time), even if the function is called multiple times. See the example below.

def f(x): print("x", x) y = 2 * x print("y", y) return y y1 = f(2) print("f", y1) print("\ncall function a second time") y1 = f(2) print("f", y1) print("\njit version follows") g = jax.jit(f) y2 = g(2) print("f", y2) print("\ncall jitted function a second time") y2 = g(2) print("f", y2)
x 2 y 4 f 4 call function a second time x 2 y 4 f 4 jit version follows x Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> y Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> f 4 call jitted function a second time f 4

Caching

If you write g = jax.jit(f), then f will get compiled and the XLA code will be cached. Subsequent calls to g reuse the cached code for speed. But if jit is called inside a loop, it effectively creates a new compiled function each time, which is slow. So typically jit is applied at the outermost scope (provided the argument shapes stay constant).
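Below is a minimal sketch of this pitfall (the function g and the loop pattern are ours, purely illustrative):

import jax
import jax.numpy as jnp

def g(x):
    return x * x + 2.0 * x

x = jnp.ones((1000, 1000))

# Slow pattern: wrapping with jit inside the loop creates a new jitted
# function on every iteration, so the compilation cache is not reused well.
for _ in range(3):
    jax.jit(g)(x).block_until_ready()

# Fast pattern: jit once, in the outer scope, and reuse the compiled function.
g_jit = jax.jit(g)
for _ in range(3):
    g_jit(x).block_until_ready()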

Also, if you specify static_argnums, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs.
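For example (a sketch we add here; the function power is hypothetical), each new value of a static argument triggers a fresh trace and compilation:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def power(x, n):
    print("tracing power with n =", n)  # only runs during tracing
    return x ** n

x = jnp.arange(3.0)
power(x, 2)  # traces and compiles for n=2
power(x, 2)  # same static value: cached, no print
power(x, 3)  # new static value: recompiles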

Strings

Jit does not work with functions that consume or return strings.

def f(x: int, y: str): if y == "add": return x + 1 else: return x - 1 print(f(42, "add")) print(f(42, "sub")) fj = jax.jit(f) print(fj(42, "add"))
43 41
--------------------------------------------------------------------------- UnfilteredStackTrace Traceback (most recent call last) <ipython-input-98-613837d92373> in <module>() 10 fj = jax.jit(f) ---> 11 print(fj(42, 'add')) /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs) 161 try: --> 162 return fun(*args, **kwargs) 163 except Exception as e: /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs) 402 for arg in args_flat: --> 403 _check_arg(arg) 404 flat_fun, out_tree = flatten_fun(f, in_tree) /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _check_arg(arg) 2484 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)): -> 2485 raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.") 2486 UnfilteredStackTrace: TypeError: Argument 'add' of type <class 'str'> is not a valid JAX type. The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: TypeError Traceback (most recent call last) <ipython-input-98-613837d92373> in <module>() 9 10 fj = jax.jit(f) ---> 11 print(fj(42, 'add')) /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _check_arg(arg) 2483 def _check_arg(arg): 2484 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)): -> 2485 raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.") 2486 2487 # TODO(necula): this duplicates code in core.valid_jaxtype TypeError: Argument 'add' of type <class 'str'> is not a valid JAX type.

Pytrees

A Pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that's not a pytree, e.g. an array. Pytrees are useful for representing hierarchical sets of parameters for DNNs (and other structured data).

Simple example

from jax import tree_util # a simple pytree t1 = [1, {"k1": 2, "k2": (3, 4)}, 5] print("tree", t1) leaves = jax.tree_leaves(t1) print("num leaves", len(leaves)) print(leaves) t4 = [jnp.array([1, 2, 3]), "foo"] print("tree", t4) leaves = jax.tree_leaves(t4) print("num leaves", len(leaves)) print(leaves)
tree [1, {'k1': 2, 'k2': (3, 4)}, 5] num leaves 5 [1, 2, 3, 4, 5] tree [DeviceArray([1, 2, 3], dtype=int32), 'foo'] num leaves 2 [DeviceArray([1, 2, 3], dtype=int32), 'foo']

Treemap

We can map functions down a pytree in the same way that we can map a function down a list. We can also combine elements in two pytrees that have the same shape to make a third pytree.

t1 = [1, {"k1": 2, "k2": (3, 4)}, 5] print(t1) t2 = tree_util.tree_map(lambda x: x * x, t1) print("square each element", t2) t3 = tree_util.tree_map(lambda x, y: x + y, t1, t2) print("t1+t2", t3)
[1, {'k1': 2, 'k2': (3, 4)}, 5] square each element [1, {'k1': 4, 'k2': (9, 16)}, 25] t1+t2 [2, {'k1': 6, 'k2': (12, 20)}, 30]

If we have a list of dicts, we can convert to a dict of lists, as shown below.

data = [dict(t=1, obs="a", val=-1), dict(t=2, obs="b", val=-2), dict(t=3, obs="c", val=-3)] data2 = jax.tree_map(lambda d0, d1, d2: list((d0, d1, d2)), data[0], data[1], data[2]) print(data2) def join_trees(list_of_trees): d = jax.tree_map(lambda *xs: list(xs), *list_of_trees) return d print(join_trees(data))
{'obs': ['a', 'b', 'c'], 't': [1, 2, 3], 'val': [-1, -2, -3]} {'obs': ['a', 'b', 'c'], 't': [1, 2, 3], 'val': [-1, -2, -3]}

Flattening / Unflattening

t1 = [1, {"k1": 2, "k2": (3, 4)}, 5] print(t1) leaves, treedef = jax.tree_util.tree_flatten(t1) print(leaves) print(treedef) t2 = jax.tree_util.tree_unflatten(treedef, leaves) print(t2)
[1, {'k1': 2, 'k2': (3, 4)}, 5] [1, 2, 3, 4, 5] PyTreeDef([*, {'k1': *, 'k2': (*, *)}, *]) [1, {'k1': 2, 'k2': (3, 4)}, 5]

Example: Linear regression

In this section we show how to use pytrees as containers for the parameters of a linear regression model. The code is based on the flax JAX tutorial. When we compute the gradient, it will also be a pytree, and will have the same shape as the parameters, so we can update the parameters with the gradient without having to flatten and unflatten them.

# Create the predict function from a set of parameters def make_predict_pytree(params): def predict(x): return jnp.dot(params["W"], x) + params["b"] return predict # Create the loss from the data points set def make_mse_pytree(x_batched, y_batched): # returns fn(params)->real def mse(params): # Define the squared loss for a single pair (x,y) def squared_error(x, y): y_pred = make_predict_pytree(params)(x) return jnp.inner(y - y_pred, y - y_pred) / 2.0 # We vectorize the previous to compute the average of the loss on all samples. return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0) return jax.jit(mse) # And finally we jit the result.
# Set problem dimensions N = 20 xdim = 10 ydim = 5 # Generate random ground truth W and b key = random.PRNGKey(0) Wtrue = random.normal(key, (ydim, xdim)) btrue = random.normal(key, (ydim,)) params_true = {"W": Wtrue, "b": btrue} true_predict_fun = make_predict_pytree(params_true) # Generate data with additional observation noise X = random.normal(key, (N, xdim)) Ytrue = jax.vmap(true_predict_fun)(X) Y = Ytrue + 0.1 * random.normal(key, (N, ydim)) # Generate MSE for our samples mse_fun = make_mse_pytree(X, Y)
# Initialize estimated W and b with zeros. params = {"W": jnp.zeros_like(Wtrue), "b": jnp.zeros_like(btrue)} mse_pytree = make_mse_pytree(X, Y) print(mse_pytree(params_true)) print(mse_pytree(params)) print(jax.grad(mse_pytree)(params))
0.022292053 24.97824 {'W': DeviceArray([[-0.039, 0.755, 0.542, 0.36 , 0.224, 1.651, 1.534, -1.342, -0.15 , -1.638], [-0.324, 0.141, -0.402, 0.498, 1.829, 4.308, 2.138, -2.43 , -0.381, -2.178], [ 1.7 , -0.707, -0.656, -0.568, 1.824, -2.194, -0.477, 0.96 , 1.622, 1.408], [-0.862, 0.321, -0.388, -0.74 , -0.82 , 0.441, 0.772, -1.713, -1.592, -0.557], [ 1.338, -0.632, -0.968, -1.127, 1.775, 0.323, 1.405, -0.638, 1.077, -0.739]], dtype=float32), 'b': DeviceArray([ 0.036, 1.092, -0.413, -1.389, -0.862], dtype=float32)}
alpha = 0.3 # Gradient step size print('Loss for "true" W,b: ', mse_pytree(params_true)) for i in range(101): gradients = jax.grad(mse_pytree)(params) params = jax.tree_map(lambda old, grad: old - alpha * grad, params, gradients) if i % 10 == 0: print("Loss step {}: ".format(i), mse_pytree(params))
Loss for "true" W,b: 0.022292053 Loss step 0: 6.5597453 Loss step 10: 0.17232798 Loss step 20: 0.04339735 Loss step 30: 0.024473602 Loss step 40: 0.017078908 Loss step 50: 0.013489487 Loss step 60: 0.011695375 Loss step 70: 0.010795272 Loss step 80: 0.010343453 Loss step 90: 0.01011666 Loss step 100: 0.010002809
print(jax.tree_map(lambda x, y: np.allclose(x, y, atol=1e-1), params, params_true))
{'W': True, 'b': True}

Compare the above to what the training code would look like if W and b were passed in as separate arguments:

for i in range(101):
    grad_W = jax.grad(mse_fun, 0)(What, bhat)
    grad_b = jax.grad(mse_fun, 1)(What, bhat)
    What = What - alpha * grad_W
    bhat = bhat - alpha * grad_b
    if i % 10 == 0:
        print("Loss step {}: ".format(i), mse_fun(What, bhat))

Example: MLPs

We now show a more interesting example, from the Deepmind tutorial, where we fit an MLP using SGD. The basic structure is similar to the linear regression case.

# define the model def init_mlp_params(layer_widths): params = [] for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]): params.append( dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in), biases=np.ones(shape=(n_out,))) ) return params def forward(params, x): *hidden, last = params for layer in hidden: x = jax.nn.relu(x @ layer["weights"] + layer["biases"]) return x @ last["weights"] + last["biases"] def loss_fn(params, x, y): return jnp.mean((forward(params, x) - y) ** 2)
# MLP with 2 hidden layers and linear output np.random.seed(0) params = init_mlp_params([1, 128, 128, 1]) jax.tree_map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)}, {'biases': (128,), 'weights': (128, 128)}, {'biases': (1,), 'weights': (128, 1)}]
LEARNING_RATE = 0.0001 @jax.jit def update(params, x, y): grads = jax.grad(loss_fn)(params, x, y) return jax.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads) np.random.seed(0) xs = np.random.normal(size=(200, 1)) ys = xs**2 for _ in range(1000): params = update(params, xs, ys) plt.scatter(xs, ys, label="truth") plt.scatter(xs, forward(params, xs), label="Prediction") plt.legend()
<matplotlib.legend.Legend at 0x7f91e6397190>
Image in a Jupyter notebook

Looping constructs

For loops in Python are slow, even when JIT-compiled. However, there are built-in primitives for loops that are fast, as we illustrate below.

For loops.

The semantics of the fori_loop function in JAX is as follows:

def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

We see that val is used to accumulate the results across iterations.

Below is an example.

# sum from 1 to N = N*(N+1)/2 def sum_exact(N): return int(N * (N + 1) / 2) def sum_slow(N): s = 0 for i in range(1, N + 1): s += i return s N = 10 assert sum_slow(N) == sum_exact(N) def sum_fast(N): s = jax.lax.fori_loop(1, N + 1, lambda i, partial_sum: i + partial_sum, 0) return s assert sum_fast(N) == sum_exact(N)
N = 1000 %timeit sum_slow(N) %timeit sum_fast(N)
10000 loops, best of 5: 54.8 µs per loop 10 loops, best of 5: 21.9 ms per loop
# Let's do more compute per step of the for loop D = 10 X = jax.random.normal(key, shape=(D, D)) def sum_slow(N): s = jnp.zeros_like(X) for i in range(1, N + 1): s += jnp.dot(X, X) return s def sum_fast(N): s = jnp.zeros_like(X) s = jax.lax.fori_loop(1, N + 1, lambda i, s: s + jnp.dot(X, X), s) return s N = 10 assert np.allclose(sum_fast(N), sum_slow(N))
N = 1000 %timeit sum_slow(N) %timeit sum_fast(N)
1 loop, best of 5: 298 ms per loop 10 loops, best of 5: 27.8 ms per loop

While loops

Here is the semantics of the JAX while loop

def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

Below is an example.

def sum_slow_while(N): s = 0 i = 0 while i <= N: s += i i += 1 return s def sum_fast_while(N): init_val = (0, 0) def cond_fun(val): s, i = val return i <= N def body_fun(val): s, i = val s += i i += 1 return (s, i) val = jax.lax.while_loop(cond_fun, body_fun, init_val) s2 = val[0] return s2 N = 10 assert sum_slow_while(N) == sum_exact(N) assert sum_slow_while(N) == sum_fast_while(N)
N = 1000 %timeit sum_slow(N) %timeit sum_fast(N)
1 loop, best of 5: 312 ms per loop 10 loops, best of 5: 28.3 ms per loop

Scan

Here is the semantics of scan:

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

Here is an example where we use scan to sample from a discrete-time, discrete-state Markov chain.

init_dist = jnp.array([0.8, 0.2]) trans_mat = jnp.array([[0.9, 0.1], [0.5, 0.5]]) rng_key = jax.random.PRNGKey(0) from jax.scipy.special import logit seq_len = 15 initial_state = jax.random.categorical(rng_key, logits=logit(init_dist), shape=(1,)) def draw_state(prev_state, key): logits = logit(trans_mat[:, prev_state]) state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,)) return state, state rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3) keys = jax.random.split(rng_state, seq_len - 1) final_state, states = jax.lax.scan(draw_state, initial_state, keys) print(states)
[[0] [0] [0] [0] [0] [0] [0] [1] [1] [1] [1] [1] [1] [1]]

Common gotchas

Handling state

In this section, we discuss how to transform code that uses object-oriented programming (which can be stateful) to pure functional programming, which is stateless, as required by JAX. Our presentation is based on the Deepmind tutorial.

To start, consider a simple class that maintains an internal counter, and when called, increments the counter and returns the next number from some sequence.

# import string # DICTIONARY = list(string.ascii_lowercase) SEQUENCE = jnp.arange(0, 100, 2) class Counter: def __init__(self): self.n = 0 def count(self) -> int: # res = DICTIONARY[self.n] res = SEQUENCE[self.n] self.n += 1 return res def reset(self): self.n = 0 counter = Counter() for _ in range(3): print(counter.count())
0 2 4

The trouble with the above code is that the call to count depends on the internal state of the object (the value n), even though n is not an argument to the function. (The code is therefore said to violate 'referential transparency'.) When we JIT-compile it, JAX will only call the code once (to convert it to a jaxpr), so the side effect of updating n will not happen, resulting in incorrect behavior, as we show below.

counter.reset() fast_count = jax.jit(counter.count) for _ in range(3): print(fast_count())
0 0 0

We can solve this problem by passing the state as an argument into the function.

CounterState = int Result = int class CounterV2: def count(self, n: CounterState) -> Tuple[Result, CounterState]: return SEQUENCE[n], n + 1 def reset(self) -> CounterState: return 0 counter = CounterV2() state = counter.reset() for _ in range(3): value, state = counter.count(state) print(value)
0 2 4

This version is functionally pure, so jit-compiles nicely.

state = counter.reset() fast_count = jax.jit(counter.count) for _ in range(3): value, state = fast_count(state) print(value)
0 2 4

We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form

class StatefulClass
  state: State
  def stateful_method(*args, **kwargs) -> Output:

and turned it into a class of the form

class StatelessClass
  def stateless_method(state: State, *args, **kwargs) -> (Output, State):

This is a common functional programming pattern, and, essentially, is the way that state is handled in all JAX programs (as we saw with the way Jax handles random number state, or parameters of a model that get updated). Note that the stateless version of the code no longer needs to use a class, but can instead group the functions into a common namespace using modules.
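For instance, a plain-function version of the counter above (a sketch of the pattern, not a cell from the original notebook) could look like this:

import jax.numpy as jnp

SEQUENCE = jnp.arange(0, 100, 2)

def counter_init():
    # returns the initial state
    return 0

def counter_count(n):
    # state in, (output, new state) out
    return SEQUENCE[n], n + 1

state = counter_init()
for _ in range(3):
    value, state = counter_count(state)
    print(value)  # 0, 2, 4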

In some cases (e.g., when working with DNNs), it is more convenient to write code in an OO way. There are several libraries (notably Flax and Haiku) that let you define a model in an OO way, and then generate functionally pure code, as sketched below.
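As a hedged illustration (this cell is not in the original notebook and assumes haiku is installed), Haiku's hk.transform turns OO-style model code into a pair of pure init/apply functions:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    mlp = hk.nets.MLP([128, 128, 1])  # define the model in an OO style
    return mlp(x)

f = hk.transform(forward)             # -> pure (init, apply) functions
rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 1))
params = f.init(rng, x)               # parameters are returned explicitly
y = f.apply(params, rng, x)           # and passed back in, functionally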

Mutation of arrays

Since JAX is functional, you cannot mutate arrays in place, since this makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. Instead, JAX offers the functional update functions: index_update, index_add, index_min, index_max, and the index helper. These are illustrated below. However it is best to avoid these if possible, since they are slow.

Note: If the input values of index_update aren't reused, jit-compiled code will perform these operations in-place, rather than making a copy.
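For example (a sketch; the function set_row is ours), the functional update syntax composes with jit, where XLA can typically perform the update without an extra copy:

import jax
import jax.numpy as jnp

@jax.jit
def set_row(A, v):
    # Functional update: returns a new array with row 1 set to v.
    return A.at[1, :].set(v)

A = jnp.zeros((3, 3))
print(set_row(A, 42.0))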

# You cannot assign directly to elements of an array. A = jnp.zeros((3, 3), dtype=np.float32) # In place update of JAX's array will yield an error! try: A[1, :] = 1.0 except: print("must use index_update")
must use index_update
from jax.ops import index, index_add, index_update D = 3 A = 2 * jnp.ones((D, D)) print("original array:") print(A) A2 = index_update(A, index[1, :], 42.0) # A[1,:] = 42 print("original array:") print(A) # unchanged print("new array:") print(A2) A3 = A.at[1, :].set(42.0) # A3=np.copy(A), A3[1,:] = 42 print("original array:") print(A) # unchanged print("new array:") print(A3) A4 = A.at[1, :].mul(42.0) # A4=np.copy(A), A4[1,:] *= 42 print("original array:") print(A) # unchanged print("new array:") print(A4)
original array: [[2. 2. 2.] [2. 2. 2.] [2. 2. 2.]] original array: [[2. 2. 2.] [2. 2. 2.] [2. 2. 2.]] new array: [[ 2. 2. 2.] [42. 42. 42.] [ 2. 2. 2.]] original array: [[2. 2. 2.] [2. 2. 2.] [2. 2. 2.]] new array: [[ 2. 2. 2.] [42. 42. 42.] [ 2. 2. 2.]] original array: [[2. 2. 2.] [2. 2. 2.] [2. 2. 2.]] new array: [[ 2. 2. 2.] [84. 84. 84.] [ 2. 2. 2.]]

Implicitly casting lists to vectors

You cannot treat a list of numbers as a vector. Instead you must explicitly create the vector using the jnp.array() constructor.

# You cannot treat a list of numbers as a vector. try: S = jnp.diag([1.0, 2.0, 3.0]) except: print("must convert indices to np.array")
must convert indices to np.array
# Instead you should explicitly construct the vector. S = jnp.diag(jnp.array([1.0, 2.0, 3.0]))