Path: blob/master/notebooks/book1/08/autodiff_jax.ipynb
Automatic differentiation using JAX
In this section, we illustrate automatic differentiation using JAX. For details, see this video or The Autodiff Cookbook.
Derivatives
We can compute the derivative $f'(x)$ of a scalar-input, scalar-output function $f$ using grad(f)(x); composing grad gives higher-order derivatives. For example, consider the function below.
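As a minimal sketch (the polynomial below is our own illustrative choice, not necessarily the one used in the original notebook), grad returns a function that evaluates the derivative, and composing grad gives higher-order derivatives:

```python
import jax

# A simple scalar function (our own choice for illustration).
def f(x):
    return x**3 + 2 * x**2 - 3 * x + 1

dfdx = jax.grad(f)        # f'(x)  = 3x^2 + 4x - 3
d2fdx2 = jax.grad(dfdx)   # f''(x) = 6x + 4

print(dfdx(1.0))    # 4.0
print(d2fdx2(1.0))  # 10.0
```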
Partial derivatives
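A hedged sketch of partial derivatives (the two-argument function here is our own example): the argnums argument of grad selects which argument to differentiate with respect to.

```python
import jax

# f(x, y) = x^2 y, so df/dx = 2xy and df/dy = x^2.
def f(x, y):
    return x**2 * y

dfdx = jax.grad(f, argnums=0)
dfdy = jax.grad(f, argnums=1)

print(dfdx(2.0, 3.0))  # 12.0
print(dfdy(2.0, 3.0))  # 4.0
```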
Gradients
Linear function: multi-input, scalar output.
Linear function: multi-input, multi-output.
Quadratic form.
Chain rule applied to sigmoid function.
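To make a couple of the above cases concrete, here is a small sketch (the matrices and vectors are our own illustrative choices) showing the gradient of a quadratic form and the chain rule applied to a sigmoid of a linear function:

```python
import jax
import jax.numpy as jnp

# Quadratic form f(x) = x^T A x; its gradient is (A + A^T) x.
A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

def quad(x):
    return x @ A @ x

x = jnp.array([1.0, 1.0])
print(jax.grad(quad)(x))      # (A + A^T) x = [7., 13.]

# Chain rule applied to a sigmoid of a linear function:
# g(w) = sigmoid(w^T x), so grad g(w) = sigmoid'(w^T x) * x.
def sigmoid(a):
    return 0.5 * (jnp.tanh(a / 2) + 1)

def g(w):
    return sigmoid(w @ x)

w = jnp.array([0.5, -0.5])
print(jax.grad(g)(w))         # sigmoid'(0) * x = [0.25, 0.25]
```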
Auxiliary return values
A function can return its value and other auxiliary results; the latter are not differentiated.
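A minimal sketch, assuming a made-up loss that also returns its residuals: passing has_aux=True to grad tells JAX to differentiate only the first output and return the auxiliary data unchanged.

```python
import jax
import jax.numpy as jnp

# The loss returns (value, aux); with has_aux=True, only the value is differentiated.
def loss_with_aux(w, X):
    residual = X @ w
    return jnp.sum(residual**2), {"residual": residual}

w = jnp.array([1.0, 2.0])
X = jnp.array([[1.0, 0.0],
               [0.0, 1.0]])

grads, aux = jax.grad(loss_with_aux, has_aux=True)(w, X)
print(grads)             # gradient of the scalar loss wrt w
print(aux["residual"])   # auxiliary output, returned untouched
```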
Jacobians
Example: Linear function: multi-input, multi-output.
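A small sketch (with an arbitrary matrix $A$ of our choosing): for the linear map $f(x) = A x$, both jacfwd (forward mode) and jacrev (reverse mode) recover the Jacobian $A$.

```python
import jax
import jax.numpy as jnp

# Multi-input, multi-output linear map f(x) = A x; its Jacobian is A.
A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])

def f(x):
    return A @ x

x = jnp.array([1.0, -1.0])
print(jax.jacfwd(f)(x))  # forward mode (efficient for "tall" Jacobians); equals A
print(jax.jacrev(f)(x))  # reverse mode (efficient for "wide" Jacobians); equals A
```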
Hessians
Quadratic form.
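A minimal sketch (with an arbitrary matrix $A$ of our choosing): the Hessian of the quadratic form $f(x) = x^\top A x$ is $A + A^\top$, which jax.hessian recovers.

```python
import jax
import jax.numpy as jnp

# Hessian of the quadratic form f(x) = x^T A x is A + A^T.
A = jnp.array([[1.0, 2.0],
               [2.0, 5.0]])

def f(x):
    return x @ A @ x

x = jnp.array([1.0, 2.0])
print(jax.hessian(f)(x))             # [[2., 4.], [4., 10.]]
print(jax.jacfwd(jax.jacrev(f))(x))  # same result: the Hessian as jacfwd of jacrev
```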
Example: Binary logistic regression
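Here is a hedged sketch with synthetic data (the design matrix, labels, and weights are all our own choices): the gradient of the negative log-likelihood with respect to the weights equals $X^\top(p - y)$, which grad computes automatically.

```python
import jax
import jax.numpy as jnp

def sigmoid(a):
    return 0.5 * (jnp.tanh(a / 2) + 1)

# Negative log-likelihood for binary logistic regression.
def nll(w, X, y):
    p = sigmoid(X @ w)
    return -jnp.sum(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

# Synthetic data: the first column of X acts as a bias feature.
X = jnp.array([[1.0, 2.0],
               [1.0, -1.0],
               [1.0, 0.5]])
y = jnp.array([1.0, 0.0, 1.0])
w = jnp.zeros(2)

print(jax.grad(nll)(w, X, y))   # gradient wrt w only (argnums=0 by default); equals X^T (p - y)
```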
Vector Jacobian Products (VJP) and Jacobian Vector Products (JVP)
Consider a bilinear mapping $f(x, W) = W x$, where $x \in \mathbb{R}^n$ and $W \in \mathbb{R}^{m \times n}$. For fixed parameters, we have $f_1(x) = W x$, so the Jacobian is $J_1(x) = W$, and the VJP is $u^\top J_1(x) = u^\top W$.
For fixed inputs, we have $f_2(W) = W x$, so the full Jacobian $J_2(W)$ is a large $m \times (mn)$ matrix, but the VJP $u^\top J_2(W)$ can be represented compactly as the outer product $u x^\top$.
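To illustrate, here is a small sketch (shapes and values are our own choices) using jax.jvp and jax.vjp on the two views of the bilinear map:

```python
import jax
import jax.numpy as jnp

W = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])        # parameters, shape (m, n) = (3, 2)
x = jnp.array([1.0, -1.0])         # input, shape (n,)
u = jnp.array([1.0, 0.0, 2.0])     # output cotangent, shape (m,)
v = jnp.array([0.5, 0.5])          # input tangent, shape (n,)

f1 = lambda x: W @ x   # fixed parameters: function of x
f2 = lambda W: W @ x   # fixed input: function of W

# JVP of f1 at x in direction v: J_1(x) v = W v.
_, jvp_out = jax.jvp(f1, (x,), (v,))
print(jvp_out)               # equals W @ v

# VJP of f2 at W with cotangent u: corresponds to the outer product u x^T.
_, vjp_fun = jax.vjp(f2, W)
print(vjp_fun(u)[0])         # equals jnp.outer(u, x)
```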
Stop-gradient
Sometimes we want to take the gradient of a complex expression wrt some parameters $\theta$, but treating $\theta$ as a constant for some parts of the expression. For example, consider the TD(0) update in reinforcement learning, which has the following form:
$$\Delta \theta = \big(r_t + v_\theta(s_t) - v_\theta(s_{t-1})\big) \nabla v_\theta(s_{t-1})$$
where $s$ is the state, $r$ is the reward, and $v$ is the value function. This update is not the gradient of any loss function. However, it can be written as the gradient of the pseudo loss function
$$L(\theta) = -\tfrac{1}{2} \big[r_t + v_\theta(s_t) - v_\theta(s_{t-1})\big]^2$$
since
$$\nabla_\theta L(\theta) = \big(r_t + v_\theta(s_t) - v_\theta(s_{t-1})\big) \nabla_\theta v_\theta(s_{t-1})$$
if the dependency of the target $r_t + v_\theta(s_t)$ on the parameter $\theta$ is ignored. We can implement this in JAX using stop_gradient, as we show below.
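A minimal sketch of the TD(0) update via stop_gradient, assuming a simple linear value function $v_\theta(s) = \theta^\top s$ (our own choice for illustration):

```python
import jax
import jax.numpy as jnp

# Linear value function v_theta(s) = theta^T s (an illustrative assumption).
def value_fn(theta, s):
    return jnp.dot(theta, s)

# Pseudo loss whose gradient is the TD(0) update: the target is wrapped in
# stop_gradient so that its dependence on theta is ignored.
def td_pseudo_loss(theta, s_prev, r, s_curr):
    target = jax.lax.stop_gradient(r + value_fn(theta, s_curr))
    return -0.5 * (target - value_fn(theta, s_prev)) ** 2

theta = jnp.array([0.1, -0.1, 0.0])
s_prev = jnp.array([1.0, 2.0, -1.0])
s_curr = jnp.array([2.0, 1.0, 0.0])
r = 1.0

delta_theta = jax.grad(td_pseudo_loss)(theta, s_prev, r, s_curr)
print(delta_theta)   # equals (r + v(s_curr) - v(s_prev)) * s_prev
```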
Straight through estimator
The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f(x)$ that is used as part of a larger function we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function, so gradients pass through $f$ unchanged. This can be implemented neatly using jax.lax.stop_gradient.
Here is an example of a non-differentiable function that converts a soft probability distribution to a one-hot vector (discretization).
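For example, a quantizer along these lines (a sketch; the exact implementation in the notebook may differ) can be written with argmax and one_hot:

```python
import jax
import jax.numpy as jnp

# Non-differentiable quantizer: soft probability vector -> one-hot vector.
def quantize(p):
    return jax.nn.one_hot(jnp.argmax(p), p.shape[0])

p = jnp.array([0.1, 0.7, 0.2])
print(quantize(p))   # [0., 1., 0.]
```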
Now suppose we define some linear function of the quantized variable, of the form $y = w^\top q(x)$. The forward pass gives a well-defined value of $y$, but the gradient with respect to $x$ is 0, because $q$ is not differentiable.
To use the straight-through estimator, we replace $q(x)$ with $x + \mathrm{SG}(q(x) - x)$, where SG is the stop-gradient operator. In the forward pass, we have $x + q(x) - x = q(x)$. In the backward pass, the gradient of SG is 0, so we effectively replace $q(x)$ with $x$. So in the backward pass we have
$$\nabla_x y \approx \nabla_x \big[w^\top x\big] = w.$$
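A minimal sketch of the straight-through estimator (the weights and probabilities are our own illustrative values):

```python
import jax
import jax.numpy as jnp

def quantize(p):
    return jax.nn.one_hot(jnp.argmax(p), p.shape[0])

# Straight-through version: quantize in the forward pass,
# behave like the identity in the backward pass.
def quantize_st(p):
    return p + jax.lax.stop_gradient(quantize(p) - p)

w = jnp.array([1.0, 2.0, 3.0])

def f(p):
    return w @ quantize_st(p)

p = jnp.array([0.1, 0.7, 0.2])
print(f(p))             # forward pass uses the one-hot vector: 2.0
print(jax.grad(f)(p))   # backward pass sees the identity: gradient is w = [1., 2., 3.]
```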
Per-example gradients
In some applications, we want to compute the gradient for every example in a batch, not just the sum of gradients over the batch. This is hard in other frameworks like TF and PyTorch but easy in JAX, as we show below.
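Here is a sketch of the idea (the loss function and data are our own illustrative choices), combining grad and vmap:

```python
import jax
import jax.numpy as jnp

# Per-example scalar loss (our own illustrative choice).
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

w = jnp.array([1.0, -1.0, 0.5])
X = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [1.0, 1.0, 2.0]])   # batch of 3 examples

# grad(loss) computes the gradient wrt w for a single example;
# vmap maps it over the batch dimension of x while sharing w.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, X)
print(per_example_grads.shape)   # (3, 3): one gradient per example
```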
To explain the above code in more depth, note that vmap converts grad(loss), which computes the gradient for a single example, into a function that accepts a batch of inputs and returns a batch of gradients. To make it work with a single, shared weight vector, we specify in_axes=(None, 0), meaning the first argument (w) is shared across the batch (not mapped), while the second argument (x) is mapped over dimension 0.