GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/15/adf_logistic_regression_demo.ipynb


# Online Logistic Regression using ADF

Online training of a logistic regression model using Assumed Density Filtering (ADF).

We compare the ADF result with a Laplace approximation of the posterior.
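In brief (a sketch of the setup in our own notation): the model is Bayesian logistic regression, and ADF maintains a fully factored Gaussian posterior over the weights that is refreshed by moment matching after every observation,

$$
p(y_t \mid x_t, w) = \mathrm{Bern}\big(y_t \mid \sigma(w^\top x_t)\big),
\qquad
q_t(w) = \prod_i \mathcal{N}\big(w_i \mid \mu_{t,i}, \tau_{t,i}\big).
$$

After seeing $(x_t, y_t)$, ADF forms the (non-Gaussian) tilted distribution proportional to $p(y_t \mid x_t, w)\, q_{t-1}(w)$ and projects it back onto the factored Gaussian family by matching first and second moments.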

For further details, see the ADF paper: O. Zoeter, "Bayesian Generalized Linear Models in a Terabyte World," 2007 5th International Symposium on Image and Signal Processing and Analysis, 2007, pp. 435-440, doi: 10.1109/ISPA.2007.4383733.

Authors: Peter G. Chang (@petergchang), Gerardo Durán-Martín (@gerdm)

## 0. Imports

import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm
from jax import vmap
from jax import lax
from jax.scipy.optimize import minimize

## 1. Simulation and Plotting

We generate a simple 2d binary classification dataset.

def generate_dataset(num_points=1000, shuffle=True, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key0, key1, key2 = jr.split(key, 3)

    # Generate standardized noisy inputs that correspond to output '0'
    num_zero_points = num_points // 2
    zero_input = jnp.array([[-1.0, -1.0]] * num_zero_points)
    zero_input += jr.normal(key0, (num_zero_points, 2))

    # Generate standardized noisy inputs that correspond to output '1'
    num_one_points = num_points - num_zero_points
    one_input = jnp.array([[1.0, 1.0]] * num_one_points)
    one_input += jr.normal(key1, (num_one_points, 2))

    # Stack the inputs and add bias term
    input = jnp.concatenate([zero_input, one_input])
    input_with_bias = jnp.concatenate([jnp.ones((num_points, 1)), input], axis=1)

    # Generate binary output
    output = jnp.concatenate([jnp.zeros((num_zero_points)), jnp.ones((num_one_points))])

    # Shuffle
    if shuffle:
        idx = jr.shuffle(key2, jnp.arange(num_points))
        input, input_with_bias, output = input[idx], input_with_bias[idx], output[idx]

    return input, input_with_bias, output
# Generate data
input, input_with_bias, output = generate_dataset()

Next, we define a function that visualizes the 2d posterior predictive distribution.

def plot_posterior_predictive(ax, X, title, colors, Xspace=None, Zspace=None, cmap="viridis"):
    if Xspace is not None and Zspace is not None:
        ax.contourf(*Xspace, Zspace, cmap=cmap, levels=20)
        ax.axis("off")
    ax.scatter(*X.T, c=colors, edgecolors="gray", s=50)
    ax.set_title(title)
    plt.tight_layout()
def plot_boundary(ax, X, colors, Xspace, w):
    ax.scatter(*X.T, c=colors, edgecolors="gray", s=50)
    ax.plot(Xspace[0], -w[1] / w[2] * Xspace[0] - w[0] / w[2])
    plt.tight_layout()

Let's look at our binary data in 2d.

fig, ax = plt.subplots()
title = "Binary classification data"
colors = ["black" if y else "red" for y in output]
plot_posterior_predictive(ax, input, title, colors)
[Figure: scatter plot of the binary classification data]

Let us define a grid on which we compute the predictive distribution.

# Define grid limits
xmin, ymin = input.min(axis=0) - 0.1
xmax, ymax = input.max(axis=0) + 0.1

# Define grid
step = 0.1
input_grid = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
_, nx, ny = input_grid.shape
input_with_bias_grid = jnp.concatenate([jnp.ones((1, nx, ny)), input_grid])

Next, we define a function that returns the posterior predictive probability for each point in the grid.
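Concretely, the function approximates the predictive probability with Monte Carlo samples of the weights drawn from the (approximate) Gaussian posterior:

$$
p(y = 1 \mid x, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} \sigma\big(x^\top w^{(s)}\big),
\qquad w^{(s)} \sim \mathcal{N}(\mu, \Sigma).
$$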

def posterior_predictive_grid(grid, mean, cov, n_samples=5000, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    samples = jax.random.multivariate_normal(key, mean, cov, (n_samples,))
    Z = jax.nn.sigmoid(jnp.einsum("mij,sm->sij", grid, samples))
    Z = Z.mean(axis=0)
    return Z

Also, we define a function that prints and plots the final ADF estimate of the weights.

def show_adf_estimate(means, vars):
    print(f"ADF estimate weights: {means}")

    # Plot the marginal posterior density of each weight
    fig, ax = plt.subplots()
    for i in range(len(means)):
        mean, std = means[i], jnp.sqrt(vars[i])
        ax = fig.gca()
        x = jnp.linspace(mean - 4 * std, mean + 4 * std, 500)
        ax.plot(x, norm.pdf(x, mean, std), label=f"$w_{i}$ posterior (ADF)", linestyle="dashdot")
    ax.legend()

Finally, we define a function that plots the convergence of the filtered estimates to the Laplace batch estimate.

def plot_adf_post_laplace(
    mu_hist, tau_hist, w_fix, lcolors, label, legend_font_size=12, bb1=(1.1, 1.1), bb2=(1, 0.3), bb3=(0.6, 0.3)
):
    elements = (mu_hist.T, tau_hist.T, w_fix, lcolors)
    n_datapoints = len(mu_hist)
    timesteps = jnp.arange(n_datapoints) + 1
    for k, (wk, Pk, wk_fix, c) in enumerate(zip(*elements)):
        fig_weight_k, ax = plt.subplots()
        ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (adf)")
        ax.axhline(y=wk_fix, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)
        ax.set_xlim(1, n_datapoints)
        ax.set_xlabel("number samples", fontsize=15)
        ax.set_ylabel("weights", fontsize=15)
        ax.tick_params(axis="both", which="major", labelsize=15)
        sns.despine()
        if k == 0:
            ax.legend(frameon=False, loc="upper right", bbox_to_anchor=bb1, fontsize=legend_font_size)
        elif k == 1:
            ax.legend(frameon=False, bbox_to_anchor=bb2, fontsize=legend_font_size)
        elif k == 2:
            ax.legend(frameon=False, bbox_to_anchor=bb3, fontsize=legend_font_size)
        plt.tight_layout()
        plt.savefig(label.format(k=k) + ".pdf", bbox_inches="tight", pad_inches=0.0)

## 2. Laplace Estimate

We compute the Laplace approximation to the posterior for comparison.
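As a reminder (stated here for convenience), the Laplace approximation replaces the posterior with a Gaussian centered at the MAP estimate, with covariance equal to the inverse Hessian of the negative log posterior at that point:

$$
p(w \mid \mathcal{D}) \approx \mathcal{N}\big(w \mid w_{\mathrm{MAP}},\, H^{-1}\big),
\qquad
H = -\nabla^2 \log p(w \mid \mathcal{D})\,\big|_{w = w_{\mathrm{MAP}}}.
$$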

def log_posterior(w, X, Y, prior_var):
    prediction = jax.nn.sigmoid(X @ w)
    log_prior = -(w @ w / (2 * prior_var))  # Gaussian prior N(0, prior_var * I)
    log_likelihood = Y * jnp.log(prediction) + (1 - Y) * jnp.log(1 - prediction)
    return log_prior + log_likelihood.sum()


def laplace_inference(X, Y, prior_var=1.0, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    input_dim = X.shape[-1]

    # Initial random guess
    w0 = jr.multivariate_normal(key, jnp.zeros(input_dim), jnp.eye(input_dim) * prior_var)

    # Energy function to minimize (negative log posterior, scaled by 1/N)
    E = lambda w: -log_posterior(w, X, Y, prior_var) / len(Y)

    # Minimize energy function to obtain the MAP estimate
    w_laplace = minimize(E, w0, method="BFGS").x

    # Laplace covariance: inverse Hessian of the (unscaled) negative log posterior at the mode
    cov_laplace = jnp.linalg.inv(jax.hessian(E)(w_laplace) * len(Y))
    return w_laplace, cov_laplace
# Compute Laplace posterior
prior_var = 1.0
w_laplace, cov_laplace = laplace_inference(input_with_bias, output, prior_var=prior_var)
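As an optional sanity check (not in the original notebook), the gradient of the energy should be approximately zero at the returned mode; E below simply rebuilds the objective used inside laplace_inference:

# Optional sanity check: the gradient of the energy should be close to zero
# at the MAP estimate returned by BFGS.
E = lambda w: -log_posterior(w, input_with_bias, output, prior_var) / len(output)
print("Laplace estimate weights:", w_laplace)
print("||grad E|| at the mode:", jnp.linalg.norm(jax.grad(E)(w_laplace)))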
fig_adf, ax = plt.subplots()
plot_boundary(ax, input, colors, input_grid, w_laplace)
[Figure: Laplace decision boundary over the data]
fig_adf, ax = plt.subplots()

# Plot Laplace posterior predictive distribution
Z_laplace = posterior_predictive_grid(input_with_bias_grid, w_laplace, cov_laplace)
title = "Laplace Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_laplace)
[Figure: Laplace predictive distribution]

## 3. ADF Estimates

First, we define the sigma-point (unscented) numerical integration rule used by the ADF update step.

def compute_weights_and_sigmas_1d(m, P, alpha=jnp.sqrt(3), beta=2, kappa=1):
    lamb = alpha**2 * (1 + kappa) - 1

    # Compute weights
    factor = 1 / (2 * (1 + lamb))
    w_mean = jnp.concatenate((jnp.array([lamb / (1 + lamb)]), jnp.ones(2) * factor))
    w_cov = jnp.concatenate((jnp.array([lamb / (1 + lamb) + (1 - alpha**2 + beta)]), jnp.ones(2) * factor))

    # Compute sigmas
    distances = jnp.sqrt(1 + lamb) * jnp.sqrt(P)
    sigma_plus = jnp.array([m + distances])
    sigma_minus = jnp.array([m - distances])
    sigmas = jnp.concatenate((jnp.array([m]), sigma_plus, sigma_minus))
    return w_mean, w_cov, sigmas


def sigma_point_gaussian_expectation_1d(f, m, P):
    w_mean, _, sigmas = compute_weights_and_sigmas_1d(m, P)
    return jnp.atleast_1d(jnp.tensordot(w_mean, vmap(f)(sigmas), axes=1))
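As a quick check (not in the original notebook), this three-point rule is exact for low-order polynomials, so it should recover the Gaussian moments $E[\eta] = m$ and $E[\eta^2] = m^2 + P$; the test values below are arbitrary:

# Quick check with arbitrary values: the 3-point rule integrates low-order
# polynomials against N(m, P) exactly.
m_test, P_test = 0.5, 2.0
print(sigma_point_gaussian_expectation_1d(lambda eta: eta, m_test, P_test))     # expect ~0.5
print(sigma_point_gaussian_expectation_1d(lambda eta: eta**2, m_test, P_test))  # expect ~0.5**2 + 2.0 = 2.25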

Next, we construct the ADF recursion for logistic regression, which maintains an independent 1d Gaussian posterior over each weight.
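The update step works on the scalar logit $\eta_t = x_t^\top w$ (notation ours). Given the predicted moments $\bar\mu_t = x_t^\top m_{t-1}$ and $\bar\sigma_t^2 = \sum_i x_{t,i}^2\, v_{t-1,i}$, the moments of $\eta_t$ under the tilted distribution are computed with the sigma-point rule defined above:

$$
Z_t = \int p(y_t \mid \eta)\, \mathcal{N}(\eta \mid \bar\mu_t, \bar\sigma_t^2)\, d\eta,
$$

$$
\mu_t = \frac{1}{Z_t} \int \eta\, p(y_t \mid \eta)\, \mathcal{N}(\eta \mid \bar\mu_t, \bar\sigma_t^2)\, d\eta,
\qquad
\sigma_t^2 = \frac{1}{Z_t} \int \eta^2\, p(y_t \mid \eta)\, \mathcal{N}(\eta \mid \bar\mu_t, \bar\sigma_t^2)\, d\eta - \mu_t^2.
$$

The change in the logit's mean and variance is then propagated back to the per-weight means and variances in the posterior-estimate block of the code below.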

def adf_logistic_regression(initial_means, initial_vars, inputs, emissions, drift=0):
    num_timesteps = len(emissions)

    def likelihood(y, eta):
        prediction = jax.nn.sigmoid(eta)
        return jnp.where(y, prediction, 1 - prediction)

    def _step(carry, t):
        prior_means, prior_vars = carry
        x, y = inputs[t], emissions[t]

        # Prediction step
        pred_means = prior_means
        pred_vars = prior_vars + drift

        # Update step
        cond_eta_mean = x @ pred_means
        cond_eta_var = x**2 @ pred_vars

        # Perform numerical integrations
        Zt = sigma_point_gaussian_expectation_1d(lambda eta: likelihood(y, eta), cond_eta_mean, cond_eta_var)
        post_eta_mean = sigma_point_gaussian_expectation_1d(
            lambda eta: eta / Zt * likelihood(y, eta), cond_eta_mean, cond_eta_var
        )
        post_eta_var = (
            sigma_point_gaussian_expectation_1d(
                lambda eta: eta**2 / Zt * likelihood(y, eta), cond_eta_mean, cond_eta_var
            )
            - post_eta_mean**2
        )

        # Posterior estimate
        mean_diff = post_eta_mean - cond_eta_mean
        var_diff = post_eta_var - cond_eta_var
        weight = x * pred_vars / (x**2 + pred_vars).sum()
        post_means = pred_means + weight * mean_diff
        post_vars = pred_vars + weight**2 * var_diff
        return (post_means, post_vars), (post_means, post_vars)

    carry = (initial_means, initial_vars)
    (post_means, post_vars), (means_hist, vars_hist) = lax.scan(_step, carry, jnp.arange(num_timesteps))
    return (post_means, post_vars), (means_hist, vars_hist)
num_dims = input_with_bias.shape[-1]

# Run ADF
(post_means, post_vars), (means_hist, vars_hist) = adf_logistic_regression(
    initial_means=jnp.zeros(num_dims),
    initial_vars=jnp.sqrt(prior_var) * jnp.ones(num_dims),
    inputs=input_with_bias,
    emissions=output,
)
show_adf_estimate(post_means, post_vars)
ADF estimate weights: [-0.02906595 2.0735168 2.0495355 ]
[Figure: ADF marginal posterior densities of the weights]
fig_adf, ax = plt.subplots()

# ADF posterior predictive distribution
Z_adf = posterior_predictive_grid(input_with_bias_grid, post_means, jnp.diag(post_vars))
title = "ADF Predictive distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_adf)
[Figure: ADF predictive distribution]
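As an illustration (not in the original notebook), the same Monte Carlo approach can score a single new input; x_new below is an arbitrary test point in homogeneous coordinates [bias, x1, x2]:

# Predictive probability p(y=1 | x_new) under the ADF posterior, via Monte Carlo.
# x_new is an arbitrary, hypothetical test point.
key = jr.PRNGKey(42)
x_new = jnp.array([1.0, 0.5, 0.5])  # [bias, x1, x2]
w_samples = jr.multivariate_normal(key, post_means, jnp.diag(post_vars), (2000,))
p_new = jax.nn.sigmoid(w_samples @ x_new).mean()
print(f"p(y=1 | x_new) = {p_new:.3f}")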

## 4. Inference over Time

Finally, we will look at the convergence of the ADF-inferred weights to the Laplace batch estimate.

lcolors = ["black", "tab:blue", "tab:red"]
label = "adf_logistic_regression_w{k}"  # filename template for the saved figures (any format string with {k} works)

plot_adf_post_laplace(
    means_hist[:: max(1, len(output) // 100)],
    vars_hist[:: max(1, len(output) // 100)],
    w_laplace,
    lcolors,
    label,
    legend_font_size=14,
    bb2=(1.1, 0.3),
    bb3=(0.8, 0.3),
)
[Figures: per-weight convergence of the ADF estimates to the Laplace batch estimate]