GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/03/linreg_divorce_ppc.ipynb
# We illustrate linear regression using the "waffle divorce" example in sec 5.1 of [Statistical Rethinking ed 2](https://xcelab.net/rm/statistical-rethinking/).
# The numpyro code is from [Du Phan's site](https://fehiepsi.github.io/rethinking-numpyro/05-the-many-variables-and-the-spurious-waffles.html).

try:
    import probml_utils as pml
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml

import numpy as np

np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
import math
import os
import warnings
import pandas as pd

import jax

print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))

import jax.numpy as jnp
from jax import random, vmap

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

try:
    import numpyro
except ModuleNotFoundError:
    %pip install -qq numpyro
    import numpyro

import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import AffineTransform
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer import Predictive, log_likelihood
from numpyro.infer import MCMC, NUTS
from numpyro.infer import SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation
import numpyro.optim as optim

try:
    import arviz as az
except ModuleNotFoundError:
    %pip install -qq arviz
    import arviz as az

# Data

# load the WaffleDivorce data
url = "https://raw.githubusercontent.com/fehiepsi/rethinking-numpyro/master/data/WaffleDivorce.csv"
WaffleDivorce = pd.read_csv(url, sep=";")
d = WaffleDivorce

# standardize variables
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())

# Model

# divorce rate D regressed on marriage rate M and median age at marriage A (model m5.3)
def model(M, A, D=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)

# fit a Laplace (quadratic) approximation to the posterior with SVI
m5_3 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_3, optim.Adam(1), Trace_ELBO(), M=d.M.values, A=d.A.values, D=d.D.values)
p5_3, losses = svi.run(random.PRNGKey(0), 1000)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000,))

# Posterior

param_names = {"a", "bA", "bM", "sigma"}
for p in param_names:
    print(f"posterior for {p}")
    print_summary(post[p], 0.95, False)

# PPC

# call Predictive without specifying new data, so it uses the original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4),))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu = post_pred["mu"]

# summarize samples across cases: posterior mean and 89% percentile interval of mu
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)

# plot predicted vs observed divorce rate, with a per-state interval bar
ax = plt.subplot(
    ylim=(float(mu_PI.min()), float(mu_PI.max())),
    xlabel="Observed divorce",
    ylabel="Predicted divorce",
)
plt.plot(d.D, mu_mean, "o")
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")
for i in range(d.shape[0]):
    plt.plot([d.D[i]] * 2, mu_PI[:, i], "b")
fig = plt.gcf()

# label selected states
for i in range(d.shape[0]):
    if d.Loc[i] in ["ID", "UT", "RI", "ME"]:
        ax.annotate(d.Loc[i], (d.D[i], mu_mean[i]), xytext=(1, 0), textcoords="offset pixels")
plt.tight_layout()
pml.savefig("linreg_divorce_postpred.pdf", dpi=300)
plt.show()
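
MCMC and NUTS are imported above but never used; the notebook relies on the Laplace approximation. As a minimal sketch (not part of the original notebook; the warmup and sample counts are arbitrary choices), the same model could also be fit with NUTS, reusing the `model` function and the data frame `d` defined above:

# Sketch (not in the original notebook): fit the same model with NUTS MCMC
# instead of the Laplace approximation; reuses `model` and `d` from above.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), M=d.M.values, A=d.A.values, D=d.D.values)
mcmc.print_summary(0.95)
post_nuts = mcmc.get_samples()  # posterior draws for a, bA, bM, sigma

The resulting draws could be passed to Predictive in place of `post` to produce the same kind of posterior predictive plot.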