Path: blob/master/notebooks/book2/19/bnn_mnist_sgld_jaxbayes.ipynb
Bayesian MLP for MNIST using preconditioned SGLD
We use the Jax Bayes library by James Vuckovic to fit an MLP to MNIST using SGD and SGLD (with RMS preconditioning). Code is based on:
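As background, RMS-preconditioned SGLD scales both the gradient step and the injected noise by a diagonal preconditioner built from a running average of squared gradients (as in pSGLD). The following is an illustrative NumPy sketch of a single update, not the Jax Bayes implementation; the function name `psgld_step` and its signature are my own.

```python
import numpy as np

def psgld_step(theta, grad_logpost, v, step_size, alpha=0.99, eps=1e-5,
               rng=np.random):
    """One RMS-preconditioned SGLD step (illustrative sketch).

    theta:        current parameters (flat array)
    grad_logpost: stochastic gradient of the log posterior at theta
    v:            running average of squared gradients (preconditioner state)
    """
    # Update the RMS statistics, as in RMSprop.
    v = alpha * v + (1 - alpha) * grad_logpost ** 2
    # Diagonal preconditioner: large steps where gradients are small.
    G = 1.0 / (np.sqrt(v) + eps)
    # Langevin noise is scaled by the same preconditioner.
    noise = rng.normal(size=theta.shape) * np.sqrt(step_size * G)
    theta = theta + 0.5 * step_size * G * grad_logpost + noise
    return theta, v
```

Iterating this step produces (approximate) posterior samples; with the noise term removed it reduces to preconditioned gradient ascent on the log posterior, i.e. essentially RMSprop.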
Setup
Data
The Bayesian NN is taken from SGMCMCJAX, with a couple of changes:
The random_layer function initialises weights from a truncated normal distribution rather than a normal distribution.
The random_layer function initialises biases with zeros rather than sampling from a normal distribution.
The activation function can be specified, instead of being fixed to softmax.
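The three changes above can be sketched as follows. This is a NumPy illustration of the modified layer initialisation, not the SGMCMCJAX code itself; the `truncated_normal` helper and the exact scaling are assumptions for the sketch.

```python
import numpy as np

def truncated_normal(rng, shape, stddev, lower=-2.0, upper=2.0):
    """Sample a normal restricted to [lower, upper] standard deviations,
    by rejection-resampling out-of-range draws."""
    x = rng.standard_normal(shape)
    while np.any((x < lower) | (x > upper)):
        bad = (x < lower) | (x > upper)
        x[bad] = rng.standard_normal(bad.sum())
    return stddev * x

def random_layer(rng, n_in, n_out, activation=np.tanh):
    """Initialise one dense layer with the modifications described above:
    truncated-normal weights, zero biases, configurable activation."""
    W = truncated_normal(rng, (n_in, n_out), stddev=1.0 / np.sqrt(n_in))
    b = np.zeros(n_out)  # biases start at zero, not sampled
    return (W, b), activation
```

Truncating the initial weights avoids rare extreme draws, and zero biases are a common default; the configurable activation lets the same builder serve hidden layers and the output layer.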
Model
SGD
SGLD
Uncertainty analysis
We select the predictions above a confidence threshold, and compute the predictive accuracy on that subset. As we increase the threshold, the accuracy should increase, but fewer examples will be selected.
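The selective-accuracy computation described above can be sketched like this; the function name `accuracy_vs_threshold` is mine, not from the notebook.

```python
import numpy as np

def accuracy_vs_threshold(probs, labels, thresholds):
    """For each confidence threshold, keep the predictions whose max
    class probability exceeds it, and report (fraction kept, accuracy
    on the kept subset)."""
    conf = probs.max(axis=1)      # confidence = highest class probability
    pred = probs.argmax(axis=1)   # predicted class
    results = []
    for t in thresholds:
        keep = conf >= t
        frac = keep.mean()
        acc = (pred[keep] == labels[keep]).mean() if keep.any() else np.nan
        results.append((frac, acc))
    return results
```

Plotting accuracy against the fraction of examples kept gives the usual confidence/coverage trade-off curve: raising the threshold should raise accuracy while shrinking coverage.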
The following two functions are taken from Jax Bayes.
SGD
For the plugin estimate, the model is very confident on nearly all of the points.
SGLD
Distribution shift
We now examine the behavior of the models on the Fashion MNIST dataset. We expect the predictions to be much less confident, since the inputs are now 'out of distribution'. We will see that this is true for the Bayesian approach, but not for the plugin approximation.
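Why the Bayesian predictive is less confident out of distribution: the plugin estimate applies softmax to a single point estimate's logits, whereas the Bayesian predictive averages the softmax over posterior samples, and when those samples disagree the average is pushed toward the uniform distribution. A minimal NumPy sketch (function names are mine):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def plugin_confidence(logits):
    """Confidence of the plugin estimate: softmax of one set of logits."""
    return softmax(logits).max(axis=-1)

def bayes_confidence(logit_samples):
    """Confidence of the Bayesian predictive: average the softmax over
    posterior samples of the logits, shape (num_samples, batch, classes)."""
    predictive = softmax(logit_samples).mean(axis=0)
    return predictive.max(axis=-1)
```

If two posterior samples confidently predict different classes, each sample alone is near-certain, but their average is close to uniform, so the Bayesian confidence collapses exactly where the posterior disagrees, which is what happens on out-of-distribution inputs like Fashion MNIST.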