GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/12/mcmc_gmm_demo.ipynb
Kernel: Python [conda env:py3713]

Metropolis-Hastings, Gibbs sampling, and autocorrelation plots

```python
# Metropolis-Hastings (MH) sampling from a mixture of two 1d Gaussians
# using a 1d Gaussian proposal with different sigma.
# Author: Gerardo Duran-Martin (@gerdm)

# !pip install -qq matplotlib==3.4.2

import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage() and "jax_enable_x64" not in record.getMessage()


logger.addFilter(CheckTypesFilter())

import warnings

warnings.filterwarnings("ignore")

import numpy as np
import matplotlib.pyplot as plt

# from numpy.random import rand

try:
    import probml_utils as pml
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml

try:
    import statsmodels.api as sm
except ModuleNotFoundError:
    %pip install -qq statsmodels
    import statsmodels.api as sm

import os
import matplotlib

try:
    import tensorflow_probability.substrates.jax as tfp
except ModuleNotFoundError:
    %pip install -qq tensorflow-probability
    import tensorflow_probability.substrates.jax as tfp

import jax
import seaborn as sns
import jax.numpy as jnp
```
```python
def plot_gmm_3d_trace(trace_hist, probs, mu, scale, title, xmin, xmax, ax, h=1, n_eval=500):
    # Plot the sample trace together with the true mixture density (red)
    # and a KDE of the samples (blue) on a 3d axis.
    norm_mixture = tfp.distributions.MixtureSameFamily(
        mixture_distribution=tfp.distributions.Categorical(probs=probs),
        components_distribution=tfp.distributions.Normal(loc=mu, scale=scale),
    )
    x_eval = jnp.linspace(xmin, xmax, n_eval)
    kde_eval = pml.kdeg(x_eval[:, None], trace_hist[:, None], h)
    px = norm_mixture.prob(x_eval[:, None]).reshape(-1)
    ax.plot(jnp.arange(n_iterations), trace_hist)
    ax.plot(jnp.zeros(n_eval), x_eval, px, c="tab:red", linewidth=2)
    ax.plot(jnp.zeros(n_eval), x_eval, kde_eval, c="tab:blue")
    ax.set_zlim(0, kde_eval.max() * 1.1)
    ax.set_xlabel("Iterations", fontsize=18)
    ax.set_ylabel("Samples", fontsize=18)
    ax.tick_params(labelsize=12)
    ax.view_init(25, -30)
    if title:
        ax.set_title(title, fontsize=18)
```

Metropolis-Hastings

```python
def metropolis_sample_jax(x0, tau, probs, mu, scale, n_iterations, rng_key):
    # Random-walk Metropolis with a N(0, tau^2) proposal targeting the GMM.
    x_curr = x0
    norm_mixture = tfp.distributions.MixtureSameFamily(
        mixture_distribution=tfp.distributions.Categorical(probs=probs),
        components_distribution=tfp.distributions.Normal(loc=mu, scale=scale),
    )
    x_samples = np.zeros(n_iterations)
    x_samples[0] = x_curr
    rng_keys = jax.random.split(rng_key, num=n_iterations)
    for n in range(1, n_iterations):
        # Propose x' = x + tau * eps, with eps ~ N(0, 1)
        x_candidate = x_curr + tau * tfp.distributions.Normal(loc=0, scale=1).sample(seed=rng_keys[n])
        p_candidate = norm_mixture.prob(x_candidate)
        p_curr = norm_mixture.prob(x_curr)
        alpha = p_candidate / p_curr
        p_accept = min(1, alpha)
        key = jax.random.split(rng_keys[n], 1)[0]
        u = jax.random.uniform(key)
        # u = rand()
        # Accept the candidate with probability p_accept, otherwise stay put
        x_curr = x_curr if u >= p_accept else x_candidate
        x_samples[n] = x_curr
    x_samples = jnp.array(x_samples)
    return x_samples
```
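A quick diagnostic worth computing here (this helper is an assumption, not part of the original notebook) is the empirical acceptance rate. A rejection stores an exact copy of the previous state, so counting moves in the trace estimates how often proposals are accepted: small $\tau$ accepts nearly everything but crawls through the space, while very large $\tau$ rejects most proposals.

```python
# Assumed helper (not in the original notebook): empirical MH acceptance rate,
# counting iterations where the chain moved. Exact float equality is safe here
# because a rejection copies the previous state verbatim.
def acceptance_rate(x_samples):
    moves = jnp.sum(x_samples[1:] != x_samples[:-1])
    return moves / (x_samples.shape[0] - 1)
```

For example, `acceptance_rate(x_samples_1)` can be compared across the three values of $\tau$ used below.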
```python
x0 = 20.0  # initial state of the chain
probs = np.array([0.3, 0.7])  # mixture weights of the GMM components
mu = np.array([-20, 20.0])  # means of the two Gaussian components
scale = np.array([10.0, 10.0])  # scales of the two Gaussian components
n_iterations = 1000
xmin, xmax = -100, 100
```
```python
# os.environ["FIG_DIR"] = "figures/"
# os.environ["LATEXIFY"] = ""
```
```python
# modification of latexify
plt.rcParams["pdf.fonttype"] = 42
plt.rcParams["backend"] = "ps"
plt.rc("text", usetex=True)
```
```python
FIG_SIZE_TRACE = (3.6, 3) if pml.is_latexify_enabled() else None


def plot_trace_outer_fun(tau, x_samples):
    fig = plt.figure(figsize=FIG_SIZE_TRACE)
    axs = plt.axes(projection="3d")
    plot_gmm_3d_trace(
        x_samples, probs, mu, scale, "MH with $\mathcal{N}" + f"(0,{int(tau)}^2)$ proposal", xmin, xmax, axs
    )
    pml.style3d(axs, 1.5, 1, 0.8)
    plt.subplots_adjust(left=0.001, bottom=0.208)
    pml.savefig(f"mh_trace_{tau}tau.pdf", pad_inches=0, tight_bbox=True)
```

$\tau=1$

Note: here $\tau^2$ is the variance of the proposal distribution, $q(x' \mid x) = \mathcal{N}(x' \mid x, \tau^2 I)$.
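Since this Gaussian random-walk proposal is symmetric in $x$ and $x'$, the Hastings correction cancels, and the acceptance probability computed inside `metropolis_sample_jax` reduces to

$$A(x' \mid x) = \min\left(1, \frac{p(x')}{p(x)}\right), \qquad p(x) = \sum_{k=1}^{2} \pi_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2).$$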

```python
tau = 1
rng_key = jax.random.PRNGKey(10)
x_samples_1 = metropolis_sample_jax(x0, tau, probs, mu, scale, n_iterations, rng_key)
plot_trace_outer_fun(tau, x_samples_1)
```
[Figure: 3D trace of MH samples, $\mathcal{N}(0,1^2)$ proposal]

$\tau=8$

```python
tau = 8
rng_key = jax.random.PRNGKey(31)
x_samples_8 = metropolis_sample_jax(x0, tau, probs, mu, scale, n_iterations, rng_key)
plot_trace_outer_fun(tau, x_samples_8)
```
[Figure: 3D trace of MH samples, $\mathcal{N}(0,8^2)$ proposal]

$\tau=500$

```python
tau = 500
rng_key = jax.random.PRNGKey(314)
x_samples_500 = metropolis_sample_jax(x0, tau, probs, mu, scale, n_iterations, rng_key)
plot_trace_outer_fun(tau, x_samples_500)
```
[Figure: 3D trace of MH samples, $\mathcal{N}(0,500^2)$ proposal]

Gibbs sampling
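The sampler below augments the target with the component indicator $z$ and alternates between the two exact conditionals of $p(x, z) = \pi_z \, \mathcal{N}(x \mid \mu_z, \sigma_z^2)$, so, unlike MH, every move is accepted:

$$p(z = k \mid x) = \frac{\pi_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2)}{\sum_j \pi_j \, \mathcal{N}(x \mid \mu_j, \sigma_j^2)}, \qquad x \mid z = k \;\sim\; \mathcal{N}(\mu_k, \sigma_k^2).$$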

```python
xmin, xmax = -100, 100
x0, z0 = 20, 0
tau = 8.0
probs = np.array([0.3, 0.7])
mu = np.array([-20, 20.0])
scale = np.array([10.0, 10.0])
kv = np.arange(2)  # possible values of the component indicator z
n_iterations = 1000
```
```python
def gibbs_sample(x0, z0, kv, probs, mu, scale, n_iterations, rng_key):
    x_curr = x0
    z_curr = z0
    z_samples = np.zeros(n_iterations)
    x_samples = np.zeros(n_iterations)
    keys = jax.random.split(rng_key, n_iterations)
    for n in range(n_iterations):
        # Sample z | x from its exact conditional: p(z=k | x) ∝ pi_k N(x | mu_k, sigma_k)
        pz = tfp.distributions.Normal(loc=mu, scale=scale).prob(x_curr) * probs
        pz = pz / pz.sum()
        # Use separate subkeys for the two conditional draws so the same
        # randomness is not reused within one iteration
        key_z, key_x = jax.random.split(keys[n])
        z_curr = jax.random.choice(key_z, kv, p=pz)
        # Sample x | z from the selected Gaussian component
        x_curr = tfp.distributions.Normal(loc=mu[z_curr], scale=scale[z_curr]).sample(seed=key_x)
        x_samples[n] = x_curr
        z_samples[n] = z_curr
    return x_samples, z_samples
```
```python
rng_key = jax.random.PRNGKey(4)
x_samples_gibbs, z_samples_gibbs = gibbs_sample(x0, z0, kv, probs, mu, scale, n_iterations, rng_key)
```
```python
# os.environ["FIG_DIR"] = "figures/"
# os.environ["LATEXIFY"] = ""
pml.latexify(width_scale_factor=2, fig_height=2)
```
```python
colors = ["tab:blue" if z else "tab:red" for z in z_samples_gibbs]
fig, axs = plt.subplots()
axs.scatter(np.arange(n_iterations), x_samples_gibbs, s=3, facecolors="none", edgecolors=colors)
axs.set_xlabel("Sample number")
axs.set_ylabel("Value of $x$")
sns.despine()
pml.savefig("gibbs_scatter.pdf")
```
[Figure: Gibbs samples of $x$ vs. sample number, colored by mixture component $z$]
```python
def plot_gmm_3d_trace_gibbs(trace_hist, probs, mu, scale, title, xmin, xmax, ax, h=1, n_eval=500):
    norm_mixture = tfp.distributions.MixtureSameFamily(
        mixture_distribution=tfp.distributions.Categorical(probs=probs),
        components_distribution=tfp.distributions.Normal(loc=mu, scale=scale),
    )
    x_eval = jnp.linspace(xmin, xmax, n_eval)
    kde_eval = pml.kdeg(x_eval[:, None], trace_hist[:, None], h)
    px = norm_mixture.prob(x_eval[:, None]).reshape(-1)
    ax.plot(np.arange(n_iterations), trace_hist)
    ax.plot(np.zeros(n_eval), x_eval, px, c="tab:red", linewidth=2)
    ax.plot(np.zeros(n_eval), x_eval, kde_eval, c="tab:blue")
    ax.set_zlim(0, kde_eval.max() * 1.2)
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Samples")
    ax.view_init(25, -30)
    if title:
        ax.set_title(title)
```
```python
fig = plt.figure()
axs = plt.axes(projection="3d")
plot_gmm_3d_trace_gibbs(x_samples_gibbs, probs, mu, scale, None, xmin, xmax, axs)
pml.style3d(axs, 1.5, 1, 0.8)
plt.subplots_adjust(left=0.001, bottom=0.208, right=0.7)
pml.savefig("gibbs_trace.pdf", tight_bbox=True, pad_inches=0)
```
[Figure: 3D trace of Gibbs samples with target density and KDE]

Autocorrelation plots
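The plots below use `sm.graphics.tsa.plot_acf` from statsmodels. For reference, here is a minimal sketch of the sample autocorrelation $\rho_k$ that call visualizes (this helper is an illustration, not part of the original notebook):

```python
# Sketch of the sample autocorrelation (assumed helper, not in the original notebook):
#   rho_k = sum_t (x_t - xbar) (x_{t+k} - xbar) / sum_t (x_t - xbar)^2
def sample_acf(x, n_lags=45):
    x = np.asarray(x)
    xc = x - x.mean()
    denom = np.sum(xc**2)
    return np.array([np.sum(xc[: len(x) - k] * xc[k:]) / denom for k in range(n_lags + 1)])
```

Samples from a well-mixing chain decorrelate after a few lags; a sticky chain (e.g. the $\tau=1$ random walk) stays correlated over many lags.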

```python
pml.latexify(width_scale_factor=2, fig_height=1.5)


def plot_autocorrelation(x_samples, fig_name=None, title=None):
    fig, ax = plt.subplots()
    if not title:
        title = "MH with $\mathcal{N}" + f"(0,{int(tau)}^2)$ proposal"
    op = sm.graphics.tsa.plot_acf(
        x_samples,
        lags=45,
        alpha=None,
        title=title,
        ax=ax,
        markersize=3,
    )
    ax.set_ylim(-0.1, 1.05)
    sns.despine()
    if fig_name:
        pml.savefig(fig_name)
```
```python
tau = 1
plot_autocorrelation(x_samples_1, f"mh_autocorrelation_{tau}tau.pdf")
```
[Figure: autocorrelation of MH samples, $\mathcal{N}(0,1^2)$ proposal]
```python
tau = 8
plot_autocorrelation(x_samples_8, f"mh_autocorrelation_{tau}tau.pdf")
```
[Figure: autocorrelation of MH samples, $\mathcal{N}(0,8^2)$ proposal]
```python
tau = 500
plot_autocorrelation(x_samples_500, f"mh_autocorrelation_{tau}tau.pdf")
```
[Figure: autocorrelation of MH samples, $\mathcal{N}(0,500^2)$ proposal]
```python
plot_autocorrelation(x_samples_gibbs, "gibbs_autocorrelation.pdf", "Gibbs sampling")
```
[Figure: autocorrelation of Gibbs samples]
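One number that summarizes these curves is the effective sample size, $n_\text{eff} = n / (1 + 2 \sum_{k \geq 1} \rho_k)$. A minimal sketch built on the `sample_acf` helper above (again an illustration, not from the original notebook), truncating the sum at the first negative autocorrelation:

```python
def effective_sample_size(x, n_lags=100):
    rho = sample_acf(x, n_lags)[1:]  # drop rho_0 = 1
    # Common truncation rule: stop summing at the first negative autocorrelation.
    cut = int(np.argmax(rho < 0)) if np.any(rho < 0) else len(rho)
    return len(x) / (1 + 2 * np.sum(rho[:cut]))
```

Comparing, e.g., `effective_sample_size(x_samples_1)` against `effective_sample_size(x_samples_gibbs)` makes the difference in mixing quantitative: the slowly decaying $\tau=1$ chain should retain only a small fraction of its nominal 1000 samples.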