Path: blob/master/notebooks/book1/13/bnn_hierarchical_numpyro.ipynb
Hierarchical Bayesian Neural Networks
Illustration of hierarchical Bayesian neural network classifiers. Code and text are based on this blog post by Thomas Wiecki. Original PyMC3 Notebook. Converted to Numpyro by Aleyna Kara (@karalleyna).
Setup
Please change your Colab runtime to CPU.
Data
The data set we are using is our battle-tested half-moons, as it is simple, non-linear, and leads to pretty visualizations. This is what it looks like:
This is just to illustrate what the data-generating distribution looks like; we will use far fewer data points and create different subsets with different rotations.
As you can see, we have 4 categories by default that share a higher-order structure (the half-moons). However, in the pure data space, no single classifier will be able to do a good job here. Also, because we only have 50 data points in each class, a NN will likely have a hard time producing robust results. But let's actually test this. A rough sketch of how such grouped data could be generated is given below.
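As a rough illustration (not the notebook's exact code), grouped half-moon data like this could be generated with scikit-learn's `make_moons` plus a per-group rotation; the sample sizes, noise level, and rotation angles below are assumptions.

```python
import numpy as np
from sklearn.datasets import make_moons

def rotate(X, deg):
    """Rotate 2-D points counter-clockwise by `deg` degrees."""
    theta = np.deg2rad(deg)
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    return X @ R.T

# One small, rotated half-moon dataset per category (angles and sizes are illustrative).
n_groups = 4
Xs, Ys = [], []
for g, deg in enumerate(np.linspace(0, 90, n_groups)):
    X, Y = make_moons(n_samples=100, noise=0.3, random_state=g)  # 50 points per class
    Xs.append(rotate(X, deg))
    Ys.append(Y)
```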
Fit separate MLPs, one per task
First we fit one MLP per task/dataset. For details, see Thomas's blog post on Bayesian Deep Learning.
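For orientation, here is a minimal NumPyro sketch of the kind of per-task Bayesian MLP being fit; the layer sizes, priors, and sampler settings are assumptions rather than the notebook's exact choices.

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def bnn_model(X, Y=None, hidden=5):
    """Two-hidden-layer Bayesian MLP with standard-normal priors on all weights."""
    n_in = X.shape[-1]
    w1 = numpyro.sample("w1", dist.Normal(0.0, 1.0).expand([n_in, hidden]))
    w2 = numpyro.sample("w2", dist.Normal(0.0, 1.0).expand([hidden, hidden]))
    w3 = numpyro.sample("w3", dist.Normal(0.0, 1.0).expand([hidden, 1]))
    h1 = jnp.tanh(X @ w1)
    h2 = jnp.tanh(h1 @ w2)
    logits = (h2 @ w3).squeeze(-1)
    numpyro.sample("Y", dist.Bernoulli(logits=logits), obs=Y)

# Fit one independent posterior per task/dataset (shown here for the first group).
mcmc = MCMC(NUTS(bnn_model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), X=jnp.asarray(Xs[0]), Y=jnp.asarray(Ys[0]))
```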
OK, that doesn't seem so bad. Now let's look at the decision surfaces -- i.e. what the classifier thinks about each point in the data space.
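One way to obtain such a surface (reusing the hypothetical `bnn_model` and `mcmc` names from the sketch above; grid bounds and resolution are assumptions) is to average the posterior predictive over a grid of input points:

```python
from numpyro.infer import Predictive

# Evaluate the posterior predictive on a dense grid covering the data space.
grid_x, grid_y = jnp.meshgrid(jnp.linspace(-3, 3, 100), jnp.linspace(-3, 3, 100))
grid = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)

predictive = Predictive(bnn_model, posterior_samples=mcmc.get_samples())
samples = predictive(random.PRNGKey(1), X=grid)["Y"]    # (num_samples, num_grid_points)
p_class1 = samples.mean(axis=0).reshape(grid_x.shape)   # mean probability of class 1
# A contour plot of p_class1 over (grid_x, grid_y) gives the decision surface.
```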
That doesn't look all that convincing. We know from the data generation process that we should get a "Z"-shaped decision surface, but we don't have enough data to properly estimate the non-linearity in every category.
Hierarchical Bayesian Neural Network
It's actually quite straightforward to turn this into one big hierarchical model for all categories, rather than many individual ones. Let's call the weight connecting neuron $i$ in layer 1 to neuron $j$ in layer 2 in category $c$: $w_{i,j,c}$ (I just omit the layer index for simplicity in notation). Rather than placing a fixed prior as we did above (i.e. $w_{i,j,c} \sim \mathcal{N}(0, 1^2)$), we will assume that each weight comes from an overarching group distribution: $w_{i,j,c} \sim \mathcal{N}(\mu_{i,j}, \sigma^2)$. The key is that we will estimate $\mu_{i,j}$ and $\sigma$ simultaneously from data.
Why not allow for a different $\sigma_{i,j}$ per connection, you might ask? Mainly just to make our life simpler and because it works well enough.
Note that we create a very rich model here. Every individual weight has its own hierarchical structure with a single group mean parameter and 16 per-category weights distributed around the group mean. While this creates a large number of group distributions (as many as the flat NN had weights), there is no problem with this per se, although it might be a bit unusual. One might argue that this model is quite complex, and while that's true, in terms of degrees of freedom this model is simpler than the unpooled one above (more on this below).
As for the code, we stack weights along a third dimension to get separate weights for each group. That way, through the power of broadcasting, the linear algebra works out almost the same as before.
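A hedged sketch of what such a hierarchical model can look like in NumPyro is below (reusing the imports from the earlier sketch): one group-level mean per weight, a shared scale per weight matrix, and per-category weights stacked along a leading group dimension so batched matrix multiplication handles all categories at once. Layer sizes, prior scales, and the centered parameterization are assumptions, not the notebook's exact code.

```python
def hierarchical_bnn_model(X, Y=None, hidden=5):
    """X: (n_groups, n_points, n_in), Y: (n_groups, n_points)."""
    n_groups, _, n_in = X.shape

    def hier_weights(name, shape):
        # Group-level mean per individual weight, one shared scale per weight matrix.
        mu = numpyro.sample(f"{name}_mu", dist.Normal(0.0, 1.0).expand(shape))
        sigma = numpyro.sample(f"{name}_sigma", dist.HalfNormal(1.0))
        # Per-category weights stacked along a leading group dimension.
        return numpyro.sample(name, dist.Normal(mu, sigma).expand([n_groups, *shape]))

    w1 = hier_weights("w1", [n_in, hidden])
    w2 = hier_weights("w2", [hidden, hidden])
    w3 = hier_weights("w3", [hidden, 1])

    # Broadcasting over the group dimension: each category uses its own weights.
    h1 = jnp.tanh(jnp.einsum("gni,gih->gnh", X, w1))
    h2 = jnp.tanh(jnp.einsum("gnh,ghk->gnk", h1, w2))
    logits = jnp.einsum("gnk,gko->gno", h2, w3).squeeze(-1)
    numpyro.sample("Y", dist.Bernoulli(logits=logits), obs=Y)
```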
Great -- we get higher train and test accuracy. Let's look at what the classifier has learned for each category.