GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/gaussian_param_inf_1d_numpyro.ipynb
Kernel: Python 3

Open In Colab

Inference for the parameters of a 1d Gaussian using a non-conjugate prior

We illustrate various inference methods using the example in sec 4.3 ("Gaussian model of height") of Statistical Rethinking ed 2. This requires computing $p(\mu,\sigma|D)$ using a Gaussian likelihood but a non-conjugate prior. The numpyro code is from Du Phan's site.

import numpy as np
np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
import math
import os
import warnings
import pandas as pd

# from scipy.interpolate import BSpline
# from scipy.stats import gaussian_kde
!mkdir figures
mkdir: cannot create directory ‘figures’: File exists
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
import jax
print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))

import jax.numpy as jnp
from jax import random, vmap

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
jax version 0.2.12
jax backend gpu
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import AffineTransform
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer import Predictive
from numpyro.infer import MCMC, NUTS
from numpyro.infer import SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation
import numpyro.optim as optim
!pip install arviz
import arviz as az

Data

We use the "Howell" dataset, which consists of measurements of height, weight, age and sex, of a certain foraging tribe, collected by Nancy Howell.

# url = 'https://github.com/fehiepsi/rethinking-numpyro/tree/master/data/Howell1.csv?raw=True'
url = "https://raw.githubusercontent.com/fehiepsi/rethinking-numpyro/master/data/Howell1.csv"
Howell1 = pd.read_csv(url, sep=";")
d = Howell1
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   height  544 non-null    float64
 1   weight  544 non-null    float64
 2   age     544 non-null    float64
 3   male    544 non-null    int64
dtypes: float64(3), int64(1)
memory usage: 17.1 KB
# get data for adults
d2 = d[d.age >= 18]
N = len(d2)
ndx = jax.random.permutation(rng_key, N)
data = d2.height.values[ndx]
N = 20  # take a subset of the 354 samples
data = data[:N]

Empirical mean and std.

print(len(data))
print(np.mean(data))
print(np.std(data))
20
154.16326000000004
7.459859122289108

Model

We use the following model for the heights (in cm):

$$
\begin{align}
h_i &\sim N(\mu,\sigma) \\
\mu &\sim N(178, 20) \\
\sigma &\sim U(0,50)
\end{align}
$$

The prior for $\mu$ has mean 178 cm, since that is the height of Richard McElreath, the author of the "Statistical Rethinking" book. The standard deviation is 20, so that roughly 95% of the prior mass lies in the range 138--218 cm.

The prior for $\sigma$ has a lower bound of 0 (since it must be positive), and an upper bound of 50, so that the interval $[\mu-\sigma, \mu+\sigma]$ has width of at most 100 cm, which seems large enough to capture the spread of human heights.
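As a quick sanity check on these priors (a minimal sketch, not part of the original notebook), we can sample $(\mu, \sigma)$ from the priors and look at the implied prior predictive distribution over heights:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

# Prior predictive check: draw (mu, sigma) from the priors,
# then draw heights from N(mu, sigma).
key1, key2, key3 = random.split(random.PRNGKey(0), 3)
mu_s = dist.Normal(178, 20).sample(key1, (1000,))
sigma_s = dist.Uniform(0, 50).sample(key2, (1000,))
h_s = dist.Normal(mu_s, sigma_s).sample(key3)

# most of the prior predictive mass should cover plausible adult heights
print(jnp.percentile(h_s, jnp.array([2.5, 50.0, 97.5])))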

Note that this is not a conjugate prior, so we cannot compute the posterior in closed form and will approximate it instead. But since there are only 2 unknowns, this is easy.
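Concretely, combining the Gaussian likelihood with these two priors, the unnormalized posterior that every method below approximates is

$$
p(\mu,\sigma \mid D) \;\propto\; \Big[\prod_{i=1}^{N} N(h_i \mid \mu, \sigma)\Big] \; N(\mu \mid 178, 20) \; U(\sigma \mid 0, 50).
$$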

Grid posterior

mu_prior = dist.Normal(178, 20)
sigma_prior = dist.Uniform(0, 50)

mu_range = [150, 160]
sigma_range = [4, 14]
ngrid = 100
plot_square = False

mu_list = jnp.linspace(start=mu_range[0], stop=mu_range[1], num=ngrid)
sigma_list = jnp.linspace(start=sigma_range[0], stop=sigma_range[1], num=ngrid)
mesh = jnp.meshgrid(mu_list, sigma_list)
print([mesh[0].shape, mesh[1].shape])
print(mesh[0].reshape(-1).shape)

post = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
post["LL"] = vmap(lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(data)))(
    post["mu"], post["sigma"]
)
logprob_mu = mu_prior.log_prob(post["mu"])
logprob_sigma = sigma_prior.log_prob(post["sigma"])
post["prob"] = post["LL"] + logprob_mu + logprob_sigma
post["prob"] = jnp.exp(post["prob"] - jnp.max(post["prob"]))
prob = post["prob"] / jnp.sum(post["prob"])  # normalize over the grid
[(100, 100), (100, 100)]
(10000,)
prob2d = prob.reshape(ngrid, ngrid)
prob_mu = jnp.sum(prob2d, axis=0)
prob_sigma = jnp.sum(prob2d, axis=1)

plt.figure()
plt.plot(mu_list, prob_mu, label="mu")
plt.legend()
plt.savefig("figures/gauss_params_1d_post_grid_marginal_mu.pdf", dpi=300)
plt.show()

plt.figure()
plt.plot(sigma_list, prob_sigma, label="sigma")
plt.legend()
plt.savefig("figures/gauss_params_1d_post_grid_marginal_sigma.pdf", dpi=300)
plt.show()
[Figure: grid-posterior marginals of mu and sigma]
plt.contour(
    post["mu"].reshape(ngrid, ngrid),
    post["sigma"].reshape(ngrid, ngrid),
    post["prob"].reshape(ngrid, ngrid),
)
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
if plot_square:
    plt.axis("square")
plt.savefig("figures/gauss_params_1d_post_grid_contours.pdf", dpi=300)
plt.show()
[Figure: contour plot of the grid posterior over (mu, sigma)]
plt.imshow(
    post["prob"].reshape(ngrid, ngrid),
    origin="lower",
    extent=(mu_range[0], mu_range[1], sigma_range[0], sigma_range[1]),
    aspect="auto",
)
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
if plot_square:
    plt.axis("square")
plt.savefig("figures/gauss_params_1d_post_grid_heatmap.pdf", dpi=300)
plt.show()
[Figure: heatmap of the grid posterior over (mu, sigma)]

Posterior samples.

nsamples = 5000  # int(1e4)
sample_rows = dist.Categorical(probs=prob).sample(random.PRNGKey(0), (nsamples,))
sample_mu = post["mu"][sample_rows]
sample_sigma = post["sigma"][sample_rows]
samples = {"mu": sample_mu, "sigma": sample_sigma}
print_summary(samples, 0.95, False)

plt.scatter(samples["mu"], samples["sigma"], s=64, alpha=0.1, edgecolor="none")
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
plt.axis("square")
plt.show()

az.plot_kde(samples["mu"], samples["sigma"])
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
if plot_square:
    plt.axis("square")
plt.savefig("figures/gauss_params_1d_post_grid.pdf", dpi=300)
plt.show()
         mean   std  median    2.5%   97.5%    n_eff  r_hat
    mu  154.39  1.75  154.34  150.91  157.78  4505.97   1.00
 sigma    8.18  1.41    8.04    5.72   10.97  5065.49   1.00
[Figure: scatter plot and KDE of grid-posterior samples]

Posterior marginals.

print(hpdi(samples["mu"], 0.95)) print(hpdi(samples["sigma"], 0.95)) fig, ax = plt.subplots() az.plot_kde(samples["mu"], ax=ax, label=r"$\mu$") fig, ax = plt.subplots() az.plot_kde(samples["sigma"], ax=ax, label=r"$\sigma$")
[150.909 157.778]
[ 5.717 10.97 ]
<matplotlib.axes._subplots.AxesSubplot at 0x7fb3d076f550>
[Figure: KDE of the grid-posterior marginals for mu and sigma]

Laplace approximation

See the numpyro documentation for AutoLaplaceApproximation.
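As a reminder (this is the standard definition, not spelled out in the original notebook), the Laplace approximation fits a Gaussian centered at the posterior mode in unconstrained parameter space, with covariance given by the inverse Hessian of the negative log joint:

$$
q(\theta) = N\big(\theta \mid \hat{\theta}, H^{-1}\big), \qquad
\hat{\theta} = \arg\max_{\theta} \log p(D, \theta), \qquad
H = -\nabla^2_{\theta} \log p(D, \theta) \Big|_{\theta = \hat{\theta}},
$$

where $\theta$ denotes the parameters after numpyro maps $\sigma$ to the real line.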

Optimization

def model(data):
    mu = numpyro.sample("mu", mu_prior)
    sigma = numpyro.sample("sigma", sigma_prior)
    numpyro.sample("height", dist.Normal(mu, sigma), obs=data)

guide = AutoLaplaceApproximation(model)
svi = SVI(model, guide, optim.Adam(1), Trace_ELBO(), data=data)
svi_result = svi.run(random.PRNGKey(0), 2000)

plt.figure()
plt.plot(svi_result.losses)
100%|██████████| 2000/2000 [00:01<00:00, 1144.90it/s, init loss: 269.4353, avg. loss [1901-2000]: 75.2568]
[<matplotlib.lines.Line2D at 0x7fb3ae842310>]
[Figure: SVI loss curve]
start = {"mu": data.mean(), "sigma": data.std()} guide = AutoLaplaceApproximation(model, init_loc_fn=init_to_value(values=start)) svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), data=data) svi_result = svi.run(random.PRNGKey(0), 2000) plt.figure() plt.plot(svi_result.losses)
100%|██████████| 2000/2000 [00:01<00:00, 1128.19it/s, init loss: 75.2585, avg. loss [1901-2000]: 75.2447]
[<matplotlib.lines.Line2D at 0x7fb3aea58b50>]
[Figure: SVI loss curve with informed initialization]

Posterior samples.

samples = guide.sample_posterior(random.PRNGKey(1), svi_result.params, (nsamples,))
print_summary(samples, 0.95, False)

plt.scatter(samples["mu"], samples["sigma"], s=64, alpha=0.1, edgecolor="none")
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
plt.show()

az.plot_kde(samples["mu"], samples["sigma"])
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
if plot_square:
    plt.axis("square")
plt.savefig("figures/gauss_params_1d_post_laplace.pdf", dpi=300)
plt.show()
         mean   std  median    2.5%   97.5%    n_eff  r_hat
    mu  154.32  1.68  154.31  151.06  157.57  4327.88   1.00
 sigma    7.70  1.25    7.59    5.45   10.22  4173.88   1.00
[Figure: scatter plot and KDE of Laplace-posterior samples]
print(hpdi(samples["mu"], 0.95)) print(hpdi(samples["sigma"], 0.95)) fig, ax = plt.subplots() az.plot_kde(samples["mu"], ax=ax, label=r"$\mu$") fig, ax = plt.subplots() az.plot_kde(samples["sigma"], ax=ax, label=r"$\sigma$")
[151.06 157.569]
[ 5.446 10.217]
<matplotlib.axes._subplots.AxesSubplot at 0x7fb3d0468990>
[Figure: KDE of the Laplace-posterior marginals for mu and sigma]

Extract 2d joint posterior

The Gaussian approximation is over transformed (unconstrained) parameters: $\mu$ is already unconstrained, while $\sigma \in (0, 50)$ is mapped to the real line, so the second component of the posterior mean below is $\sigma$ on that transformed scale.
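Concretely, numpyro maps the interval-constrained $\sigma$ to an unconstrained value $z$ via a scaled logit, which is what the numerical check with logit and sigmoid below verifies:

$$
z = \mathrm{logit}(\sigma / 50), \qquad \sigma = 50 \, \mathrm{sigmoid}(z).
$$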

post = guide.get_posterior(svi_result.params)
print(post.mean)
print(post.covariance_matrix)
[154.335  -1.716]
[[2.885 0.01 ]
 [0.01  0.036]]
def logit(p):
    return jnp.log(p / (1 - p))

def sigmoid(a):
    return 1 / (1 + jnp.exp(-a))

scale = 50
print(logit(7.7 / scale))
print(sigmoid(-1.7) * scale)
-1.7035668
7.723263
unconstrained_samples = post.sample(rng_key, sample_shape=(nsamples,))
constrained_samples = guide._unpack_and_constrain(unconstrained_samples, svi_result.params)

print(unconstrained_samples.shape)
print(jnp.mean(unconstrained_samples, axis=0))
print(jnp.mean(constrained_samples["mu"], axis=0))
print(jnp.mean(constrained_samples["sigma"], axis=0))
(5000, 2)
[154.326  -1.724]
154.32643
7.6484103

We can sample from the posterior, which returns results in the original parameterization.

samples = guide.sample_posterior(random.PRNGKey(1), svi_result.params, (nsamples,))
x = jnp.stack(list(samples.values()), axis=0)
print(x.shape)
print("mean of samples\n", jnp.mean(x, axis=1))

vcov = jnp.cov(x)
print("cov of samples\n", vcov)  # variance-covariance matrix

# correlation matrix
R = vcov / jnp.sqrt(jnp.outer(jnp.diagonal(vcov), jnp.diagonal(vcov)))
print("corr of samples\n", R)
(2, 5000)
mean of samples
 [154.324   7.702]
cov of samples
 [[2.839 0.051]
 [0.051 1.56 ]]
corr of samples
 [[1.    0.024]
 [0.024 1.   ]]

Variational inference

We use a factorized (mean-field) variational posterior $q(\mu,\sigma) = N(\mu|m,s) \, \mathrm{Ga}(\sigma|a,b)$.
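SVI fits the variational parameters $(m, s, a, b)$ by maximizing the evidence lower bound (ELBO), which Trace_ELBO in the cell below estimates by Monte Carlo (the optimizer minimizes its negative):

$$
\mathrm{ELBO}(m, s, a, b) = \mathbb{E}_{q(\mu,\sigma)}\big[\log p(D, \mu, \sigma) - \log q(\mu, \sigma)\big] \;\le\; \log p(D).
$$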

def guide(data):
    data_mean = jnp.mean(data)
    data_std = jnp.std(data)
    m = numpyro.param("m", data_mean)
    s = numpyro.param("s", 10, constraint=constraints.positive)
    a = numpyro.param("a", data_std, constraint=constraints.positive)
    b = numpyro.param("b", 1, constraint=constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(m, s))
    sigma = numpyro.sample("sigma", dist.Gamma(a, b))

optimizer = numpyro.optim.Momentum(step_size=0.001, mass=0.1)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
nsteps = 2000
svi_result = svi.run(rng_key_, nsteps, data=data)

print(svi_result.params)
print(svi_result.losses.shape)

plt.plot(svi_result.losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss");
100%|██████████| 2000/2000 [00:03<00:00, 518.24it/s, init loss: 75.1782, avg. loss [1901-2000]: 74.7501]
{'a': DeviceArray(22.394, dtype=float32), 'b': DeviceArray(2.443, dtype=float32), 'm': DeviceArray(154.246, dtype=float32), 's': DeviceArray(1.785, dtype=float32)}
(2000,)
[Figure: ELBO loss curve]

Extract variational parameters.

print(svi_result.params)
a = np.array(svi_result.params["a"])
b = np.array(svi_result.params["b"])
m = np.array(svi_result.params["m"])
s = np.array(svi_result.params["s"])
{'a': DeviceArray(22.394, dtype=float32), 'b': DeviceArray(2.443, dtype=float32), 'm': DeviceArray(154.246, dtype=float32), 's': DeviceArray(1.785, dtype=float32)}
print("empirical mean", jnp.mean(data)) print("empirical std", jnp.std(data)) print(r"posterior mean and std of $\mu$") post_mean = dist.Normal(m, s) print([post_mean.mean, jnp.sqrt(post_mean.variance)]) print(r"posterior mean and std of unconstrained $\sigma$") post_sigma = dist.Gamma(a, b) print([post_sigma.mean, jnp.sqrt(post_sigma.variance)])
empirical mean 154.16325
empirical std 7.459859
posterior mean and std of $\mu$
[array(154.246, dtype=float32), DeviceArray(1.785, dtype=float32)]
posterior mean and std of $\sigma$
[9.165675, DeviceArray(1.937, dtype=float32)]

Posterior samples

predictive = Predictive(guide, params=svi_result.params, num_samples=nsamples)
samples = predictive(rng_key, data)
print_summary(samples, 0.95, False)

plt.scatter(samples["mu"], samples["sigma"], s=64, alpha=0.1, edgecolor="none")
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
plt.show()

az.plot_kde(samples["mu"], samples["sigma"])
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
if plot_square:
    plt.axis("square")
plt.savefig("figures/gauss_params_1d_post_vi.pdf", dpi=300)
plt.show()
         mean   std  median    2.5%   97.5%    n_eff  r_hat
    mu  154.32  1.79  154.29  150.85  157.88  4950.79   1.00
 sigma    9.17  1.91    9.01    5.56   12.88  5093.46   1.00
[Figure: scatter plot and KDE of VI-posterior samples]
print(hpdi(samples["mu"], 0.95)) print(hpdi(samples["sigma"], 0.95)) fig, ax = plt.subplots() az.plot_kde(samples["mu"], ax=ax, label=r"$\mu$") fig, ax = plt.subplots() az.plot_kde(samples["sigma"], ax=ax, label=r"$\sigma$")
[150.846 157.881]
[ 5.559 12.877]
<matplotlib.axes._subplots.AxesSubplot at 0x7fb3d0e4cd90>
[Figure: KDE of the VI-posterior marginals for mu and sigma]

MCMC

# note: the model has no sample site named "data", so this condition handler is a
# no-op; the observations are actually passed via obs=data inside the model.
conditioned_model = numpyro.handlers.condition(model, {"data": data})

nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=nsamples)
mcmc.run(rng_key_, data)
mcmc.print_summary()
samples = mcmc.get_samples()
sample: 100%|██████████| 5100/5100 [00:19<00:00, 266.79it/s, 3 steps of size 1.66e-01. acc. prob=0.83]
         mean   std  median    5.0%   95.0%    n_eff  r_hat
    mu  154.34  1.86  154.33  151.36  157.39  3652.68   1.00
 sigma    8.24  1.51    8.05    5.88   10.36  2806.00   1.00

Number of divergences: 0
print_summary(samples, 0.95, False)

plt.scatter(samples["mu"], samples["sigma"], s=64, alpha=0.1, edgecolor="none")
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
plt.show()

az.plot_kde(samples["mu"], samples["sigma"])
plt.xlim(mu_range[0], mu_range[1])
plt.ylim(sigma_range[0], sigma_range[1])
plt.xlabel(r"$\mu$")
plt.ylabel(r"$\sigma$")
if plot_square:
    plt.axis("square")
plt.savefig("figures/gauss_params_1d_post_mcmc.pdf", dpi=300)
plt.show()
         mean   std  median    2.5%   97.5%    n_eff  r_hat
    mu  154.34  1.86  154.33  150.70  158.08  3652.68   1.00
 sigma    8.24  1.51    8.05    5.67   11.04  2806.00   1.00
[Figure: scatter plot and KDE of MCMC-posterior samples]