Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Avatar for Software 20.04.
Download
175 views
ubuntu2004-dev
Kernel: Python 3 (system-wide)
# Check which numpyro version this kernel provides (the notebook pins 0.8.0 below).
import numpyro

numpyro.__version__
'0.8.0'
# Imports and plotting/setup for the whole notebook.
import os

from IPython.display import set_matplotlib_formats
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

plt.style.use("bmh")
# Use SVG figures when this notebook is executed by the docs (Sphinx) build.
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

# The notebook was written against numpyro 0.8.0; fail fast on other versions.
assert numpyro.__version__.startswith("0.8.0")
/usr/local/lib/python3.8/dist-packages/jax/experimental/optimizers.py:28: FutureWarning: jax.experimental.optimizers is deprecated, import jax.example_libraries.optimizers instead warnings.warn('jax.experimental.optimizers is deprecated, ' /usr/local/lib/python3.8/dist-packages/jax/experimental/stax.py:28: FutureWarning: jax.experimental.stax is deprecated, import jax.example_libraries.stax instead warnings.warn('jax.experimental.stax is deprecated, '
# WaffleDivorce dataset from McElreath's "Statistical Rethinking" repository:
# per-US-state marriage/divorce statistics plus Waffle House counts.
DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
dset.head()
# Pairwise scatter plots of the main predictors and the outcome.
# NOTE: renamed `vars` -> `plot_vars`; the original name shadowed the
# `vars` builtin.
plot_vars = [
    "Population",
    "MedianAgeMarriage",
    "Marriage",
    "WaffleHouses",
    "South",
    "Divorce",
]
sns.pairplot(dset, x_vars=plot_vars, y_vars=plot_vars, palette="husl");
Image in a Jupyter notebook
# Marginal regression of divorce rate on Waffle House count — the spurious
# correlation this tutorial sets out to explain away.
sns.regplot(x="WaffleHouses", y="Divorce", data=dset);
Image in a Jupyter notebook
def standardize(x):
    """Center *x* to zero mean and scale to unit (sample) standard deviation.

    Using ``def`` rather than a lambda bound to a name (PEP 8 E731) so the
    function has a proper ``__name__`` in tracebacks.
    """
    return (x - x.mean()) / x.std()


# Standardized copies of the columns used by the regression models below.
dset["AgeScaled"] = dset.MedianAgeMarriage.pipe(standardize)
dset["MarriageScaled"] = dset.Marriage.pipe(standardize)
dset["DivorceScaled"] = dset.Divorce.pipe(standardize)
def model(marriage=None, age=None, divorce=None):
    """Linear regression of (standardized) divorce rate on optional predictors.

    Predictor slopes ``bM`` / ``bA`` are sampled only when the corresponding
    argument is given, so the same function serves as Model 1 (marriage only),
    Model 2 (age only) and Model 3 (both).  Passing ``divorce`` conditions the
    ``obs`` site on the data; leaving it ``None`` yields prior/posterior
    predictive draws.

    NOTE: the sample-site order (a, bM, bA, sigma, obs) is deliberately
    preserved — it determines how the PRNG key is consumed.
    """
    intercept = numpyro.sample("a", dist.Normal(0.0, 0.2))
    mean = intercept
    if marriage is not None:
        slope_m = numpyro.sample("bM", dist.Normal(0.0, 0.5))
        mean = mean + slope_m * marriage
    if age is not None:
        slope_a = numpyro.sample("bA", dist.Normal(0.0, 0.5))
        mean = mean + slope_a * age
    scale = numpyro.sample("sigma", dist.Exponential(1.0))
    numpyro.sample("obs", dist.Normal(mean, scale), obs=divorce)
# Start from this source of randomness. We will split keys for subsequent operations. rng_key = random.PRNGKey(0) rng_key, rng_key_ = random.split(rng_key) # Run NUTS. kernel = NUTS(model) num_samples = 2000 mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples) mcmc.run( rng_key_, marriage=dset.MarriageScaled.values, divorce=dset.DivorceScaled.values ) mcmc.print_summary() samples_1 = mcmc.get_samples()
sample: 100%|██████████| 3000/3000 [00:04<00:00, 608.52it/s, 3 steps of size 7.86e-01. acc. prob=0.92]
mean std median 5.0% 95.0% n_eff r_hat a 0.01 0.11 0.01 -0.17 0.19 1511.88 1.00 bM 0.35 0.13 0.35 0.14 0.57 1694.14 1.00 sigma 0.95 0.10 0.94 0.78 1.10 1696.30 1.00 Number of divergences: 0
def plot_regression(x, y_mean, y_hpdi):
    """Plot a fitted regression line with its HPDI band over the observations.

    ``x`` is the predictor, ``y_mean`` the posterior-mean fit per point, and
    ``y_hpdi`` a (2, n) array of lower/upper interval bounds.  Points are
    sorted by ``x`` so line and band render cleanly.

    NOTE: reads the observed outcomes from the notebook-global ``dset``
    (``DivorceScaled``) rather than taking them as a parameter.
    """
    order = jnp.argsort(x)
    xs = x[order]
    fit = y_mean[order]
    band = y_hpdi[:, order]
    observed = dset.DivorceScaled.values[order]

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(xs, fit)
    ax.plot(xs, observed, "o")
    ax.fill_between(xs, band[0], band[1], alpha=0.3, interpolate=True)
    return ax


# Empirical posterior over the regression mean: mu = a + bM * marriage,
# one row per posterior draw.
posterior_mu = jnp.expand_dims(samples_1["a"], -1) + jnp.expand_dims(
    samples_1["bM"], -1
) * dset.MarriageScaled.values

mean_mu = jnp.mean(posterior_mu, axis=0)
hpdi_mu = hpdi(posterior_mu, 0.9)
ax = plot_regression(dset.MarriageScaled.values, mean_mu, hpdi_mu)
ax.set(
    xlabel="Marriage rate", ylabel="Divorce rate", title="Regression line with 90% CI"
);
Image in a Jupyter notebook
from numpyro.infer import Predictive

# Prior predictive: sample "obs" from the model WITHOUT conditioning on the
# observed divorce rates, to sanity-check the priors.
rng_key, rng_key_ = random.split(rng_key)
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(rng_key_, marriage=dset.MarriageScaled.values)[
    "obs"
]

mean_prior_pred = jnp.mean(prior_predictions, axis=0)
hpdi_prior_pred = hpdi(prior_predictions, 0.9)

ax = plot_regression(dset.MarriageScaled.values, mean_prior_pred, hpdi_prior_pred)
ax.set(xlabel="Marriage rate", ylabel="Divorce rate", title="Predictions with 90% CI");
Image in a Jupyter notebook
# Posterior predictive for Model 1: condition the model on the MCMC samples.
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model, samples_1)
predictions = predictive(rng_key_, marriage=dset.MarriageScaled.values)["obs"]
df = dset.filter(["Location"])
df["Mean Predictions"] = jnp.mean(predictions, axis=0)
df.head()
def predict(rng_key, post_samples, model, *args, **kwargs):
    """Draw one predictive sample of the "obs" site by hand.

    Substitutes one set of posterior draws into the model via
    ``handlers.condition``, seeds the remaining randomness with ``rng_key``,
    runs a trace, and returns the value recorded at the "obs" site.  This is
    the manual equivalent of what ``Predictive`` does internally.
    """
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return trace["obs"]["value"]


# Vectorize over (key, posterior-draw) pairs with vmap.
predict_fn = vmap(
    lambda rng_key, samples: predict(
        rng_key, samples, model, marriage=dset.MarriageScaled.values
    )
)
# Using the same key as we used for Predictive - note that the results are identical. predictions_1 = predict_fn(random.split(rng_key_, num_samples), samples_1) mean_pred = jnp.mean(predictions_1, axis=0) df = dset.filter(["Location"]) df["Mean Predictions"] = mean_pred df.head()
# 90% HPDI band around the posterior-predictive mean for Model 1.
hpdi_pred = hpdi(predictions_1, 0.9)

ax = plot_regression(dset.MarriageScaled.values, mean_pred, hpdi_pred)
ax.set(xlabel="Marriage rate", ylabel="Divorce rate", title="Predictions with 90% CI");
Image in a Jupyter notebook
def log_likelihood(rng_key, params, model, *args, **kwargs):
    """Pointwise log-likelihood of the observed data under one posterior draw.

    ``rng_key`` is accepted only so the signature matches the vmapped caller
    in ``log_pred_density``; conditioning fixes every latent site, so the key
    is never consumed.
    """
    model = handlers.condition(model, params)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    obs_node = model_trace["obs"]
    return obs_node["fn"].log_prob(obs_node["value"])


def log_pred_density(rng_key, params, model, *args, **kwargs):
    """Log pointwise predictive density, summed over observations.

    For each observation this computes log(1/n * sum_i p(y | theta_i)) over
    the ``n`` posterior draws in ``params`` (via logsumexp for stability),
    then sums across observations.
    """
    # Number of posterior draws: every site in `params` shares this leading
    # dimension.  next(iter(...)) avoids building a throwaway list.
    n = next(iter(params.values())).shape[0]
    log_lk_fn = vmap(
        lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs)
    )
    log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
    return (logsumexp(log_lk_vals, 0) - jnp.log(n)).sum()
# Log pointwise predictive density for Model 1 (marriage rate only) —
# higher is better when comparing the three models.
rng_key, rng_key_ = random.split(rng_key)
print(
    "Log posterior predictive density: {}".format(
        log_pred_density(
            rng_key_,
            samples_1,
            model,
            marriage=dset.MarriageScaled.values,
            divorce=dset.DivorceScaled.values,
        )
    )
)
Log posterior predictive density: -66.6878890991211
# Model 2: divorce rate regressed on median age at marriage only
# (reuses the same MCMC object with a fresh key).
rng_key, rng_key_ = random.split(rng_key)

mcmc.run(rng_key_, age=dset.AgeScaled.values, divorce=dset.DivorceScaled.values)
mcmc.print_summary()
samples_2 = mcmc.get_samples()
sample: 100%|██████████| 3000/3000 [00:05<00:00, 599.87it/s, 3 steps of size 7.32e-01. acc. prob=0.92]
mean std median 5.0% 95.0% n_eff r_hat a -0.00 0.10 -0.01 -0.17 0.17 2028.08 1.00 bA -0.57 0.11 -0.57 -0.74 -0.38 1974.47 1.00 sigma 0.82 0.08 0.82 0.68 0.95 1860.86 1.00 Number of divergences: 0
# Posterior over the regression mean mu = a + bA * age for Model 2.
posterior_mu = (
    jnp.expand_dims(samples_2["a"], -1)
    + jnp.expand_dims(samples_2["bA"], -1) * dset.AgeScaled.values
)

mean_mu = jnp.mean(posterior_mu, axis=0)
hpdi_mu = hpdi(posterior_mu, 0.9)
ax = plot_regression(dset.AgeScaled.values, mean_mu, hpdi_mu)
ax.set(
    xlabel="Median marriage age",
    ylabel="Divorce rate",
    title="Regression line with 90% CI",
);
Image in a Jupyter notebook
# Posterior predictive for Model 2 (age predictor).
rng_key, rng_key_ = random.split(rng_key)
predictions_2 = Predictive(model, samples_2)(rng_key_, age=dset.AgeScaled.values)["obs"]

mean_pred = jnp.mean(predictions_2, axis=0)
hpdi_pred = hpdi(predictions_2, 0.9)

ax = plot_regression(dset.AgeScaled.values, mean_pred, hpdi_pred)
ax.set(xlabel="Median Age", ylabel="Divorce rate", title="Predictions with 90% CI");
Image in a Jupyter notebook
# Log pointwise predictive density for Model 2 (age only).
rng_key, rng_key_ = random.split(rng_key)
print(
    "Log posterior predictive density: {}".format(
        log_pred_density(
            rng_key_,
            samples_2,
            model,
            age=dset.AgeScaled.values,
            divorce=dset.DivorceScaled.values,
        )
    )
)
Log posterior predictive density: -59.274169921875
# Model 3: divorce rate regressed on BOTH marriage rate and median age.
rng_key, rng_key_ = random.split(rng_key)

mcmc.run(
    rng_key_,
    marriage=dset.MarriageScaled.values,
    age=dset.AgeScaled.values,
    divorce=dset.DivorceScaled.values,
)
mcmc.print_summary()
samples_3 = mcmc.get_samples()
sample: 100%|██████████| 3000/3000 [00:05<00:00, 589.76it/s, 7 steps of size 5.67e-01. acc. prob=0.92]
mean std median 5.0% 95.0% n_eff r_hat a 0.00 0.10 0.00 -0.17 0.17 2005.88 1.00 bA -0.61 0.15 -0.61 -0.87 -0.37 1496.72 1.00 bM -0.07 0.15 -0.07 -0.34 0.17 1526.17 1.00 sigma 0.83 0.09 0.82 0.69 0.96 1752.97 1.00 Number of divergences: 0
# Log pointwise predictive density for Model 3 (both predictors).
rng_key, rng_key_ = random.split(rng_key)
print(
    "Log posterior predictive density: {}".format(
        log_pred_density(
            rng_key_,
            samples_3,
            model,
            marriage=dset.MarriageScaled.values,
            age=dset.AgeScaled.values,
            divorce=dset.DivorceScaled.values,
        )
    )
)
Log posterior predictive density: -59.02093505859375
# Predictions for Model 3.
rng_key, rng_key_ = random.split(rng_key)
predictions_3 = Predictive(model, samples_3)(
    rng_key_, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values
)["obs"]

# Generalized: number of states taken from the data instead of hard-coded 50.
num_states = dset.shape[0]
y = jnp.arange(num_states)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 16))
pred_mean = jnp.mean(predictions_3, axis=0)
pred_hpdi = hpdi(predictions_3, 0.9)

# Residuals (observed - predicted) per posterior draw; computed ONCE here
# (the original computed this identical triple twice).
residuals_3 = dset.DivorceScaled.values - predictions_3
residuals_mean = jnp.mean(residuals_3, axis=0)
residuals_hpdi = hpdi(residuals_3, 0.9)
# Order states by mean residual so both panels share the same ranking.
idx = jnp.argsort(residuals_mean)

# Plot posterior predictive vs. actuals, one row per state.
ax[0].plot(jnp.zeros(num_states), y, "--")
ax[0].errorbar(
    pred_mean[idx],
    y,
    xerr=pred_hpdi[1, idx] - pred_mean[idx],
    marker="o",
    ms=5,
    mew=4,
    ls="none",
    alpha=0.8,
)
ax[0].plot(dset.DivorceScaled.values[idx], y, marker="o", ls="none", color="gray")
ax[0].set(
    xlabel="Posterior Predictive (red) vs. Actuals (gray)",
    ylabel="State",
    title="Posterior Predictive with 90% CI",
)
ax[0].set_yticks(y)
ax[0].set_yticklabels(dset.Loc.values[idx], fontsize=10)

# Plot residuals with their 90% HPDI half-widths.
err = residuals_hpdi[1] - residuals_mean
ax[1].plot(jnp.zeros(num_states), y, "--")
ax[1].errorbar(
    residuals_mean[idx], y, xerr=err[idx], marker="o", ms=5, mew=4, ls="none", alpha=0.8
)
ax[1].set(xlabel="Residuals", ylabel="State", title="Residuals with 90% CI")
ax[1].set_yticks(y)
ax[1].set_yticklabels(dset.Loc.values[idx], fontsize=10);
Image in a Jupyter notebook