Deep Ensembles with Randomized Prior Functions
This notebook illustrates a way to improve predictive uncertainty estimates using deep ensembles. It is based on the following paper:
I. Osband, J. Aslanides, and A. Cassirer, "Randomized prior functions for deep reinforcement learning," in NeurIPS, 2018. [Online]. Available: https://proceedings.neurips.cc/paper/2018/file/5a7b238ba0f6502e5d6be14424b20ded-Paper.pdf
The original Tensorflow demo is from https://www.kaggle.com/code/gdmarmerola/introduction-to-randomized-prior-functions/notebook
This JAX translation is by Peter G. Chang (@peterchang0414)
0. Imports
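As a minimal sketch, the later sections rely on roughly the following imports; the use of flax for networks and optax for optimization is an assumption on our part, not necessarily the notebook's exact dependency list:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn   # assumed network library
import optax              # assumed optimizer library
import matplotlib.pyplot as plt
```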
1. Dataset
We will use a 1D synthetic regression dataset from the following paper:
C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra, "Weight Uncertainty in Neural Networks," in ICML, 2015. [Online]. Available: http://arxiv.org/abs/1505.05424
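As a hedged sketch, the dataset can be generated as follows; the input range and noise scale are assumptions based on the curve described in Blundell et al. (2015), $y = x + 0.3\sin(2\pi(x+\epsilon)) + 0.3\sin(4\pi(x+\epsilon)) + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 0.02)$:

```python
def make_dataset(key, n=100):
    # Inputs on [0, 0.5] and noise std 0.02 follow Blundell et al. (2015);
    # both values are assumptions about what the notebook uses exactly.
    key_x, key_eps = jax.random.split(key)
    x = jax.random.uniform(key_x, (n,), minval=0.0, maxval=0.5)
    eps = 0.02 * jax.random.normal(key_eps, (n,))
    y = x + 0.3 * jnp.sin(2 * jnp.pi * (x + eps)) + 0.3 * jnp.sin(4 * jnp.pi * (x + eps)) + eps
    return x[:, None], y[:, None]

X, Y = make_dataset(jax.random.PRNGKey(0))
```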
2. Randomized Prior Functions
The core idea is to represent each ensemble member by $Q_k(x) = f_{\theta_k}(x) + \beta\, p_k(x)$, where $f_{\theta_k}$ is a trainable network, $p_k$ is a fixed, but random, prior network, and the scalar $\beta$ controls the weight of the prior.
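A minimal Flax sketch of one such member; the MLP widths are illustrative, and keeping the prior fixed via jax.lax.stop_gradient is one possible implementation, not necessarily the notebook's:

```python
class MLP(nn.Module):
    features: tuple = (16, 16, 1)  # illustrative layer widths

    @nn.compact
    def __call__(self, x):
        for f in self.features[:-1]:
            x = nn.elu(nn.Dense(f)(x))
        return nn.Dense(self.features[-1])(x)

class RandomizedPriorNet(nn.Module):
    beta: float = 1.0  # weight of the fixed random prior

    @nn.compact
    def __call__(self, x):
        trainable = MLP(name="trainable")(x)
        # stop_gradient zeroes the prior's gradients, so a plain SGD/Adam
        # update leaves its parameters at their random initialization.
        prior = jax.lax.stop_gradient(MLP(name="prior")(x))
        return trainable + self.beta * prior
```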
3. Bootstrapped Ensembles
To implement bootstrapping in JAX, we generate a random map from seed values to dataset index values by calling jax.random.randint with a randomly generated seed. We assume the random key space is large enough that we need not be concerned with overlapping seed ranges.
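A sketch of this index-map construction; the member count and dataset size below are illustrative:

```python
def bootstrap_indices(seed, n):
    # Map an integer seed to n dataset indices drawn with replacement.
    key = jax.random.PRNGKey(seed)
    return jax.random.randint(key, shape=(n,), minval=0, maxval=n)

# One randomly generated seed (and hence one resample) per ensemble member.
member_seeds = jax.random.randint(jax.random.PRNGKey(42), (10,),
                                  0, jnp.iinfo(jnp.int32).max)
resamples = jnp.stack([bootstrap_indices(int(s), n=100) for s in member_seeds])
```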
4. Effects of Changing Beta
Let us go beyond the original Kaggle notebook and inspect the relationship between $\beta$, the weight of the prior, and the variance among the predictions of the ensembled models.
Intuitively, since the random priors are not trained, the variance should increase with $\beta$. Let us verify this visually.
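One way to quantify this, assuming the notebook's get_ensemble_predictions helper returns per-member predictions of shape (n_members, n_test) on a test grid x_test (the exact signature here is an assumption):

```python
# Sweep beta and record the spread of the ensemble's predictions;
# the beta values are chosen for illustration.
for beta in [0.0, 1.0, 3.0, 10.0]:
    preds = get_ensemble_predictions(X, Y, x_test, beta=beta)
    # Spread among members: average per-point standard deviation.
    print(f"beta={beta}: mean predictive std = {preds.std(axis=0).mean():.3f}")
```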
5. Effects of Prior and Bootstrapping
Let us construct and compare the following four models:
- Ensemble of nets with prior, with bootstrap (original model)
- Ensemble of nets with prior, without bootstrap
- Ensemble of nets without prior, with bootstrap
- Ensemble of nets without prior, without bootstrap
Note that our previous constructions extend easily to the other three model types. For nets without a prior, we simply set `beta=0` (or, alternatively, use `GenericNet`), and for nets without bootstrapping, we call `get_ensemble_predictions(..., bootstrap=False)`.
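Putting this together, a sketch of the four configurations, again assuming the get_ensemble_predictions signature used above:

```python
configs = {
    "prior + bootstrap":      dict(beta=1.0, bootstrap=True),   # original model
    "prior, no bootstrap":    dict(beta=1.0, bootstrap=False),
    "no prior + bootstrap":   dict(beta=0.0, bootstrap=True),
    "no prior, no bootstrap": dict(beta=0.0, bootstrap=False),
}
predictions = {name: get_ensemble_predictions(X, Y, x_test, **kwargs)
               for name, kwargs in configs.items()}
```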