Path: blob/master/deprecated/notebooks/gp_kernel_opt.ipynb
GP kernel parameter optimization / inference

Slightly modified from https://tinygp.readthedocs.io/en/latest/tutorials/modeling.html. (We simplify the code by not passing the yerr variable to the model, and we rename the input t (for time) to x, to make it more generic.)
Data
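A minimal sketch of the kind of dataset the tutorial uses: a noisy signal sampled at irregular inputs x (renamed from t, as noted above), plus a test grid for the conditional predictions later on. The functional form and constants here are illustrative placeholders, not the tutorial's exact data.

```python
import numpy as np

# Hypothetical stand-in for the tutorial's dataset: a noisy sinusoid
# with a linear trend, observed at irregular, sorted inputs.
rng = np.random.default_rng(42)
x = np.sort(rng.uniform(0.0, 10.0, 50))
true_mean = np.sin(x) + 0.1 * x          # the "true model" we compare against later
y = true_mean + 0.2 * rng.normal(size=len(x))
x_test = np.linspace(0.0, 10.0, 200)     # grid for conditional predictions
```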
Optimizing hyper-parameters using flax & optax
We find the maximum (marginal) likelihood hyperparameters for the GP model.
To set up our model, we define a custom linen.Module and optimize its parameters as follows:
Our Module defined above also returns the conditional predictions, which we can compare to the true model:
Inferring hyper-parameters using HMC in numpyro
We can compute a posterior over the kernel parameters, and hence the posterior predictive over the mean function, using HMC.
Let's examine the posterior. For that task, let's use ArviZ:
Finally, we can plot our posterior inferences of the conditional process, compared to the true model:
Inferring hyper-parameters using HMC in BlackJax
TBD