Path: blob/master/notebooks/book2/17/randomized_priors.ipynb
Deep Ensembles with Randomized Prior Functions
This notebook illustrates a way to improve predictive uncertainty estimates using deep ensembles. It is based on this paper:
I. Osband, J. Aslanides, and A. Cassirer, “Randomized prior functions for deep reinforcement learning,” in NIPS, Jun. 2018 [Online]. Available: https://proceedings.neurips.cc/paper/2018/file/5a7b238ba0f6502e5d6be14424b20ded-Paper.pdf.
The original Tensorflow demo is from https://www.kaggle.com/code/gdmarmerola/introduction-to-randomized-prior-functions/notebook
This JAX translation is by Peter G. Chang (@peterchang0414)
0. Imports
1. Dataset
We will use a 1d synthetic regression dataset from this paper:
C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra, “Weight Uncertainty in Neural Networks,” in ICML, May 2015 [Online]. Available: http://arxiv.org/abs/1505.05424
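For reference, here is a minimal JAX sketch of such a generator. In Blundell et al. the curve is $y = x + 0.3\sin(2\pi(x+\epsilon)) + 0.3\sin(4\pi(x+\epsilon)) + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 0.02)$; the function name, sample size, and input range below are illustrative, not the notebook's actual code.

```python
import jax
import jax.numpy as jnp


def make_blundell_dataset(key, n=100, x_min=0.0, x_max=0.5, noise=0.02):
    # Illustrative generator for the 1d curve in Blundell et al. (2015):
    # y = x + 0.3 sin(2*pi*(x + eps)) + 0.3 sin(4*pi*(x + eps)) + eps,
    # where eps is Gaussian with scale `noise`. Name and ranges are assumptions.
    key_x, key_eps = jax.random.split(key)
    x = jax.random.uniform(key_x, (n,), minval=x_min, maxval=x_max)
    eps = noise * jax.random.normal(key_eps, (n,))
    y = x + 0.3 * jnp.sin(2 * jnp.pi * (x + eps)) + 0.3 * jnp.sin(4 * jnp.pi * (x + eps)) + eps
    return x[:, None], y


X, y = make_blundell_dataset(jax.random.PRNGKey(0))
```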
2. Randomized Prior Functions
The core idea is to represent each ensemble member by $Q_{\theta_k}(x) = f_{\theta_k}(x) + \beta p_k(x)$, where $f_{\theta_k}$ is a trainable network, $p_k$ is a fixed, but random, prior network, and $\beta \geq 0$ scales the weight of the prior.
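A minimal pure-JAX sketch of one ensemble member follows; the notebook builds its ensembles with its own classes (e.g. GenericNet), so the layer sizes and function names here are illustrative only.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(1, 16, 16, 1)):
    # One (W, b) pair per layer, with simple 1/sqrt(fan_in) scaling.
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def predict(trainable, prior, x, beta):
    # One ensemble member: trainable net plus a fixed, randomly initialized
    # prior net. Gradients flow only through `trainable`; `prior` is frozen.
    return mlp(trainable, x) + beta * mlp(jax.lax.stop_gradient(prior), x)
```

The key point is that gradients never reach the prior's parameters, so each member remains anchored to a different random function even after training.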
3. Bootstrapped Ensembles
To implement bootstrapping in JAX, we map each seed value to a set of dataset indices, drawn with replacement via jax.random.randint using a randomly generated seed. We assume the random key space is large enough that we need not be concerned with generating overlapping seed ranges.
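A minimal sketch of this mapping (the helper name is illustrative; only the jax.random.randint call mirrors the text):

```python
import jax


def bootstrap_indices(seed, n):
    # One integer seed -> one bootstrap sample: n indices drawn
    # uniformly with replacement from 0..n-1.
    key = jax.random.PRNGKey(seed)
    return jax.random.randint(key, shape=(n,), minval=0, maxval=n)


# Each ensemble member trains on its own resampled dataset,
# e.g. X[idx], y[idx] with idx = bootstrap_indices(member_seed, len(X)).
idx = bootstrap_indices(seed=0, n=100)
```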
4. Effects of Changing Beta
Let us go beyond the original Kaggle notebook and inspect the relationship between the weight of the prior, $\beta$, and the variance among the predictions of the ensembled models.
Intuitively, since the random priors are not trained, the variance should increase with $\beta$. Let us verify this visually.
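One way to measure this, reusing init_mlp and predict from the sketch in Section 2 (the ensemble size, input grid, and $\beta$ values here are arbitrary choices, not the notebook's):

```python
import jax
import jax.numpy as jnp

# Ten ensemble members, each with its own trainable net and prior net.
x_grid = jnp.linspace(-0.5, 1.5, 200)[:, None]
keys = jax.random.split(jax.random.PRNGKey(1), 10)
trainables = [init_mlp(k) for k in keys]
priors = [init_mlp(jax.random.fold_in(k, 1)) for k in keys]

for beta in (0.0, 1.0, 3.0, 10.0):
    preds = jnp.stack([predict(t, p, x_grid, beta)
                       for t, p in zip(trainables, priors)])
    # Spread across ensemble members, averaged over the input grid;
    # the disagreement contributed by the fixed priors grows with beta.
    print(beta, float(preds.std(axis=0).mean()))
```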
5. Effects of Prior and Bootstrapping
Let us construct and compare the following four models:
Ensemble of nets with prior, with bootstrap (original model)
Ensemble of nets with prior, without bootstrap
Ensemble of nets without prior, with bootstrap
Ensemble of nets without prior, without bootstrap
Note that our previous constructions extend easily to the other three model types. For nets without a prior, we simply set beta=0 (or, alternatively, use GenericNet), and for nets without bootstrapping, we call get_ensemble_predictions(..., bootstrap=False).
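Schematically, and assuming beta is also accepted as a keyword by get_ensemble_predictions (the text only states the bootstrap keyword explicitly), the four variants differ only in these arguments; the `...` stands for the remaining arguments, which are elided in the text above as well.

```python
# Hypothetical calls mirroring the notebook's API; not runnable as-is.
prior_boot = get_ensemble_predictions(..., bootstrap=True)
prior_noboot = get_ensemble_predictions(..., bootstrap=False)
noprior_boot = get_ensemble_predictions(..., beta=0.0, bootstrap=True)
noprior_noboot = get_ensemble_predictions(..., beta=0.0, bootstrap=False)
```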