GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/05/vib_demo.ipynb
Kernel: Python [conda env:jax_gpu]

Variational Information Bottleneck Demo

This notebook aims to serve as a modern tutorial introduction to the variational information bottleneck (VIB) method of Alemi et al. (2016).

Source: https://github.com/alexalemi/vib_demo/blob/master/vib_demo_2021.ipynb

Imports

%%capture
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())

import jax
import jax.numpy as np
from jax import grad, vmap, jit, random

try:
    import flax
except ModuleNotFoundError:
    %pip install -qq flax
    import flax
import flax.linen as nn

import matplotlib.pyplot as plt
import seaborn as sns

try:
    import einops
except ModuleNotFoundError:
    %pip install -qq einops
    import einops

try:
    import tensorflow_probability.substrates.jax as tfp
except ModuleNotFoundError:
    %pip install -qq tensorflow-probability
    import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

from functools import partial

try:
    from probml_utils import savefig, latexify, is_latexify_enabled
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    from probml_utils import savefig, latexify, is_latexify_enabled

try:
    import tensorflow_datasets as tfds
except ModuleNotFoundError:
    %pip install -qq tensorflow tensorflow_datasets
    import tensorflow_datasets as tfds

import typing
from typing import Any

try:
    import chex
except ModuleNotFoundError:
    %pip install -qq chex
    import chex

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax
latexify(width_scale_factor=3, fig_height=2)
/home/patel_zeel/miniconda3/envs/jax_gpu/lib/python3.9/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying
  warnings.warn("LATEXIFY environment variable not set, not latexifying")

We'll experiment on the MNIST dataset, which we can load into memory.

Data

# From https://twitter.com/alemi/status/1042834067173957632
# "Here is a one line MNIST Dataset Loader in Python in a tweet."
import numpy as onp

try:
    import imageio
except ModuleNotFoundError:
    %pip install -qq imageio
    import imageio

ims, labels = onp.split(
    imageio.imread(
        "https://gist.github.com/alexalemi/4b240729f6ce8aa62b24b4eb1cc34167/raw/2b7211ce842a262684b453746de4e0946c29c2cc/mnist.png"
    ).ravel(),
    [-70000],
)
ims, labels = [onp.split(y, [60000]) for y in (ims.reshape((-1, 28, 28)), labels.ravel())]
/tmp/ipykernel_638232/2516621731.py:11: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning dissapear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.
  imageio.imread(
train_ims, test_ims = ims
train_lbs, test_lbs = labels

cut = 100
x = train_ims[:cut]
y = train_lbs[:cut]

test_batch = (test_ims[:100], test_lbs[:100])

Utils

Some helpful utility functions and things we'll use below.

from IPython.display import display_png
import matplotlib.cm as cm
import matplotlib as mpl
import six


def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):
    """Convert an array to an image."""
    sm = cm.ScalarMappable(cmap=cmap)
    sm.set_clim(vmin, vmax)
    if origin is None:
        origin = mpl.rcParams["image.origin"]
    if origin == "lower":
        arr = arr[::-1]
    rgba = sm.to_rgba(arr, bytes=True)
    return rgba


def rawarrview(array, **kwargs):
    """Visualize an array as if it was an image in colab notebooks.

    Arguments:
      array: an array which will be turned into an image.
      **kwargs: Additional keyword arguments passed to imify.
    """
    f = six.BytesIO()
    imarray = imify(array, **kwargs)
    plt.imsave(f, imarray, format="png")
    f.seek(0)
    dat = f.read()
    f.close()
    display_png(dat, raw=True)
def reshape_image_batch(ims, cut=None, rows=None):
    if cut is not None:
        ims = ims[:cut]
    n = ims.shape[0]
    if rows is None:
        rows = int(np.sqrt(n))
    cols = n // rows
    return einops.rearrange(ims, "(rows cols) x y -> (rows x) (cols y)", rows=rows, cols=cols)
def batcher(rng, batch_size, ims, lbs):
    """Creates a python iterator that batches data."""
    n = ims.shape[0]
    while True:
        rng, spl = random.split(rng)
        pks = random.permutation(spl, n)
        ims = ims[pks]
        lbs = lbs[pks]
        for i in range(n // batch_size):
            yield (
                ims[i * batch_size : (i + 1) * batch_size],
                lbs[i * batch_size : (i + 1) * batch_size],
            )
def generate_tril_dist(params, dim=2):
    """Given flat parameters, assemble a MultivariateNormalTriL distribution."""
    loc = params[..., :dim]
    tril_params = tfp.bijectors.FillScaleTriL(diag_bijector=tfp.bijectors.Exp(), diag_shift=None).forward(
        params[..., dim:]
    )
    return tfd.MultivariateNormalTriL(loc=loc, scale_tril=tril_params)
@flax.struct.dataclass
class Store:
    """A simple dataclass to hold our optimization data."""

    params: chex.Array
    state: Any
    rng: Any
    step: int = 0
seethrough = plt.cm.colors.ListedColormap([(0, 0, 0, 0), (0, 0, 0, 1)])
colors = plt.cm.tab10(np.linspace(0, 1, 10))


def add_im(axs, im, xy, h=0.3, **kwargs):
    x, y = xy
    axs.imshow(im, extent=(x - h / 2, x + h / 2, y - h / 2, y + h / 2), cmap=seethrough, zorder=2, **kwargs)
def ellipse_coords(cov):
    """Given a covariance matrix, return the ellipse coordinates."""
    u, v = jax.scipy.linalg.eigh(cov)
    width = np.sqrt(5.991 * u[1])
    height = np.sqrt(5.991 * u[0])
    angle = np.arctan2(v[1][1], v[1][0])
    return width, height, angle
from matplotlib.patches import Ellipse


def add_ellipse(axs, mean, cov, color, alpha=0.4):
    width, height, angle = ellipse_coords(cov)
    ep = Ellipse(
        xy=(mean[0], mean[1]),
        width=width,
        height=height,
        angle=np.degrees(angle),
        color=color,
        alpha=alpha,
        fill=False,
        linewidth=2,
    )
    axs.add_artist(ep)

Model

To start we'll create the model components: first the encoder, which takes our images and turns them into a vector representation.

Next we'll define our classifier, the network that will take the vector representation and predict which class each image is in.

class Encoder(nn.Module):
    """The encoder takes in images and spits out a vector representation."""

    output_dim: int = 2

    @nn.compact
    def __call__(self, x):
        # We'll use a simple feed forward neural network.
        x = np.expand_dims(x, -1) / 128.0 - 1.0
        x = einops.rearrange(x, "b x y d -> b (x y d)")
        x = nn.Dense(512)(x)
        x = nn.LayerNorm()(nn.gelu(x))
        y = nn.Dense(512)(x)
        y = nn.LayerNorm()(nn.gelu(y))
        x = x + nn.gelu(y)
        z = nn.Dense(self.output_dim)(x)
        return z


class Classifier(nn.Module):
    """The classifier takes the representation and attempts to predict a label."""

    @nn.compact
    def __call__(self, x):
        # A single hidden layer neural network.
        y = nn.Dense(32)(x)
        y = nn.LayerNorm()(nn.gelu(y))
        return nn.Dense(10)(y)

Deterministic Classifier

Model definition

class Deterministic(nn.Module):
    """Represents a traditional, deterministic classifier."""

    z_dim: int = 2  # size of the intermediate representation.

    def setup(self):
        self.encoder = Encoder(self.z_dim)
        self.classifier = Classifier()

    def encode(self, x):
        return self.encoder(x)

    def classify(self, z):
        return self.classifier(z)

    def __call__(self, x):
        z = self.encode(x)
        return self.classify(z)

Training

# In JAX we need to be deliberate about our randomness, we'll choose a seed here.
rng = random.PRNGKey(0)

# We'll generate our first model, the deterministic network here.
model = Deterministic()

# We'll initialize the model and get its parameters.
logits, params = model.init_with_output(rng, x)

nparams = jax.tree_util.tree_reduce(lambda acc, val: acc + np.prod(np.array(val.shape)), params, 0)
print(f"Number of parameters: {nparams:,}")
Number of parameters: 668,140

We'll write a simple loss function whose first argument is the parameters; this will make it easy to use JAX's automatic differentiation capabilities to generate the gradients for optimization.
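(As a side illustration, not part of the original notebook: jax.grad differentiates with respect to the first positional argument by default, which is why the loss signature puts params first. A toy sketch using the jax and np (jax.numpy) names imported above:)

def toy_loss(params, x):
    # A made-up scalar function of (params, x), purely for illustration.
    return np.sum(params * x)


# The gradient with respect to params (the first argument) is just x here.
print(jax.grad(toy_loss)(np.ones(3), np.arange(3.0)))  # -> [0. 1. 2.]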

@jit
def loss(params, batch):
    """The loss function for our deterministic network."""
    ims, lbs = batch
    # Compute the logits for each image by going through the model.
    logits = model.apply(params, ims)
    # Use those to create a categorical distribution.
    pred_dist = tfd.Categorical(logits=logits)
    # Evaluate the negative log likelihood.
    class_err = -pred_dist.log_prob(lbs).mean()
    # Output several statistics that we can look at.
    aux = {"correct": logits.argmax(axis=-1) == lbs, "logits": logits, "class_err": class_err, "preddist": pred_dist}
    # JAX expects the first argument to be a scalar
    # that we can differentiate.
    return class_err, aux

To optimize we'll use the AdaBelief optimizer as implemented in Optax. Optax separates gradient computation from the application of updates; the basic usage looks something like this:

opt = optax.sgd(learning_rate)
state = opt.init(params)


def train_step(params, state, batch):
    grads = jax.grad(loss)(params, batch)
    updates, state = opt.update(grads, state, params)
    # Updates are pure-functional: we return the new params and state.
    params = optax.apply_updates(params, updates)
    return params, state
opt = optax.adabelief(1e-4)
store = Store(params, opt.init(params), rng, 0)
batches = batcher(rng, 500, train_ims, train_lbs)
testset = (test_ims, test_lbs)
@jit
def train_step(store: Store, batch):
    """Implements a training step, computing and applying gradients for a batch."""
    (val, aux), grads = jax.value_and_grad(loss, has_aux=True)(store.params, batch)
    updates, state = opt.update(grads, store.state, store.params)
    params = optax.apply_updates(store.params, updates)
    return store.replace(params=params, state=state, step=store.step + 1), (val, aux)

Now we can train for a while and observe as the network makes better predictions.

for step, batch in zip(range(6_000), batches):
    store, (val, aux) = train_step(store, batch)
    if store.step % 300 == 0:
        print(f"step = {store.step}, train acc = {aux['correct'].mean():0.2%} train loss = {val:.3}", flush=True)
        # Evaluate on the test set.
        tval, taux = loss(store.params, testset)
        print(f"test acc = {taux['correct'].mean():0.2%} test loss = {taux['class_err'].mean():.3}", flush=True)
step = 300, train acc = 61.20% train loss = 1.3
test acc = 62.28% test loss = 1.29
step = 600, train acc = 79.40% train loss = 0.87
test acc = 77.22% test loss = 0.917
step = 900, train acc = 86.40% train loss = 0.642
test acc = 85.74% test loss = 0.671
step = 1200, train acc = 90.40% train loss = 0.506
test acc = 90.58% test loss = 0.503
step = 1500, train acc = 93.20% train loss = 0.356
test acc = 92.78% test loss = 0.377
step = 1800, train acc = 96.20% train loss = 0.235
test acc = 93.65% test loss = 0.318
step = 2100, train acc = 97.40% train loss = 0.179
test acc = 94.54% test loss = 0.277
step = 2400, train acc = 98.60% train loss = 0.125
test acc = 94.70% test loss = 0.256
step = 2700, train acc = 97.60% train loss = 0.117
test acc = 94.88% test loss = 0.246
step = 3000, train acc = 98.80% train loss = 0.0788
test acc = 95.15% test loss = 0.23
step = 3300, train acc = 99.40% train loss = 0.0559
test acc = 95.14% test loss = 0.231
step = 3600, train acc = 98.80% train loss = 0.0748
test acc = 94.99% test loss = 0.233
step = 3900, train acc = 99.60% train loss = 0.0568
test acc = 95.32% test loss = 0.227
step = 4200, train acc = 99.20% train loss = 0.0764
test acc = 95.04% test loss = 0.231
step = 4500, train acc = 100.00% train loss = 0.0169
test acc = 95.53% test loss = 0.228
step = 4800, train acc = 99.60% train loss = 0.0208
test acc = 95.26% test loss = 0.235
step = 5100, train acc = 98.40% train loss = 0.088
test acc = 95.23% test loss = 0.249
step = 5400, train acc = 99.80% train loss = 0.0282
test acc = 95.56% test loss = 0.233
step = 5700, train acc = 100.00% train loss = 0.00962
test acc = 95.70% test loss = 0.221
step = 6000, train acc = 100.00% train loss = 0.00661
test acc = 95.86% test loss = 0.225

After training, our final parameters are stored in store.params, which we can use to evaluate the accuracy on the full training set.

val, aux = loss(store.params, (train_ims, train_lbs))
print(f"Final Train accuracy: {aux['correct'].mean():0.2%}")
Final Train accuracy: 99.87%

We can also visualize our representation; since we chose a two-dimensional representation, it's easy to render a simple scatterplot.

fig, axs = plt.subplots()
xx = test_ims[::10]
yy = test_lbs[::10]
zs = model.apply(store.params, xx, method=model.encode)
axs.scatter(zs[..., 0], zs[..., 1], c=yy, cmap="tab10", alpha=0.3, s=3)
sns.despine()
savefig("vib-deterministic-2d")
/home/patel_zeel/miniconda3/envs/jax_gpu/lib/python3.9/site-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures
  warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook

Notice that the network has learned to separate the different classes (here each point is colored according to its class).

To better see what is going on, we can overlay some example images on the plot.

fig, axs = plt.subplots()
axs.scatter(zs[..., 0], zs[..., 1], c=yy, cmap="tab10", alpha=0.3, s=3)
xlim = plt.xlim()
ylim = plt.ylim()
for i in range(1, xx.shape[0], 10):
    add_im(axs, xx[i], zs[i], h=1, alpha=0.8)
plt.xlim(xlim)
plt.ylim(ylim)
sns.despine()
savefig("vib-deterministic-2d-digits")
Image in a Jupyter notebook

VIB

Now we'll try to train a VIB version of this network. Whereas before we learned a deterministic representation of each image, now our representation will be stochastic: each image will be mapped to a distribution.

We'll keep things two dimensional, so we'll use a two-dimensional Normal distribution, parameterized by a two-dimensional mean and three additional parameters for the lower-triangular Cholesky factor of the covariance matrix.
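(To make the parameter bookkeeping concrete, here is a small sketch that is not in the original notebook; it reuses the generate_tril_dist helper from the Utils section, and the five numbers are made up: two for the mean, three for the Cholesky factor.)

flat = np.array([[0.0, 0.0, 0.1, 0.2, -0.3]])  # hypothetical encoder output for one image
dist = generate_tril_dist(flat, dim=2)
print(dist.batch_shape, dist.event_shape)  # a batch of one 2-D Gaussian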

class VIB(nn.Module):
    """An implementation of the Variational Information Bottleneck Method."""

    z_dim: int = 2
    mix_components: int = 32  # not used in this demo

    def setup(self):
        # We'll use the same encoder as before but predict additional parameters
        # for our distribution.
        num_params = self.z_dim + (self.z_dim * (self.z_dim + 1)) // 2
        self.encoder = Encoder(output_dim=num_params, name="encoder")
        # We'll use the same classifier as before.
        self.decoder = Classifier(name="classifier")
        # We also need a marginal distribution, here we'll use a fixed
        # isotropic Gaussian, though this could also be learned.
        self.marginal = tfd.MultivariateNormalDiag(loc=np.zeros(self.z_dim))

    def encode(self, x):
        """Take an image and return a two dimensional distribution."""
        params = self.encoder(x)
        return generate_tril_dist(params, self.z_dim)

    def decode(self, z_samps):
        """Given a sampled representation, predict the class."""
        logits = self.decoder(z_samps)
        return tfd.Categorical(logits=logits)

    def __call__(self, rng, batch, num_samps=16):
        """Compute relevant VIB quantities."""
        ims, lbs = batch
        z_dist = self.encode(ims)
        z_samps = z_dist.sample((num_samps,) if num_samps > 1 else (), seed=rng)
        y_dist = self.decode(z_samps)
        lps = y_dist.log_prob(lbs)
        class_err = -lps.mean(0)
        rate = (z_dist.log_prob(z_samps) - self.marginal.log_prob(z_samps)).mean(0)
        logits = y_dist.logits
        correct = y_dist.logits.argmax(-1) == lbs
        # Now that we have a stochastic representation, we can marginalize out
        # that representation to get increased predictive performance.
        lse_class_err = -jax.nn.logsumexp(lps - np.log(num_samps), axis=0)
        lse_correct = jax.nn.logsumexp(y_dist.logits - np.log(num_samps), axis=0).argmax(-1) == lbs
        return {
            "class_err": class_err,
            "rate": rate,
            "correct": correct,
            "lse_correct": lse_correct,
            "lse_class_err": lse_class_err,
            "logits": logits,
        }

Create the model and initialize the parameters.

vib = VIB()
aux, vib_params = vib.init_with_output(rng, rng, batch)
vib_store = Store(vib_params, opt.init(vib_params), rng)

Our new loss is:

$$\left\langle -\log q(y|z) + \beta \log \frac{p(z|x)}{q(z)} \right\rangle$$

or

$$C + \beta R$$

the combination of our classification error $C$ (the term on the left) and $\beta$ times the rate $R$.
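In expectation the rate term is just the KL divergence between the encoder distribution and the variational marginal; the VIB class above estimates it with Monte Carlo samples from the encoder (this identity is standard, and is implied by rather than stated in the original notebook):

$$R = \left\langle \log \frac{p(z|x)}{q(z)} \right\rangle_{p(z|x)} = \mathrm{KL}\big(p(z|x)\,\|\,q(z)\big) \approx \frac{1}{S}\sum_{s=1}^{S}\left[\log p(z_s|x) - \log q(z_s)\right], \qquad z_s \sim p(z|x).$$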

@jit
def vib_loss(params, rng, batch, beta=0.5):
    aux = vib.apply(params, rng, batch)
    loss = aux["class_err"].mean() + beta * aux["rate"].mean()
    return loss, aux

The VIB train step is the same as before, though now we also update the random seed at each step, since our representation is stochastic.
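(A minimal sketch of the JAX PRNG idiom, with an arbitrary seed, in case it is unfamiliar: we split the key each step, carrying one half forward and consuming the other for this step's samples.)

key = random.PRNGKey(42)  # arbitrary seed for illustration
key, subkey = random.split(key)  # keep `key` for later steps, use `subkey` now
print(random.normal(subkey, (3,)))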

@jit
def vib_train_step(store: Store, batch):
    rng, spl = random.split(store.rng)
    (val, aux), grads = jax.value_and_grad(vib_loss, has_aux=True)(store.params, spl, batch)
    updates, state = opt.update(grads, store.state, store.params)
    params = optax.apply_updates(store.params, updates)
    return store.replace(params=params, state=state, step=store.step + 1, rng=rng), (val, aux)
for step, batch in zip(range(3_000), batches):
    vib_store, (val, aux) = vib_train_step(vib_store, batch)
    if vib_store.step % 300 == 0:
        print(
            f"step = {store.step}, train acc = {aux['correct'].mean():0.2%} lse acc = {aux['lse_correct'].mean():0.2%} train loss = {val:.3}",
            flush=True,
        )
        # Evaluate on the test set. (Note: the fourth positional argument of vib_loss is beta.)
        tval, taux = vib_loss(vib_store.params, rng, testset, 128)
        print(
            f"test acc = {taux['correct'].mean():0.2%} rate: {taux['rate'].mean():.3} lse_acc = {taux['lse_correct'].mean():0.2%}",
            flush=True,
        )
step = 6000, train acc = 26.29% lse acc = 35.80% train loss = 2.31
test acc = 26.26% rate: 0.493 lse_acc = 37.92%
step = 6000, train acc = 35.31% lse acc = 53.40% train loss = 2.21
test acc = 34.20% rate: 0.854 lse_acc = 48.25%
step = 6000, train acc = 41.13% lse acc = 57.00% train loss = 2.09
test acc = 40.81% rate: 1.13 lse_acc = 55.72%
step = 6000, train acc = 46.74% lse acc = 64.20% train loss = 2.04
test acc = 44.66% rate: 1.26 lse_acc = 59.99%
step = 6000, train acc = 47.86% lse acc = 65.00% train loss = 2.01
test acc = 48.84% rate: 1.44 lse_acc = 64.87%
step = 6000, train acc = 55.18% lse acc = 72.80% train loss = 1.94
test acc = 54.34% rate: 1.62 lse_acc = 71.74%
step = 6000, train acc = 55.35% lse acc = 75.20% train loss = 1.94
test acc = 57.67% rate: 1.64 lse_acc = 77.64%
step = 6000, train acc = 61.31% lse acc = 79.40% train loss = 1.89
test acc = 61.77% rate: 1.74 lse_acc = 82.59%
step = 6000, train acc = 67.28% lse acc = 88.60% train loss = 1.81
test acc = 65.08% rate: 1.81 lse_acc = 86.45%
step = 6000, train acc = 69.46% lse acc = 92.20% train loss = 1.81
test acc = 68.63% rate: 1.87 lse_acc = 89.02%

Now each image, instead of being mapped to a point, is mapped to an entire distribution.

rawarrview(reshape_image_batch(test_batch[0]), cmap="bone_r")
Image in a Jupyter notebook
z_dist = vib.apply(vib_store.params, test_batch[0], method=vib.encode)
print(z_dist)
tfp.distributions.MultivariateNormalTriL("MultivariateNormalTriL", batch_shape=[100], event_shape=[2], dtype=float32)

We can visualize the probability distribution associated with any one image.

xs = np.linspace(-2, 2, 300)
zs = np.array(np.meshgrid(xs, xs)).T
print(zs.shape)
zz = zs[..., None, :]
print(zz.shape)
probs = z_dist.prob(zz)

fig, axs = plt.subplots(2, 10, figsize=(20, 4))
for i in range(10):
    axs[0, i].imshow(test_batch[0][i], cmap="bone_r")
    axs[0, i].axis("off")
    axs[1, i].pcolor(xs, xs, probs[..., i])
(300, 300, 2)
(300, 300, 1, 2)
Image in a Jupyter notebook

For example, if we look at the two 1s in the batch above, it's difficult to tell them apart from samples of the encoder: the VIB encoder has learned to map the two images to very similar distributions.

z_samps = z_dist.sample((128,), seed=rng)
print(z_samps.shape)

fig, axs = plt.subplots()
axs.scatter(*z_samps[:, 2].T)
axs.scatter(*z_samps[:, 5].T)
add_ellipse(axs, z_dist.mean()[2], z_dist.covariance()[2], "C0", alpha=0.7)
add_ellipse(axs, z_dist.mean()[5], z_dist.covariance()[5], "C1", alpha=0.7)
(128, 100, 2)
Image in a Jupyter notebook

We can visualize the mean of the embeddings of each image.

fig, axs = plt.subplots()
z_dist = vib.apply(vib_store.params, xx, method=vib.encode)
means = z_dist.mean()
covs = z_dist.covariance()
axs.scatter(means[..., 0], means[..., 1], c=yy, cmap="tab10", s=3)
sns.despine()
savefig("vib-deterministic-2d-embeddings")

To better visualize what is happening, we can show each distribution with an ellipse denoting its one sigma contour.

from matplotlib.patches import Ellipse

fig, axs = plt.subplots()
for i in range(200):
    add_ellipse(axs, means[i], covs[i], colors[yy[i]])
# axs.grid('off')
# axs.patch.set_facecolor('white')

edge = np.max(np.abs(means))
axs.set_xlim((-edge, edge))
axs.set_ylim((-edge, edge))

for i in range(1, xx.shape[0], 10):
    add_im(axs, xx[i], means[i], h=0.4, alpha=0.8)
sns.despine()
savefig("vib-stochastic-2d-digits")
Image in a Jupyter notebook

Notice that the ellipses frequently sit on top of one another: the network has thrown out a lot of the information contained in the original image and now largely focuses on its class.

aux = vib.apply(vib_store.params, rng, (xx, yy))
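The rate computed above also lets us quantify that claim: in expectation the average rate upper-bounds the mutual information between the representation and the input. (This small check is an addition, not part of the original notebook; it simply converts the mean KL from nats to bits.)

# Mean rate in bits per image; a small value means the representation
# retains little information about the input beyond its class.
print(f"average rate ~ {aux['rate'].mean() / np.log(2):.2f} bits per image")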
(~aux["lse_correct"]).sum()
DeviceArray(111, dtype=int32)
(aux["lse_correct"] == False).sum()
DeviceArray(111, dtype=int32)
args = aux["lse_class_err"].argsort()  # lse_class_err is a negative log prob, so most probable first
print(args[:10])
print(aux["lse_class_err"][args[:10]])
[760 479 824 995 860 215 889 335 606 140]
[0.18474936 0.19205356 0.19407225 0.19793534 0.20415187 0.20613337 0.2114532 0.21730804 0.21964717 0.22603917]
# Show the most unusual (least probable under the classifier) images.
args = aux["lse_class_err"].argsort()
plt.imshow(
    einops.rearrange(np.take(xx, args[-40:], axis=0), "(h1 h2) x y -> (h1 x) (h2 y)", h1=5, h2=8),
    cmap="bone_r",
)
<matplotlib.image.AxesImage at 0x7f6b0ca512b0>
Image in a Jupyter notebook
xs = np.linspace(-2, 2, 300)
all_zs = np.array(np.meshgrid(xs, xs)).T
y_dist = vib.apply(vib_store.params, all_zs, method=vib.decode)