Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/18/gp_poisson_1d.ipynb
1192 views
Kernel: Python 3.10.4 ('PyroNB')

GP with a Poisson Likelihood

https://tinygp.readthedocs.io/en/latest/tutorials/likelihoods.html

We use the tinygp library to define the model, and the numpyro library to do inference, using either MCMC or SVI.

try: import tinygp except ModuleNotFoundError: %pip install -q tinygp import tinygp try: import numpyro except ModuleNotFoundError: %pip install -qq numpyro %pip install -q numpyro jax jaxlib import numpyro try: import arviz except ModuleNotFoundError: %pip install arviz import arviz try: from probml_utils import latexify, savefig, is_latexify_enabled except ModuleNotFoundError: %pip install git+https://github.com/probml/probml-utils.git from probml_utils import latexify, savefig, is_latexify_enabled import seaborn as sns import numpyro.distributions as dist import jax import jax.numpy as jnp import matplotlib.pyplot as plt from tinygp import kernels, GaussianProcess from jax.config import config config.update("jax_enable_x64", True) latexify(width_scale_factor=3, fig_height=1.5)

Data

# Simulate count data: 20 inputs on [-3, 3], a cosine log-rate, and Poisson
# counts drawn with a fixed PRNG key so the dataset is reproducible.
key = jax.random.PRNGKey(203618)
x = jnp.linspace(-3, 3, 20)
true_log_rate = 2 * jnp.cos(2 * x)
y = jax.random.poisson(key, lam=jnp.exp(true_log_rate))

# Show the sampled counts together with the rate that generated them.
plt.figure()
plt.plot(x, y, ".k", label="data")
plt.plot(x, jnp.exp(true_log_rate), "C1", label="true rate")
plt.xlabel("$x$")
plt.ylabel("counts")
sns.despine()
plt.legend(loc="upper right", prop=dict(size=5), frameon=False)
savefig("gp-poisson-data")

Markov chain Monte Carlo (MCMC)

We set up the model in numpyro and run MCMC. Note that the log_rate parameter doesn't have the obs=... argument set, since it is latent.

%%capture def model(x, y=None): # The parameters of the GP model mean = numpyro.sample("mean", dist.Normal(0.0, 2.0)) sigma = numpyro.sample("sigma", dist.HalfNormal(3.0)) rho = numpyro.sample("rho", dist.HalfNormal(10.0)) # Set up the kernel and GP objects kernel = sigma**2 * kernels.Matern52(rho) gp = GaussianProcess(kernel, x, diag=1e-5, mean=mean) log_rate = numpyro.sample("log_rate", gp.numpyro_dist()) # Finally, our observation model is Poisson numpyro.sample("obs", dist.Poisson(jnp.exp(log_rate)), obs=y) # Run the MCMC nuts_kernel = numpyro.infer.NUTS(model, target_accept_prob=0.9) mcmc = numpyro.infer.MCMC( nuts_kernel, num_warmup=500, num_samples=500, num_chains=2, progress_bar=False, ) key = jax.random.PRNGKey(55873) mcmc.run(key, x, y=y) samples = mcmc.get_samples()

We can summarize the MCMC results by plotting the posterior median of the inferred rate together with a shaded credible interval computed from percentiles of the posterior samples, and compare it to the known ground truth:

# Pointwise posterior summary of the latent log-rate across MCMC draws:
# the 2.5th, 50th and 97.5th percentiles. The outer two bound a central 95%
# interval, matching the "95%" legend label (the 5th/95th percentiles would
# only span a 90% interval).
percentile = jnp.percentile(samples["log_rate"], jnp.array([2.5, 50, 97.5]), axis=0)

plt.figure()
plt.plot(x, y, ".k", label="data")
plt.plot(x, jnp.exp(true_log_rate), "--", color="C1", label="true rate")
# Exponentiate to move from log-rate to rate; percentile[1] is the median.
plt.plot(x, jnp.exp(percentile[1]), color="C0", label="MCMC inferred rate")
plt.fill_between(
    x,
    jnp.exp(percentile[0]),
    jnp.exp(percentile[-1]),
    alpha=0.3,
    lw=0,
    color="C0",
    # Raw string: "\%" is not a valid escape sequence in a normal string.
    label=r"$95\%$ Confidence",
)
plt.legend(loc=1, prop={"size": 5}, frameon=False)
sns.despine()
plt.xlabel("$x$")
plt.ylabel("counts")
savefig("gp-poisson-mcmc")

Stochastic variational inference (SVI)

For larger datasets, it is faster to use stochastic variational inference (SVI) instead of MCMC.

def model(x, y=None):
    """NumPyro model for SVI: GP prior on the log-rate, Poisson likelihood.

    The GP mean and kernel hyperparameters are ``numpyro.param`` sites
    (optimized point values) rather than sampled sites; ``log_rate`` is the
    only latent sample site.
    """
    # Optimizable GP hyperparameters; "sigma" and "rho" are constrained positive.
    gp_mean = numpyro.param("mean", jnp.zeros(()))
    amplitude = numpyro.param(
        "sigma", jnp.ones(()), constraint=dist.constraints.positive
    )
    scale = numpyro.param(
        "rho", 2 * jnp.ones(()), constraint=dist.constraints.positive
    )

    # Matern-5/2 kernel scaled by the squared amplitude.
    gp = GaussianProcess(
        amplitude**2 * kernels.Matern52(scale),
        x,
        diag=1e-5,
        mean=gp_mean,
    )
    log_rate = numpyro.sample("log_rate", gp.numpyro_dist())

    # Poisson observation model with rate exp(log_rate).
    numpyro.sample("obs", dist.Poisson(jnp.exp(log_rate)), obs=y)


def guide(x, y=None):
    """Variational guide: independent Normal over ``log_rate`` at each input.

    The variational mean starts at ``log(y + 1)`` when counts are supplied
    and at zero otherwise; the variational scale starts at one.
    """
    if y is None:
        init_loc = jnp.zeros_like(x)
    else:
        init_loc = jnp.log(y + 1)

    loc = numpyro.param("log_rate_mu", init_loc)
    spread = numpyro.param(
        "log_rate_sigma",
        jnp.ones_like(x),
        constraint=dist.constraints.positive,
    )
    numpyro.sample("log_rate", dist.Independent(dist.Normal(loc, spread), 1))


# Fit with Adam (step size 0.01) for 3000 steps; Trace_ELBO is given 10 as
# its first argument.
optim = numpyro.optim.Adam(0.01)
svi = numpyro.infer.SVI(model, guide, optim, numpyro.infer.Trace_ELBO(10))
results = svi.run(jax.random.PRNGKey(5583), 3000, x, y=y, progress_bar=False)

As above, we can plot our inferred conditional model and compare it to the ground truth:

# Fitted variational parameters of the guide's Normal over the log-rate.
mu = results.params["log_rate_mu"]
sigma = results.params["log_rate_sigma"]

plt.figure()
plt.plot(x, y, ".k", label="data")
plt.plot(x, jnp.exp(true_log_rate), "--", color="C1", label="true rate")
# Exponentiate to move from log-rate to rate.
plt.plot(x, jnp.exp(mu), color="C0", label="VI inferred rate")
# Band is mu +/- 2 sigma in log-rate space, mapped through exp.
plt.fill_between(
    x,
    jnp.exp(mu - 2 * sigma),
    jnp.exp(mu + 2 * sigma),
    alpha=0.3,
    lw=0,
    color="C0",
    # Raw string: "\%" is not a valid escape sequence in a normal string.
    label=r"$95\%$ Confidence",
)
plt.legend(loc=1, prop={"size": 5}, frameon=False)
plt.xlabel("$x$")
plt.ylabel("counts")
sns.despine()
savefig("gp-poisson-svi")