GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/18/gp_deep_kernel_learning.ipynb
Kernel: Python 3 (ipykernel)
%%capture
import os

try:
    from tinygp import kernels, transforms, GaussianProcess
except ModuleNotFoundError:
    %pip install -qq tinygp
    from tinygp import kernels, transforms, GaussianProcess

try:
    import flax.linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    import flax.linen as nn
from flax.linen.initializers import zeros

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax

try:
    import probml_utils as pml
except ModuleNotFoundError:
    %pip install git+https://github.com/probml/probml-utils.git
    import probml_utils as pml

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp

from jax.config import config

config.update("jax_enable_x64", True)

Data
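
We generate noisy observations of a step function. The discontinuity at $x = 0$ is deliberately hard for a stationary kernel with a single length scale, which is what motivates learning a feature map below.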

pml.latexify(width_scale_factor=2)
markersize = 3 if pml.is_latexify_enabled() else 6
/home/patel_zeel/miniconda3/envs/pyprobml/lib/python3.7/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying
  warnings.warn("LATEXIFY environment variable not set, not latexifying")
random = np.random.default_rng(567)
noise = 0.1
x = np.sort(random.uniform(-1, 1, 100))


def true_fn(x):
    return 2 * (x > 0) - 1


y = true_fn(x) + random.normal(0.0, noise, len(x))

t = np.linspace(-1.5, 1.5, 500)
plt.plot(t, true_fn(t), "k", lw=1, label="truth")
plt.plot(x, y, ".k", markersize=markersize, label="data")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("$x$")
plt.ylabel("$y$")
_ = plt.legend()
sns.despine()
pml.savefig("gp-dkl-data")
/home/patel_zeel/miniconda3/envs/pyprobml/lib/python3.7/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures
  warnings.warn("set FIG_DIR environment variable to save figures")
[Figure gp-dkl-data: noisy observations of the step function, with the true function overlaid]

Deep kernel

We transform the (1d) input with a small MLP and then evaluate a Matérn-3/2 kernel on the transformed coordinates, giving a deep kernel.
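
Concretely, writing $g_\theta$ for the MLP and $k_{3/2}$ for the stationary Matérn-3/2 base kernel with amplitude $\sigma$ and length scale $\rho$, the deep kernel evaluates the base kernel on the warped inputs:

$$k_{\text{deep}}(x, x') = k_{3/2}\big(g_\theta(x),\, g_\theta(x')\big),$$

and $\theta$ is learned jointly with $(\sigma, \rho)$ and the jitter by minimizing the GP negative log marginal likelihood.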

# Define a small neural network used to non-linearly transform the input data in our model
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=15)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class GPdeep(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # Define a custom transform to pass the input coordinates through our
        # `Transformer` network from above
        transform = Transformer()
        kernel = transforms.Transform(transform, base_kernel)

        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(kernel, x[:, None], diag=jnp.exp(2 * log_jitter))
        pred_gp = gp.condition(y, t[:, None]).gp
        return -gp.log_probability(y), (pred_gp.loc, pred_gp.variance)
# Define and train the model
def loss(params):
    return model.apply(params, x, y, t)[0]


model = GPdeep()
params = model.init(jax.random.PRNGKey(1234), x, y, t)

tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
for i in range(1000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
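The loop above discards `loss_val`; to check convergence one can report the negative log likelihood as training proceeds. A minimal sketch of the same loop, reusing the `loss_grad_fn`, `tx`, `params`, and `opt_state` defined above (the printing is our addition, not part of the original notebook):

# Same training loop, reporting the NLL every 100 steps
for i in range(1000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print(f"step {i:4d}  nll = {loss_val:.4f}")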
# Plot the results and compare to the true model
plt.figure()
mu, var = model.apply(params, x, y, t)[1]
plt.plot(t, true_fn(t), "k", lw=1, label="truth")
plt.plot(x, y, ".k", markersize=markersize, label="data")
plt.plot(t, mu)
plt.fill_between(t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label="model")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("$x$")
plt.ylabel("$y$")
_ = plt.legend()
sns.despine()
pml.savefig("gp-dkl-deep")
[Figure gp-dkl-deep: deep-kernel GP posterior mean and ±1σ band against the truth and data]
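
To inspect what the network has learned, we can evaluate the warp $g_\theta$ on the test grid. A minimal sketch, assuming flax's default auto-naming of the submodule (`Transformer_0`) inside the trained `GPdeep` parameters:

# Pull out the Transformer sub-parameters and apply the learned warp to the grid
warp_params = {"params": params["params"]["Transformer_0"]}
g = Transformer().apply(warp_params, t[:, None])

plt.figure()
plt.plot(t, g[:, 0], "k", lw=1)
plt.xlabel("$x$")
plt.ylabel(r"$g_\theta(x)$")

A sharp change in $g_\theta$ near $x = 0$ is what allows the stationary base kernel to model the step.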

Shallow kernel
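
For comparison, we fit the same Matérn-3/2 kernel directly on the raw 1d inputs, with no learned feature map. In terms of the scaled distance $r = |x - x'| / \rho$, this stationary kernel is

$$k_{3/2}(x, x') = \sigma^2 \left(1 + \sqrt{3}\,r\right) \exp\!\left(-\sqrt{3}\,r\right).$$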

class GPshallow(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(base_kernel, x[:, None], diag=jnp.exp(2 * log_jitter))
        pred_gp = gp.condition(y, t[:, None]).gp
        return -gp.log_probability(y), (pred_gp.loc, pred_gp.variance)
# Train the shallow model; `loss` looks up the global `model` at call time,
# so re-binding `model` here makes the freshly jitted loss use GPshallow
model = GPshallow()
params = model.init(jax.random.PRNGKey(1234), x, y, t)

tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
for i in range(1000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
# Plot the results and compare to the true model
plt.figure()
mu, var = model.apply(params, x, y, t)[1]
plt.plot(t, true_fn(t), "k", lw=1, label="truth")
plt.plot(x, y, ".k", markersize=markersize, label="data")
plt.plot(t, mu)
plt.fill_between(t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label="model")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("$x$")
plt.ylabel("$y$")
_ = plt.legend()
sns.despine()
pml.savefig("gp-dkl-shallow")
[Figure gp-dkl-shallow: shallow-kernel GP posterior mean and ±1σ band against the truth and data]
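
Comparing the two fits: the shallow stationary kernel must choose a single length scale for the whole domain, so it tends to oversmooth the discontinuity at $x = 0$, whereas the deep kernel's learned warp lets the GP fit the jump while remaining smooth elsewhere.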