Path: blob/master/deprecated/scripts/bayes_unigauss_2d_pyro.py
1192 views
# Approximate 2d posterior using pyro SVI
# https://www.ritchievink.com/blog/2019/06/10/bayesian-inference-how-we-are-able-to-chase-the-posterior/
# We use the same data and model as in posteriorGrid2d.py

import superimport

import os

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.infer import SVI, Trace_ELBO
import torch
import torch.distributions.constraints as constraints

figdir = "../figures"


def save_fig(fname):
    """Save the current matplotlib figure as `fname` inside `figdir`.

    Does nothing when `figdir` is empty/falsy.
    """
    if figdir:
        plt.savefig(os.path.join(figdir, fname))


np.random.seed(0)
# All sampling below goes through torch, so the torch RNG must be seeded as
# well for the SVI run and the final scatter plot to be reproducible.
pyro.set_rng_seed(0)

data = np.array([195, 182])
# Same observations as a float tensor, shared with the Pyro model below.
data_t = torch.tensor(data, dtype=torch.float32)

# lets create a grid of our two parameters
mu = np.linspace(150, 250)
# Reversed so that large sigma is the first row, i.e. the top of the image.
sigma = np.linspace(0, 15)[::-1]
mm, ss = np.meshgrid(mu, sigma)  # just broadcasted parameters
likelihood = stats.norm(mm, ss).pdf(data[0]) * stats.norm(mm, ss).pdf(data[1])
aspect = mm.max() / ss.max() / 3
extent = [mm.min(), mm.max(), ss.min(), ss.max()]
# extent = left right bottom top

prior = stats.norm(200, 15).pdf(mm) * stats.cauchy(0, 10).pdf(ss)
# Posterior - grid
unnormalized_posterior = prior * likelihood
# The grid includes sigma == 0, where scipy's normal pdf evaluates to nan;
# nan_to_num keeps those cells out of the normalising sum.
posterior = unnormalized_posterior / np.nan_to_num(unnormalized_posterior).sum()

plt.figure()
plt.imshow(posterior, cmap='Blues', aspect=aspect, extent=extent)
plt.xlabel(r'$\mu$')
plt.ylabel(r'$\sigma$')
plt.title('Grid approximation')
plt.show()


def model():
    """Generative model: Normal likelihood with unknown mean and scale.

    Priors are mu ~ Normal(200, 15) and sigma ~ HalfCauchy(10). The
    observations are taken from the module-level `data`, so this model
    conditions on the same values as the grid approximation above.
    """
    # priors
    mu = pyro.sample('mu', dist.Normal(loc=torch.tensor(200.),
                                       scale=torch.tensor(15.)))
    sigma = pyro.sample('sigma', dist.HalfCauchy(scale=torch.tensor(10.)))

    # likelihood
    with pyro.plate('plate', size=len(data_t)):
        pyro.sample('obs', dist.Normal(loc=mu, scale=sigma), obs=data_t)


def guide():
    """Factorized variational family: q(mu) = Normal, q(sigma) = Chi2."""
    # variational parameters
    var_mu = pyro.param('var_mu', torch.tensor(180.))
    var_mu_sig = pyro.param('var_mu_sig', torch.tensor(5.),
                            constraint=constraints.positive)
    # Chi2 degrees of freedom must stay positive; without the constraint an
    # optimiser step can push it out of the valid range.
    var_sig = pyro.param('var_sig', torch.tensor(5.),
                         constraint=constraints.positive)

    # factorized distribution
    pyro.sample('mu', dist.Normal(loc=var_mu, scale=var_mu_sig))
    pyro.sample('sigma', dist.Chi2(var_sig))


pyro.clear_param_store()
pyro.enable_validation(True)

svi = SVI(model, guide,
          optim=pyro.optim.ClippedAdam({"lr": 0.01}),
          loss=Trace_ELBO())

# do gradient steps
for step in range(1000):
    loss = svi.step()
    if step % 100 == 0:
        # Iterations are reported 1-based.
        print("[iteration {:>4}] loss: {:.4f}".format(step + 1, loss))

# Draw samples from the fitted variational distributions.
sigma = dist.Chi2(pyro.param('var_sig')).sample((10000,)).numpy()
mu = dist.Normal(pyro.param('var_mu'), pyro.param('var_mu_sig')).sample((10000,)).numpy()

plt.figure()
plt.scatter(mu, sigma, alpha=0.01)
# Reuse the grid plot's axis limits so the two figures are comparable.
plt.xlim([extent[0], extent[1]])
plt.ylim([extent[2], extent[3]])
plt.ylabel(r'$\sigma$')
plt.xlabel(r'$\mu$')
plt.title('VI samples')
save_fig('bayes_unigauss_2d_pyro_post.pdf')
plt.show()