GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/tutorials/numpyro_intro.ipynb


NumPyro is a probabilistic programming language built on top of JAX. It is very similar to Pyro, which is built on top of PyTorch. However, the HMC algorithm in NumPyro is much faster.

Both Pyro flavors are usually also faster than PyMC3, and allow for more complex models, since Pyro is integrated into Python.

Installation

import numpy as np
np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
import math
# When running in colab pro (high RAM mode), you get 4 CPUs.
# But we need to force XLA to use all 4 CPUs.
# This is generally faster than running in GPU mode.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
# http://num.pyro.ai/en/stable/getting_started.html#installation
# CPU mode: often faster in colab!
!pip install numpyro
# GPU mode: as of July 2021, this does not seem to work
#!pip install numpyro[cuda111] -f https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting numpyro
  Downloading numpyro-0.7.2-py3-none-any.whl (250 kB)
     |████████████████████████████████| 250 kB 5.2 MB/s eta 0:00:01
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.41.1)
Requirement already satisfied: jaxlib>=0.1.65 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.1.69+cuda110)
Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.2.17)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.19.5)
Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (0.12.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (3.3.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro) (1.4.1)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro) (1.12)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->numpyro) (1.15.0)
Installing collected packages: numpyro
Successfully installed numpyro-0.7.2
import jax
print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
print(jax.lib.xla_bridge.device_count())
print(jax.local_device_count())
import jax.numpy as jnp
from jax import random
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax version 0.2.17
jax backend cpu
4
4
import numpyro
# numpyro.set_platform('gpu')
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer import SVI, Trace_ELBO, init_to_value
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer.autoguide import AutoLaplaceApproximation

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

Example: 1d Gaussian with unknown mean

We use the simple example from the Pyro intro. The goal is to infer the weight $\theta$ of an object, given noisy measurements $y$. We assume the following model:
$$
\begin{align}
\theta &\sim N(\mu=8.5, \tau^2=1.0) \\
y &\sim N(\theta, \sigma^2=0.75^2)
\end{align}
$$

where $\mu=8.5$ is the initial guess.

Exact inference

By Bayes' rule for Gaussians, we know that the exact posterior, given a single observation $y=9.5$, is given by

$$
\begin{align}
\theta|y &\sim N(m, s^2) \\
m &= \frac{\sigma^2 \mu + \tau^2 y}{\sigma^2 + \tau^2}
  = \frac{0.75^2 \times 8.5 + 1^2 \times 9.5}{0.75^2 + 1^2} = 9.14 \\
s^2 &= \frac{\sigma^2 \tau^2}{\sigma^2 + \tau^2}
  = \frac{0.75^2 \times 1^2}{0.75^2 + 1^2} = 0.6^2
\end{align}
$$
mu = 8.5
tau = 1.0
sigma = 0.75
hparams = (mu, tau, sigma)

y = 9.5

m = (sigma**2 * mu + tau**2 * y) / (sigma**2 + tau**2)
s2 = (sigma**2 * tau**2) / (sigma**2 + tau**2)
s = np.sqrt(s2)
print(m)
print(s)
9.14
0.6
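To visualize this update, we can plot the prior, the likelihood (viewed as a function of $\theta$), and the exact posterior. This is a minimal sketch, assuming scipy is available (it is listed in the pip log above); it reuses mu, tau, sigma, y, m and s from the previous cell.

from scipy.stats import norm  # scipy assumed available (see pip output above)

thetas = np.linspace(5, 13, 200)
plt.plot(thetas, norm.pdf(thetas, mu, tau), label="prior")
plt.plot(thetas, norm.pdf(thetas, y, sigma), label="likelihood (as a function of theta)")
plt.plot(thetas, norm.pdf(thetas, m, s), label="exact posterior")
plt.xlabel("theta")
plt.legend();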
def model(hparams, y=None):
    prior_mean, prior_sd, obs_sd = hparams
    theta = numpyro.sample("theta", dist.Normal(prior_mean, prior_sd))
    y = numpyro.sample("y", dist.Normal(theta, obs_sd), obs=y)
    return y

Ancestral sampling

def model2(hparams):
    prior_mean, prior_sd, obs_sd = hparams
    theta = numpyro.sample("theta", dist.Normal(prior_mean, prior_sd))
    yy = numpyro.sample("y", dist.Normal(theta, obs_sd))
    return theta, yy
with numpyro.handlers.seed(rng_seed=0):
    for i in range(5):
        theta, yy = model2(hparams)
        print([theta, yy])
[DeviceArray(7.248, dtype=float32), DeviceArray(6.808, dtype=float32)]
[DeviceArray(8.986, dtype=float32), DeviceArray(9.149, dtype=float32)]
[DeviceArray(7.851, dtype=float32), DeviceArray(8.856, dtype=float32)]
[DeviceArray(9.538, dtype=float32), DeviceArray(8.973, dtype=float32)]
[DeviceArray(7.895, dtype=float32), DeviceArray(6.133, dtype=float32)]
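NumPyro also provides the Predictive class (imported above) for drawing such prior samples in a vectorized way. A minimal sketch, assuming that, with no posterior samples supplied, Predictive returns draws for both the latent theta and the observed site y:

prior_predictive = Predictive(model, num_samples=5)
prior_samples = prior_predictive(rng_key_, hparams)  # y=None, so "y" is drawn from the prior predictive
print(prior_samples["theta"])
print(prior_samples["y"])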
conditioned_model = numpyro.handlers.condition(model, {"y": y})
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel, num_warmup=200, num_samples=200, num_chains=4)
mcmc.run(rng_key_, hparams)
mcmc.print_summary()
samples = mcmc.get_samples()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     theta      9.20      0.61      9.19      8.31     10.27    259.86      1.01

Number of divergences: 0
print(type(samples))
print(type(samples["theta"]))
print(samples["theta"].shape)
<class 'dict'>
<class 'jax.interpreters.xla._DeviceArray'>
(1000,)
nuts_kernel = NUTS(model)  # this is the unconditioned model
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
mcmc.run(rng_key_, hparams, y)  # we need to specify the observations here
mcmc.print_summary()
samples = mcmc.get_samples()
sample: 100%|██████████| 1100/1100 [00:02<00:00, 386.90it/s, 3 steps of size 1.58e+00. acc. prob=0.83]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     theta      9.09      0.60      9.09      8.11     10.07    350.32      1.00

Number of divergences: 0
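As a sanity check, we can compare the NUTS samples to the exact posterior $N(m, s^2)$ computed earlier. A minimal sketch, reusing samples, m and s from the cells above and assuming scipy for the density:

from scipy.stats import norm  # scipy assumed available (see pip output above)

plt.hist(np.array(samples["theta"]), bins=30, density=True, alpha=0.5, label="NUTS samples")
thetas = np.linspace(6, 12, 200)
plt.plot(thetas, norm.pdf(thetas, m, s), label="exact posterior")
plt.xlabel("theta")
plt.legend();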

Stochastic variational inference

See the documentation

# the guide must have the same signature as the model
def guide(hparams, y):
    prior_mean, prior_sd, obs_sd = hparams
    m = numpyro.param("m", y)  # location
    s = numpyro.param("s", prior_sd, constraint=constraints.positive)  # scale
    return numpyro.sample("theta", dist.Normal(m, s))


# The numpyro optimizers wrap the JAX ones, so they have unusual keywords
# https://jax.readthedocs.io/en/latest/jax.experimental.optimizers.html
# optimizer = numpyro.optim.Adam(step_size=0.001)
optimizer = numpyro.optim.Momentum(step_size=0.001, mass=0.1)

# svi = SVI(model, guide, optimizer, Trace_ELBO(), hparams=hparams, y=y)  # specify static args to model/guide
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

nsteps = 2000
svi_result = svi.run(rng_key_, nsteps, hparams, y)  # or specify the arguments here
print(svi_result.params)
print(svi_result.losses.shape)

plt.plot(svi_result.losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss");
100%|██████████| 2000/2000 [00:01<00:00, 1295.73it/s, init loss: 1.8902, avg. loss [1901-2000]: 1.4572]
{'m': DeviceArray(9.084, dtype=float32), 's': DeviceArray(0.586, dtype=float32)}
(2000,)
[Figure: ELBO loss vs. step]
print([svi_result.params["m"], svi_result.params["s"]])
[DeviceArray(9.084, dtype=float32), DeviceArray(0.586, dtype=float32)]
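These should be close to the exact posterior moments computed earlier ($m \approx 9.14$, $s = 0.6$). A quick check, reusing the global m and s from the exact-inference cell:

print("exact posterior:       m={:.3f}, s={:.3f}".format(m, s))
print("variational posterior: m={:.3f}, s={:.3f}".format(float(svi_result.params["m"]), float(svi_result.params["s"])))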

Laplace (quadratic) approximation

See the documentation

guide_laplace = AutoLaplaceApproximation(model)
svi = SVI(model, guide_laplace, optimizer, Trace_ELBO(), hparams=hparams, y=y)
svi_run = svi.run(rng_key_, 2000)
params = svi_run.params
losses = svi_run.losses
plt.figure()
plt.plot(losses)
100%|██████████| 2000/2000 [00:01<00:00, 1607.21it/s, init loss: 155.0839, avg. loss [1901-2000]: 1.8711]
[<matplotlib.lines.Line2D at 0x7f4af42bcb10>]
[Figure: loss curve for the Laplace-approximation SVI run]
# Posterior is an MVN
# https://num.pyro.ai/en/stable/distributions.html#multivariatenormal
post = guide_laplace.get_posterior(params)
print(post)
m = post.mean
s = jnp.sqrt(post.covariance_matrix)
print([m, s])
<numpyro.distributions.continuous.MultivariateNormal object at 0x7f4af21b00d0>
[DeviceArray([9.118], dtype=float32), DeviceArray([[0.6]], dtype=float32)]
samples = guide_laplace.sample_posterior(rng_key_, params, (1000,))
print_summary(samples, 0.89, False)
                mean       std    median      5.5%     94.5%     n_eff     r_hat
     theta      9.16      0.61      9.12      8.18     10.04    955.50      1.00

Example: Beta-Bernoulli model

This example is from the SVI tutorial.

The model is
$$
\begin{align}
\theta &\sim \text{Beta}(\alpha, \beta) \\
x_i &\sim \text{Ber}(\theta)
\end{align}
$$
where $\alpha=\beta=10$. In the code, $\theta$ is called latent_fairness.

alpha0 = 10.0
beta0 = 10.0


def model(data):
    f = numpyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        numpyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
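The Python loop above creates one named sample site per observation. An equivalent and usually faster formulation vectorizes over the data with numpyro.plate; a minimal sketch (the name model_vectorized is illustrative and is not used elsewhere in the notebook):

def model_vectorized(data):
    f = numpyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # a single vectorized sample site over all N observations
    with numpyro.plate("data", len(data)):
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)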
# create some data with 6 observed heads and 4 observed tails
data = jnp.hstack((jnp.ones(6), jnp.zeros(4)))
print(data)
N1 = jnp.sum(data == 1)
N0 = jnp.sum(data == 0)
print([N1, N0])
[1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
[DeviceArray(6, dtype=int32), DeviceArray(4, dtype=int32)]

Exact inference

The posterior is given by
$$
\begin{align}
\theta | x_{1:N} &\sim \text{Beta}(\alpha + N_1, \beta + N_0) \\
N_1 &= \sum_{i=1}^N [x_i=1] \\
N_0 &= \sum_{i=1}^N [x_i=0]
\end{align}
$$

alpha_q = alpha0 + N1
beta_q = beta0 + N0
print("exact posterior: alpha={:0.3f}, beta={:0.3f}".format(alpha_q, beta_q))

post_mean = alpha_q / (alpha_q + beta_q)
post_var = (post_mean * beta_q) / ((alpha_q + beta_q) * (alpha_q + beta_q + 1))
post_std = np.sqrt(post_var)
print([post_mean, post_std])
exact posterior: alpha=16.000, beta=14.000
[DeviceArray(0.533, dtype=float32), 0.08960287]
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)
print([inferred_mean, inferred_std])
[DeviceArray(0.533, dtype=float32), DeviceArray(0.09, dtype=float32)]

Variational inference

def guide(data):
    alpha_q = numpyro.param("alpha_q", alpha0, constraint=constraints.positive)
    beta_q = numpyro.param("beta_q", beta0, constraint=constraints.positive)
    numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
# optimizer = numpyro.optim.Adam(step_size=0.001)
optimizer = numpyro.optim.Momentum(step_size=0.001, mass=0.1)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

nsteps = 2000
svi_result = svi.run(rng_key_, nsteps, data)
print(svi_result.params)
print(svi_result.losses.shape)

plt.plot(svi_result.losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss");
100%|██████████| 2000/2000 [00:03<00:00, 605.23it/s, init loss: 6.8058, avg. loss [1901-2000]: 7.0744]
{'alpha_q': DeviceArray(13.071, dtype=float32), 'beta_q': DeviceArray(11.132, dtype=float32)}
(2000,)
[Figure: ELBO loss vs. step]
# grab the learned variational parameters
alpha_q = svi_result.params["alpha_q"]
beta_q = svi_result.params["beta_q"]
print("variational posterior: alpha={:0.3f}, beta={:0.3f}".format(alpha_q, beta_q))

post_mean = alpha_q / (alpha_q + beta_q)
post_var = (post_mean * beta_q) / ((alpha_q + beta_q) * (alpha_q + beta_q + 1))
post_std = np.sqrt(post_var)
print([post_mean, post_std])
variational posterior: alpha=13.071, beta=11.132
[DeviceArray(0.54, dtype=float32), 0.09927754]
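We can compare the variational posterior Beta(alpha_q, beta_q) to the exact posterior Beta(16, 14) by plotting both densities. A minimal sketch, assuming scipy is available and reusing alpha_q, beta_q, N1 and N0 from the cells above:

from scipy.stats import beta as beta_dist  # scipy assumed available (see pip output above)

thetas = np.linspace(0.01, 0.99, 200)
plt.plot(thetas, beta_dist.pdf(thetas, float(alpha0 + N1), float(beta0 + N0)), label="exact posterior")
plt.plot(thetas, beta_dist.pdf(thetas, float(alpha_q), float(beta_q)), label="variational posterior")
plt.xlabel("latent_fairness")
plt.legend();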

MCMC

nuts_kernel = NUTS(model)  # this model conditions on the data internally via obs=data[i]
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
mcmc.run(rng_key_, data)
mcmc.print_summary()
samples = mcmc.get_samples()
sample: 100%|██████████| 1100/1100 [00:05<00:00, 207.57it/s, 3 steps of size 1.16e+00. acc. prob=0.88]
                          mean       std    median      5.0%     95.0%     n_eff     r_hat
  latent_fairness         0.53      0.09      0.53      0.37      0.68    380.27      1.00

Number of divergences: 0

Distributions

1d Gaussian

# a single 1d Gaussian
mu = 1.5
sigma = 2
d = dist.Normal(mu, sigma)
dir(d)
['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_batch_shape', '_event_shape', '_validate_args', '_validate_sample', 'arg_constraints', 'batch_shape', 'cdf', 'enumerate_support', 'event_dim', 'event_shape', 'expand', 'expand_by', 'has_enumerate_support', 'has_rsample', 'icdf', 'infer_shapes', 'is_discrete', 'loc', 'log_prob', 'mask', 'mean', 'reparametrized_params', 'rsample', 'sample', 'sample_with_intermediates', 'scale', 'set_default_validate_args', 'shape', 'support', 'to_event', 'tree_flatten', 'tree_unflatten', 'variance']
rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d.sample(rng_key_, (nsamples,))
print(ys.shape)

mu_hat = np.mean(ys, 0)
print(mu_hat)
sigma_hat = np.std(ys, 0)
print(sigma_hat)
(1000,)
1.5070927
2.0493808

Multivariate Gaussian

mu = np.array([-1, 1])
sigma = np.array([1, 2])
Sigma = np.diag(sigma)  # the diagonal entries are the variances
d2 = dist.MultivariateNormal(mu, Sigma)
# rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d2.sample(rng_key_, (nsamples,))
print(ys.shape)

mu_hat = np.mean(ys, 0)
print(mu_hat)
Sigma_hat = np.cov(ys, rowvar=False)  # jax.np.cov not implemented
print(Sigma_hat)
(1000, 2)
[-1.013  1.009]
[[ 0.977 -0.005]
 [-0.005  1.972]]

Shape semantics

NumPyro, Pyro, TFP, and Distrax all distinguish between 'event shape' and 'batch shape'. For a D-dimensional Gaussian, the event shape is (D,) and the batch shape is (), meaning we have a single instance of this distribution. If the covariance is diagonal, we can instead view it as D independent 1d Gaussians stored along the batch dimension; this has event shape () and batch shape (D,).

When we sample from a distribution, we also specify the sample_shape. Suppose we draw N samples from a single D-dimensional diagonal Gaussian, and N samples from D independent 1d Gaussians. These samples have the same shape, but the semantics of log_prob differ. We illustrate this below.

mu = np.array([-1, 1])
sigma = np.array([1, 2])
Sigma = np.diag(sigma)
d2 = dist.MultivariateNormal(mu, Sigma)
print(f"event shape {d2.event_shape}, batch shape {d2.batch_shape}")

nsamples = 3
ys2 = d2.sample(rng_key_, (nsamples,))
print("samples, shape {}".format(ys2.shape))
print(ys2)

# 2 independent 1d gaussians (same as one 2d diagonal Gaussian)
d3 = dist.Normal(mu, scale=np.sqrt(np.diag(Sigma)))  # scalar Gaussian needs std, not variance
print(f"event shape {d3.event_shape}, batch shape {d3.batch_shape}")

ys3 = d3.sample(rng_key_, (nsamples,))
print("samples, shape {}".format(ys3.shape))
print(ys3)

print(np.allclose(ys2, ys3))

y = ys2[0, :]  # 2 numbers
print(d2.log_prob(y))  # log prob of a single 2d distribution on 2d input
print(d3.log_prob(y))  # log prob of two 1d distributions on 2d input
event shape (2,), batch shape ()
samples, shape (3, 2)
[[-0.068  0.994]
 [-1.74  -1.018]
 [ 0.06   2.314]]
event shape (), batch shape (2,)
samples, shape (3, 2)
[[-0.068  0.994]
 [-1.74  -1.018]
 [ 0.06   2.314]]
True
-2.6185904
[-1.353 -1.266]

We can turn a set of independent distributions into a single product distribution using the Independent class.

d4 = dist.Independent(d3, 1)  # treat the first batch dimension as an event dimension
# now d4 is just like d2
print(f"event shape {d4.event_shape}, batch shape {d4.batch_shape}")
print(d4.log_prob(y))
event shape (2,), batch shape ()
-2.6185904
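An equivalent shorthand is the to_event method (it appears in the dir(d) output above), which reinterprets the rightmost batch dimensions as event dimensions. A small sketch reusing d3 and y from above:

d5 = d3.to_event(1)  # reinterpret the last batch dimension as an event dimension
print(f"event shape {d5.event_shape}, batch shape {d5.batch_shape}")
print(d5.log_prob(y))  # should agree with d2 and d4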