Path: blob/master/notebooks/book2/04/rbm_contrastive_divergence.ipynb
A demonstration of using contrastive divergence to train the parameters of a restricted Boltzmann machine.
References and Materials
This notebook makes use of various textbooks, articles, and other resources; some particularly relevant ones are listed below.
RBM and CD Background:
[1] K. Murphy. Probabilistic Machine Learning: Advanced Topics. MIT Press, 2023.
D. MacKay. Information Theory, Inference, and Learning Algorithms. Cambridge University Press, 2003.
T. Hastie, R. Tibshirani, and J. Friedman. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. 2nd ed. Springer, 2009.
Practical advice for training RBMs with the CD algorithm:
[2] G. Hinton. A Practical Guide to Training Restricted Boltzmann Machines. Tech. rep. U. Toronto, 2010.
Code:
Plotting functions
Restricted Boltzmann Machines
Restricted Boltzmann Machines (RBMs) are a type of energy-based model in which the connectivity of nodes is carefully designed to facilitate efficient sampling methods. For details of RBMs see the sections on undirected graphical models (Section 4.3) and energy-based models (Chapter 23) in [1]. We reproduce here some of the relevant sampling equations, which we will implement below.
We will be considering RBMs with binary units in both the hidden, $\mathbf{h}$, and visible, $\mathbf{v}$, layers.
In general, for Boltzmann machines with hidden units, the probability of a particular state $\mathbf{v}$ of the visible nodes is given by

$$p(\mathbf{v} ; \boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \sum_{\mathbf{h}} \exp\left(-\mathcal{E}(\mathbf{v}, \mathbf{h} ; \boldsymbol{\theta})\right), \qquad Z(\boldsymbol{\theta}) = \sum_{\mathbf{v}} \sum_{\mathbf{h}} \exp\left(-\mathcal{E}(\mathbf{v}, \mathbf{h} ; \boldsymbol{\theta})\right),$$

where $\boldsymbol{\theta} = \{\mathbf{W}, \mathbf{b}, \mathbf{c}\}$ is the collection of parameters: the weight matrix $\mathbf{W}$, the visible biases $\mathbf{b}$, and the hidden biases $\mathbf{c}$,

and the energy of state $(\mathbf{v}, \mathbf{h})$ is given by:

$$\mathcal{E}(\mathbf{v}, \mathbf{h} ; \boldsymbol{\theta}) = -\left(\mathbf{v}^{\top} \mathbf{W} \mathbf{h} + \mathbf{b}^{\top} \mathbf{v} + \mathbf{c}^{\top} \mathbf{h}\right).$$
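As a concrete illustration, a minimal JAX sketch of this energy function is given below. The names `W`, `b`, and `c` for the weights and the visible/hidden biases follow the notation above and are not necessarily the names used in this notebook's code cells.

```python
import jax.numpy as jnp

def energy(v, h, W, b, c):
    """E(v, h) = -(v^T W h + b^T v + c^T h) for binary vectors v (visible) and h (hidden)."""
    return -(v @ W @ h + b @ v + c @ h)
```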
In restricted Boltzmann machines the hidden units are independent of one another conditional on the visible units, and vice versa. This means that it is straightforward to do conditional block-sampling of the state of the network.
This independence structure has the property that, when conditionally sampling, the probability that the $j$-th hidden unit is active is

$$p(h_j = 1 \mid \mathbf{v}) = \sigma\Big(c_j + \sum_i W_{ij} v_i\Big),$$

and the probability that the $i$-th visible unit is active is given by

$$p(v_i = 1 \mid \mathbf{h}) = \sigma\Big(b_i + \sum_j W_{ij} h_j\Big).$$

The function $\sigma(\cdot)$ is the sigmoid function:

$$\sigma(x) = \frac{1}{1 + e^{-x}}.$$
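A sketch of the corresponding block-sampling steps, using the same parameter naming as the energy sketch above:

```python
import jax
import jax.numpy as jnp

def sample_h_given_v(key, v, W, c):
    """Block-sample all hidden units: p(h_j = 1 | v) = sigmoid(c_j + sum_i v_i W_ij)."""
    p_h = jax.nn.sigmoid(c + v @ W)
    return jax.random.bernoulli(key, p_h).astype(jnp.float32), p_h

def sample_v_given_h(key, h, W, b):
    """Block-sample all visible units: p(v_i = 1 | h) = sigmoid(b_i + sum_j W_ij h_j)."""
    p_v = jax.nn.sigmoid(b + W @ h)
    return jax.random.bernoulli(key, p_v).astype(jnp.float32), p_v
```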
Contrastive Divergence
Contrastive divergence (CD) is the name for a family of algorithms used to perform approximate maximum likelihood training for RBMs.
Contrastive divergence approximates the gradient of the log probability of the data (our desired objective function) by initialising an MCMC chain at the data vector and sampling for a small number of steps. The insight behind CD is that, even with a very small number of steps, the process still provides gradient information which can be used to fit the model parameters.
Here we implement the CD1 algorithm which uses just a single round of Gibbs sampling.
For more details on the CD algorithm see [1] (Section 23.2.2).
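As an illustration (not the notebook's actual implementation), a CD-1 gradient estimate for a single data vector can be sketched as follows, reusing the conditional distributions above:

```python
import jax
import jax.numpy as jnp

def cd1_grads(key, v0, W, b, c):
    """One CD-1 step for a single data vector v0 (approximate log-likelihood gradients)."""
    k1, k2 = jax.random.split(key)
    # Positive phase: hidden activations driven by the data.
    p_h0 = jax.nn.sigmoid(c + v0 @ W)
    h0 = jax.random.bernoulli(k1, p_h0).astype(jnp.float32)
    # Negative phase: one step of block Gibbs sampling.
    p_v1 = jax.nn.sigmoid(b + W @ h0)
    v1 = jax.random.bernoulli(k2, p_v1).astype(jnp.float32)
    p_h1 = jax.nn.sigmoid(c + v1 @ W)
    # Difference between data-driven and model-driven statistics.
    dW = jnp.outer(v0, p_h0) - jnp.outer(v1, p_h1)
    db = v0 - v1
    dc = p_h0 - p_h1
    return dW, db, dc
```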
Load MNIST
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
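The download messages above come from tensorflow_datasets. A minimal sketch of loading and binarising MNIST in this way is shown below; the split name and preprocessing choices are assumptions, not necessarily what the notebook's cell does.

```python
import numpy as np
import tensorflow_datasets as tfds

# Load the full training split as numpy arrays.
images, labels = tfds.as_numpy(
    tfds.load("mnist", split="train", batch_size=-1, as_supervised=True)
)
# Flatten and binarise the pixels so they match the RBM's binary visible units.
v_train = (images.reshape(-1, 784) / 255.0 > 0.5).astype(np.float32)
```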
Training with optax
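A rough sketch of how the CD-1 gradients could be fed to an optax optimizer is given below; the layer sizes, learning rate, and the `cd1_grads` helper from the sketch above are illustrative assumptions rather than the notebook's actual settings.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical setup; sizes and learning rate are assumptions.
n_visible, n_hidden = 784, 128
key = jax.random.PRNGKey(0)
params = {
    "W": 0.01 * jax.random.normal(key, (n_visible, n_hidden)),
    "b": jnp.zeros(n_visible),
    "c": jnp.zeros(n_hidden),
}
optimizer = optax.sgd(learning_rate=1e-2)
opt_state = optimizer.init(params)

def update(params, opt_state, key, v0):
    """Apply one optimizer step using the CD-1 gradient estimate."""
    dW, db, dc = cd1_grads(key, v0, params["W"], params["b"], params["c"])
    # optax minimises, so pass the negative of the log-likelihood ascent direction.
    grads = {"W": -dW, "b": -db, "c": -dc}
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```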
Evaluating Training
The reconstruction loss is a heuristic measure of training performance. It measures a combination of two effects:
The difference between the equilibrium distribution of the RBM and the empirical distribution of the data.
The mixing rate of the Gibbs sampling.
The first of these effects tends to be what we care about; however, it is impossible to distinguish it from the second [2].
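As an illustration, a one-step reconstruction error could be computed as follows, assuming the parameter naming used in the sketches above:

```python
import jax
import jax.numpy as jnp

def reconstruction_error(key, v0, W, b, c):
    """Mean squared difference between a data vector and its one-step reconstruction."""
    p_h = jax.nn.sigmoid(c + v0 @ W)
    h = jax.random.bernoulli(key, p_h).astype(jnp.float32)
    p_v = jax.nn.sigmoid(b + W @ h)  # reconstruction probabilities
    return jnp.mean((v0 - p_v) ** 2)
```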
The objective function which contrastive divergence optimizes is the probability that the RBM assigns to the dataset. We cannot calculate this directly, however, because it requires knowledge of the partition function.
We can, however, compare the average free energy between two different sets of data; in this comparison the partition functions cancel out. Hinton [2] suggests using this comparison as a measure of overfitting. If the model is not overfitting, the two values should be approximately equal. As the model starts to overfit, the free energy of the validation data will increase relative to that of the training data, so the difference between the two values will become increasingly negative.
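For binary RBMs the free energy of a visible vector has a closed form, so the comparison can be sketched as follows (the batched layout and parameter names are assumptions consistent with the earlier sketches):

```python
import jax
import jax.numpy as jnp

def free_energy(v, W, b, c):
    """F(v) = -b^T v - sum_j log(1 + exp(c_j + sum_i v_i W_ij)); the partition function is not needed."""
    return -(v @ b) - jnp.sum(jax.nn.softplus(c + v @ W), axis=-1)

def overfitting_gap(v_train, v_valid, W, b, c):
    """Average free-energy difference; increasingly negative values suggest overfitting."""
    return jnp.mean(free_energy(v_train, W, b, c)) - jnp.mean(free_energy(v_valid, W, b, c))
```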
Classification
While Boltzmann machines are generative models, they can be adapted for classification and other discriminative tasks.
Here we use the RBM to transform a sample image into its hidden representation and then use this as input to a logistic regression classifier.
This classification is more accurate than when using the raw image data as input. Furthermore, the accuracy of the classification increases as the training time increases.
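A minimal sketch of this pipeline using scikit-learn is given below; `v_train`/`v_test` are assumed to hold binarised images, `y_train`/`y_test` their labels, and `W`, `c` the trained RBM parameters.

```python
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.linear_model import LogisticRegression

def hidden_features(v, W, c):
    """Deterministic hidden activations p(h = 1 | v), used as classifier features."""
    return np.asarray(jax.nn.sigmoid(jnp.asarray(v) @ W + c))

# Fit a logistic regression classifier on the hidden representation.
clf = LogisticRegression(max_iter=1000)
clf.fit(hidden_features(v_train, W, c), y_train)
print("test accuracy:", clf.score(hidden_features(v_test, W, c), y_test))
```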
Alternatively, an RBM can be made to include a set of visible units which encode the class label. Classification is then performed by clamping each of the class units in turn along with the test sample; the unit that gives the lowest free energy determines the chosen class [2].
The increase in accuracy here is modest because of the small number of hidden units. When 1000 hidden units are used the Epoch-5 accuracy approaches 97.5%.
We can explore the quality of the learned hidden transformation by inspecting reconstructions of these test images.
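One way to produce such reconstructions is a deterministic up-down pass; a sketch is given below, where `v_test`, `W`, `b`, and `c` are assumed names for the test images and the trained parameters.

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def reconstruct(v, W, b, c):
    """Deterministic reconstruction: v -> p(h | v) -> p(v | h)."""
    p_h = jax.nn.sigmoid(c + v @ W)
    return jax.nn.sigmoid(b + p_h @ W.T)

# Show a few test images next to their reconstructions.
idx = [0, 1, 2, 3]
fig, axes = plt.subplots(2, len(idx), figsize=(2 * len(idx), 4))
for col, i in enumerate(idx):
    axes[0, col].imshow(v_test[i].reshape(28, 28), cmap="gray")
    axes[1, col].imshow(reconstruct(v_test[i], W, b, c).reshape(28, 28), cmap="gray")
    axes[0, col].axis("off")
    axes[1, col].axis("off")
plt.show()
```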
You can explore this by choosing different subsets of images in the cell below: