Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/18/gpr_demo_marglik.ipynb
1193 views
Kernel: Python 3.10.4 ('PyroNB')

Illustration of local minima in the marginal likelihood surface

try: import tinygp except ImportError: %pip install -q tinygp import tinygp import jax.numpy as jnp from tinygp import kernels, GaussianProcess from jax.config import config from scipy.optimize import minimize import matplotlib.pyplot as plt import seaborn as sns try: from probml_utils import latexify, savefig, is_latexify_enabled except ModuleNotFoundError: %pip install git+https://github.com/probml/probml-utils.git from probml_utils import latexify, savefig, is_latexify_enabled latexify(width_scale_factor=3, fig_height=1.5) config.update("jax_enable_x64", True) marksize = 20 if is_latexify_enabled() else 100

GP negative log likelihood

def build_gp(x, length_scale, sigma_f, sigma_y):
    """Construct a tinygp ``GaussianProcess`` over the inputs ``x``.

    The kernel is ``sigma_f**2 * ExpSquared(scale=length_scale)`` and
    ``sigma_y**2`` is passed as the ``diag`` argument of the process.
    All three hyperparameters are taken on their natural (non-log) scale.
    """
    amplitude = sigma_f**2
    noise_variance = sigma_y**2
    scaled_kernel = amplitude * kernels.ExpSquared(scale=length_scale)
    return GaussianProcess(scaled_kernel, x, diag=noise_variance)


def neg_log_likelihood(X, y, sigma_f, length_scale, sigma_y):
    """Return the negated ``log_probability`` of ``y`` under the GP built on ``X``.

    Note that ``length_scale`` and ``sigma_y`` are exponentiated before the GP
    is built, i.e. callers pass them in log space; ``sigma_f`` is used as is.
    """
    actual_length_scale = jnp.exp(length_scale)
    actual_sigma_y = jnp.exp(sigma_y)
    model = build_gp(X, actual_length_scale, sigma_f, actual_sigma_y)
    log_prob = model.log_probability(y)
    return -log_prob

Plot the GP predictions

def plot_gp_pred(x, y, xtest, sigma_f, length_scale, sigma_y):
    """Draw the data, the conditioned GP mean and a +/- 2 std band on the current axes.

    A GP is built on ``x`` with the given hyperparameters, conditioned on ``y``
    and evaluated at ``xtest``; the resulting mean and variance drive the plot.
    """
    prior = build_gp(x, length_scale, sigma_f, sigma_y)
    posterior = prior.condition(y, xtest).gp
    mean = posterior.loc
    variance = posterior.variance

    # Half-width of the shaded band: two standard deviations.
    half_width = 2 * jnp.sqrt(variance)
    upper = mean + half_width
    lower = mean - half_width

    plt.scatter(x, y, s=marksize, c="k", marker="+", label="Data")
    plt.plot(xtest, mean, color="black", label="Mean")
    plt.fill_between(
        xtest,
        upper,
        lower,
        color="tab:gray",
        alpha=0.3,
        edgecolor="none",
        label="Confidence",
    )

    sns.despine()
    plt.xlabel("$x$")
    plt.ylabel("$y$")

Plot marginal likelihood surface

def plot_marginal_likelihood_surface(x, y, sigma_f, l_space, sigma_y_space, levels=None):
    """Contour-plot the negative log likelihood over a hyperparameter grid.

    Args:
        x: Training inputs forwarded to ``neg_log_likelihood``.
        y: Training targets forwarded to ``neg_log_likelihood``.
        sigma_f: Kernel amplitude forwarded to ``neg_log_likelihood``.
        l_space: 1-D grid of *log* length scales (x-axis of the plot).
        sigma_y_space: 1-D grid of *log* noise standard deviations (y-axis).
        levels: Contour levels handed to ``plt.contour``. ``None`` lets
            matplotlib choose the levels automatically.
    """
    # P[0] holds the log length scales and P[1] the log noise std devs.
    P = jnp.stack(jnp.meshgrid(l_space, sigma_y_space), axis=0)
    # Evaluate the objective for every (log length scale, log noise) pair.
    Z = jnp.apply_along_axis(lambda p: neg_log_likelihood(x, y, sigma_f, *p), 0, P)

    # Axes show the natural-scale values; the log spacing is restored below
    # through the log-scaled axes.
    length_scales, noise_stds = jnp.exp(P)

    # `levels` must be passed by keyword: a positional ``None`` is treated as an
    # explicit level specification rather than "use the default levels".
    plt.contour(length_scales, noise_stds, Z, levels=levels)

    plt.xlabel("length scale (log scale)")
    plt.ylabel("noise std_dev \n (log scale)")
    plt.xscale("log")
    plt.yscale("log")
    sns.despine()

Plots

# Kernel amplitude shared by every GP below.
sigma_f = 1.0

# Training data; the arrays are built from flat lists and are therefore 1-D.
x = jnp.array([-1.3089, 6.7612, 1.0553, -1.1734, -2.9339, 7.2530, -6.5843])
y = jnp.array([1.6218, 1.8558, 0.4102, 1.2526, -0.0133, 1.6380, 0.2189])

# Locations at which the conditioned GP is evaluated.
x_test = jnp.linspace(-7.5, 7.5, 201)

# Two (length_scale, sigma_y) settings to visualise.
params = [(1.0, 0.2), (10, 0.8)]

# One figure per setting: (output name, hyperparameters, y-axis limits).
figure_specs = (
    ("gpr_config0", params[0], (-3, 6)),
    ("gpr_config1", params[1], (-1, 3)),
)
for fig_name, hyperparams, y_limits in figure_specs:
    fig = plt.figure()
    plot_gp_pred(x, y, x_test, sigma_f, *hyperparams)
    plt.ylim(*y_limits)
    plt.legend(prop={"size": 6}, frameon=False, loc=2)
    savefig(fig_name)
ngrid = 41

# Starting points for the optimiser, expressed as
# (log length scale, log noise std dev).
params1 = jnp.array([jnp.log(1), jnp.log(0.1)])
params2 = jnp.array([jnp.log(10), jnp.log(0.8)])


def _objective(log_params):
    """Negative log likelihood as a function of the two log hyperparameters."""
    return neg_log_likelihood(x, y, sigma_f, *log_params)


# Run the optimiser once from each starting point.
local_minima1 = minimize(_objective, params1)
local_minima2 = minimize(_objective, params2)

# Contour levels and the log-spaced grid on which the surface is evaluated.
levels = jnp.array([8.3, 8.5, 8.9, 9.3, 9.8, 11.5, 15])
length_space = jnp.linspace(jnp.log(0.1), jnp.log(80), ngrid)
sigma_y_space = jnp.linspace(jnp.log(0.03), jnp.log(3), ngrid)

fig = plt.figure()
plot_marginal_likelihood_surface(x, y, sigma_f, length_space, sigma_y_space, levels=levels)

# Mark both optimiser results on the surface, converted back from log space.
for optimum in (local_minima1, local_minima2):
    plt.scatter(*jnp.exp(optimum.x), marker="+", s=marksize, c="red")

savefig("gpr_marginal_likelihood")
plt.show()