GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/17/randomized_priors.ipynb

Open in Colab

Deep Ensembles with Randomized Prior Functions

This notebook illustrates a way to improve predictive uncertainty estimates using deep ensembles. It is based on the following paper:

I. Osband, J. Aslanides, and A. Cassirer, “Randomized prior functions for deep reinforcement learning,” in NIPS, Jun. 2018 [Online]. Available: https://proceedings.neurips.cc/paper/2018/file/5a7b238ba0f6502e5d6be14424b20ded-Paper.pdf.

The original TensorFlow demo is from https://www.kaggle.com/code/gdmarmerola/introduction-to-randomized-prior-functions/notebook

This JAX translation is by Peter G. Chang (@peterchang0414)

0. Imports

%matplotlib inline
from functools import partial

import matplotlib.pyplot as plt

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax

import jax
import jax.numpy as jnp
from jax import jit
from jax.nn.initializers import glorot_normal

try:
    import probml_utils as pml
    from probml_utils import latexify, savefig, is_latexify_enabled
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml
    from probml_utils import savefig, latexify, is_latexify_enabled

try:
    import flax
except ModuleNotFoundError:
    %pip install -qq flax
    import flax

import flax.linen as nn
from flax.training import train_state

try:
    import seaborn as sns
except ModuleNotFoundError:
    %pip install seaborn
    import seaborn as sns

1. Dataset

We will use a 1D synthetic regression dataset from this paper:

C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra, “Weight Uncertainty in Neural Networks,” in ICML, May 2015 [Online]. Available: http://arxiv.org/abs/1505.05424

\begin{align}
y &= x + 0.3 \sin(2 \pi(x + \epsilon)) + 0.3 \sin(4 \pi(x + \epsilon)) + \epsilon \\
\epsilon &\sim \mathcal{N}(0, 0.02) \\
x &\sim \mathcal{U}(0.0, 0.5)
\end{align}
# Generate dataset and grid
key, subkey = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.uniform(key, shape=(100, 1), minval=0.0, maxval=0.5)
x_grid = jnp.linspace(-5, 5, 1000).reshape(-1, 1)


# Define function
def target_toy(key, x):
    epsilons = jax.random.normal(key, shape=(3,)) * 0.02
    return (
        x
        + 0.3 * jnp.sin(2 * jnp.pi * (x + epsilons[0]))
        + 0.3 * jnp.sin(4 * jnp.pi * (x + epsilons[1]))
        + epsilons[2]
    )


# Define vectorized version of function
target_vmap = jax.vmap(target_toy, in_axes=(0, 0), out_axes=0)

# Generate target values
keys = jax.random.split(subkey, X.shape[0])
Y = target_vmap(keys, X)
# Plot the generated data
plt.figure()  # figsize=[12,6], dpi=200)
plt.plot(X, Y, "kx", label="Toy data", alpha=0.8)
# plt.title('Simple 1D example with toy data by Blundell et. al (2015)')
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.xlim(-0.5, 1.0)
plt.ylim(-0.8, 1.6)
plt.legend()
plt.show()
[Figure: scatter plot of the 1D toy regression data]

2. Randomized Prior Functions

The core idea is to represent each ensemble member $g_i$ as $g_i(x; \theta_i) = t_i(x; \theta_i) + \beta \, p_i(x)$, where $t_i$ is a trainable network and $p_i$ is a fixed, but randomly initialized, prior network. Only the parameters $\theta_i$ of the trainable network are optimized; the prior network's parameters stay frozen at their initial values, and $\beta$ controls how strongly the prior influences the ensemble member.

# Prior and trainable networks have the same architecture
class GenericNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        dense = partial(nn.Dense, kernel_init=glorot_normal())
        return nn.Sequential([dense(16), nn.elu, dense(16), nn.elu, dense(1)])(x)


# Model that combines prior and trainable nets
class Model(nn.Module):
    prior: GenericNet = GenericNet()
    trainable: GenericNet = GenericNet()
    beta: float = 3

    @nn.compact
    def __call__(self, x):
        x1 = self.prior(x)
        x2 = self.trainable(x)
        return self.beta * x1 + x2
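As a quick sanity check (a minimal sketch using the classes above; the variable names are illustrative and not part of the original notebook), we can verify that the combined model's output is exactly $\beta$ times the prior net's output plus the trainable net's output:

# Hypothetical sanity check: g(x) = beta * p(x) + t(x)
check_key = jax.random.PRNGKey(1)
check_model = Model(beta=3.0)
check_params = check_model.init(check_key, X)["params"]  # contains "prior" and "trainable"

g_out = check_model.apply({"params": check_params}, X)
p_out = GenericNet().apply({"params": check_params["prior"]}, X)
t_out = GenericNet().apply({"params": check_params["trainable"]}, X)
assert jnp.allclose(g_out, 3.0 * p_out + t_out, atol=1e-5)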
def create_train_state(key, X, lr):
    key, _ = jax.random.split(key)
    model = Model()
    params = model.init(key, X)["params"]
    p_params, t_params = params["prior"], params["trainable"]
    opt = optax.adam(learning_rate=lr)
    return p_params, train_state.TrainState.create(apply_fn=model.apply, params=t_params, tx=opt)


@jax.jit
def train_epoch(beta, prior, train_state, X, Y):
    def loss_fn(params):
        model = Model(beta=beta)
        Yhat = model.apply({"params": {"prior": prior, "trainable": params}}, X)
        loss = jnp.mean((Yhat - Y) ** 2)
        return loss

    loss, grads = jax.value_and_grad(loss_fn)(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    return train_state


def train(key, beta, lr, epochs, X, Y):
    p_params, train_state = create_train_state(key, X, lr)
    for epoch in range(1, epochs + 1):
        train_state = train_epoch(beta, p_params, train_state, X, Y)
    return {"prior": p_params, "trainable": train_state.params}


# Prediction function to be reused in Part 3
@jax.jit
def get_predictions(beta, params, X):
    p_param, t_param = params["prior"], params["trainable"]
    generic = GenericNet()
    model = Model(beta=beta)
    Y_prior = generic.apply({"params": p_param}, X)
    Y_trainable = generic.apply({"params": t_param}, X)
    Y_model = model.apply({"params": params}, X)
    return Y_prior, Y_trainable, Y_model
beta = 3
epochs = 2000
learning_rate = 0.03

# Train the model and get predictions for x_grid
params = train(jax.random.PRNGKey(0), beta, learning_rate, epochs, X, Y)
Y_prior, Y_trainable, Y_model = get_predictions(beta, params, x_grid)
# Plot the results
plt.figure()  # figsize=[12,6], dpi=200)
plt.plot(X, Y, "kx", label="Toy data", alpha=0.8)
plt.plot(x_grid, beta * Y_prior, label="prior net (p)")  # scaled prior contribution, beta * p(x)
plt.plot(x_grid, Y_trainable, label="trainable net (t)")
plt.plot(x_grid, Y_model, label="resultant (g)")
# plt.title('Predictions of the prior network: random function')
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.xlim(-0.5, 1.0)
plt.ylim(-0.6, 1.4)
plt.legend()
# plt.savefig("randomized_priors_single_model.pdf")
# plt.savefig("randomized_priors_single_model.png")
plt.show()
[Figure: prior, trainable, and combined network predictions for a single model]

3. Bootstrapped Ensembles

To implement bootstrapping in JAX, we generate a random map from seed values to dataset indices, $\{\text{seed}, \text{seed}+1, \dots, \text{seed}+99\} \to \{0, 1, \dots, 99\}$, by calling jax.random.randint with a randomly generated seed. We assume the random key space is large enough that overlapping seed ranges are not a concern.

# Generate bootstrap of given size
def generate_bootstrap(key, size):
    seed, _ = jax.random.split(key)
    return [jax.random.randint(seed + i, (), 0, size) for i in range(size)]


# An ensemble model with randomized prior nets
def build_ensemble(n_estimators, beta, lr, epochs, X, Y, bootstrap):
    train_keys = jax.random.split(jax.random.PRNGKey(0), n_estimators)

    # Stack the (bootstrapped) training sets for the ensemble models
    if bootstrap:
        bootstraps = jnp.stack(
            jnp.array([generate_bootstrap(jax.random.PRNGKey(42 * i), X.shape[0]) for i in range(1, n_estimators + 1)])
        )
        X_b = jnp.expand_dims(jax.vmap(jnp.take, in_axes=(None, 0))(X, bootstraps), 2)
        Y_b = jnp.expand_dims(jax.vmap(jnp.take, in_axes=(None, 0))(Y, bootstraps), 2)
    else:
        X_b = jnp.tile(X, (n_estimators, 1, 1))
        Y_b = jnp.tile(Y, (n_estimators, 1, 1))

    # Train each ensemble model on its corresponding training set
    ensemble_train = jax.vmap(train, in_axes=(0, None, None, None, 0, 0), out_axes={"prior": 0, "trainable": 0})
    return ensemble_train(train_keys, beta, lr, epochs, X_b, Y_b)


# Array of predictions for each model in trained ensemble
def get_ensemble_predictions(n_estimators, beta, lr, epochs, X, Y, X_new, bootstrap=True):
    ensemble = build_ensemble(n_estimators, beta, lr, epochs, X, Y, bootstrap)
    result = jax.vmap(get_predictions, in_axes=(None, 0, None))(beta, ensemble, X_new)
    return result
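As an aside (not used in this notebook), bootstrap indices can also be drawn in a single call with jax.random.choice on a properly split key, which avoids the per-index seed arithmetic above. A minimal sketch, with an illustrative helper name:

# Hypothetical alternative: draw a whole bootstrap sample of indices at once
def generate_bootstrap_choice(key, size):
    # Sample `size` indices uniformly from {0, ..., size-1} with replacement
    return jax.random.choice(key, size, shape=(size,), replace=True)

idx = generate_bootstrap_choice(jax.random.PRNGKey(0), X.shape[0])
X_boot, Y_boot = X[idx], Y[idx]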
# Compute prediction values for each net in ensemble
n_estimators = 9
beta = 3
learning_rate = 0.03
epochs = 3000

p_grid, t_grid, y_grid = get_ensemble_predictions(n_estimators, beta, learning_rate, epochs, X, Y, x_grid)
latexify(width_scale_factor=1.8, fig_height=3)
# Plot the results
# plt.figure(figsize=[12, 12], dpi=200)
fig, axs = plt.subplots(nrows=4, ncols=2, sharex=True, sharey=True)
# axs = axs.flatten()
for i in range(8):
    k, j = divmod(i, 2)
    # pl.subplot(4, 2,i+1)
    axs[k, j].text(-0.25, 1, "model " + r"$\#$" + f"{i+1}")
    if k == 3:
        # at = jnp.array([-0.5,0,0.5,1])
        # labels= [f"{label:.0f}" for label in at]
        # axs[k,j].set_xticks(at)
        axs[k, j].set_xticklabels([-0.5, 0, 0.5, 1])
        axs[k, j].set_xlim([-0.5, 1])
    else:
        axs[k, j].set_xticks([-0.5, 0, 0.5, 1])
        axs[k, j].set_xlim([-0.5, 1])
    axs[k, j].set_ylim(-0.6, 1.4)
    axs[k, j].plot(X, Y, "kx", label="Toy data", alpha=0.8, markersize=1)
    axs[k, j].plot(x_grid, p_grid[i, :, 0], label="prior net (p)")
    axs[k, j].plot(x_grid, t_grid[i, :, 0], label="trainable net (t)")
    axs[k, j].plot(x_grid, y_grid[i, :, 0], label="resultant (g)")
    # ax.set_title("Ensemble: Model \#{}".format(i + 1))
    # plt.xlabel('$x$'); plt.ylabel('$y$')
    # ax.legend(frameon = False)
    sns.despine()

if is_latexify_enabled():
    plt.subplots_adjust(wspace=3)
# plt.tight_layout()
pml.savefig("randomized_priors_multi_model")
# plt.savefig("randomized_priors_multi_model.png")
plt.show()
[Figure: prior, trainable, and combined predictions for eight ensemble members]

4. Effects of Changing Beta

Let us go beyond the original Kaggle notebook and inspect the relationship between the prior weight $\beta$ and the variance among the predictions of the ensemble members.

Intuitively, since the random priors are not trained, the variance should increase with $\beta$. Let us verify this visually.

# Choose a diverse selection of beta values
betas = jnp.array([0.001, 5, 50, 100])
n_estimators, lr, epochs = 9, 0.03, 3000

# Get ensemble predictions for the corresponding beta values
pred_beta_batches = jax.vmap(
    get_ensemble_predictions,
    in_axes=(None, 0, None, None, None, None, None),
)
preds = pred_beta_batches(n_estimators, betas, lr, epochs, X, Y, x_grid)
means = jax.vmap(lambda x: x.mean(axis=0)[:, 0])(preds[2])
stds = jax.vmap(lambda x: x.std(axis=0)[:, 0])(preds[2])
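To complement the plots below, we can also check the trend numerically, e.g. by averaging the predictive standard deviation over the grid for each $\beta$. This is a small sketch using the means/stds computed above (the printed values are not shown here); if the intuition holds, the averages should grow with $\beta$:

# Average predictive standard deviation over the grid, one value per beta
avg_std = stds.mean(axis=1)
for b, s in zip(betas, avg_std):
    print(f"beta = {float(b):7.3f}  ->  mean predictive std = {float(s):.3f}")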
# Plot mean and std for each beta
# fig = plt.figure(figsize=[8, len(betas) * 3], dpi=150)
fig, axs = plt.subplots(nrows=4, ncols=2, sharex=True, sharey=True)
for i, beta in enumerate(betas):
    # Plot predictive mean and std (left graph)
    # plt.subplot(len(betas), 2, 2 * i + 1)
    axs[i, 0].plot(X, Y, "kx", label="Toy data", markersize=1)
    # plt.title(f'Mean and Deviation for beta={beta}', fontsize=12)
    if beta < 1:
        axs[i, 0].text(-0.2, 1.2, r"$\beta$" + f"={beta:.3f}")
    else:
        axs[i, 0].text(-0.2, 1.2, r"$\beta$=" + f"{int(beta)}")
    axs[i, 0].set_xlim(-0.5, 1)
    axs[i, 0].set_ylim(-2, 2)
    # plt.legend()
    axs[i, 0].plot(x_grid, means[i], "r--", linewidth=1)
    axs[i, 0].fill_between(x_grid.reshape(1, -1)[0], means[i] - stds[i], means[i] + stds[i], alpha=0.5, color="red")
    axs[i, 0].fill_between(
        x_grid.reshape(1, -1)[0], means[i] + 2 * stds[i], means[i] - 2 * stds[i], alpha=0.2, color="red"
    )

    # Plot means of each net in ensemble (right graph)
    # plt.subplot(len(betas), 2, 2 * i + 2)
    axs[i, 1].plot(X, Y, "kx", label="Toy data", markersize=1)
    # plt.title(f'Samples for beta={beta}', fontsize=12)
    if beta < 1:
        axs[i, 1].text(-0.2, 1.2, r"$\beta$" + f"={beta:.3f}")
    else:
        axs[i, 1].text(-0.2, 1.2, r"$\beta$=" + f"{int(beta)}")
    if is_latexify_enabled():
        axs[i, 1].set_xticklabels([-0.5, 0, 0.5, 1])
    axs[i, 1].set_xlim(-0.5, 1)
    axs[i, 1].set_ylim(-1.5, 2)
    sns.despine()
    # plt.legend()
    for j in range(n_estimators):
        axs[i, 1].plot(x_grid, preds[2][i][j, :, 0], linestyle="--", linewidth=1)

plt.tight_layout()
if is_latexify_enabled():
    plt.subplots_adjust(wspace=5.5)
pml.savefig("randomized_priors_changing_beta")
# plt.savefig("randomized_priors_changing_beta.png")
plt.show()
[Figure: predictive mean with 1- and 2-std bands (left) and individual ensemble member predictions (right) for each beta]

5. Effects of Prior and Bootstrapping

Let us construct and compare the following four models:

  1. Ensemble of nets with prior, with bootstrap (original model)

  2. Ensemble of nets with prior, without bootstrap

  3. Ensemble of nets without prior, with bootstrap

  4. Ensemble of nets without prior, without bootstrap

Note that our previous constructions extend easily to the three other model types. For nets without a prior, we simply set beta=0 (or, equivalently, use GenericNet directly), and for nets without bootstrapping, we call get_ensemble_predictions(..., bootstrap=False).

n, beta, lr, epochs = 9, 50, 0.03, 3000

# With prior, with bootstrap (original)
*_, y_grid_1 = get_ensemble_predictions(n, beta, lr, epochs, X, Y, x_grid)
means_1, stds_1 = y_grid_1.mean(axis=0)[:, 0], y_grid_1.std(axis=0)[:, 0]

# With prior, without bootstrap
*_, y_grid_2 = get_ensemble_predictions(n, beta, lr, epochs, X, Y, x_grid, False)
means_2, stds_2 = y_grid_2.mean(axis=0)[:, 0], y_grid_2.std(axis=0)[:, 0]

# Without prior, with bootstrap
*_, y_grid_3 = get_ensemble_predictions(n, 0, lr, epochs, X, Y, x_grid)
means_3, stds_3 = y_grid_3.mean(axis=0)[:, 0], y_grid_3.std(axis=0)[:, 0]

# Without prior, without bootstrap
*_, y_grid_4 = get_ensemble_predictions(n, 0, lr, epochs, X, Y, x_grid, False)
means_4, stds_4 = y_grid_4.mean(axis=0)[:, 0], y_grid_4.std(axis=0)[:, 0]
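For a quick numerical comparison of the four variants, we can also print the average predictive standard deviation of each. This is a small illustrative sketch using the quantities just computed (the printed values are not shown here); the variant names are informal labels:

# Average predictive std over the grid for each of the four model variants
for name, s in [
    ("prior + bootstrap", stds_1),
    ("prior, no bootstrap", stds_2),
    ("no prior, bootstrap", stds_3),
    ("no prior, no bootstrap", stds_4),
]:
    print(f"{name:24s} mean predictive std = {float(s.mean()):.3f}")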
# Plot the four types of models
fig = plt.figure(figsize=[12, 9], dpi=150)
# fig.suptitle('Bootstrapping and priors: impact of model components on result',
#              verticalalignment='center')

# With prior, with bootstrap
plt.subplot(2, 2, 1)
plt.plot(X, Y, "kx", label="Toy data")
plt.title("Full model with priors and bootstrap", fontsize=12)
plt.xlim(-0.5, 1.0)
plt.ylim(-1.5, 1.5)
plt.legend()
plt.plot(x_grid, means_1, "r--", linewidth=1.5)
plt.fill_between(x_grid.reshape(1, -1)[0], means_1 - stds_1, means_1 + stds_1, alpha=0.5, color="red")
plt.fill_between(x_grid.reshape(1, -1)[0], means_1 + 2 * stds_1, means_1 - 2 * stds_1, alpha=0.2, color="red")

# With prior, without bootstrap
plt.subplot(2, 2, 2)
plt.plot(X, Y, "kx", label="Toy data")
plt.title("No bootstrapping, but use of priors", fontsize=12)
plt.xlim(-0.5, 1.0)
plt.ylim(-1.5, 1.5)
plt.legend()
plt.plot(x_grid, means_2, "r--", linewidth=1.5)
plt.fill_between(x_grid.reshape(1, -1)[0], means_2 - stds_2, means_2 + stds_2, alpha=0.5, color="red")
plt.fill_between(x_grid.reshape(1, -1)[0], means_2 + 2 * stds_2, means_2 - 2 * stds_2, alpha=0.2, color="red")

# Without prior, with bootstrap
plt.subplot(2, 2, 3)
plt.plot(X, Y, "kx", label="Toy data")
plt.title("No priors, but use of bootstrapping", fontsize=12)
plt.xlim(-0.5, 1.0)
plt.ylim(-1.5, 1.5)
plt.legend()
plt.plot(x_grid, means_3, "r--", linewidth=1.5)
plt.fill_between(x_grid.reshape(1, -1)[0], means_3 - stds_3, means_3 + stds_3, alpha=0.5, color="red")
plt.fill_between(x_grid.reshape(1, -1)[0], means_3 + 2 * stds_3, means_3 - 2 * stds_3, alpha=0.2, color="red")

# Without prior, without bootstrap
plt.subplot(2, 2, 4)
plt.plot(X, Y, "kx", label="Toy data")
plt.title("Both bootstrapping and priors turned off", fontsize=12)
plt.xlim(-0.5, 1.0)
plt.ylim(-1.5, 1.5)
plt.legend()
plt.plot(x_grid, means_4, "r--", linewidth=1.5)
plt.fill_between(x_grid.reshape(1, -1)[0], means_4 - stds_4, means_4 + stds_4, alpha=0.5, color="red")
plt.fill_between(x_grid.reshape(1, -1)[0], means_4 + 2 * stds_4, means_4 - 2 * stds_4, alpha=0.2, color="red")

plt.tight_layout()
# plt.savefig("randomized_priors_components.pdf")
# plt.savefig("randomized_priors_components.png")
plt.show()
[Figure: predictive mean and uncertainty bands for the four model variants]