Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/18/krr_vs_gpr.ipynb
1192 views
Kernel: Python 3.7.13 ('py3713')

Kernel Ridge Regression vs. Gaussian Process Regression

# Numerical and plotting stack used throughout the notebook.
import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt

# probml_utils supplies the book's figure helpers (latexify / savefig /
# is_latexify_enabled); it is installed on demand if missing.
try:
    from probml_utils import latexify, savefig, is_latexify_enabled
except ModuleNotFoundError:
    %pip install git+https://github.com/probml/probml-utils.git
    from probml_utils import latexify, savefig, is_latexify_enabled

# Switch JAX to 64-bit floats for everything computed below
# (presumably for numerical stability of the GP linear algebra).
jax.config.update("jax_enable_x64", True)

# tinygp: the GP library used for the GPR model (installed on demand).
try:
    import tinygp
except ModuleNotFoundError:
    %pip install -qqq tinygp
    import tinygp

# jaxopt: provides ScipyMinimize, used to fit the GP hyperparameters.
try:
    import jaxopt
except ModuleNotFoundError:
    %pip install jaxopt
    import jaxopt

from tinygp import GaussianProcess, kernels
import time
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV
# scikit-learn's ExpSineSquared is the kernel for KRR; the GP cells use
# tinygp's `kernels.ExpSineSquared` instead, so the two names don't clash.
from sklearn.gaussian_process.kernels import ExpSineSquared

# Configure figure sizing for the book (full width, fig_height=2).
latexify(width_scale_factor=1, fig_height=2)
# Use smaller scatter markers when producing LaTeX-sized figures.
marksize = 10 if is_latexify_enabled() else 30

Data

# Generate a noisy sine dataset: 100 inputs drawn uniformly from [0, 15) and
# targets sin(x) plus uniform noise in (-1.5, 1.5].
key = jax.random.PRNGKey(0)
# Split the root key before drawing anything so that no key is consumed twice
# (JAX PRNG keys must never be reused): one subkey for X, one for the noise.
key_split = jax.random.split(key, 2)
X = 15 * jax.random.uniform(key_split[0], (100, 1))
y = jnp.sin(X).ravel()
y += 3 * (0.5 - jax.random.uniform(key_split[1], (X.shape[0],)))  # add noise

Fitting the Models

# Fit KernelRidge, choosing the ridge penalty `alpha` and the ExpSineSquared
# kernel's (length_scale, periodicity) by 5-fold cross-validated grid search
# over 4 x 10 x 10 = 400 candidate settings.
# Hyperparameter values are converted to plain Python floats so that
# scikit-learn stores ordinary numbers rather than JAX arrays in its kernels.
length_scales = [float(v) for v in jnp.logspace(-2, 2, 10)]
periodicities = [float(v) for v in jnp.logspace(0, 2, 10)]
param_grid = {
    "alpha": [1e0, 1e-1, 1e-2, 1e-3],
    "kernel": [ExpSineSquared(ls, p) for ls in length_scales for p in periodicities],
}
# cv=5 is explicit so the fold count does not depend on the installed
# scikit-learn version's default.
kr = GridSearchCV(KernelRidge(), param_grid=param_grid, cv=5)
stime = time.perf_counter()
kr.fit(X, y)
print("Time for KRR fitting: %.3f" % (time.perf_counter() - stime))
Time for KRR fitting: 21.660
# Fit the GP hyperparameters by minimising the negative log marginal
# likelihood with jaxopt.ScipyMinimize (a wrapper around scipy.optimize.minimize).
# Parameters live in log space so their exponentials stay positive:
#   log_diag  -> value added to the covariance diagonal (observation noise)
#   log_scale -> `scale` of tinygp's ExpSineSquared kernel
#   log_gamma -> `gamma` of tinygp's ExpSineSquared kernel
theta_init = {"log_diag": jnp.log(1e-1), "log_scale": jnp.log(5.0), "log_gamma": jnp.log(2.0)}


def neg_log_likelihood(theta, X, y):
    """Return the negative GP log marginal likelihood of ``y`` observed at ``X``.

    ``theta`` is a dict with keys "log_diag", "log_scale" and "log_gamma";
    each value is exponentiated before being handed to tinygp.
    """
    kernel = kernels.ExpSineSquared(scale=jnp.exp(theta["log_scale"]), gamma=jnp.exp(theta["log_gamma"]))
    gp = GaussianProcess(kernel, X, diag=jnp.exp(theta["log_diag"]))
    return -gp.log_probability(y)


stime = time.perf_counter()
# ScipyMinimize obtains gradients of `fun` via JAX autodiff itself, so no
# separately built value_and_grad wrapper is needed.
solver = jaxopt.ScipyMinimize(fun=neg_log_likelihood)
soln = solver.run(theta_init, X=X, y=y)
print("Time for GPR fitting: %.3f" % (time.perf_counter() - stime))
Time for GPR fitting: 0.485

Predicting with the Models

# Predict with kernel ridge regression on a dense grid over [0, 20], which
# extends past the training inputs' range [0, 15) to show extrapolation.
X_plot = jnp.linspace(0, 20, 10000)[:, None]
stime = time.perf_counter()
y_kr = kr.predict(X_plot)
print("Time for KRR prediction: %.3f" % (time.perf_counter() - stime))


def build_gp(theta_, X):
    """Build the tinygp GaussianProcess for log-space hyperparameters ``theta_``.

    ``theta_`` uses the keys "log_diag", "log_scale" and "log_gamma"; each is
    exponentiated to give the diagonal term and the ExpSineSquared
    scale/gamma respectively.
    """
    kernel = kernels.ExpSineSquared(scale=jnp.exp(theta_["log_scale"]), gamma=jnp.exp(theta_["log_gamma"]))
    gp = GaussianProcess(kernel, X, diag=jnp.exp(theta_["log_diag"]))
    return gp


# The GP posterior is obtained with `condition` (tinygp's `predict` is
# deprecated). The test grid keeps the same (N, 1) layout as the training
# inputs X. JAX dispatches work asynchronously, so each timer is stopped only
# after blocking on the computed result.

# predict without variance
stime = time.perf_counter()
gp = build_gp(soln.params, X)
y_mu = gp.condition(y, X_plot).gp.loc
y_mu.block_until_ready()
print("Time for GPR prediction: %.3f" % (time.perf_counter() - stime))

# predict with variance
stime = time.perf_counter()
gp = build_gp(soln.params, X)
cond_gp = gp.condition(y, X_plot).gp
y_mu, y_var = cond_gp.loc, cond_gp.variance
y_var.block_until_ready()
print("Time for GPR prediction with standard-deviation: %.3f" % (time.perf_counter() - stime))
Time for KRR prediction: 0.042
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/ipykernel_launcher.py:25: DeprecationWarning: The 'predict' method is deprecated and 'condition' should be preferred
Time for GPR prediction: 5.817
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/ipykernel_launcher.py:31: DeprecationWarning: The 'predict' method is deprecated and 'condition' should be preferred
Time for GPR prediction with standard-deviation: 5.894

Plots

# Plot the data, the true function, the KRR and GP fits, and a +/- 1 standard
# deviation band for the GP's predictive distribution including observation
# noise.
plt.figure()
X_plot = X_plot.reshape(-1, 1)
x_grid = X_plot.flatten()
# The fitted noise variance (the GP's diagonal term) is added into a local
# quantity instead of updating `y_var` in place, so re-running this cell does
# not stack the noise term onto the variance again.
y_std = jnp.sqrt(y_var + jnp.exp(soln.params["log_diag"]))
gp_mean = y_mu.flatten()

plt.scatter(X, y, c="k", label="$data$", s=marksize)
plt.plot(X_plot, jnp.sin(X_plot), color="navy", label="True")
plt.plot(X_plot, y_kr, color="turquoise", label="KRR")
plt.plot(X_plot, y_mu, color="darkorange", label="GPR")
plt.fill_between(x_grid, gp_mean - y_std, gp_mean + y_std, color="darkorange", alpha=0.2)
plt.xlabel("data")
plt.ylabel("target")
plt.ylim(-4, 4)
sns.despine()
plt.legend(bbox_to_anchor=(0.8, 0.6), frameon=False, fontsize=8)
savefig("krr_vs_gpr_latexified")
plt.show()
saving image to figures/krr_vs_gpr_latexified.pdf Figure size: [6. 2.]