Path: blob/master/deprecated/notebooks/bnn_hierarchical_blackjax.ipynb
Hierarchical Bayesian neural networks
Code is based on this blog post by Thomas Wiecki and the original PyMC3 notebook. Converted to Blackjax by Aleyna Kara (@karalleyna) and Kevin Murphy (@murphyk). (For a Numpyro version, see here.)
We create T=18 different versions of the "two moons" dataset, each rotated by a different amount. These correspond to T different nonlinear binary classification "tasks" that we have to solve. We only get a few labeled samples from each task, so solving them separately (with T independent MLPs, or multilayer perceptrons) will result in poor performance. If we pool all the data and fit a single MLP, we also get poor performance, because we are mixing together different decision boundaries. But if we use a hierarchical Bayesian model, with one MLP per task and one learned prior MLP, we will get better results, as we will see.
Below is a high level illustration of the multi-task setup. $\theta$ is the learned prior, and $\theta_t$ are the parameters for task $t$. We assume $N$ training samples per task, and $M$ test samples. (We could of course consider more imbalanced scenarios.)
Setup
Data
We create T=18 different versions of the "two moons" dataset, each rotated by a different amount. These correspond to T different binary classification "tasks" that we have to solve.
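As a rough sketch of how such data could be generated (not the notebook's actual code), each task rotates a `sklearn.datasets.make_moons` sample by a task-specific angle; the helper name `make_rotated_moons` and the sample sizes are illustrative.

```python
import numpy as np
from sklearn.datasets import make_moons

def make_rotated_moons(n_samples, angle_deg, noise=0.1, seed=0):
    """Sample a 'two moons' dataset and rotate it by angle_deg degrees."""
    X, y = make_moons(n_samples=n_samples, noise=noise, random_state=seed)
    theta = np.deg2rad(angle_deg)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    return X @ rot.T, y

n_tasks = 18
angles = np.linspace(0, 360, n_tasks, endpoint=False)  # one rotation per task
tasks = [make_rotated_moons(n_samples=50, angle_deg=a, seed=t)
         for t, a in enumerate(angles)]
```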
Utility functions for training and testing
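One such utility is a posterior-predictive accuracy check. The sketch below is illustrative (the names `posterior_accuracy` and `predict_fn` are assumptions, not the notebook's code): it averages predicted probabilities over posterior samples (Bayes model averaging) before thresholding.

```python
import jax
import jax.numpy as jnp

def posterior_accuracy(predict_fn, samples, X, y):
    """Accuracy of the Bayesian model average: average the predicted
    probabilities over posterior samples, threshold at 0.5, compare to labels.

    samples is a pytree of posterior draws with a leading sample dimension."""
    probs = jax.vmap(lambda p: predict_fn(p, X))(samples)  # (num_samples, N)
    return jnp.mean((probs.mean(axis=0) > 0.5) == y)
```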
Hyperparameters
We use an MLP with 2 hidden layers, each with 5 hidden units.
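For concreteness, here is a minimal JAX sketch of such an MLP, mapping 2D inputs through two tanh layers of 5 units each to a scalar logit. The parameter layout (a dict of weight matrices, no biases) and the names `mlp_logits` / `init_params` are assumptions, not the notebook's exact code.

```python
import jax
import jax.numpy as jnp

def init_params(key, n_in=2, n_hidden=5):
    """Random initial weights for a 2-hidden-layer MLP (no biases)."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {"w1": jax.random.normal(k1, (n_in, n_hidden)),
            "w2": jax.random.normal(k2, (n_hidden, n_hidden)),
            "w3": jax.random.normal(k3, (n_hidden, 1))}

def mlp_logits(params, X):
    """Forward pass: (N, 2) inputs -> (N,) logits, tanh activations."""
    h = jnp.tanh(X @ params["w1"])
    h = jnp.tanh(h @ params["w2"])
    return (h @ params["w3"]).squeeze(-1)

def predict_proba(params, X):
    return jax.nn.sigmoid(mlp_logits(params, X))
```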
Fit separate MLPs, one per task
Let $w_{ijl}^t$ be the weight from node $i$ to node $j$ in layer $l$ in task $t$. We assume $w_{ijl}^t \sim \mathcal{N}(0,1)$ and compute the posterior for all the weights.
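A hedged sketch of this per-task model, reusing the illustrative `mlp_logits`, `init_params`, and `tasks` from the sketches above: the log joint combines the $\mathcal{N}(0,1)$ prior with a Bernoulli likelihood, and a few NUTS steps are drawn with BlackJAX. The fixed step size and identity mass matrix are placeholders; a real run would tune them (e.g. with warmup/adaptation).

```python
import jax
import jax.numpy as jnp
import blackjax
from jax.scipy.stats import norm

def logjoint_single_task(params, X, y):
    """N(0, 1) prior on every weight plus a Bernoulli log likelihood."""
    logits = mlp_logits(params, X)
    log_prior = sum(jnp.sum(norm.logpdf(w))
                    for w in jax.tree_util.tree_leaves(params))
    log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits)
                      + (1 - y) * jax.nn.log_sigmoid(-logits))
    return log_prior + log_lik

# Data for one task (from the earlier sketch) and an un-tuned NUTS kernel.
X, y = map(jnp.asarray, tasks[0])
logdensity = lambda p: logjoint_single_task(p, X, y)
position = init_params(jax.random.PRNGKey(0))
dim = sum(w.size for w in jax.tree_util.tree_leaves(position))
nuts = blackjax.nuts(logdensity, step_size=1e-2, inverse_mass_matrix=jnp.ones(dim))
state = nuts.init(position)

def one_step(state, key):
    state, _ = nuts.step(key, state)
    return state, state.position

keys = jax.random.split(jax.random.PRNGKey(1), 500)
_, samples = jax.lax.scan(one_step, state, keys)  # pytree with leading dim 500
```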
Results
Accuracy is reasonable, but the decision boundaries have not captured the underlying Z pattern in the data, due to having too little data per task. (Bayes model averaging results in a simple linear decision boundary, and prevents overfitting.)
Below we show that the decision boundaries do not look reasonable, since there is not enough data to fit each model separately.
Hierarchical Model
Now we use a hierarchical Bayesian model, which has a common Gaussian prior for all the weights, but allows each task to have its own task-specific parameters. More precisely, let $w_{ijl}^t$ be the weight from node $i$ to node $j$ in layer $l$ in task $t$. We assume

$$w_{ijl}^t \sim \mathcal{N}(\mu_{ijl}, \sigma_l^2)$$

or, in non-centered form,

$$w_{ijl}^t = \mu_{ijl} + \sigma_l \, \epsilon_{ijl}^t, \qquad \epsilon_{ijl}^t \sim \mathcal{N}(0,1).$$
In the figure below, we illustrate this prior, using an MLP with $D$ inputs, 2 hidden layers (of size $H_1$ and $H_2$), and a scalar output (representing the logit).
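To make the non-centered parameterization concrete, here is a rough sketch of the hierarchical log joint, again reusing the illustrative `mlp_logits`. The parameter layout, the $\mathcal{N}(0,1)$ hyper-prior on $\log \sigma_l$, and all names are assumptions, not the notebook's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def hierarchical_logjoint(params, Xs, ys):
    """Non-centered hierarchical log joint over all T tasks.

    params["mu"]      : shared means mu_{ijl}, e.g. {"w1": (2,5), "w2": (5,5), "w3": (5,1)}
    params["log_sig"] : per-layer log standard deviations log(sigma_l), scalars
    params["eps"]     : per-task offsets eps^t_{ijl}, each with leading dim T
    Xs, ys            : stacked task data of shape (T, N, 2) and (T, N)
    """
    mu, log_sig, eps = params["mu"], params["log_sig"], params["eps"]
    logp = 0.0
    for name in mu:
        # Hyper-priors: mu ~ N(0,1), eps ~ N(0,1), and (one choice) log(sigma_l) ~ N(0,1).
        logp += jnp.sum(norm.logpdf(mu[name]))
        logp += jnp.sum(norm.logpdf(eps[name]))
        logp += jnp.sum(norm.logpdf(log_sig[name]))

    def task_loglik(t):
        # Non-centered per-task weights: w^t_{ijl} = mu_{ijl} + sigma_l * eps^t_{ijl}.
        w_t = {name: mu[name] + jnp.exp(log_sig[name]) * eps[name][t] for name in mu}
        logits = mlp_logits(w_t, Xs[t])
        return jnp.sum(ys[t] * jax.nn.log_sigmoid(logits)
                       + (1 - ys[t]) * jax.nn.log_sigmoid(-logits))

    return logp + jnp.sum(jax.vmap(task_loglik)(jnp.arange(Xs.shape[0])))
```

This log density can be passed to the same NUTS setup as in the non-hierarchical sketch; only the parameter pytree and the prior terms change.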
Results
We see that the train and test accuracy are higher, and the decision boundaries all have the shared "Z" shape, as desired.