Path: blob/master/notebooks/book2/05/minimize_kl_divergence.ipynb
Kernel: spyder-dev
Minimizing the KL divergence between two distributions
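For reference, the quantity minimized below is the Kullback-Leibler (KL) divergence between two densities,

$$\mathrm{KL}[p \,\|\, q] = \mathbb{E}_{x \sim p}\big[\log p(x) - \log q(x)\big],$$

which is asymmetric in its arguments. The notebook fits a single Gaussian to a bimodal mixture under both orderings and compares the resulting fits.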
In [1]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

try:
    import distrax
except:
    %pip install -qq distrax
    import distrax

try:
    import optax
except:
    %pip install -qq optax
    import optax

try:
    import seaborn as sns
except:
    %pip install -qq seaborn
    import seaborn as sns

try:
    from probml_utils import savefig, latexify, is_latexify_enabled
except ModuleNotFoundError:
    print("installing probml_utils")
    %pip install -qq git+https://github.com/probml/probml-utils.git
    from probml_utils import savefig, latexify, is_latexify_enabled

from jax.config import config

config.update("jax_enable_x64", True)
In [2]:
# for making the book figures
if False:
    import os

    os.environ["LATEXIFY"] = "1"
    os.environ["FIG_DIR"] = "/Users/kpmurphy/github/bookv2/figures"

latexify(width_scale_factor=1.6, fig_height=1.0)
In [3]:
# generating a bimodal distribution ptrue
mix = 0.5
mean_one, mean_two = 1, 10
scale_one, scale_two = 1, 1.5
ptrue = distrax.MixtureSameFamily(
    mixture_distribution=distrax.Categorical(probs=[mix, 1 - mix]),
    components_distribution=distrax.Normal(loc=[mean_one, mean_two], scale=[scale_one, scale_two]),
)
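In symbols, the target defined above is the two-component Gaussian mixture

$$p_{\text{true}}(x) = 0.5\,\mathcal{N}(x \mid 1, 1^2) + 0.5\,\mathcal{N}(x \mid 10, 1.5^2),$$

whose two well-separated modes are what make the forward and reverse KL fits differ.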
In [4]:
key = jax.random.PRNGKey(1234)


def kl_sampling(params, q, samples=100000):
    # loss for KL(p : q), estimated with samples drawn from the fitted Gaussian p
    p = distrax.Normal(loc=params[0], scale=params[1])
    sample_set = p.sample(seed=key, sample_shape=samples)
    return jnp.mean(p.log_prob(sample_set) - q.log_prob(sample_set))


def kl_sampling_inverse(params, q, samples=100000):
    # loss for KL(q : p), estimated with samples drawn from the true distribution q
    p = distrax.Normal(loc=params[0], scale=params[1])
    sample_set = q.sample(seed=key, sample_shape=samples)
    return jnp.mean(q.log_prob(sample_set) - p.log_prob(sample_set))
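Writing $q_\theta = \mathcal{N}(\theta_1, \theta_2)$ for the Gaussian being fitted and $p$ for the true mixture (note that the local variable names inside the functions are the other way round), the two losses are Monte Carlo estimates of the reverse and forward KL divergences:

$$\mathrm{KL}[q_\theta \,\|\, p] \approx \frac{1}{S}\sum_{s=1}^{S}\big[\log q_\theta(x_s) - \log p(x_s)\big], \qquad x_s \sim q_\theta,$$

estimated by kl_sampling, and

$$\mathrm{KL}[p \,\|\, q_\theta] \approx \frac{1}{S}\sum_{s=1}^{S}\big[\log p(x_s) - \log q_\theta(x_s)\big], \qquad x_s \sim p,$$

estimated by kl_sampling_inverse. In the first case the samples themselves depend on $\theta$, but the gradient is still usable because distrax draws Normal samples by reparameterization (standard normal noise shifted by loc and scaled by scale), so gradients flow through sample_set.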
In [7]:
def fit(params, optimizer, loss_fun, n_itr):
    opt_state = optimizer.init(params)
    loss = []
    fn = jax.jit(jax.value_and_grad(loss_fun))
    for i in range(n_itr):
        samples = 100000
        loss_value, grads = jax.value_and_grad(loss_fun)(params, ptrue, samples)
        # loss_value, grads = fn(params, ptrue, samples)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        loss.append(loss_value)
    return params, loss
In [8]:
optimizer = optax.adam(learning_rate=0.05)
n_itr = 650

params_one = jnp.array([5.0, 8.0])
optimized_params_one, loss_one = fit(params=params_one, optimizer=optimizer, loss_fun=kl_sampling, n_itr=n_itr)

params_two = jnp.array([5.0, 10.0])
optimized_params_two, loss_two = fit(params=params_two, optimizer=optimizer, loss_fun=kl_sampling_inverse, n_itr=n_itr)
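As a quick sanity check, one could compare the fitted [loc, scale] pairs against the mixture's overall moments: the forward-KL fit (optimized_params_two) should roughly match the mixture's mean and standard deviation, while the reverse-KL fit (optimized_params_one) should concentrate on a single mode. A minimal sketch, not part of the original notebook (check_key is a new, hypothetical variable introduced here):

# Sanity-check sketch: compare the fitted parameters with the mixture's moments.
check_key = jax.random.PRNGKey(0)  # extra key, separate from `key` used above
true_samples = ptrue.sample(seed=check_key, sample_shape=100000)
print("true mixture mean/std:", true_samples.mean(), true_samples.std())
print("reverse KL fit, min KL[q;p]:", optimized_params_one)
print("forward KL fit, min KL[p;q]:", optimized_params_two)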
In [9]:
x_loss = jnp.linspace(1, n_itr, n_itr)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.plot(x_loss, loss_one, color="green")
ax2.plot(x_loss, loss_two, color="orange")
ax1.set_xlabel("Iterations")
ax1.set_ylabel("loss")
ax1.set_title("Iteration vs loss\n" + r"$\min_{q}\ KL[q ; p]$")
ax2.set_xlabel("Iterations")
ax2.set_ylabel("loss")
ax2.set_title("Iteration vs loss\n" + r"$\min_{q}\ KL[p ; q]$")
Out[9]:
Text(0.5, 1.0, 'Iteration vs loss\n$\\min_{q}\\ KL[p ; q]$')
In [11]:
fig, ax = plt.subplots(1, 1)
x = jnp.linspace(-8, 20, int(1e6))

label_one = str()
label_two = str()
if is_latexify_enabled():
    label_one = r"$$ \min_q\ KL[q ; p]$$"
    label_two = r"$$ \min_q\ KL[p ; q]$$"
else:
    label_one = r"$\min_q\ KL[q ; p]$"
    label_two = r"$\min_q\ KL[p ; q]$"

ax.plot(x, ptrue.prob(x), label=r"$p$")
ax.plot(
    x,
    distrax.Normal(loc=optimized_params_two[0], scale=optimized_params_two[1]).prob(x),
    color="orange",
    label=label_two,
    linestyle="--",
)
ax.plot(
    x,
    distrax.Normal(loc=optimized_params_one[0], scale=optimized_params_one[1]).prob(x),
    color="green",
    label=label_one,
    linestyle="-.",
)
ax.set_xlabel("x")
ax.set_ylabel("P(x)")
ax.legend(fontsize=5.2, loc="upper right")
sns.despine()
savefig("minimize_kl_divergence_latexified")
Out[11]:
saving image to /Users/kpmurphy/github/bookv2/figures/minimize_kl_divergence_latexified.pdf
Figure size: [3.75 1. ]