GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/advi_beta_binom_jax.ipynb

Open In Colab

ADVI from scratch in JAX

Authors: karm-patel@, murphyk@

In this notebook we apply ADVI (automatic differentiation variational inference) to the beta-binomial model, using a normal distribution as the variational posterior. This involves a change of variables from the unconstrained z in R to the constrained theta in [0,1].
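Concretely, the constrained and unconstrained variables are related by the sigmoid and its inverse, the logit (standard definitions, written out here for reference):
\begin{align}
\theta &= \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad z = \sigma^{-1}(\theta) = \log \frac{\theta}{1 - \theta}.
\end{align}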

try:
    import jax
except ModuleNotFoundError:
    %pip install -qqq jax jaxlib
    import jax
import jax.numpy as jnp
from jax import lax

try:
    from tensorflow_probability.substrates import jax as tfp
except ModuleNotFoundError:
    %pip install -qqq tensorflow_probability
    from tensorflow_probability.substrates import jax as tfp

try:
    import optax
except ModuleNotFoundError:
    %pip install -qqq optax
    import optax

try:
    from rich import print
except ModuleNotFoundError:
    %pip install -qqq rich
    from rich import print

try:
    from tqdm import trange
except ModuleNotFoundError:
    %pip install -qqq tqdm
    from tqdm import trange

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")

dist = tfp.distributions

plt.rc("font", size=10)  # controls default text sizes
plt.rc("axes", labelsize=12)  # fontsize of the x and y labels
plt.rc("legend", fontsize=12)  # legend fontsize
plt.rc("figure", titlesize=15)  # fontsize of the figure title

Functions

Helper functions which will be used later

def prior_dist():
    return dist.Beta(concentration1=1.0, concentration0=1.0)


def likelihood_dist(theta):
    return dist.Bernoulli(probs=theta)


def transform_fn(x):
    return 1 / (1 + jnp.exp(-x))  # sigmoid


def positivity_fn(x):
    return jnp.log(1 + jnp.exp(x))  # softplus


def variational_distribution_q(params):
    loc = params["loc"]
    scale = positivity_fn(params["scale"])  # apply softplus
    return dist.Normal(loc, scale)


jacobian_fn = jax.jacfwd(transform_fn)  # function that returns the Jacobian of transform_fn

Dataset

Now we create the dataset. One option is to sample theta_true (the probability of heads) from the Beta prior and then draw n_samples coin tosses from the Bernoulli likelihood (shown commented out below). Here we instead reuse the fixed dataset from the companion PyMC notebook (ten tails and one head) so the results are directly comparable.

# preparing dataset
# key = jax.random.PRNGKey(128)
# n_samples = 12
# theta_true = prior_dist().sample((5,), key)[0]
# dataset = likelihood_dist(theta_true).sample(n_samples, key)
# print(f"Dataset: {dataset}")
# n_heads = dataset.sum()
# n_tails = n_samples - n_heads

# Use same data as https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb
key = jax.random.PRNGKey(128)
dataset = np.repeat([0, 1], (10, 1))
n_samples = len(dataset)
print(f"Dataset: {dataset}")
n_heads = dataset.sum()
n_tails = n_samples - n_heads

Prior, Likelihood, and True Posterior

For the coin-toss problem the posterior is available in closed form, so we can compare the prior, likelihood, and true posterior below.
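By Beta-Bernoulli conjugacy, the exact posterior for our dataset (a = b = 1, one head, ten tails) works out to
\begin{align}
p(\theta|\mathcal{D}) &= \text{Beta}(\theta \mid a + N_{\text{heads}},\; b + N_{\text{tails}}) = \text{Beta}(\theta \mid 2, 11), \qquad E[\theta|\mathcal{D}] = \frac{2}{13} \approx 0.154.
\end{align}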

# closed form of beta posterior
a = prior_dist().concentration1
b = prior_dist().concentration0
exact_posterior = dist.Beta(concentration1=a + n_heads, concentration0=b + n_tails)

theta_range = jnp.linspace(0.01, 0.99, 100)

ax = plt.gca()
ax2 = ax.twinx()
(plt2,) = ax2.plot(theta_range, exact_posterior.prob(theta_range), "g--", label="True Posterior")
(plt3,) = ax2.plot(theta_range, prior_dist().prob(theta_range), label="Prior")

likelihood = jax.vmap(lambda x: jnp.prod(likelihood_dist(x).prob(dataset)))(theta_range)
(plt1,) = ax.plot(theta_range, likelihood, "r-.", label="Likelihood")

ax.set_xlabel("theta")
ax.set_ylabel("Likelihood")
ax2.set_ylabel("Prior & Posterior")
ax2.legend(handles=[plt1, plt2, plt3], bbox_to_anchor=(1.6, 1));
[Figure: prior, likelihood, and exact posterior over theta]

Optimizing the ELBO

In order to minimize the KL divergence between the variational distribution and the true posterior, we need to minimize the negative ELBO, as we describe below.
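To make this explicit, recall the standard identity relating the KL divergence and the ELBO (added here for clarity):
\begin{align}
KL\left( q(z|\psi) \,\|\, p(z|\mathcal{D}) \right) &= \log p(\mathcal{D}) - ELBO(\psi).
\end{align}
Since the log marginal likelihood $\log p(\mathcal{D})$ does not depend on $\psi$, minimizing the KL divergence is equivalent to maximizing the ELBO, i.e. minimizing the negative ELBO.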

We start with the ELBO, which is given by
\begin{align}
ELBO(\psi) &= E_{z \sim q(z|\psi)} \left[ \log p(\mathcal{D}|z) + \log p(z) - \log q(z|\psi) \right]
\end{align}
where $\psi = (\mu, \sigma)$ are the variational parameters, $p(\mathcal{D}|z) = p(\mathcal{D}|\theta=\sigma(z))$ is the likelihood, and the prior is given by the change-of-variables formula
\begin{align}
p(z) &= p(\theta) \left| \frac{\partial \theta}{\partial z} \right| = p(\theta) |J|
\end{align}
where $J$ is the Jacobian of the $z \rightarrow \theta$ mapping. We will use a Monte Carlo approximation of the expectation over $z$. We also apply the reparameterization trick to replace $z \sim q(z|\psi)$ with
\begin{align}
\epsilon &\sim \mathcal{N}(0, 1) \\
z &= \mu + \sigma \epsilon.
\end{align}
Putting it all together, our estimate of the negative ELBO (for a single sample of $\epsilon$) is
\begin{align}
-L(\psi; z) &= -\left( \log p(\mathcal{D}|\theta) + \log p(\theta) + \log |J_{\sigma}(z)| - \log q(z|\psi) \right).
\end{align}
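For the sigmoid transform used in this notebook, the Jacobian term can be written in closed form (a standard calculation, spelled out here as a check on the code below):
\begin{align}
\left| \frac{\partial \theta}{\partial z} \right| &= \sigma(z)\,(1 - \sigma(z)) = \theta (1 - \theta), \qquad
\log |J_{\sigma}(z)| = \log \theta + \log (1 - \theta).
\end{align}
The code instead computes this generically via jax.jacfwd and a determinant, which reduces to the same quantity in this one-dimensional case.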

def log_prior_likelihood_jacobian(normal_sample, dataset):
    theta = transform_fn(normal_sample)  # transform the normal (unconstrained) sample to [0,1]
    likelihood_log_prob = likelihood_dist(theta).log_prob(dataset).sum()  # log probability of likelihood
    prior_log_prob = prior_dist().log_prob(theta)  # log probability of prior
    log_det_jacob = jnp.log(
        jnp.abs(jnp.linalg.det(jacobian_fn(normal_sample).reshape(1, 1)))
    )  # log of the absolute determinant of the Jacobian
    return likelihood_log_prob + prior_log_prob + log_det_jacob
# reference: https://code-first-ml.github.io/book2/notebooks/introduction/variational.html
def negative_elbo(params, dataset, n_samples=10, key=jax.random.PRNGKey(1)):
    q = variational_distribution_q(params)  # Normal distribution
    q_loc, q_scale = q.loc, q.scale
    std_normal = dist.Normal(0, 1)
    sample_set = std_normal.sample(
        seed=key,
        sample_shape=[
            n_samples,
        ],
    )
    sample_set = q_loc + q_scale * sample_set  # reparameterization trick
    # calculate the log joint for each sample of z
    p_log_prob = jax.vmap(log_prior_likelihood_jacobian, in_axes=(0, None))(sample_set, dataset)
    return jnp.mean(q.log_prob(sample_set) - p_log_prob)

We now apply stochastic gradient descent to minimize the negative ELBO and optimize the variational parameters (loc and scale).

loss_and_grad_fn = jax.value_and_grad(negative_elbo, argnums=0)
loss_and_grad_fn = jax.jit(loss_and_grad_fn)  # jit the loss-and-grad function

params = {"loc": 0.0, "scale": 0.5}
elbo, grads = loss_and_grad_fn(params, dataset)
print(f"loss: {elbo}")
print(f"grads:\n loc: {grads['loc']}\n scale: {grads['scale']} ")

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)
# jax-scannable function for training
def train_step(carry, data_output):
    # unpack the carry
    key = carry["key"]
    elbo = carry["elbo"]
    grads = carry["grads"]
    params = carry["params"]
    opt_state = carry["opt_state"]
    updates = carry["updates"]

    # one training step
    key, subkey = jax.random.split(key)
    elbo, grads = loss_and_grad_fn(params, dataset, key=subkey)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # forward the carry to the next iteration
    carry = {"key": subkey, "elbo": elbo, "grads": grads, "params": params, "opt_state": opt_state, "updates": updates}
    output = {"elbo": elbo, "params": params}
    return carry, output
%%time
# dummy iteration to build the initial carry for the jax-scannable function train_step()
key, subkey = jax.random.split(key)
elbo, grads = loss_and_grad_fn(params, dataset, key=subkey)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
carry = {"key": key, "elbo": elbo, "grads": grads, "params": params, "opt_state": opt_state, "updates": updates}

num_iter = 1000
elbos = np.empty(num_iter)

# use lax.scan() to run the training loop
last_carry, output = lax.scan(train_step, carry, elbos)
elbo = output["elbo"]
params = output["params"]
optimized_params = last_carry["params"]
CPU times: user 2.09 s, sys: 20.1 ms, total: 2.11 s Wall time: 2.03 s
print(params["loc"].shape)
print(params["scale"].shape)
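For readers less familiar with lax.scan, the scanned loop above is equivalent (apart from compilation of the loop body) to an ordinary Python training loop. The sketch below is purely illustrative: it reuses loss_and_grad_fn, optimizer, and dataset defined earlier, and the *_loop variable names are hypothetical.

# Illustrative alternative to lax.scan: a plain Python training loop
# (assumes loss_and_grad_fn, optimizer, and dataset from the cells above).
params_loop = {"loc": 0.0, "scale": 0.5}
opt_state_loop = optimizer.init(params_loop)
loop_key = jax.random.PRNGKey(0)
elbo_trace = []
for _ in range(1000):
    loop_key, subkey = jax.random.split(loop_key)
    loss, grads = loss_and_grad_fn(params_loop, dataset, key=subkey)
    updates, opt_state_loop = optimizer.update(grads, opt_state_loop)
    params_loop = optax.apply_updates(params_loop, updates)
    elbo_trace.append(loss)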

We now plot the negative ELBO.

plt.plot(elbo)
plt.xlabel("Iterations")
plt.ylabel("Negative ELBO")
sns.despine()
plt.savefig("advi_beta_binom_jax_loss.pdf")
[Figure: negative ELBO vs. iterations]

We can see that after about 200 iterations the negative ELBO has converged and changes very little.

Samples using optimized parameters

Now we draw 1000 samples from the variational distribution (a Normal over the unconstrained z) and map them to the constrained [0,1] space by applying transform_fn (the sigmoid). We then compare the density of the transformed samples with the exact posterior.

q_learned = variational_distribution_q(optimized_params)

key = jax.random.PRNGKey(128)
q_learned_samples = q_learned.sample(1000, seed=key)  # samples of z ~ q(z|D)
transformed_samples = transform_fn(q_learned_samples)  # transform Normal samples into [0,1] samples

theta_range = jnp.linspace(0.01, 0.99, 100)
plt.plot(theta_range, exact_posterior.prob(theta_range), "r", label="$p(x)$: true posterior")
sns.kdeplot(transformed_samples, color="blue", label="$q(x)$: learned", bw_adjust=1.5, clip=(0.0, 1.0), linestyle="--")
plt.xlabel("theta")
plt.legend()  # bbox_to_anchor=(1.5, 1));
sns.despine()
plt.savefig("advi_beta_binom_jax_posterior.pdf")
[Figure: exact posterior vs. KDE of the transformed variational samples]

We can see that the learned q(x) is a reasonably good approximation to the true posterior. It appears to have support over negative theta, but this is an artefact of the KDE.

# print(transformed_samples)
print(len(transformed_samples))
print(jnp.sum(transformed_samples < 0))  # all samples of theta should be in [0,1]
print(jnp.sum(transformed_samples > 1))  # all samples of theta should be in [0,1]
print(q_learned)
print(q_learned.mean())
print(jnp.sqrt(q_learned.variance()))
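As an extra sanity check (not part of the original notebook), we can compare the first two moments of the transformed samples with those of the exact Beta posterior; the variable names below are purely illustrative.

# Illustrative check: moments of q in theta space vs. the exact posterior.
approx_mean = float(jnp.mean(transformed_samples))
approx_std = float(jnp.std(transformed_samples))
exact_mean = float(exact_posterior.mean())
exact_std = float(jnp.sqrt(exact_posterior.variance()))
print(f"approx posterior: mean={approx_mean:.3f}, std={approx_std:.3f}")
print(f"exact posterior:  mean={exact_mean:.3f}, std={exact_std:.3f}")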
locs, scales = params["loc"], params["scale"]
sigmas = positivity_fn(jnp.array(scales))

plt.plot(locs, label="mu")
plt.xlabel("Iterations")
plt.ylabel("$E_q[z]$")
plt.legend()
sns.despine()
plt.savefig("advi_beta_binom_jax_post_mu_vs_time.pdf")
plt.show()

plt.plot(sigmas, label="sigma")
plt.xlabel("Iterations")
# plt.ylabel(r'$\sqrt{\text{var}(z)}')
plt.ylabel("$std_{q}[z]$")
plt.legend()
sns.despine()
plt.savefig("advi_beta_binom_jax_post_sigma_vs_time.pdf")
plt.show()
[Figures: variational loc (mu) and scale (sigma) vs. iterations]

Comparison with pymc.ADVI()

Now we compare our implementation with PyMC's built-in ADVI.

Note: For the pymc implementation, the code is taken from this notebook: https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb

try:
    import pymc3 as pm
except ModuleNotFoundError:
    %pip install pymc3
    import pymc3 as pm

try:
    import scipy.stats as stats
except ModuleNotFoundError:
    %pip install scipy
    import scipy.stats as stats
import scipy.special as sp

try:
    import arviz as az
except ModuleNotFoundError:
    %pip install arviz
    import arviz as az

import math
a = prior_dist().concentration1
b = prior_dist().concentration0

with pm.Model() as mf_model:
    theta = pm.Beta("theta", a, b)
    y = pm.Binomial("y", n=1, p=theta, observed=dataset)  # Bernoulli
    advi = pm.ADVI()
    tracker = pm.callbacks.Tracker(
        mean=advi.approx.mean.eval,  # callable that returns the mean
        std=advi.approx.std.eval,  # callable that returns the std
    )
    approx = advi.fit(callbacks=[tracker], n=20000)
    trace_approx = approx.sample(1000)

thetas = trace_approx["theta"]
plt.plot(advi.hist, label="ELBO")
plt.xlabel("Iterations")
plt.ylabel("ELBO")
plt.legend()
sns.despine()
plt.savefig("advi_beta_binom_pymc_loss.pdf")
plt.show()
[Figure: PyMC ADVI loss history vs. iterations]
print(f"ELBO comparison for last 1% iterations:\nJAX ELBO: {elbo[-10:].mean()}\nPymc ELBO: {advi.hist[-100:].mean()}")

True posterior, JAX q(x), and pymc q(x)

plt.plot(theta_range, exact_posterior.prob(theta_range), "b--", label="$p(x)$: True Posterior")
sns.kdeplot(transformed_samples, color="red", label="$q(x)$: learnt - jax", clip=(0.0, 1.0), bw_adjust=1.5)
sns.kdeplot(thetas, label="$q(x)$: learnt - pymc", clip=(0.0, 1.0), bw_adjust=1.5)
plt.xlabel("theta")
plt.legend(bbox_to_anchor=(1.3, 1))
sns.despine()
[Figure: true posterior vs. learned q(x) from JAX and PyMC]

Plot of loc and scale for the variational distribution

fig1, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
locs, scales = params["loc"], params["scale"]

# plot loc
# JAX
ax1.plot(locs, label="JAX: loc")
ax1.set_ylabel("loc")
ax1.legend()
# pymc
ax2.plot(tracker["mean"], label="Pymc: loc")
ax2.legend()
sns.despine()

# plot scale
fig2, (ax3, ax4) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
# JAX
ax3.plot(positivity_fn(jnp.array(scales)), label="JAX: scale")  # apply softplus to the raw scale
ax3.set_xlabel("Iterations")
ax3.set_ylabel("scale")
ax3.legend()
# pymc
ax4.plot(tracker["std"], label="Pymc: scale")
ax4.set_xlabel("Iterations")
ax4.legend()
sns.despine();
[Figures: loc and scale traces over training, JAX vs. PyMC]