Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/18/combining_kernels_by_multiplication.ipynb
1192 views
Kernel: Python 3 (ipykernel)
import jax import jax.numpy as jnp import seaborn as sns import matplotlib.pyplot as plt try: from probml_utils import latexify, savefig, is_latexify_enabled except ModuleNotFoundError: %pip install git+https://github.com/probml/probml-utils.git from probml_utils import latexify, savefig, is_latexify_enabled try: import tinygp except ModuleNotFoundError: %pip install -qqq tinygp import tinygp kernels = tinygp.kernels from tinygp import GaussianProcess
# Request publication-style figure sizing from probml_utils. Per the warning
# this call emits, it does nothing unless the LATEXIFY environment variable
# is set.
latexify(width_scale_factor=4, fig_height=2)
/home/patel_zeel/miniconda3/lib/python3.9/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
def make_graph(data, save_name):
    """Plot a product kernel and two GP prior samples drawn from it.

    The top panel shows the product kernel evaluated between every point of
    the module-level grid ``x`` and the fixed point ``x' = 1``. The bottom
    panel shows two samples from ``GaussianProcess(kernel, x)`` drawn with
    the module-level PRNG ``key``.

    Note: this function reads the globals ``x`` and ``key``; both must be
    defined before it is called.

    Args:
        data: Mapping with the keys ``"kernel1"`` and ``"kernel2"`` (the two
            kernels to multiply), ``"title"`` (title of the top panel) and
            ``"xlabel"`` (x-axis label of the bottom panel).
        save_name: File name handed to ``savefig``. An empty (or ``None``)
            value skips saving.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``; ``ax`` holds the
        two stacked axes.
    """
    if is_latexify_enabled():
        # latexify() already configured the figure size.
        fig, ax = plt.subplots(2, 1)
    else:
        fig, ax = plt.subplots(2, 1, figsize=(6.4, 6))

    # Plot kernel
    kernel = data["kernel1"] * data["kernel2"]
    x2 = jnp.array([1.0]).reshape(-1, 1)
    kernel_values = kernel(x, x2)
    ax[0].plot(x.ravel(), kernel_values.ravel(), color="k")

    # Plot samples
    gp = GaussianProcess(kernel, x)
    samples = gp.sample(key, (2,))
    for sample in samples:
        ax[1].plot(x, sample)

    ax[0].set_title(data["title"])
    ax[1].set_xlabel(data["xlabel"])
    for axes in ax:
        axes.set_xticks([])
    ax[0].set_xlabel("$x$ (with $x'=1$)")

    plt.tight_layout()
    sns.despine()
    if save_name:
        savefig(save_name)
    return fig, ax


# Input grid as a column vector, and the PRNG key shared by every figure
# (both are read as globals by make_graph).
x = jnp.arange(-3.0, 5.1, 0.1).reshape(-1, 1)
N = len(x)
key = jax.random.PRNGKey(4)

fig, ax = make_graph(
    {
        "kernel1": kernels.Polynomial(order=1),
        "kernel2": kernels.Polynomial(order=1),
        "title": "Lin x Lin",
        "xlabel": "quadratic functions",
    },
    save_name="kernel_mul_lin_lin.pdf",
)

fig, ax = make_graph(
    {
        "kernel1": kernels.ExpSquared(scale=4.0),
        "kernel2": kernels.ExpSineSquared(scale=2.0, gamma=0.5),
        "title": "SE x Per",
        "xlabel": "locally periodic",
    },
    save_name="kernel_mul_se_per.pdf",
)

fig, ax = make_graph(
    {
        "kernel1": kernels.Polynomial(order=1),
        "kernel2": kernels.ExpSquared(scale=1.0),
        "title": "Lin x SE",
        "xlabel": "increasing variation",
    },
    # Was "kernel_mu_lin_se.pdf"; renamed to follow the "kernel_mul_" prefix
    # used by every other figure in this script.
    save_name="kernel_mul_lin_se.pdf",
)

fig, ax = make_graph(
    {
        "kernel1": kernels.Polynomial(order=1),
        "kernel2": kernels.ExpSineSquared(scale=2.0, gamma=1.0),
        "title": "Lin x Per",
        "xlabel": "growing amplitude",
    },
    save_name="kernel_mul_lin_per.pdf",
)
/home/patel_zeel/miniconda3/lib/python3.9/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures") /home/patel_zeel/miniconda3/lib/python3.9/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures") /home/patel_zeel/miniconda3/lib/python3.9/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures") /home/patel_zeel/miniconda3/lib/python3.9/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebook