Path: blob/master/notebooks/misc/bnn_mnist_sgld_jaxbayes.ipynb
Bayesian MLP for MNIST using preconditioned SGLD
We use the Jax Bayes library by James Vuckovic to fit an MLP to MNIST using SGD and SGLD (with RMS preconditioning). Code is based on:
Setup
Data
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Model
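The notebook's model is a small MLP whose exact definition (built with jax_bayes's utilities) is not shown in this chunk. As a sketch of the architecture, here is a minimal NumPy version with illustrative layer sizes; the initializer and widths are assumptions, not the notebook's settings:

```python
import numpy as np

def init_mlp(sizes, rng):
    # He-style random init for each (W, b) layer pair.
    # sizes, e.g. [784, 300, 100, 10], are illustrative only.
    return [(rng.normal(0.0, np.sqrt(2.0 / m), (m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp_logits(params, x):
    # ReLU hidden layers followed by a linear output layer (logits).
    for W, b in params[:-1]:
        x = np.maximum(x @ W + b, 0.0)
    W, b = params[-1]
    return x @ W + b
```

The logits are left unnormalized; a softmax is applied later when forming predictive probabilities.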
SGD
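The SGD baseline trains a single point estimate of the weights. A minimal sketch of the loss and the update rule (plain SGD on a softmax cross-entropy loss; the learning rate is a placeholder, not the notebook's value):

```python
import numpy as np

def softmax_xent(logits, labels):
    # Mean cross-entropy of integer labels under softmax(logits),
    # with the usual max-subtraction for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -logp[np.arange(len(labels)), labels].mean()

def sgd_step(params, grads, lr=1e-3):
    # Plain SGD: theta <- theta - lr * grad, applied layer by layer.
    return [(W - lr * gW, b - lr * gb)
            for (W, b), (gW, gb) in zip(params, grads)]
```

In the notebook the gradients come from JAX's autodiff rather than being supplied by hand.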
SGLD
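SGLD adds Gaussian noise to each gradient step so the iterates sample from (an approximation to) the posterior; the RMS preconditioner rescales each coordinate by a running average of squared gradients. A single-step sketch, with hyperparameter values as placeholders rather than the notebook's settings:

```python
import numpy as np

def rms_sgld_step(theta, grad_logpost, v, rng, lr=1e-4, alpha=0.99, eps=1e-5):
    # One RMS-preconditioned SGLD step.
    # theta: flat parameter vector; grad_logpost: stochastic gradient of
    # the log posterior; v: running average of squared gradients.
    v = alpha * v + (1 - alpha) * grad_logpost**2   # RMS statistic
    G = 1.0 / (np.sqrt(v) + eps)                    # diagonal preconditioner
    noise = rng.normal(size=theta.shape)
    # Drift up the log posterior plus injected Gaussian noise scaled by G.
    theta = theta + 0.5 * lr * G * grad_logpost + np.sqrt(lr * G) * noise
    return theta, v
```

With `lr` decayed appropriately, collecting the iterates after burn-in yields approximate posterior samples of the weights.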
Uncertainty analysis
We select the predictions whose confidence exceeds a threshold, and compute the predictive accuracy on that subset. As we raise the threshold, the accuracy should increase, but fewer examples will be selected.
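The thresholding procedure just described can be sketched as follows (`probs` holds the predictive probabilities from either method):

```python
import numpy as np

def accuracy_vs_confidence(probs, labels, thresholds):
    # For each threshold t, keep examples whose max predictive probability
    # exceeds t; report (t, fraction kept, accuracy on the kept subset).
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    out = []
    for t in thresholds:
        keep = conf > t
        frac = keep.mean()
        acc = (pred[keep] == labels[keep]).mean() if keep.any() else np.nan
        out.append((t, frac, acc))
    return out
```

Plotting accuracy and fraction-kept against the threshold gives the curves analyzed below.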
SGD
For the plugin estimate, the model is very confident on nearly all of the points.
SGLD
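For SGLD, the predictive probabilities come from averaging the softmax outputs over the posterior samples of the weights, rather than from a single point estimate. A minimal sketch (the sample count and shapes are assumptions):

```python
import numpy as np

def posterior_predictive(logits_samples):
    # logits_samples: (S, N, C) logits from S posterior (SGLD) samples
    # on N examples with C classes. Average the per-sample softmax
    # probabilities to get the Bayesian predictive distribution.
    z = logits_samples - logits_samples.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return p.mean(axis=0)
```

This averaging is what spreads probability mass across classes when the samples disagree, lowering the confidence on ambiguous inputs.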
Distribution shift
We now examine the behavior of the models on the Fashion MNIST dataset. We expect the predictions to be much less confident, since the inputs are now 'out of distribution'. We will see that this is true for the Bayesian approach, but not for the plugin approximation.
Downloading and preparing dataset fashion_mnist/3.0.1 (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /root/tensorflow_datasets/fashion_mnist/3.0.1...
Dataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.
SGD
We see that the plugin estimate is confident (but wrong!) on many of the predictions, which is undesirable. If we consider a confidence threshold of 0.6, the plugin approach predicts on about 80% of the examples, even though its accuracy on these is only about 6%.
SGLD
If we consider a confidence threshold of 0.6, the Bayesian approach predicts on fewer than 20% of the examples, on which the accuracy is ~4%.