GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/funnel_pymc3.ipynb
Kernel: Python 3

In this notebook, we explore the "funnel of hell". This refers to a posterior in which the mean and variance of a variable are highly correlated, and have a funnel shape. (The term "funnel of hell" is from this blog post by Thomas Wiecki.)
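To make the funnel shape concrete, here is a minimal sketch using only the Python standard library. It assumes a hypothetical two-variable distribution (not the model fit in this notebook) in which `log_tau` is normal and `theta` is normal with standard deviation `tau`; the scale of 1.5 and the cut-offs at -1 and +1 are arbitrary choices for illustration.

```python
import math
import random
import statistics

random.seed(0)

# Hypothetical funnel-shaped joint distribution (an assumption for illustration,
# not the notebook's model):
#   log_tau ~ Normal(0, 1.5),  theta | tau ~ Normal(0, tau)
draws = []
for _ in range(20000):
    log_tau = random.gauss(0.0, 1.5)
    theta = random.gauss(0.0, math.exp(log_tau))
    draws.append((log_tau, theta))

# Compare the spread of theta in the neck (small tau) and the mouth (large tau).
neck = [theta for log_tau, theta in draws if log_tau < -1.0]
mouth = [theta for log_tau, theta in draws if log_tau > 1.0]

neck_sd = statistics.pstdev(neck)
mouth_sd = statistics.pstdev(mouth)
print(f"sd(theta | log_tau < -1) = {neck_sd:.3f}")
print(f"sd(theta | log_tau > +1) = {mouth_sd:.3f}")
```

The spread of `theta` differs by an order of magnitude between the two regions, so a single step size is unlikely to suit both the narrow neck and the wide mouth of the funnel.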

We illustrate this using a hierarchical Bayesian model for inferring Gaussian means, fit to synthetic data. The setup is similar to the 8 schools example, except that we vary the sample size across groups and fix the observation variance. This code is based on this notebook from Justin Bois.
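For reference, the model defined in the code cells of this notebook can be written as follows. This is a transcription under two assumptions: that PyMC3's `sd` argument is a standard deviation and that `HalfCauchy(beta)` has scale `beta`. The index notation $g[i]$ (the group of observation $i$) is introduced here and does not appear in the original.

```latex
\begin{aligned}
\mu &\sim \mathcal{N}(0,\ 5^2) \\
\tau &\sim \text{HalfCauchy}(2.5) \\
\theta_j \mid \mu, \tau &\sim \mathcal{N}(\mu,\ \tau^2), \qquad j = 1, \dots, 10 \\
x_i \mid \theta &\sim \mathcal{N}(\theta_{g[i]},\ \sigma^2), \qquad \sigma = 10 \text{ (fixed)}
\end{aligned}
```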

%matplotlib inline
import sklearn
import scipy.stats as stats
import scipy.optimize
import matplotlib.pyplot as plt
import seaborn as sns
import time
import numpy as np
import os
import pandas as pd
# Quote the requirement so the shell does not treat ">" as a redirect
!pip install -U "pymc3>=3.8"
import pymc3 as pm
print(pm.__version__)

import arviz as az
print(az.__version__)
3.11.2
0.11.2
import math
import pickle

import numpy as np
import pandas as pd
import scipy.stats as st
import theano.tensor as tt
import theano
np.random.seed(0)

# Specify parameters for random data
mu_val = 8
tau_val = 3
sigma_val = 10
n_groups = 10

# Generate number of replicates for each repeat
n = np.random.randint(low=3, high=10, size=n_groups, dtype=int)
print(n)
print(sum(n))
[7 8 3 6 6 6 4 6 8 5]
59
# Generate data set
mus = np.zeros(n_groups)
x = np.array([])
for i in range(n_groups):
    mus[i] = np.random.normal(mu_val, tau_val)
    samples = np.random.normal(mus[i], sigma_val, size=n[i])
    x = np.append(x, samples)

print(x.shape)

group_ind = np.concatenate([[i] * n_val for i, n_val in enumerate(n)])
(59,)
with pm.Model() as centered_model:
    # Hyperpriors
    mu = pm.Normal("mu", mu=0, sd=5)
    tau = pm.HalfCauchy("tau", beta=2.5)
    log_tau = pm.Deterministic("log_tau", tt.log(tau))

    # Prior on theta
    theta = pm.Normal("theta", mu=mu, sd=tau, shape=n_groups)

    # Likelihood
    x_obs = pm.Normal("x_obs", mu=theta[group_ind], sd=sigma_val, observed=x)

np.random.seed(0)
with centered_model:
    centered_trace = pm.sample(10000, chains=2)

pm.summary(centered_trace).round(2)
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:20: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [theta, tau, mu]
Sampling 2 chains for 1_000 tune and 10_000 draw iterations (2_000 + 20_000 draws total) took 39 seconds.
There were 329 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6086240934009718, but should be close to 0.8. Try to increase the number of tuning steps.
There were 449 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.703261908881651, but should be close to 0.8. Try to increase the number of tuning steps.
The estimated number of effective samples is smaller than 200 for some parameters.
/usr/local/lib/python3.7/dist-packages/arviz/data/io_pymc3.py:100: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  FutureWarning,
with pm.Model() as noncentered_model:
    # Hyperpriors
    mu = pm.Normal("mu", mu=0, sd=5)
    tau = pm.HalfCauchy("tau", beta=2.5)
    log_tau = pm.Deterministic("log_tau", tt.log(tau))

    # Prior on theta: instead of the centered form
    #   theta = pm.Normal("theta", mu=mu, sd=tau, shape=n_groups)
    # sample a standard normal and shift/scale it deterministically
    var_theta = pm.Normal("var_theta", mu=0, sd=1, shape=n_groups)
    theta = pm.Deterministic("theta", mu + var_theta * tau)

    # Likelihood
    x_obs = pm.Normal("x_obs", mu=theta[group_ind], sd=sigma_val, observed=x)

np.random.seed(0)
with noncentered_model:
    noncentered_trace = pm.sample(1000, chains=2)

pm.summary(noncentered_trace).round(2)
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:20: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [var_theta, tau, mu]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 6 seconds.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 59 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6508565774982737, but should be close to 0.8. Try to increase the number of tuning steps.
The estimated number of effective samples is smaller than 200 for some parameters.
/usr/local/lib/python3.7/dist-packages/arviz/data/io_pymc3.py:100: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  FutureWarning,
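The noncentered model replaces a direct draw of `theta` from Normal(mu, tau) with a standard normal draw that is then shifted by `mu` and scaled by `tau`. The sketch below, using only the standard library, checks that the two constructions give the same distribution for fixed hyperparameters. The fixed values `mu = 8` and `tau = 3` are an assumption for illustration (they match `mu_val` and `tau_val` used to generate the data); in the actual models both are random.

```python
import random
import statistics

random.seed(0)

# Hyperparameters held fixed for illustration only (an assumption of this sketch).
mu, tau = 8.0, 3.0
n_draws = 50000

# Centered: draw theta directly from Normal(mu, tau).
centered = [random.gauss(mu, tau) for _ in range(n_draws)]

# Noncentered: draw a standard normal, then shift and scale deterministically.
noncentered = [mu + tau * random.gauss(0.0, 1.0) for _ in range(n_draws)]

for name, draws in [("centered", centered), ("noncentered", noncentered)]:
    mean = statistics.fmean(draws)
    sd = statistics.pstdev(draws)
    print(f"{name:12s} mean={mean:.3f} sd={sd:.3f}")
```

Both constructions imply the same distribution for `theta`, but in the noncentered form the sampled variable `var_theta` has a prior scale that does not depend on `tau`, which is the usual motivation for this reparameterization.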
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)

x = pd.Series(centered_trace["mu"], name="mu")
y = pd.Series(centered_trace["tau"], name="tau")
axs[0].plot(x, y, ".")
axs[0].set(title="Centered", xlabel="µ", ylabel="τ")
axs[0].axhline(0.01)

x = pd.Series(noncentered_trace["mu"], name="mu")
y = pd.Series(noncentered_trace["tau"], name="tau")
axs[1].plot(x, y, ".")
axs[1].set(title="NonCentered", xlabel="µ", ylabel="τ")
axs[1].axhline(0.01)

xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()
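[Figure: scatter plots of τ against µ for the centered (left) and noncentered (right) models, each with a horizontal line at τ = 0.01]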
x = pd.Series(centered_trace["mu"], name="mu")
y = pd.Series(centered_trace["tau"], name="tau")
g = sns.jointplot(x, y, xlim=xlim, ylim=ylim)
plt.suptitle("centered")
plt.show()
/usr/local/lib/python3.7/dist-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation. FutureWarning
[Figure: joint plot of τ against µ for the centered model]
x = pd.Series(noncentered_trace["mu"], name="mu")
y = pd.Series(noncentered_trace["tau"], name="tau")
g = sns.jointplot(x, y, xlim=xlim, ylim=ylim)
plt.suptitle("noncentered")
plt.show()
[Figure: joint plot of τ against µ for the noncentered model]
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)

x = pd.Series(centered_trace["mu"], name="mu")
y = pd.Series(centered_trace["log_tau"], name="log_tau")
axs[0].plot(x, y, ".")
axs[0].set(title="Centered", xlabel="µ", ylabel="log(τ)")

x = pd.Series(noncentered_trace["mu"], name="mu")
y = pd.Series(noncentered_trace["log_tau"], name="log_tau")
axs[1].plot(x, y, ".")
axs[1].set(title="NonCentered", xlabel="µ", ylabel="log(τ)")

xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()
[Figure: scatter plots of log(τ) against µ for the centered (left) and noncentered (right) models]
# https://seaborn.pydata.org/generated/seaborn.jointplot.html
x = pd.Series(centered_trace["mu"], name="mu")
y = pd.Series(centered_trace["log_tau"], name="log_tau")
g = sns.jointplot(x, y, xlim=xlim, ylim=ylim)
plt.suptitle("centered")
plt.show()
[Figure: joint plot of log(τ) against µ for the centered model]
x = pd.Series(noncentered_trace["mu"], name="mu")
y = pd.Series(noncentered_trace["log_tau"], name="log_tau")
g = sns.jointplot(x, y, xlim=xlim, ylim=ylim)
plt.suptitle("noncentered")
plt.show()
[Figure: joint plot of log(τ) against µ for the noncentered model]
az.plot_forest(
    [centered_trace, noncentered_trace],
    model_names=["centered", "noncentered"],
    var_names="theta",
    combined=True,
    hdi_prob=0.95,
);
[Figure: forest plot of the 95% HDIs of theta for the centered and noncentered models]
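The forest plot above summarizes each `theta` by a 95% highest-density interval (`hdi_prob=0.95`). As a rough illustration of what such an interval is, the sketch below finds the narrowest window of sorted samples that contains the requested fraction. The function `hdi` here is a hypothetical helper written for this example using only the standard library; it is not ArviZ's implementation and may differ from it in how the window size is rounded.

```python
import math
import random


def hdi(samples, prob=0.95):
    """Return the narrowest interval containing at least `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    k = math.ceil(prob * n)  # number of samples the interval must contain
    if not 1 <= k <= n:
        raise ValueError("prob must be in (0, 1] and samples must be non-empty")
    # Slide a window over k consecutive order statistics and keep the narrowest.
    start = min(range(n - k + 1), key=lambda i: ordered[i + k - 1] - ordered[i])
    return ordered[start], ordered[start + k - 1]


random.seed(0)
# Stand-in samples from a standard normal (not draws from the notebook's traces).
normal_draws = [random.gauss(0.0, 1.0) for _ in range(20000)]
lo, hi = hdi(normal_draws, prob=0.95)
print(f"95% HDI of standard normal draws: ({lo:.2f}, {hi:.2f})")
```

For a symmetric unimodal distribution such as the normal, this interval is close to the central 95% interval; for skewed posteriors the two can differ noticeably.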