GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/29/gp_mauna_loa.ipynb
Kernel: Python 3.10.4 ('PyroNB')

Gaussian process time series forecasting for Mauna Loa CO2

In the following, we'll reproduce the analysis for Figure 5.6 in Chapter 5 of Rasmussen & Williams (R&W).

Code is from https://tinygp.readthedocs.io/en/latest/tutorials/quickstart.html

try:
    import tinygp
except ImportError:
    %pip install -q tinygp
    import tinygp

try:
    from statsmodels.datasets import co2
except ModuleNotFoundError:
    %pip install -qq statsmodels
    from statsmodels.datasets import co2

import jax
import jax.numpy as jnp
from tinygp import kernels, transforms, GaussianProcess
from jax.config import config
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import seaborn as sns

try:
    from probml_utils import latexify, savefig, is_latexify_enabled
except ModuleNotFoundError:
    %pip install git+https://github.com/probml/probml-utils.git
    from probml_utils import latexify, savefig, is_latexify_enabled

latexify(width_scale_factor=2, fig_height=2)
config.update("jax_enable_x64", True)

marksize = 6 if is_latexify_enabled() else 50

Data

The data are measurements of the atmospheric CO2 concentration made at Mauna Loa, Hawaii (Keeling & Whorf 2004). The raw data can be found at http://scrippsco2.ucsd.edu/data/atmospheric_co2/primary_mlo_co2_record. We use the [statsmodels version](http://statsmodels.sourceforge.net/devel/datasets/generated/co2.html).
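Before converting the time index, it can help to glance at what the statsmodels loader actually returns. This is a small sketch that is not part of the original notebook; in recent statsmodels versions the frame has a weekly date index and a single co2 column, but the exact index type and date range may vary by version.

# Peek at the raw statsmodels frame before converting dates and filtering.
# (Illustrative only; index type and start date may differ across statsmodels versions.)
raw = co2.load_pandas().data
print(raw.head())                        # weekly CO2 samples, one `co2` column
print(raw.index.min(), raw.index.max())  # rough date range of the record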

data = co2.load_pandas().data
# Convert the Julian-date index to decimal years.
t = 2000 + (jnp.array(data.index.to_julian_date()) - 2451545.0) / 365.25
y = jnp.array(data.co2)
# Keep finite measurements before 1996 and thin to every 4th point.
m = jnp.isfinite(t) & jnp.isfinite(y) & (t < 1996)
t, y = t[m][::4], y[m][::4]

plt.figure()
plt.scatter(t, y, s=marksize, c="k", marker=".", label="Data")
sns.despine()
plt.xlabel("year")
plt.ylabel("CO$_2$ in ppm")
plt.legend(frameon=False)
savefig("gp-mauna-loa-data")

Kernel

In this figure, you can see that there is a periodic (or quasi-periodic) signal with a one-year period superimposed on a long-term trend. We will follow R&W and model these effects non-parametrically using a complicated covariance function. The covariance function that we'll use is:

$$
k(r) = k_1(r) + k_2(r) + k_3(r) + k_4(r)
$$

where

$$
\begin{aligned}
k_1(r) &= \theta_0^2 \, \exp\left(-\frac{r^2}{2\,\theta_1^2}\right) \\
k_2(r) &= \theta_2^2 \, \exp\left(-\frac{r^2}{2\,\theta_3^2} - \theta_5\,\sin^2\!\left(\frac{\pi\,r}{\theta_4}\right)\right) \\
k_3(r) &= \theta_6^2 \, \left[1 + \frac{r^2}{2\,\theta_7^2\,\theta_8}\right]^{-\theta_8} \\
k_4(r) &= \theta_{9}^2 \, \exp\left(-\frac{r^2}{2\,\theta_{10}^2}\right) + \theta_{11}^2\,\delta_{ij}
\end{aligned}
$$
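To build intuition for what each term contributes, the sketch below evaluates the four components at a grid of lags, using the R&W hyperparameter values quoted later in the Model fitting section. It is not part of the original notebook, and the component labels are informal descriptions.

# Evaluate each kernel component k_i(r) at lags r (in years), using the R&W values
# listed in the Model fitting section below. Illustrative only.
th = jnp.array([66.0, 67.0, 2.4, 90.0, 1.0, 4.3, 0.66, 1.2, 0.78, 0.18, 1.6, 0.19])
k1 = th[0] ** 2 * kernels.ExpSquared(th[1])
k2 = th[2] ** 2 * kernels.ExpSquared(th[3]) * kernels.ExpSineSquared(scale=th[4], gamma=th[5])
k3 = th[6] ** 2 * kernels.RationalQuadratic(alpha=th[7], scale=th[8])
k4 = th[9] ** 2 * kernels.ExpSquared(th[10])

r = jnp.linspace(0.0, 12.0, 500)  # lags in years
zero = jnp.zeros(1)
plt.figure()
for name, k in [("long-term trend", k1), ("quasi-periodic", k2), ("medium-term", k3), ("short-term", k4)]:
    plt.plot(r, k(r, zero)[:, 0], label=name)
plt.xlabel("lag $r$ (years)")
plt.ylabel("$k_i(r)$")
plt.yscale("log")  # the component amplitudes span several orders of magnitude
plt.legend(frameon=False)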

We can implement this kernel in tinygp as follows (we'll use the R&W results as the hyperparameters for now):

def build_gp(theta, X):
    mean = theta[-1]
    # We want most of our parameters to be positive, so we take the `exp` here
    theta = jnp.exp(theta[:-1])

    # Construct the kernel by multiplying and adding `Kernel` objects
    k1 = theta[0] ** 2 * kernels.ExpSquared(theta[1])
    k2 = theta[2] ** 2 * kernels.ExpSquared(theta[3]) * kernels.ExpSineSquared(scale=theta[4], gamma=theta[5])
    k3 = theta[6] ** 2 * kernels.RationalQuadratic(alpha=theta[7], scale=theta[8])
    k4 = theta[9] ** 2 * kernels.ExpSquared(theta[10])
    kernel = k1 + k2 + k3 + k4

    return GaussianProcess(kernel, X, diag=theta[11] ** 2, mean=mean)


def neg_log_likelihood(theta, X, y):
    gp = build_gp(theta, X)
    return -gp.log_probability(y)

Model fitting

# Objective
obj = jax.jit(jax.value_and_grad(neg_log_likelihood))

# These are the parameters from R&W
mean_output = 340.0
theta_init = jnp.append(
    jnp.log(jnp.array([66.0, 67.0, 2.4, 90.0, 1.0, 4.3, 0.66, 1.2, 0.78, 0.18, 1.6, 0.19])),
    mean_output,
)
obj(theta_init, t, y)

Using the loss function defined above, we'll run a gradient-based optimization routine from scipy (you could also use a jax-specific optimizer, as sketched after the next cell, but that's not necessary) to fit this model as follows:

soln = minimize(obj, theta_init, jac=True, args=(t, y))
print(f"Final negative log likelihood: {soln.fun}")
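If you prefer to stay entirely inside jax, a first-order optimizer such as Adam from optax can be used instead of scipy. The sketch below is not part of the original notebook: it assumes optax is installed, and the learning rate and step count are illustrative rather than tuned, so it may need adjustment to match the scipy result.

# A jax-native alternative to scipy.optimize (sketch only).
import optax

params = theta_init
opt = optax.adam(learning_rate=1e-2)
opt_state = opt.init(params)

for i in range(2000):
    loss, grads = obj(params, t, y)  # reuse the jitted value-and-grad objective
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(f"Adam negative log likelihood: {loss}")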

Warning: An optimization like this should work on most problems, but the results can be very sensitive to your choice of initialization and algorithm. If the results are nonsense, try a better initial guess or a different value of the method parameter of scipy.optimize.minimize.
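For example, one simple variation (not in the original notebook) is to switch to an explicit method such as L-BFGS-B, which also uses the gradient we supply via jac=True:

# Retry the fit with a different scipy optimizer if the default misbehaves.
soln_alt = minimize(obj, theta_init, jac=True, args=(t, y), method="L-BFGS-B")
print(f"L-BFGS-B negative log likelihood: {soln_alt.fun}")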

Plot results

# Predict from the end of the training data out to 2025.
x = jnp.linspace(max(t), 2025, 2000)
gp = build_gp(soln.x, t)
gp_condition = gp.condition(y, x).gp
mu, var = gp_condition.loc, gp_condition.variance

plt.figure()
plt.scatter(t, y, s=marksize, c="k", marker=".", label="Data")
plt.plot(x, mu, color="C0", label="Mean")
# Shade a one-standard-deviation band around the predictive mean.
plt.fill_between(x, mu + jnp.sqrt(var), mu - jnp.sqrt(var), color="C0", alpha=0.5, label="Confidence")
sns.despine()
plt.xlabel("year")
plt.ylabel("CO$_2$ in ppm")
plt.legend(prop={"size": 5}, frameon=False)
savefig("gp-mauna-loa-pred")
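As a final sanity check (not part of the original notebook), you can map the optimized parameters back to their natural scale and compare them with the R&W values used for initialization:

# The last entry of `soln.x` is the mean level; the rest are log-scale kernel parameters
# (see `build_gp` above).
fitted_theta = jnp.exp(soln.x[:-1])
fitted_mean = soln.x[-1]
print("Fitted kernel hyperparameters:", fitted_theta)
print("Fitted mean level (ppm):", fitted_mean)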