Variational Information Bottleneck Demo
This notebook aims to serve as a modern tutorial introduction to the variational information bottleneck method of Alemi et al. (2016).
Source: https://github.com/alexalemi/vib_demo/blob/master/vib_demo_2021.ipynb
Imports
We'll experiment on the MNIST dataset, which is small enough to load entirely in memory.
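As a minimal sketch of the preprocessing such an in-memory pipeline typically does, here we scale pixel values to [0, 1] and flatten each 28×28 image into a 784-dimensional vector. The arrays below are random stand-ins for the real MNIST data, not the notebook's actual loading code.

```python
import numpy as np

# Random stand-ins for the real uint8 MNIST arrays.
raw_images = np.random.randint(0, 256, size=(64, 28, 28), dtype=np.uint8)
raw_labels = np.random.randint(0, 10, size=(64,))

images = raw_images.astype(np.float32) / 255.0   # pixel values in [0, 1]
images = images.reshape(images.shape[0], -1)     # flatten to (batch, 784)
```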
Data
Utils
Some helpful utility functions and things we'll use below.
Model
To start, we'll create the model components. First, the encoder, which will take our images and turn them into a vector representation.
Next we'll define our classifier, the network that will take the vector representation and predict which class each image is in.
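The two pieces can be sketched in plain JAX as below; the notebook's real definitions may use a neural-network library on top of JAX, and the layer sizes here are illustrative assumptions. The encoder maps a flattened image to a 2-D vector, and the classifier maps that vector to 10 class logits.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=784, hidden=128, z_dim=2, n_classes=10):
    # One small MLP encoder plus a linear classifier on top of z.
    k1, k2, k3 = jax.random.split(key, 3)
    scale = 0.01
    return {
        "enc_w1": scale * jax.random.normal(k1, (in_dim, hidden)),
        "enc_b1": jnp.zeros(hidden),
        "enc_w2": scale * jax.random.normal(k2, (hidden, z_dim)),
        "enc_b2": jnp.zeros(z_dim),
        "cls_w": scale * jax.random.normal(k3, (z_dim, n_classes)),
        "cls_b": jnp.zeros(n_classes),
    }

def encode(params, x):
    h = jax.nn.relu(x @ params["enc_w1"] + params["enc_b1"])
    return h @ params["enc_w2"] + params["enc_b2"]   # (batch, 2)

def classify(params, z):
    return z @ params["cls_w"] + params["cls_b"]     # (batch, 10) logits

params = init_params(jax.random.PRNGKey(0))
z = encode(params, jnp.ones((5, 784)))
logits = classify(params, z)
```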
Deterministic Classifier
Model definition
Training
We'll write a simple loss function whose first argument is the parameters; this will make it easy for us to use JAX's automatic differentiation capabilities to generate the gradients for optimization.
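The pattern looks like the following sketch: because the parameters come first, `jax.grad(loss)` differentiates with respect to them by default. A single linear layer stands in here for the notebook's real encoder-plus-classifier.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Mean cross-entropy over the batch.
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))

params = {"w": jnp.zeros((784, 10)), "b": jnp.zeros(10)}
x = jnp.ones((4, 784))
y = jnp.array([0, 1, 2, 3])
grads = jax.grad(loss)(params, x, y)   # same pytree structure as params
```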
Now we can train for a while and observe as the network makes better predictions.
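A minimal sketch of such a training loop, assuming plain SGD (the notebook may use a fancier optimizer); a tiny linear model on random data stands in for the real model and MNIST batches.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    logits = x @ params["w"] + params["b"]
    return -jnp.mean(jax.nn.log_softmax(logits)[jnp.arange(y.shape[0]), y])

@jax.jit
def train_step(params, x, y, lr=0.1):
    # One SGD step: gradient of the loss w.r.t. the parameter pytree.
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

x = jax.random.normal(jax.random.PRNGKey(0), (32, 20))
y = jax.random.randint(jax.random.PRNGKey(1), (32,), 0, 10)
params = {"w": jnp.zeros((20, 10)), "b": jnp.zeros(10)}

before = loss(params, x, y)
for _ in range(100):
    params = train_step(params, x, y)
after = loss(params, x, y)
```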
After training, our final parameters are stored in store.params
which we could use to evaluate the full training set accuracy
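The accuracy computation itself is just comparing the argmax of the logits to the labels, as in this small sketch:

```python
import jax.numpy as jnp

def accuracy(logits, labels):
    # Fraction of examples whose highest-scoring class matches the label.
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)

logits = jnp.array([[2.0, 0.1, 0.0],
                    [0.0, 3.0, 0.5],
                    [1.0, 0.2, 0.1]])
labels = jnp.array([0, 1, 2])
acc = accuracy(logits, labels)   # two of the three predictions are correct
```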
We can also visualize our representation: since we chose a two-dimensional representation, it's easy to render a simple scatterplot.
Notice that the network has learned to separate the different classes (here each point is colored according to its class).
To better see what is going on we can embed some example images onto the scene.
VIB
Now we'll try to train a VIB version of this network. Whereas before we learned a deterministic representation of each image, now our representation will be stochastic: each image will be mapped to a distribution.
We'll keep things two-dimensional, so we'll use a two-dimensional Normal distribution, parameterized by a two-dimensional mean and three parameters for the covariance matrix (a symmetric 2×2 matrix has three free entries).
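One standard way to turn three unconstrained numbers into a valid 2×2 covariance, sketched below, is to use them as the entries of a lower-triangular Cholesky factor L (passing the diagonal through softplus to keep it positive) and set Σ = L Lᵀ. The notebook's exact parameterization may differ.

```python
import jax.numpy as jnp

def build_scale_tril(raw):
    # raw = (diag1, diag2, off-diagonal), all unconstrained reals.
    d1, d2, off = raw
    softplus = lambda v: jnp.logaddexp(v, 0.0)   # log(1 + exp(v)) > 0
    return jnp.array([[softplus(d1), 0.0],
                      [off, softplus(d2)]])

raw = jnp.array([0.3, -0.5, 0.8])   # the three covariance parameters
L = build_scale_tril(raw)
sigma = L @ L.T                      # symmetric positive-definite covariance
```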
Create the model and initialize the parameters.
Our new loss is:

$$ \mathcal{L} = \mathbb{E}_{p(z|x)}\left[ -\log q(y|z) \right] + \beta \, \mathrm{KL}\left[ p(z|x) \,\|\, r(z) \right] $$

or

$$ \mathcal{L} = \text{classification error} + \beta \times \text{rate}, $$

the combination of our classification error (the term on the left) and $\beta$ times the rate (the KL term on the right).
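When the prior $r(z)$ is a standard normal, the rate term has a closed form, $\mathrm{KL}[\mathcal{N}(\mu, \Sigma) \,\|\, \mathcal{N}(0, I)] = \tfrac{1}{2}\left(\operatorname{tr}\Sigma + \mu^\top\mu - k - \log\det\Sigma\right)$ in $k$ dimensions. A sketch of that computation (the notebook may instead call a distributions library):

```python
import jax.numpy as jnp

def kl_to_standard_normal(mu, sigma):
    # KL[N(mu, sigma) || N(0, I)] = 0.5 (tr(sigma) + mu.mu - k - log det sigma)
    k = mu.shape[-1]
    trace = jnp.trace(sigma)
    logdet = jnp.linalg.slogdet(sigma)[1]
    return 0.5 * (trace + mu @ mu - k - logdet)

mu = jnp.zeros(2)
sigma = jnp.eye(2)
kl = kl_to_standard_normal(mu, sigma)   # 0 when the encoder equals the prior
```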
The VIB train step is the same as before, though now we also update the random seed each step, since our representation is stochastic.
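The stochastic part can be sketched as follows: each step we split the PRNG key for fresh randomness, and we sample the representation with the reparameterization trick, z = μ + L ε, so gradients flow through the mean and Cholesky factor. The shapes and the identity scale here are illustrative stand-ins for the encoder's real outputs.

```python
import jax
import jax.numpy as jnp

def sample_z(key, mu, scale_tril):
    # Reparameterized sample: z = mu + L @ eps, one sample per row of mu.
    eps = jax.random.normal(key, mu.shape)
    return mu + eps @ scale_tril.T

key = jax.random.PRNGKey(0)
mu = jnp.zeros((4, 2))
scale_tril = jnp.eye(2)

for step in range(3):
    key, subkey = jax.random.split(key)   # fresh randomness every step
    z = sample_z(subkey, mu, scale_tril)
    # ...compute the loss with z and take a gradient step as before...
```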
Now each image, instead of being mapped to a point, is mapped to an entire distribution.
We can visualize the probability distribution associated with any one image.
For example, if we look at the two 1s in the batch above, it's difficult to tell them apart from samples of the encoder alone: the VIB encoder has learned to map the two images to very similar distributions.
We can visualize the mean of the embeddings of each image.
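A self-contained sketch of this visualization: random 2-D points stand in for the per-image means (`z_dist.mean()` in the notebook) and random integers for the class labels; note that the cell needs `import matplotlib.pyplot as plt`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")               # render off-screen
import matplotlib.pyplot as plt

# Stand-ins for the encoder means and the MNIST class labels.
rng = np.random.default_rng(0)
means = rng.normal(size=(256, 2))
labels = rng.integers(0, 10, size=256)

fig, axs = plt.subplots()
axs.scatter(means[:, 0], means[:, 1], c=labels, cmap="tab10", s=8)
```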
To better visualize what is happening, we can show each distribution with an ellipse denoting its one-sigma contour.
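Extracting that ellipse from a 2×2 covariance can be sketched as below: the ellipse axes point along the eigenvectors of Σ, and their half-lengths are the square roots of the eigenvalues (the example covariance is an illustrative assumption).

```python
import numpy as np

def one_sigma_ellipse(sigma):
    # Eigendecomposition of the covariance; eigenvalues come back ascending.
    evals, evecs = np.linalg.eigh(sigma)
    half_lengths = np.sqrt(evals)                 # one-sigma semi-axes
    angle = np.degrees(np.arctan2(evecs[1, 1], evecs[0, 1]))  # major-axis angle
    return half_lengths, angle

sigma = np.array([[2.0, 0.0],
                  [0.0, 0.5]])
half_lengths, angle = one_sigma_ellipse(sigma)
```

The `half_lengths` and `angle` are exactly what a plotting ellipse primitive (width, height, rotation) expects, up to the factor of two between semi-axes and full axes.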
Notice that the ellipses frequently lie on top of one another: the network has thrown out much of the information contained in the original image and now largely focuses on its class.