Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/03/schools8_blackjax.ipynb
1192 views
Kernel: Python [conda env:py3713]

8-School Problem: Hierarchical Bayesian model

author: @anandShegde, @karm-patel

In this notebook, we fit a hierarchical Bayesian model to the "8 schools" dataset. It is based on: https://github.com/probml/pyprobml/blob/master/notebooks/book2/03/schools8_pymc3.ipynb

%matplotlib inline import jax import jax.numpy as jnp import scipy.stats as stats import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import logging logger = logging.getLogger() class CheckTypesFilter(logging.Filter): def filter(self, record): return "check_types" not in record.getMessage() logger.addFilter(CheckTypesFilter())
# !pip install -qq -U pymc3>=3.8 try: import blackjax except ModuleNotFoundError: %pip install -qq blackjax import blackjax try: import probml_utils as pml except: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml from probml_utils.blackjax_utils import inference_loop_multiple_chains, arviz_trace_from_states, inference_loop try: import arviz as az except ModuleNotFoundError: %pip install -qq arviz import arviz as az try: import tensorflow_probability.substrates.jax as tfp except ModuleNotFoundError: %pip install -qq tensorflow_probability import tensorflow_probability.substrates.jax as tfp tfd = tfp.distributions
# import os # os.environ["LATEXIFY"] = "" # os.environ["FIG_DIR"] = "figures"

Data

# Request latexified figure settings (width 2, height 1.5) from probml_utils.
# As the warning printed below this cell shows, the call does not latexify
# unless the LATEXIFY environment variable is set.
pml.latexify(fig_width=2, fig_height=1.5)
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:25: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# https://github.com/probml/pyprobml/blob/master/scripts/schools8_pymc3.py
# Data of the Eight Schools Model
J = 8  # number of schools; every per-school array below has this length
treatment_effects = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
treatment_stddevs = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Tick labels "0" ... "J-1"; derived from J rather than a repeated literal 8.
names = [str(t) for t in range(J)]

# Plot raw data: one horizontal error bar per school.
fig, ax = plt.subplots()
y_pos = jnp.arange(J)
ax.errorbar(treatment_effects, y_pos, xerr=treatment_stddevs, fmt="o", elinewidth=1, markersize=3)
ax.set_yticks(y_pos)
ax.set_yticklabels(names)
ax.invert_yaxis()  # labels read top-to-bottom
# Dashed vertical line at the unweighted mean of the observed effects.
ax.axvline(jnp.mean(treatment_effects), color="r", linestyle="--", label="pooled MLE")
ax.set_ylabel("$\\theta$")
sns.despine()
pml.savefig("schools8_data")
plt.show()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) /home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:84: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook

Centered model

mu_prior = tfd.Normal(loc=0, scale=5)
sigma_prior = tfd.Cauchy(loc=0, scale=5)
# Maps the unconstrained sampler variable to a positive scale.
positive_fn = jax.nn.softplus


def log_post_fn_centered(params):
    """Unnormalised log posterior of the centered 8-schools model.

    Args:
        params: dict with keys ``"mu"``, ``"log_tou"`` and ``"theta"``.
            ``"log_tou"`` is the unconstrained value that ``positive_fn``
            (softplus, despite the ``log_`` name) maps to the scale ``tou``;
            ``"theta"`` is flattened to 1-D before use.

    Returns:
        Sum of the log likelihood, the log priors on ``mu``, ``tou`` and
        ``theta``, and the log-Jacobian of the ``log_tou -> tou`` transform.
    """
    mu, log_tou = params["mu"], params["log_tou"]
    theta = params["theta"].reshape((-1,))

    log_prior_mu = mu_prior.log_prob(mu)

    # Change of variable: the prior is placed on the positive scale `tou`, so
    # it must be evaluated at `tou` (not at the unconstrained `log_tou`); the
    # Jacobian term then converts that density to the unconstrained space.
    # On tou > 0 the Cauchy density equals a half-Cauchy up to an additive
    # constant in log space, which does not affect MCMC.
    tou = positive_fn(log_tou)
    log_prior_tou = sigma_prior.log_prob(tou)
    log_tou_jacob = jnp.log(jnp.abs(jax.jacfwd(positive_fn)(log_tou)))

    log_prior_theta = jnp.sum(tfd.Normal(mu, tou).log_prob(theta))
    log_like = jnp.sum(tfd.Normal(theta, treatment_stddevs).log_prob(treatment_effects))

    return log_like + log_prior_mu + log_prior_theta + log_prior_tou + log_tou_jacob
# Sampling budget: chains, warmup (adaptation) steps and retained draws.
num_chains = 4
num_warmup = 25_000
num_samples = 10_000

# Split one seeded root key into the keys used for sampling, initialisation
# and warmup; the fourth piece is kept as the carry-over `key`.
root_key = jax.random.PRNGKey(311)
key_samples, key_init, key_warmup, key = jax.random.split(root_key, 4)
# One warmup key per chain. `num_chains` is taken from the configuration cell
# instead of being silently re-set to 4 here.
keys_warmup = jax.random.split(key_warmup, num_chains)

# Use a distinct subkey for every random draw: passing the same PRNG key to
# several samplers produces correlated (non-independent) initial values.
key_mu, key_sigma, key_theta = jax.random.split(key_init, 3)

mu_inital = mu_prior.sample(seed=key_mu, sample_shape=num_chains)
# Draw for the unconstrained "log_tou" parameter; it may be negative.
sigma_initial = sigma_prior.sample(seed=key_sigma, sample_shape=num_chains)
# theta is drawn around mu with the *positive* scale implied by the
# unconstrained value (a raw Cauchy draw is not a valid Normal scale).
# Sampling gives shape (J, num_chains); transpose to (num_chains, J).
theta_initial = (
    tfd.Normal(loc=mu_inital, scale=positive_fn(sigma_initial)).sample(seed=key_theta, sample_shape=J).T
)

# Initial positions for all chains (leading axis = chain) ...
params_centerd = {"mu": mu_inital, "log_tou": sigma_initial, "theta": theta_initial}
# ... and the first chain's initial position on its own.
params_centerd_one = {"mu": mu_inital[0], "log_tou": sigma_initial[0], "theta": theta_initial[0]}
%%time
# Window adaptation for NUTS on the centered log posterior, run for
# `num_warmup` steps.
adapt = blackjax.window_adaptation(blackjax.nuts, log_post_fn_centered, num_warmup)
# Run the adaptation once per chain (vmapped over keys and initial params),
# keeping only element [0] of each run's result.
final_states_cent = jax.vmap(lambda key, param: adapt.run(key, param)[0])(keys_warmup, params_centerd)
# A separate, additional adaptation run supplies the `kernel` used below.
# NOTE(review): this run uses keys_warmup[1] together with the first chain's
# initial params, so the kernel is not the one tuned in any of the vmapped
# runs above — confirm this is intended; `final_state` is unused afterwards.
final_state, kernel, _ = adapt.run(keys_warmup[1], params_centerd_one)
# Draw `num_samples` samples for `num_chains` chains with that single kernel.
states_cent, infos_cent = inference_loop_multiple_chains(
    key_samples, kernel, final_states_cent, num_samples, num_chains
)
CPU times: user 19.5 s, sys: 256 ms, total: 19.8 s Wall time: 19.6 s
# Number of divergent transitions per chain: summing over axis 0 leaves one
# count per chain (four values in the output below).
infos_cent.is_divergent.sum(axis=0)
DeviceArray([ 1314, 2165, 10000, 10000], dtype=int32)
# Swap the unconstrained "log_tou" samples for the positive scale "tou".
log_tou_samples = states_cent.position.pop("log_tou")
states_cent.position["tou"] = positive_fn(log_tou_samples)

# Convert the sampler output to an ArviZ trace and plot it.
trace_centered = arviz_trace_from_states(states_cent, infos_cent)
az.plot_trace(trace_centered);
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/arviz/stats/density_utils.py:491: UserWarning: Your data appears to have a single value or no finite values warnings.warn("Your data appears to have a single value or no finite values") /home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/arviz/stats/density_utils.py:491: UserWarning: Your data appears to have a single value or no finite values warnings.warn("Your data appears to have a single value or no finite values") /home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/arviz/stats/density_utils.py:491: UserWarning: Your data appears to have a single value or no finite values warnings.warn("Your data appears to have a single value or no finite values")
Image in a Jupyter notebook
az.summary(trace_centered)

# Display the total number and percentage of divergent transitions.
# `is_divergent` flags individual transitions (one per sample per chain), so
# the messages say "Transitions" rather than the misleading "Chains".
num_divergent = infos_cent.is_divergent.sum()
print(f"Number of Divergent Transitions: {num_divergent}")
diverging_pct = num_divergent / (num_samples * num_chains) * 100
print(f"Percentage of Divergent Transitions: {diverging_pct:.1f}")
Number of Divergent Chains: 23479 Percentage of Divergent Chains: 58.7

Non-centered

mu_prior = tfd.Normal(loc=0, scale=5)
sigma_prior = tfd.Cauchy(loc=0, scale=5)
# Maps the unconstrained sampler variable to a positive scale.
positive_fn = jax.nn.softplus


def log_post_fn_non_centered(params):
    """Unnormalised log posterior of the non-centered 8-schools model.

    The per-school effects are parameterised as ``mu + tou * eta`` with
    ``eta`` given a standard normal prior.

    Args:
        params: dict with keys ``"mu"``, ``"log_tou"`` and ``"eta"``.
            ``"log_tou"`` is the unconstrained value that ``positive_fn``
            (softplus, despite the ``log_`` name) maps to the scale ``tou``.

    Returns:
        Sum of the log likelihood, the log priors on ``mu``, ``tou`` and
        ``eta``, and the log-Jacobian of the ``log_tou -> tou`` transform.
    """
    mu, log_tou, eta = params["mu"], params["log_tou"], params["eta"]

    log_prior_mu = mu_prior.log_prob(mu)

    # Change of variable: the prior is placed on the positive scale `tou`, so
    # it must be evaluated at `tou` (not at the unconstrained `log_tou`); the
    # Jacobian term then converts that density to the unconstrained space.
    tou = positive_fn(log_tou)
    log_prior_tou = sigma_prior.log_prob(tou)
    log_tou_jacob = jnp.log(jnp.abs(jax.jacfwd(positive_fn)(log_tou)))

    log_prior_eta = jnp.sum(tfd.Normal(0, 1).log_prob(eta))
    log_like = jnp.sum(tfd.Normal(mu + tou * eta, treatment_stddevs).log_prob(treatment_effects))

    return log_like + log_prior_mu + log_prior_eta + log_prior_tou + log_tou_jacob
potential = log_post_fn_non_centered

# Fresh seed for the non-centered run.
key = jax.random.PRNGKey(315)
key_samples, key_init, key_warmup, key = jax.random.split(key, 4)
# Derive the per-chain warmup keys from this run's `key_warmup`; without this
# line the freshly split `key_warmup` would never be used.
keys_warmup = jax.random.split(key_warmup, num_chains)

# Use a distinct subkey for every random draw: passing the same PRNG key to
# several samplers produces correlated (non-independent) initial values.
key_mu, key_sigma, key_eta = jax.random.split(key_init, 3)

mu_inital = mu_prior.sample(seed=key_mu, sample_shape=num_chains)
# Draw for the unconstrained "log_tou" parameter; it may be negative.
sigma_initial = sigma_prior.sample(seed=key_sigma, sample_shape=num_chains)
# Standard-normal eta. Sampling gives shape (J, num_chains); transpose to
# (num_chains, J). Sizes come from num_chains / J rather than literals 4 / 8.
eta_initial = tfd.Normal(loc=jnp.zeros(num_chains), scale=1.0).sample(seed=key_eta, sample_shape=J).T

# Initial positions for all chains (leading axis = chain) ...
params_noncenterd = {
    "mu": mu_inital,
    "log_tou": sigma_initial,
    "eta": eta_initial,
}
# ... and the first chain's initial position on its own.
params_noncenterd_one = {
    "mu": mu_inital[0],
    "log_tou": sigma_initial[0],
    "eta": eta_initial[0],
}
%%time
# Window adaptation for NUTS on the non-centered log posterior, run for
# `num_warmup` steps.
adapt = blackjax.window_adaptation(blackjax.nuts, log_post_fn_non_centered, num_warmup)
# Run the adaptation once per chain (vmapped over keys and initial params),
# keeping only element [0] of each run's result.
final_states_non_cent = jax.vmap(lambda key, param: adapt.run(key, param)[0])(keys_warmup, params_noncenterd)
# A separate, additional adaptation run supplies the `kernel` used below.
# NOTE(review): this run uses keys_warmup[1] together with the first chain's
# initial params, so the kernel is not the one tuned in any of the vmapped
# runs above — confirm this is intended; `final_state` is unused afterwards.
final_state, kernel, _ = adapt.run(keys_warmup[1], params_noncenterd_one)
# Draw `num_samples` samples for `num_chains` chains with that single kernel.
states_non_cent, infos_non_cent = inference_loop_multiple_chains(
    key_samples, kernel, final_states_non_cent, num_samples, num_chains
)
CPU times: user 16.1 s, sys: 193 ms, total: 16.3 s Wall time: 16.1 s
# Replace the unconstrained "log_tou" samples by the positive scale "tou".
states_non_cent.position["tou"] = positive_fn(states_non_cent.position["log_tou"])
del states_non_cent.position["log_tou"]

# NOTE(review): presumably needed so the trace conversion finds a momentum
# entry for the derived "theta" variable — confirm against
# arviz_trace_from_states.
infos_non_cent.momentum["theta"] = infos_non_cent.momentum["eta"]

# Recover theta = mu + tou * eta for every draw. "tou" and "mu" have one value
# per (sample, chain); add a trailing axis and broadcast it across the school
# axis of "eta" instead of tiling each array 8 times through vmap.
eta = states_non_cent.position["eta"]
tou = jnp.broadcast_to(states_non_cent.position["tou"][..., None], eta.shape)
mu = jnp.broadcast_to(states_non_cent.position["mu"][..., None], eta.shape)
theta = eta * tou + mu
states_non_cent.position["theta"] = theta
# Number of divergent transitions per chain: summing over axis 0 leaves one
# count per chain (four values in the output below).
infos_non_cent.is_divergent.sum(axis=0)
DeviceArray([0, 0, 0, 0], dtype=int32)
# Convert the non-centered sampler output to an ArviZ trace.
trace_noncentered = arviz_trace_from_states(states_non_cent, infos_non_cent)
# Tabular summary of the trace.
az.summary(trace_noncentered)
# Trace plots for the trace; tight_layout adjusts the spacing between panels.
az.plot_trace(trace_noncentered)
plt.tight_layout()
Image in a Jupyter notebook
num_schools = 8
burnin = 300

# Posterior theta draws with the first `burnin` samples discarded.
thetas = states_non_cent.position["theta"]
kept_thetas = thetas[burnin:, :]
# Average over axis 0 twice, collapsing the two leading axes in turn; the
# error bar is the per-axis-0 standard deviation averaged the same way.
posterior_mean = jnp.mean(jnp.mean(kept_thetas, axis=0), axis=0)
posterior_std = jnp.mean(jnp.std(kept_thetas, axis=0), axis=0)

school_ids = range(num_schools)
observed_mean = treatment_effects.mean()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# Left: raw observed effects. Right: posterior estimates.
ax1.bar(school_ids, treatment_effects, yerr=treatment_stddevs)
ax2.bar(school_ids, posterior_mean, yerr=posterior_std)
# Same horizontal reference line (mean of the observed effects) on both panels.
for panel in (ax1, ax2):
    panel.plot([0, 8], [observed_mean, observed_mean], color="r", label="global mean")

ax1.set_xlabel("School")
ax1.set_ylabel("Treatment effect")
ax1.set_title("Without Pooling")
ax2.set_title("Non Centerd partial Pooling ")
sns.despine()
plt.suptitle("8 Schools treatment effects")
# pyplot-level labels apply to the current axes.
plt.xlabel("School")
plt.ylabel("Treatment effect")
plt.show()
Image in a Jupyter notebook
# Request latexified figure settings (width 2) from probml_utils. As the
# warning printed below this cell shows, the call does not latexify unless the
# LATEXIFY environment variable is set.
pml.latexify(fig_width=2)
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:25: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Forest plot of the non-centered posterior for theta with 95% intervals.
fig, ax = plt.subplots(figsize=(4, 3))
forest = az.plot_forest(trace_noncentered, var_names="theta", combined=True, hdi_prob=0.95, ax=ax, textsize=16)
# Raw string: "\%" is not a valid Python escape sequence, so the non-raw form
# triggers an invalid-escape warning. The resulting text is unchanged.
forest[0].set_title(r"95\% Credible Interval", fontsize=16)
# Dashed reference line at the mean over all theta draws (a single .mean()
# already reduces to a scalar; the extra jnp.mean wrapper was redundant).
plt.axvline(states_non_cent.position["theta"].mean(), color="k", linestyle="--")
forest[0].set_xticks([-10, -5, 0, 5, 10, 15, 20])
pml.savefig("hbayes_schools8_forest")
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:84: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
# Posterior density of the scale parameter ("tou") with a 95% interval.
pml.latexify(width_scale_factor=2)
post_plot = az.plot_posterior(trace_noncentered, var_names="tou", hdi_prob=0.95)
post_plot.set_title("$\\tau$")
# Overwrite the text of the axes' sixth child artist.
# NOTE(review): presumably this is the HDI annotation; the positional index
# depends on ArviZ's drawing order — confirm when upgrading ArviZ.
text = post_plot.get_children()[5]
# Raw string: "\%" is not a valid Python escape sequence, so the non-raw form
# triggers an invalid-escape warning. The resulting text is unchanged.
text.set_text(r"95\% HPD")
pml.savefig("hbayes_schools8_tau")
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:25: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
Image in a Jupyter notebook
# Forest plot comparing theta under the centered and non-centered fits.
# (The stray `from cProfile import label` that preceded this call was removed:
# nothing in the cell uses that name.)
az.plot_forest(
    [trace_centered, trace_noncentered],
    model_names=["centered", "noncentered"],
    var_names="theta",
    combined=True,
    hdi_prob=0.95,
)
# Dashed reference line at the mean of the observed treatment effects.
plt.axvline(jnp.mean(treatment_effects), color="k", linestyle="--")
<matplotlib.lines.Line2D at 0x7fe74837e150>
Image in a Jupyter notebook
# Same centered vs. non-centered comparison of theta, drawn as a ridge plot
# (kind="ridgeplot"); the trailing semicolon suppresses the cell's text output.
az.plot_forest(
    [trace_centered, trace_noncentered],
    model_names=["centered", "noncentered"],
    var_names=["theta"],
    kind="ridgeplot",
    combined=True,
    hdi_prob=0.95,
);
Image in a Jupyter notebook

Funnel of hell

# Request latexified figure settings for the funnel plots.
pml.latexify(width_scale_factor=1.5, fig_height=2)
# When latexify is enabled, pass no explicit figsize (None); otherwise use a
# fixed 10x5 figure.
FIG_SIZE = None if pml.is_latexify_enabled() else (10, 5)
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:25: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Plot the "funnel of hell"
# Based on
# https://github.com/twiecki/WhileMyMCMCGentlySamples/blob/master/content/downloads/notebooks/GLM_hierarchical_non_centered.ipynb
burnin = num_samples // 4
group = 0


def _scatter_funnel(axis, states, panel_title):
    """Scatter theta[group] against log(tou) for chain 0, after burn-in, on `axis`."""
    theta_draws = pd.Series(states.position["theta"][burnin:, 0, group], name=f"alpha {group}")
    log_tou_draws = pd.Series(jnp.log(states.position["tou"][burnin:, 0]), name="log_sigma_alpha")
    axis.plot(theta_draws, log_tou_draws, ".", markersize=1)
    axis.set(title=panel_title, xlabel=r"$\theta_0$", ylabel=r"$\log(\tau)$")


fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=FIG_SIZE)

# Left panel: centered parameterisation, with explicit x ticks.
_scatter_funnel(axs[0], states_cent, "Centered")
axs[0].set_xticks([-10, 10, 30, 50])

# Right panel: non-centered parameterisation (axes are shared with the left).
_scatter_funnel(axs[1], states_non_cent, "NonCentered")

sns.despine()
pml.savefig("schools8_funnel_group0")
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:84: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
# Reuse the axis limits of the first panel of the current `axs` so both joint
# plots below share the same view.
xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()

# Joint distribution of mu and log(tou) for chain 0, centered model.
x = pd.Series(states_cent.position["mu"][:, 0], name="mu")
y = pd.Series(jnp.log(states_cent.position["tou"][:, 0]), name="log sigma_alpha")
# x and y are passed by keyword: seaborn warns that positional use is
# deprecated and becomes an error from version 0.12.
sns.jointplot(x=x, y=y, xlim=xlim, ylim=ylim)
plt.suptitle("centered")
plt.savefig("schools8_centered_joint.png", dpi=300)

# Same plot for the non-centered model.
x = pd.Series(states_non_cent.position["mu"][:, 0], name="mu")
y = pd.Series(jnp.log(states_non_cent.position["tou"][:, 0]), name="log sigma_alpha")
sns.jointplot(x=x, y=y, xlim=xlim, ylim=ylim)
plt.suptitle("noncentered")
plt.savefig("schools8_noncentered_joint.png", dpi=300)
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. FutureWarning /home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. FutureWarning
Image in a Jupyter notebookImage in a Jupyter notebook
# One centered vs. non-centered funnel figure per school, using chain 0.
burnin = num_samples // 4  # loop-invariant, so computed once
# The last axis of theta is the one indexed by `group` below; take the number
# of groups from it instead of hard-coding 8.
num_groups = states_cent.position["theta"].shape[-1]

for group in range(num_groups):
    # The x label carries the current group index (it previously read
    # \alpha_0 on every figure).
    group_label = rf"$\alpha_{{{group}}}$"

    fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(10, 5))

    x = pd.Series(states_cent.position["theta"][burnin:, 0, group], name=f"alpha {group}")
    y = pd.Series(jnp.log(states_cent.position["tou"][burnin:, 0]), name="log_sigma_alpha")
    axs[0].plot(x, y, ".")
    axs[0].set(title="Centered", xlabel=group_label, ylabel=r"$\log(\sigma_\alpha)$")

    x = pd.Series(states_non_cent.position["theta"][burnin:, 0, group], name=f"alpha {group}")
    y = pd.Series(jnp.log(states_non_cent.position["tou"][burnin:, 0]), name="log_sigma_alpha")
    axs[1].plot(x, y, ".")
    axs[1].set(title="NonCentered", xlabel=group_label, ylabel=r"$\log(\sigma_\alpha)$")

    # Not used inside the loop; kept so that `xlim` / `ylim` still hold the
    # last figure's limits afterwards, as before.
    xlim = axs[0].get_xlim()
    ylim = axs[0].get_ylim()

    sns.despine()
    plt.savefig(f"schools8_funnel_group{group}.png", dpi=300)
    plt.show()
Image in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebook