GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/tutorials/practical_jax_tips.ipynb
import jax
import jax.numpy as jnp
from IPython.display import display, Latex

If-else condition with lax

f(\mathbf{x}) = \sum_{x \in \mathbf{x}} \begin{cases} x^2, & \text{if } x > 5 \\ x^3, & \text{otherwise} \end{cases}
x = [jnp.array(10.0), jnp.array(2.0)]


@jax.jit
@jax.value_and_grad
def f(x):
    bool_val = jax.tree_map(lambda val: val > 5.0, x)
    ans = jax.tree_map(
        lambda val, bool: jax.lax.cond(bool, lambda: val**2, lambda: val**3),
        x,
        bool_val,
    )
    return jax.tree_util.tree_reduce(lambda a, b: a + b, ans)


value, grad = f(x)
display(Latex(f"$f(\mathbf{{x}}) = {value}$"))
for idx in range(len(x)):
    display(Latex(f"$\\frac{{df}}{{dx_{idx}}} = {grad[idx]}$"))

f(\mathbf{x}) = 108.0

\frac{df}{dx_0} = 20.0

\frac{df}{dx_1} = 12.0
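As a quick sanity check, these values agree with applying the case rule by hand: since 10 > 5 and 2 ≤ 5,

f(\mathbf{x}) = 10^2 + 2^3 = 108, \qquad \frac{df}{dx_0} = 2 \cdot 10 = 20, \qquad \frac{df}{dx_1} = 3 \cdot 2^2 = 12.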

Pair-wise distance with vmap

# create your pairwise function
def distance(a, b):
    return jnp.linalg.norm(a - b)


# map-based combinator to operate on all pairs
def all_pairs(f):
    f = jax.vmap(f, in_axes=(None, 0))
    f = jax.vmap(f, in_axes=(0, None))
    return f


# transform to operate over sets
distances = all_pairs(distance)

# Example
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 4.0, 5.0])
distances(x, y)
DeviceArray([[2., 3., 4.], [1., 2., 3.], [0., 1., 2.]], dtype=float32)
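As a sanity check (not part of the original notebook), the vmapped result can be compared against an explicit double loop; entry [i, j] should equal the distance between x[i] and y[j]:

# Illustrative check: entry [i, j] is the distance between x[i] and y[j].
expected = jnp.array([[distance(xi, yj) for yj in y] for xi in x])
assert jnp.allclose(distances(x, y), expected)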

Compute Hessian with JAX

Let us consider the linear regression loss function:

\begin{align}
\mathcal{L}(\boldsymbol{\theta}) &= (\boldsymbol{y} - X\boldsymbol{\theta})^T(\boldsymbol{y} - X\boldsymbol{\theta}) \\
\frac{d\mathcal{L}}{d\boldsymbol{\theta}} &= -2X^T\boldsymbol{y} + 2X^TX\boldsymbol{\theta} \\
H_{\mathcal{L}}(\boldsymbol{\theta}) &= 2X^TX
\end{align}
def loss_function_per_point(theta, x, y):
    y_pred = x.T @ theta
    return jnp.square(y_pred - y)


def loss_function(theta, x, y):
    loss_per_point = jax.vmap(loss_function_per_point, in_axes=(None, 0, 0))(theta, x, y)
    return jnp.sum(loss_per_point)


def gt_loss(theta, x, y):
    return jnp.sum(jnp.square(x @ theta - y))


def gt_grad(theta, x, y):
    return 2 * (x.T @ x @ theta - x.T @ y)


def gt_hess(theta, x, y):
    return 2 * x.T @ x

Simulate dataset

key = jax.random.PRNGKey(0)
key, subkey1, subkey2 = jax.random.split(key, num=3)
N = 100
D = 11
x = jax.random.uniform(key, shape=(N, D))
y = jax.random.uniform(subkey1, shape=(N,))
theta = jax.random.uniform(subkey2, shape=(D,))

Verify loss and gradient values

loss_and_grad_function = jax.value_and_grad(loss_function)
loss_val, grad = loss_and_grad_function(theta, x, y)
assert jnp.allclose(loss_val, gt_loss(theta, x, y))
assert jnp.allclose(grad, gt_grad(theta, x, y))

Verify Hessian matrix

Way-1

hess = jax.hessian(loss_function)(theta, x, y)
assert jnp.allclose(hess, gt_hess(theta, x, y))

Way-2

hess = jax.jacfwd(jax.jacrev(loss_function))(theta, x, y)
assert jnp.allclose(hess, gt_hess(theta, x, y))
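If the full Hessian is not needed, a Hessian-vector product can be computed without materializing it, using forward-over-reverse differentiation. This is a minimal sketch; the helper hvp, the closure loss_theta, and the probe vector v are illustrative names introduced here, not part of the original notebook:

# Hessian-vector product: the JVP of the gradient gives H @ v without building H.
def hvp(f, primal, tangent):
    return jax.jvp(jax.grad(f), (primal,), (tangent,))[1]


loss_theta = lambda t: loss_function(t, x, y)
v = jnp.ones(D)
assert jnp.allclose(hvp(loss_theta, theta, v), gt_hess(theta, x, y) @ v, atol=1e-3)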

tree_map in JAX

The only requirement for tree_map to work is that the output should have the same structure as the first argument (as explained here). For example:

import jax
import jax.numpy as jnp

try:
    import distrax
except ImportError:
    %pip install -qqq distrax
    import distrax

dists = {"Normal": distrax.Normal(3.0, 4.0), "Gamma": distrax.Gamma(3.0, 4.0)}
samples = {"Normal": jnp.array(2.0), "Gamma": jnp.array(3.0)}

try:
    log_probs = jax.tree_map(lambda dist, sample: dist.log_prob(sample), dists, samples)
except Exception as e:
    print(e)
Custom node type mismatch: expected type: <class 'distrax._src.distributions.normal.Normal'>, value: DeviceArray(2., dtype=float32, weak_type=True).

The problem here is that dists does not have the same structure as log_probs (the structure of log_probs matches that of samples). So, we should pass samples as the first argument:

log_probs = jax.tree_map(lambda sample, dist: dist.log_prob(sample), samples, dists)
log_probs
{'Gamma': DeviceArray(-6.33704, dtype=float32, weak_type=True), 'Normal': DeviceArray(-2.336483, dtype=float32, weak_type=True)}
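One way to see the mismatch directly (a small illustrative check, not in the original notebook) is to compare the tree structures of the two containers; the distrax distributions are themselves registered pytree nodes, so dists has a deeper structure than samples:

# samples has plain array leaves; dists flattens through the distribution nodes.
print(jax.tree_structure(samples))
print(jax.tree_structure(dists))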

Use of lax.scan to accelerate a training loop

Here we create a dummy training loop and check the performance of lax.scan. The example also shows how to convert a training loop to its lax.scan version.

value_and_grad_fun = jax.jit(jax.value_and_grad(lambda x: jnp.sum(x**2)))


def training_loop(n_iters, params):
    values = []
    for i in range(n_iters):
        value, grad = value_and_grad_fun(params)
        params = params - learning_rate * grad
        values.append(value)
    return values


@jax.jit
def one_step(params, xs):
    value, grad = value_and_grad_fun(params)
    params = params - learning_rate * grad
    return params, value
key = jax.random.PRNGKey(0)
N = 1000
n_iters = 10000
learning_rate = 0.01
params = jax.random.uniform(key, (N,))
training_loop(1, params)  # warm up
%timeit -n 1 -r 1 training_loop(n_iters, params)
155 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
%timeit -n 1 -r 1 jax.lax.scan(one_step, params, xs=None, length=n_iters)
64.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Note that an xs array can be passed in case we want to scan over it. An example can be found in the blackjax documentation.
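For illustration, here is a minimal sketch of scanning over an xs array, using a made-up per-step learning-rate schedule (lr_schedule and one_step_with_lr are not part of the original benchmark):

# Illustrative: each scan step receives one slice of xs (here, one learning rate).
lr_schedule = jnp.linspace(0.01, 0.001, n_iters)


def one_step_with_lr(params, lr):
    value, grad = value_and_grad_fun(params)
    params = params - lr * grad
    return params, value


final_params, values = jax.lax.scan(one_step_with_lr, params, xs=lr_schedule)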

tree_flatten vs ravel_pytree

  • tree_flatten: This function is used to get a list of leaves from a PyTree

  • ravel_pytree: This function is used to convert all the leaves into a one-dimensional JAX array

tree_flatten and tree_unflatten

pytree = {"theta": jnp.ones((2, 3)), "alpha": jnp.zeros((2,))}
pytree
{'theta': DeviceArray([[1., 1., 1.], [1., 1., 1.]], dtype=float32), 'alpha': DeviceArray([0., 0.], dtype=float32)}
leaves, treedef = jax.tree_flatten(pytree)
leaves
[DeviceArray([0., 0.], dtype=float32), DeviceArray([[1., 1., 1.], [1., 1., 1.]], dtype=float32)]
treedef
PyTreeDef({'alpha': *, 'theta': *})
reconstructed_pytree = jax.tree_unflatten(treedef, leaves)
reconstructed_pytree
{'alpha': DeviceArray([0., 0.], dtype=float32), 'theta': DeviceArray([[1., 1., 1.], [1., 1., 1.]], dtype=float32)}

ravel_pytree

from jax.flatten_util import ravel_pytree
pytree
{'theta': DeviceArray([[1., 1., 1.], [1., 1., 1.]], dtype=float32), 'alpha': DeviceArray([0., 0.], dtype=float32)}
array, unravel_fn = ravel_pytree(pytree)
array
DeviceArray([0., 0., 1., 1., 1., 1., 1., 1.], dtype=float32)
reconstructed_pytree2 = unravel_fn(array)
reconstructed_pytree2
{'alpha': DeviceArray([0., 0.], dtype=float32), 'theta': DeviceArray([[1., 1., 1.], [1., 1., 1.]], dtype=float32)}
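A common use of the flat representation (sketched here with a hypothetical quadratic objective, not from the original notebook) is to feed a single 1-D parameter vector to code that expects one, and then map results back with unravel_fn:

# Illustrative: differentiate a function of the flat vector, then unravel the
# flat gradient back into the original pytree structure.
flat_loss = lambda flat_params: jnp.sum(unravel_fn(flat_params)["theta"] ** 2)
flat_grad = jax.grad(flat_loss)(array)
grad_pytree = unravel_fn(flat_grad)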

is_leaf while working with PyTrees

Sometimes you do not want to work with the leaves of your PyTree; instead, you may want to treat a non-leaf node as a leaf, depending on your requirement. Let us see such an example with distrax.

distribution_pytree = {
    "normal": distrax.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2)),
    "gamma": distrax.Gamma(concentration=0.5, rate=2),
}

Suppose we want to sample from the above distribution_pytree.

seed = jax.random.PRNGKey(0)
samples = jax.tree_map(lambda dist: dist.sample(seed), distribution_pytree)
samples
{'gamma': <distrax._src.distributions.gamma.Gamma at 0x7f1ac86ed6a0>, 'normal': <distrax._src.distributions.mvn_diag.MultivariateNormalDiag at 0x7f1ac8747f70>}
jax.tree_leaves(distribution_pytree)
[]

The problem here is that there are no leaves returned by tree_leaves, but we want the leaves to be distrax distributions. Let us use is_leaf for this purpose.

is_leaf = lambda dist: isinstance(dist, distrax.Distribution)
jax.tree_leaves(distribution_pytree, is_leaf=is_leaf)
[<distrax._src.distributions.gamma.Gamma at 0x7f1ac872f550>, <distrax._src.distributions.mvn_diag.MultivariateNormalDiag at 0x7f1ac872f700>]

And we get what we anticipated. Let us now try to get samples by passing is_leaf to tree_map.

samples = jax.tree_map(lambda dist: dist.sample(seed=seed), distribution_pytree, is_leaf=is_leaf)
samples
{'gamma': DeviceArray(0.08460148, dtype=float32), 'normal': DeviceArray([-0.78476596, 0.85644484], dtype=float32)}

We can see that we are able to get the samples now.