# Path: blob/master/deprecated/scripts/beta_binom_approx_post_pymc3.py
# 1192 views
"""1d approximations to the posterior of a beta-binomial model.

The data are coin flips modelled as Bernoulli draws (``Binomial`` with
``n=1``) under a uniform ``Beta(1, 1)`` prior, so the exact posterior is
``Beta(h + 1, t + 1)``. The script compares that exact posterior with:

* a grid approximation,
* a Laplace (quadratic) approximation around the MAP estimate,
* posterior samples from ``pm.sample`` (the "HMC" section),
* mean-field ADVI (``pm.fit`` / ``pm.ADVI``), including traces of the
  variational mean, std and negative ELBO during optimisation.

Each figure is written with ``pyprobml_utils.savefig``.

Adapted from https://github.com/aloctavodia/BAP
"""

import superimport

import pymc3 as pm
import numpy as np
import seaborn as sns
import scipy.stats as stats
import matplotlib.pyplot as plt
import arviz as az
import math
import pyprobml_utils as pml

# Observed coin flips: 0 = tails, 1 = heads.
# data = np.repeat([0, 1], (10, 3))  # alternative dataset: 10 tails, 3 heads
data = np.repeat([0, 1], (10, 1))  # 10 tails, 1 head
h = data.sum()  # number of heads
t = len(data) - h  # number of tails

# Exact

# Exact posterior Beta(h + 1, t + 1) evaluated on a fine grid, then
# normalised to sum to 1 so that it is a probability mass per grid cell,
# like the grid approximation below.
plt.figure()
xs = np.linspace(0, 1, 100)
dx_exact = xs[1] - xs[0]
post_exact = stats.beta.pdf(xs, h + 1, t + 1)
post_exact = post_exact / np.sum(post_exact)
plt.plot(xs, post_exact)
plt.yticks([])
plt.title('exact posterior')
pml.savefig('bb_exact.pdf')


# Grid
def posterior_grid(heads, tails, grid_points=100):
    """Grid approximation to the posterior of a Bernoulli success rate.

    Uses a uniform prior over ``grid_points`` equally spaced values in
    [0, 1] and the binomial likelihood of ``heads`` successes in
    ``heads + tails`` trials.

    Args:
        heads: Number of observed successes.
        tails: Number of observed failures.
        grid_points: Number of grid points, including both endpoints.

    Returns:
        ``(grid, posterior)``, where ``posterior`` sums to 1 over ``grid``
        (a probability mass per grid point, not a density).
    """
    grid = np.linspace(0, 1, grid_points)
    prior = np.repeat(1 / grid_points, grid_points)  # uniform prior
    likelihood = stats.binom.pmf(heads, heads + tails, grid)
    posterior = likelihood * prior
    posterior /= posterior.sum()
    return grid, posterior


n = 20
grid, posterior = posterior_grid(h, t, n)
dx_grid = grid[1] - grid[0]
# Both post_exact and the grid posterior sum to 1, but over cells of
# different widths. Multiplying by the ratio of cell widths (the Jacobian
# scale factor) puts the exact curve on the grid's per-cell scale.
sf = dx_grid / dx_exact
plt.figure()
# Bar width equals the actual grid spacing 1 / (n - 1), so adjacent bars
# touch without gaps or overlap (1 / n was slightly too narrow).
plt.bar(grid, posterior, width=dx_grid, alpha=0.2)
plt.plot(xs, post_exact * sf)
plt.title('grid approximation')
plt.yticks([])
plt.xlabel('θ')
pml.savefig('bb_grid.pdf')


# Laplace
# Gaussian centred at the MAP estimate; its std is the inverse square root of
# the (1x1) Hessian returned by pm.find_hessian at that point.
# NOTE(review): pymc3 optimises over an unconstrained (log-odds) version of
# theta, so the MAP may include a change-of-variables term -- confirm that mu
# matches the Beta(h + 1, t + 1) mode h / (h + t).
with pm.Model() as normal_approximation:
    theta = pm.Beta('theta', 1., 1.)
    y = pm.Binomial('y', n=1, p=theta, observed=data)  # Bernoulli
    mean_q = pm.find_MAP()
    hessian = pm.find_hessian(mean_q, vars=[theta])
    # Single free parameter, so the Hessian has exactly one entry; .item()
    # turns it into a plain scalar (the old `[0]` left a length-1 array).
    std_q = np.sqrt(1 / np.asarray(hessian)).item()
mu = mean_q['theta']

print([mu, std_q])

plt.figure()
plt.plot(xs, stats.norm.pdf(xs, mu, std_q), '--', label='Laplace')
# Compare against the unnormalised Beta density (same scale as the Gaussian
# pdf). A separate name keeps the grid-normalised post_exact intact.
pdf_exact = stats.beta.pdf(xs, h + 1, t + 1)
plt.plot(xs, pdf_exact, label='exact')
plt.title('Quadratic approximation')
plt.xlabel('θ', fontsize=14)
plt.yticks([])
plt.legend()
pml.savefig('bb_laplace.pdf')


# HMC
# Posterior samples from pm.sample: 1000 draws per chain, 2 chains, fixed seed.
with pm.Model() as hmc_model:
    theta = pm.Beta('theta', 1., 1.)
    y = pm.Binomial('y', n=1, p=theta, observed=data)  # Bernoulli
    trace = pm.sample(1000, random_seed=42, cores=1, chains=2)
thetas = trace['theta']
axes = az.plot_posterior(thetas, hdi_prob=0.95)
pml.savefig('bb_hmc.pdf')

az.plot_trace(trace)
pml.savefig('bb_hmc_trace.pdf', dpi=300)

# ADVI
# Mean-field variational approximation, then 1000 draws from it.
with pm.Model() as mf_model:
    theta = pm.Beta('theta', 1., 1.)
    y = pm.Binomial('y', n=1, p=theta, observed=data)  # Bernoulli
    mean_field = pm.fit(method='advi')
    trace_mf = mean_field.sample(1000)
thetas = trace_mf['theta']
axes = az.plot_posterior(thetas, hdi_prob=0.95)
pml.savefig('bb_mf.pdf')

plt.show()


# track mean and std
# Same model fitted with an explicit pm.ADVI object so that a Tracker
# callback can record the variational mean and std at every iteration.
# NOTE(review): these are presumably the parameters of the Gaussian in the
# unconstrained (log-odds) space, not of theta itself -- confirm.
with pm.Model() as mf_tracked_model:
    theta = pm.Beta('theta', 1., 1.)
    y = pm.Binomial('y', n=1, p=theta, observed=data)  # Bernoulli
    advi = pm.ADVI()
    tracker = pm.callbacks.Tracker(
        mean=advi.approx.mean.eval,  # callable that returns mean
        std=advi.approx.std.eval,  # callable that returns std
    )
    approx = advi.fit(callbacks=[tracker])

trace_approx = approx.sample(1000)
thetas = trace_approx['theta']

plt.figure()
plt.plot(tracker['mean'])
plt.title('Mean')
pml.savefig('bb_mf_mean.pdf')

plt.figure()
plt.plot(tracker['std'])
plt.title('Std ')
pml.savefig('bb_mf_std.pdf')

plt.figure()
plt.plot(advi.hist)  # per-iteration loss recorded by advi.fit
plt.title('Negative ELBO')
pml.savefig('bb_mf_elbo.pdf')

plt.figure()
sns.kdeplot(thetas)
plt.title('KDE of posterior samples')
pml.savefig('bb_mf_kde.pdf')


# The same four diagnostics side by side in one panel.
fig, axs = plt.subplots(1, 4, figsize=(30, 10))
mu_ax, std_ax, elbo_ax, kde_ax = axs
mu_ax.plot(tracker['mean'])
mu_ax.set_title('Mean')
std_ax.plot(tracker['std'])
std_ax.set_title('Std ')
elbo_ax.plot(advi.hist)
elbo_ax.set_title('Negative ELBO')
# Draw on the intended panel explicitly instead of relying on the implicit
# "current axes".
sns.kdeplot(thetas, ax=kde_ax)
kde_ax.set_title('KDE of posterior samples')
pml.savefig('bb_mf_panel.pdf')


# Mean/std tracks on the top row, negative ELBO across the bottom row.
fig = plt.figure(figsize=(16, 9))
mu_ax = fig.add_subplot(221)
std_ax = fig.add_subplot(222)
hist_ax = fig.add_subplot(212)
mu_ax.plot(tracker['mean'])
mu_ax.set_title('Mean track')
std_ax.plot(tracker['std'])
std_ax.set_title('Std track')
hist_ax.plot(advi.hist)
hist_ax.set_title('Negative ELBO track')
pml.savefig('bb_mf_tracker.pdf')

# Fresh draws from the tracked approximation for the final posterior plot.
trace_approx = approx.sample(1000)
thetas = trace_approx['theta']
axes = az.plot_posterior(thetas, hdi_prob=0.95)

plt.show()