
Automatic differentiation using JAX

In this section, we illustrate automatic differentiation using JAX. For details, see this video or The Autodiff Cookbook.

# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals
from functools import partial
import os
import time
import numpy as np

np.set_printoptions(precision=3)
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from typing import Tuple, NamedTuple

from IPython import display

%matplotlib inline

import sklearn
# Load JAX
import jax
import jax.numpy as jnp
from jax import random, vmap, jit, grad, value_and_grad, hessian, jacfwd, jacrev

print("jax version {}".format(jax.__version__))
# Check the jax backend
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
key = random.PRNGKey(0)
jax version 0.2.9
jax backend gpu

Derivatives

We can compute $(\nabla f)(x)$ using grad(f)(x). For example, consider

$$f(x) = x^3 + 2x^2 - 3x + 1$$

$$f'(x) = 3x^2 + 4x - 3$$

$$f''(x) = 6x + 4$$

$$f'''(x) = 6$$

$$f^{iv}(x) = 0$$

f = lambda x: x**3 + 2 * x**2 - 3 * x + 1

dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

print(dfdx(1.0))
print(d2fdx(1.0))
print(d3fdx(1.0))
print(d4fdx(1.0))
4.0
10.0
6.0
0.0
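
As an aside (not part of the original notebook), grad differentiates a scalar-to-scalar function at a single point; to evaluate the derivative on a grid of points we can wrap it with vmap. A minimal sketch, reusing dfdx from above:

xs = jnp.arange(-2.0, 3.0)  # [-2, -1, 0, 1, 2]
print(vmap(dfdx)(xs))  # evaluates 3x^2 + 4x - 3 at each point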

Partial derivatives

$$\begin{align}
f(x,y) &= x^2 + y \\
\frac{\partial f}{\partial x} &= 2x \\
\frac{\partial f}{\partial y} &= 1
\end{align}$$
def f(x, y):
    return x**2 + y


# Partial derivatives
x = 2.0
y = 3.0
v, gx = value_and_grad(f, argnums=0)(x, y)
print(v)
print(gx)
gy = grad(f, argnums=1)(x, y)
print(gy)
7.0
4.0
1.0
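
grad also accepts a tuple of argnums; a small sketch (not in the original notebook) that returns both partial derivatives in a single call and should reproduce the values above:

gx, gy = grad(f, argnums=(0, 1))(x, y)
print(gx)  # df/dx = 2x = 4.0
print(gy)  # df/dy = 1.0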

Gradients

Linear function: multi-input, scalar output.

$$\begin{align}
f(x; a) &= a^T x \\
\nabla_x f(x;a) &= a
\end{align}$$
def fun1d(x):
    return jnp.dot(a, x)[0]


Din = 3
Dout = 1
a = np.random.normal(size=(Dout, Din))
x = np.random.normal(size=(Din,))

g = grad(fun1d)(x)
assert np.allclose(g, a)

# It is often useful to get the function value and gradient at the same time
val_grad_fn = jax.value_and_grad(fun1d)
v, g = val_grad_fn(x)
print(v)
print(g)
assert np.allclose(v, fun1d(x))
assert np.allclose(a, g)
-1.0599848
[-1.311 0.546 0.915]

Linear function: multi-input, multi-output.

$$\begin{align}
f(x;A) &= A x \\
\frac{\partial f(x;A)}{\partial x} &= A
\end{align}$$
# We construct a multi-output linear function.
# We check forward and reverse mode give same Jacobians.
def fun(x):
    return jnp.dot(A, x)


Din = 3
Dout = 4
A = np.random.normal(size=(Dout, Din))
x = np.random.normal(size=(Din,))

Jf = jacfwd(fun)(x)
Jr = jacrev(fun)(x)
assert np.allclose(Jf, Jr)
assert np.allclose(Jf, A)

Quadratic form.

$$\begin{align}
f(x;A) &= x^T A x \\
\nabla_x f(x;A) &= (A+A^T) x
\end{align}$$
D = 4
A = np.random.normal(size=(D, D))
x = np.random.normal(size=(D,))

quadfun = lambda x: jnp.dot(x, jnp.dot(A, x))

g = grad(quadfun)(x)
assert np.allclose(g, jnp.dot(A + A.T, x))

Chain rule applied to sigmoid function.

$$\begin{align}
\mu(x;w) &= \sigma(w^T x) \\
\nabla_w \mu(x;w) &= \sigma'(w^T x) x \\
\sigma'(a) &= \sigma(a)(1-\sigma(a))
\end{align}$$
D = 4
w = np.random.normal(size=(D,))
x = np.random.normal(size=(D,))
y = 0


def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1)


def mu(w):
    return sigmoid(jnp.dot(w, x))


def deriv_mu(w):
    return mu(w) * (1 - mu(w)) * x


deriv_mu_jax = grad(mu)

print(deriv_mu(w))
print(deriv_mu_jax(w))
assert np.allclose(deriv_mu(w), deriv_mu_jax(w), atol=1e-3)
[-0.458 0.022 -0.266 -0.005]
[-0.458 0.022 -0.266 -0.005]

Auxiliary return values

A function can return its value and other auxiliary results; the latter are not differentiated.

def f(x, y):
    return x**2 + y, 42


(v, aux), g = value_and_grad(f, has_aux=True)(x, y)
print(v)
print(aux)
print(g)
7.0
42
4.0
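
grad itself also supports has_aux; a minimal sketch (not in the original notebook), assuming the same f(x, y) = (x**2 + y, 42) as above, that returns the gradient together with the auxiliary output:

g, aux = grad(f, has_aux=True)(x, y)
print(g)  # gradient wrt x: 4.0
print(aux)  # auxiliary output: 42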

Jacobians

Example: Linear function: multi-input, multi-output.

$$\begin{align}
f(x;A) &= A x \\
\frac{\partial f(x;A)}{\partial x} &= A
\end{align}$$
# We construct a multi-output linear function.
# We check forward and reverse mode give same Jacobians.
def fun(x):
    return jnp.dot(A, x)


Din = 3
Dout = 4
A = np.random.normal(size=(Dout, Din))
x = np.random.normal(size=(Din,))

Jf = jacfwd(fun)(x)
Jr = jacrev(fun)(x)
assert np.allclose(Jf, Jr)

Hessians

Quadratic form.

$$\begin{align}
f(x;A) &= x^T A x \\
\nabla_x^2 f(x;A) &= A + A^T
\end{align}$$
D = 4
A = np.random.normal(size=(D, D))
x = np.random.normal(size=(D,))

quadfun = lambda x: jnp.dot(x, jnp.dot(A, x))

H1 = hessian(quadfun)(x)
assert np.allclose(H1, A + A.T)


def my_hessian(fun):
    return jacfwd(jacrev(fun))


H2 = my_hessian(quadfun)(x)
assert np.allclose(H1, H2)
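
For high-dimensional problems it is often cheaper to compute Hessian-vector products without materializing the Hessian. Below is a minimal sketch (not part of the original notebook) of the standard forward-over-reverse pattern, reusing quadfun, for which $H v = (A + A^T) v$:

def hvp(f, x, v):
    # Differentiate grad(f) at x in the direction v, giving H(x) @ v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


xj = jnp.array(x)  # ensure the primal and the tangent share a dtype
vj = jax.random.normal(key, shape=(D,))
assert np.allclose(hvp(quadfun, xj, vj), jnp.dot(A + A.T, vj), atol=1e-3)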

Example: Binary logistic regression

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1)


def predict_single(w, x):
    return sigmoid(jnp.dot(w, x))  # <(D) , (D)> = (1) # inner product


def predict_batch(w, X):
    return sigmoid(jnp.dot(X, w))  # (N,D) * (D,1) = (N,1) # matrix-vector multiply


# negative log likelihood
def loss(weights, inputs, targets):
    preds = predict_batch(weights, inputs)
    logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets)
    return -jnp.sum(logprobs)


D = 2
N = 3
w = jax.random.normal(key, shape=(D,))
X = jax.random.normal(key, shape=(N, D))
y = jax.random.choice(key, 2, shape=(N,))  # uniform binary labels
# logits = jnp.dot(X, w)
# y = jax.random.categorical(key, logits)

print(loss(w, X, y))

# Gradient function
grad_fun = grad(loss)

# Gradient of each example in the batch - 2 different ways
grad_fun_w = partial(grad_fun, w)
grads = vmap(grad_fun_w)(X, y)
print(grads)
assert grads.shape == (N, D)

grads2 = vmap(grad_fun, in_axes=(None, 0, 0))(w, X, y)
assert np.allclose(grads, grads2)

# Gradient for entire batch
grad_sum = jnp.sum(grads, axis=0)
assert grad_sum.shape == (D,)
print(grad_sum)
1.5545294
[[ 0.042 -0.287]
 [-0.236 -0.454]
 [-0.14 0.067]]
[-0.334 -0.673]
# Textbook implementation of gradient
def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_batch(weights, X)
    g = jnp.sum(jnp.dot(jnp.diag(mu - y), X), axis=0)
    return g


grad_sum_batch = NLL_grad(w, (X, y))
print(grad_sum_batch)
assert np.allclose(grad_sum, grad_sum_batch)
[-0.334 -0.673]
# We can also compute Hessians, as we illustrate below.
hessian_fun = hessian(loss)

# Hessian on one example
H0 = hessian_fun(w, X[0, :], y[0])
print("Hessian(example 0)\n{}".format(H0))

# Hessian for batch
Hbatch = vmap(hessian_fun, in_axes=(None, 0, 0))(w, X, y)
print("Hbatch shape {}".format(Hbatch.shape))

Hbatch_sum = jnp.sum(Hbatch, axis=0)
print("Hbatch sum\n {}".format(Hbatch_sum))
Hessian(example 0)
[[ 0.006 -0.042]
 [-0.042 0.286]]
Hbatch shape (3, 2, 2)
Hbatch sum
 [[0.118 0.139]
 [0.139 0.65 ]]
# Textbook implementation of Hessian
def NLL_hessian(weights, batch):
    X, y = batch
    mu = predict_batch(weights, X)
    S = jnp.diag(mu * (1 - mu))
    H = jnp.dot(jnp.dot(X.T, S), X)
    return H


H2 = NLL_hessian(w, (X, y))
assert np.allclose(Hbatch_sum, H2, atol=1e-2)

Vector Jacobian Products (VJP) and Jacobian Vector Products (JVP)

Consider a bilinear mapping $f(x,W) = x W$. For fixed parameters, we have $f_1(x) = W x$, so $J(x) = W$, and $u^T J(x) = J(x)^T u = W^T u$.

n = 3
m = 2
W = jax.random.normal(key, shape=(m, n))
x = jax.random.normal(key, shape=(n,))
u = jax.random.normal(key, shape=(m,))


def f1(x):
    return jnp.dot(W, x)


J1 = jacfwd(f1)(x)
print(J1.shape)
assert np.allclose(J1, W)

tmp1 = jnp.dot(u.T, J1)
print(tmp1)

# jax.vjp returns the primal output and a function that computes the VJP u^T J
(val, vjp_fun) = jax.vjp(f1, x)
tmp2 = vjp_fun(u)
assert np.allclose(tmp1, tmp2)

tmp3 = np.dot(W.T, u)
assert np.allclose(tmp1, tmp3)
(2, 3)
[ 0.922 1.216 -0.61 ]

For fixed inputs, we have $f_2(W) = W x$, so $J(W) = \text{something complex}$, but $u^T J(W) = J(W)^T u = u x^T$.

def f2(W):
    return jnp.dot(W, x)


J2 = jacfwd(f2)(W)
print(J2.shape)

tmp1 = jnp.dot(u.T, J2)
print(tmp1)
print(tmp1.shape)

(val, vjp_fun) = jax.vjp(f2, W)
tmp2 = vjp_fun(u)
assert np.allclose(tmp1, tmp2)

tmp3 = np.outer(u, x)
assert np.allclose(tmp1, tmp3)
(2, 2, 3)
[[-1.425 0.379 -0.267]
 [ 1.555 -0.413 0.291]]
(2, 3)
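
The heading also mentions Jacobian-vector products. As a brief sketch (not in the original notebook), jax.jvp pushes a tangent vector $v$ through f1 in forward mode, returning the primal output together with $J v = W v$:

v = jax.random.normal(key, shape=(n,))
val, jv = jax.jvp(f1, (x,), (v,))
assert np.allclose(jv, jnp.dot(W, v))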

Stop-gradient

Sometimes we want to take the gradient of a complex expression wrt some parameters $\theta$, but treating $\theta$ as a constant for some parts of the expression. For example, consider the TD(0) update in reinforcement learning, which has the following form:

$$\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})$$

where $s$ is the state, $r$ is the reward, and $v$ is the value function. This update is not the gradient of any loss function. However, it can be written as the gradient of the pseudo loss function

$$L(\theta) = [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2$$

since

$$\nabla_{\theta} L(\theta) = 2 [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})] \nabla v_{\theta}(s_{t-1})$$

if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored. We can implement this in JAX using stop_gradient, as we show below.

def td_loss(theta, s_prev, r_t, s_t):
    v_prev = value_fn(theta, s_prev)
    target = r_t + value_fn(theta, s_t)
    return 0.5 * (jax.lax.stop_gradient(target) - v_prev) ** 2


td_update = jax.grad(td_loss)

# An example transition.
s_prev = jnp.array([1.0, 2.0, -1.0])
r_t = jnp.array(1.0)
s_t = jnp.array([2.0, 1.0, 0.0])

# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.0])

print(td_update(theta, s_prev, r_t, s_t))
[-1.2 -2.4 1.2]
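
As a quick sanity check (not part of the original notebook, and assuming the linear value_fn above, for which $\nabla_\theta v_\theta(s) = s$), the gradient of the pseudo loss should equal $-(r_t + v_\theta(s_t) - v_\theta(s_{t-1})) s_{t-1}$; the minus sign appears because we differentiate a squared error rather than apply the TD update directly:

delta = r_t + value_fn(theta, s_t) - value_fn(theta, s_prev)  # TD error
expected = -delta * s_prev
assert np.allclose(td_update(theta, s_prev, r_t, s_t), expected)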

Straight through estimator

The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function, so gradients pass through $f$ ignoring the $f'$ term. This can be implemented neatly using jax.lax.stop_gradient.

Here is an example of a non-differentiable function that converts a soft probability distribution to a one-hot vector (discretization).

def onehot(labels, num_classes):
    y = labels[..., None] == jnp.arange(num_classes)[None]
    return y.astype(jnp.float32)


def quantize(y_soft):
    y_hard = onehot(jnp.argmax(y_soft), 3)[0]
    return y_hard


y_soft = np.array([0.1, 0.2, 0.7])
print(quantize(y_soft))
[0. 0. 1.]

Now suppose we define some linear function of the quantized variable of the form $f(y) = w^T q(y)$. If $w=[1,2,3]$ and $q(y)=[0,0,1]$, we get $f(y) = 3$. But the gradient is 0 because $q$ is not differentiable.

def f(y):
    w = jnp.array([1, 2, 3])
    yq = quantize(y)
    return jnp.dot(w, yq)


print(f(y_soft))
print(grad(f)(y_soft))
3.0
[0. 0. 0.]

To use the straight-through estimator, we replace $q(y)$ with $y + SG(q(y)-y)$, where SG is stop gradient. In the forward pass, we have $y + q(y) - y = q(y)$. In the backward pass, the gradient of SG is 0, so we effectively replace $q(y)$ with $y$. So in the backward pass we have

$$\begin{align}
f(y) &= w^T q(y) \approx w^T y \\
\nabla_y f(y) &\approx w
\end{align}$$

def f_ste(y):
    w = jnp.array([1, 2, 3])
    yq = quantize(y)
    yy = y + jax.lax.stop_gradient(yq - y)  # gives yq on the forward pass, and y on the backward pass
    return jnp.dot(w, yy)


print(f_ste(y_soft))
print(grad(f_ste)(y_soft))
3.0
[1. 2. 3.]

Per-example gradients

In some applications, we want to compute the gradient for every example in a batch, not just the sum of gradients over the batch. This is hard in other frameworks like TF and PyTorch but easy in JAX, as we show below.

def loss(w, x):
    return jnp.dot(w, x)


w = jnp.ones((3,))
x0 = jnp.array([1.0, 2.0, 3.0])
x1 = 2 * x0
X = jnp.stack([x0, x1])
print(X.shape)

perex_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
print(perex_grads(w, X))
(2, 3)
[[1. 2. 3.]
 [2. 4. 6.]]

To explain the above code in more depth, note that vmap converts the function loss to take a batch of inputs for each of its arguments and to return a batch of outputs. To make it work with a single, shared weight vector, we specify in_axes=(None, 0), meaning the first argument (w) is not mapped over (it is shared across the batch), and the second argument (x) is mapped along its leading axis (dimension 0).

gradfn = jax.grad(loss)

W = jnp.stack([w, w])
print(jax.vmap(gradfn)(W, X))

print(jax.vmap(gradfn, in_axes=(None, 0))(w, X))
[[1. 2. 3.]
 [2. 4. 6.]]
[[1. 2. 3.]
 [2. 4. 6.]]