Path: blob/master/deprecated/scripts/adf_logistic_regression_demo.py
# Online training of a logistic regression model
# using Assumed Density Filtering (ADF).
# We compare the ADF result with MCMC sampling
# of the posterior distribution.
# For further details, see the ADF paper:
# * O. Zoeter, "Bayesian Generalized Linear Models in a Terabyte World,"
#   2007 5th International Symposium on Image and Signal Processing and Analysis,
#   2007, pp. 435-440, doi: 10.1109/ISPA.2007.4383733.
# Dependencies:
#   !pip install jax_cosmo
# Author: Gerardo Durán-Martín (@gerdm)

import superimport

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pyprobml_utils as pml
from jax import random
from jax.scipy.stats import norm
from jax_cosmo.scipy import integrate
from functools import partial
from jsl.demos import logreg_biclusters_demo as demo

# jax_cosmo seems to only support numerical integration in CPU mode
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

figures, data = demo.main()

X = data["X"]
y = data["y"]
Phi = data["Phi"]
Xspace = data["Xspace"]
Phispace = data["Phispace"]
w_laplace = data["w_laplace"]


def sigmoid(z): return jnp.exp(z) / (1 + jnp.exp(z))
def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z))


def Zt_func(eta, y, mu, v):
    # Unnormalized posterior over the latent logit eta:
    # Bernoulli likelihood times Gaussian prior predictive.
    # NB: `v` is the standard deviation (scale), not the variance.
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
    return jnp.exp(log_term)


def mt_func(eta, y, mu, v, Zt):
    # Integrand for the first moment E[eta] under the tilted distribution
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
    return eta * jnp.exp(log_term) / Zt


def vt_func(eta, y, mu, v, Zt):
    # Integrand for the second moment E[eta^2] under the tilted distribution
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
    return eta ** 2 * jnp.exp(log_term) / Zt


def adf_step(state, xs, prior_variance, lbound, ubound):
    # One ADF update: after observing (Phi_t, y_t), project the non-Gaussian
    # posterior back onto a diagonal Gaussian by moment matching.
    mu_t, tau_t = state
    Phi_t, y_t = xs

    mu_t_cond = mu_t
    tau_t_cond = tau_t + prior_variance

    # prior predictive distribution over the logit eta = Phi_t @ w
    m_t_cond = (Phi_t * mu_t_cond).sum()
    v_t_cond = (Phi_t ** 2 * tau_t_cond).sum()
    v_t_cond_sqrt = jnp.sqrt(v_t_cond)

    # Moment-matched Gaussian approximation elements
    Zt = integrate.romb(lambda eta: Zt_func(eta, y_t, m_t_cond, v_t_cond_sqrt), lbound, ubound)
    mt = integrate.romb(lambda eta: mt_func(eta, y_t, m_t_cond, v_t_cond_sqrt, Zt), lbound, ubound)
    vt = integrate.romb(lambda eta: vt_func(eta, y_t, m_t_cond, v_t_cond_sqrt, Zt), lbound, ubound)
    vt = vt - mt ** 2

    # Posterior estimation: push the change in the logit's moments
    # back onto the weights
    delta_m = mt - m_t_cond
    delta_v = vt - v_t_cond
    a = Phi_t * tau_t_cond / (Phi_t ** 2 * tau_t_cond).sum()
    mu_t = mu_t_cond + a * delta_m
    tau_t = tau_t_cond + a ** 2 * delta_v

    return (mu_t, tau_t), (mu_t, tau_t)
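
# --- Quadrature sanity check (an addition, not part of the original demo) ---
# For y = 1 the normalizer computed in adf_step is the logistic-Gaussian
# integral Zt = int sigmoid(eta) N(eta | m, v) d(eta), which admits the
# well-known closed-form probit approximation
#     Zt ~= sigmoid(m / sqrt(1 + pi * v / 8)).
# Comparing the quadrature against it is a cheap way to verify that the
# integration bounds used below are wide enough; m and v here are
# arbitrary test inputs.
def check_quadrature(m=0.5, v=2.0, lbound=-20, ubound=20):
    z_quad = integrate.romb(lambda eta: Zt_func(eta, 1, m, jnp.sqrt(v)), lbound, ubound)
    z_probit = sigmoid(m / jnp.sqrt(1 + jnp.pi * v / 8))
    print(f"Zt by quadrature: {float(z_quad):.4f}; probit approximation: {float(z_probit):.4f}")
# check_quadrature()  # uncomment to run the check
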
# ** ADF inference **
prior_variance = 0.0  # no drift term: the weights are assumed static
# Lower and upper bounds of integration. Ideally, we would like to
# integrate from -inf to inf, but we run into numerical issues.
n_datapoints, ndims = Phi.shape
lbound, ubound = -20, 20
mu_t = jnp.zeros(ndims)
tau_t = jnp.ones(ndims) * 1.0

init_state = (mu_t, tau_t)
xs = (Phi, y)

adf_loop = partial(adf_step, prior_variance=prior_variance, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)
print("ADF weights")
print(mu_t)

# ADF posterior predictive distribution
n_samples = 5000
key = random.PRNGKey(3141)
adf_samples = random.multivariate_normal(key, mu_t, jnp.diag(tau_t), (n_samples,))
Z_adf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, adf_samples))
Z_adf = Z_adf.mean(axis=0)

# ** Plotting predictive distribution **
colors = ["black" if el else "white" for el in y]

# Add posterior marginal for ADF-estimated weights
for i in range(ndims):
    mean, std = mu_t[i], jnp.sqrt(tau_t[i])
    fig = figures[f"logistic_regression_weights_marginals_w{i}"]
    ax = fig.gca()
    x = jnp.linspace(mean - 4 * std, mean + 4 * std, 500)
    ax.plot(x, norm.pdf(x, mean, std), label="posterior (ADF)", linestyle="dashdot")
    ax.legend()

fig_adf, ax = plt.subplots()
title = "ADF Predictive distribution"
demo.plot_posterior_predictive(ax, X, Xspace, Z_adf, title, colors)
pml.savefig("logistic_regression_surface_adf.pdf")

# Posterior vs time
lcolors = ["black", "tab:blue", "tab:red"]
elements = mu_t_hist.T, tau_t_hist.T, w_laplace, lcolors
timesteps = jnp.arange(n_datapoints) + 1

for k, (wk, Pk, wk_laplace, c) in enumerate(zip(*elements)):
    fig_weight_k, ax = plt.subplots()
    ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (ADF)")
    ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)
    ax.set_xlim(1, n_datapoints)
    ax.legend(framealpha=0.7, loc="upper right")
    ax.set_xlabel("number of samples")
    ax.set_ylabel("weights")
    plt.tight_layout()
    pml.savefig(f"logistic_regression_hist_adf_w{k}")

# for name, figure in figures.items():
#     filename = f"./../figures/{name}.pdf"
#     figure.savefig(filename)

plt.show()
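
# --- Optional numerical comparison (an addition, not in the original script) ---
# Report the training accuracy of the ADF posterior-mean weights and the
# largest per-weight gap to the batch Laplace estimate; Phi, y, mu_t, and
# w_laplace are all defined above.
p_train = sigmoid(Phi @ mu_t)
accuracy = ((p_train > 0.5) == y).mean()
print(f"ADF training accuracy: {float(accuracy):.3f}")
print(f"max |ADF - Laplace| weight gap: {float(jnp.abs(mu_t - w_laplace).max()):.3f}")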