Path: blob/master/notebooks/book2/03/change_of_variable_hmc.ipynb
Change of variables in HMC
This notebook reproduces the hierarchical binomial model from https://github.com/probml/pyprobml/blob/master/notebooks/book2/03/hierarchical_binom_rats.ipynb, but performs HMC inference both with and without a change of variables. The purpose is to illustrate the deleterious effects of forgetting to use bijectors and Jacobians. Note that the blackjax library will not detect such mistakes for you, but we show how to use arviz to diagnose the problem.
Data
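The referenced notebook uses the classic rat-tumor dataset (the number of rats that developed a tumor out of the number tested, for each of roughly 70 control groups). Below is a minimal sketch of how the data arrays might be set up; the values shown here are hypothetical placeholders, so see the linked notebook for the actual data.

```python
import jax.numpy as jnp

# Hypothetical placeholder data: tumor counts and group sizes for a few groups.
# The real notebook uses the classic rat-tumor dataset with ~70 groups.
n_of_positives = jnp.array([0., 1., 2., 4., 5., 6.])    # rats with tumors per group
n_of_rats = jnp.array([20., 19., 20., 19., 20., 20.])   # rats tested per group
n_rat_tumors = len(n_of_positives)                      # number of groups
```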
Posterior sampling
Now we use Blackjax's NUTS algorithm to get posterior samples of $a$, $b$, and $\theta$.
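As a reminder of the model, here is a sketch of the joint log-probability written directly in terms of the constrained $\theta$ (i.e. without a change of variables). It assumes the placeholder data above and the usual improper prior $p(a, b) \propto (a+b)^{-5/2}$ from the rat-tumor example; the function in the original notebook may differ in its details.

```python
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

def joint_logprob(params):
    """Sketch of the joint log-density with theta kept in its constrained (0, 1) space."""
    a, b, theta = params["a"], params["b"], params["theta"]
    # Improper prior on the Beta hyperparameters, p(a, b) ∝ (a + b)^(-5/2) (assumption).
    logprob = -2.5 * jnp.log(a + b)
    # Per-group tumor probabilities.
    logprob += tfd.Beta(a, b).log_prob(theta).sum()
    # Binomial likelihood for the observed tumor counts.
    logprob += tfd.Binomial(total_count=n_of_rats, probs=theta).log_prob(n_of_positives).sum()
    return logprob
```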
We draw the initial parameters from a uniform distribution.
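A minimal sketch of drawing the initial positions is shown below; the parameter names and ranges are illustrative assumptions.

```python
import jax

def init_param_fn(seed):
    """Draw a random initial position; the ranges here are illustrative assumptions."""
    key_a, key_b, key_theta = jax.random.split(seed, 3)
    return {
        "a": jax.random.uniform(key_a, minval=0.5, maxval=3.0),
        "b": jax.random.uniform(key_b, minval=0.5, maxval=3.0),
        "theta": jax.random.uniform(key_theta, shape=(n_rat_tumors,)),
    }

rng_key = jax.random.PRNGKey(0)
n_chains = 4
init_positions = [init_param_fn(k) for k in jax.random.split(rng_key, n_chains)]
```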
Now we use Blackjax's window adaptation algorithm to build the NUTS kernel and its initial state. Window adaptation automatically tunes the inverse_mass_matrix and the step size.
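A sketch of this warm-up step is shown below. The exact Blackjax API has changed across versions, so the call signatures here are assumptions based on a recent release.

```python
import blackjax

# Stan-style window adaptation tunes the step size and inverse mass matrix.
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logprob)
warmup_key, rng_key = jax.random.split(rng_key)
# In recent blackjax releases, run() returns the adapted state and the tuned parameters.
(initial_state, tuned_params), _ = warmup.run(warmup_key, init_positions[0], num_steps=1000)

# Build the NUTS step function with the tuned parameters.
nuts_kernel = blackjax.nuts(joint_logprob, **tuned_params).step
```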
Now we write an inference loop for multiple chains.
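A common pattern is to wrap the kernel in a `jax.lax.scan`; the sketch below assumes the kernel and state from the previous step. For multiple chains, the same loop can be `jax.vmap`-ed over per-chain keys and initial states.

```python
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run `num_samples` steps of the kernel and stack the states and infos."""
    def one_step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos

n_samples = 1000
sample_key, rng_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, nuts_kernel, initial_state, n_samples)
```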
Arviz plots
All of our posterior samples are stored in the states.position dictionary, and infos stores additional information such as the acceptance probability, divergences, etc. Now we can use diagnostics to judge whether our MCMC samples have converged to the stationary distribution. Some widely used diagnostics are trace plots, the potential scale reduction factor (R hat), and divergences. The Arviz library provides convenient ways to analyze these diagnostics. We can use arviz.summary() and arviz.plot_trace(), but these functions expect a specific input format (arviz's trace), so first we convert states and infos into a trace.
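A sketch of this conversion is shown below, assuming single-chain states and infos whose leading dimension is the number of draws; ArviZ expects arrays of shape (chain, draw, ...). The info field names (e.g. is_divergent, acceptance_rate) depend on the Blackjax version.

```python
import arviz as az
import numpy as np

def arviz_trace_from_states(states, infos):
    """Pack blackjax states/infos into an arviz InferenceData (single chain assumed)."""
    posterior = {
        name: np.asarray(samples)[np.newaxis, ...]  # add a leading chain dimension
        for name, samples in states.position.items()
    }
    sample_stats = {
        # Field names vary across blackjax versions; these are assumptions.
        "diverging": np.asarray(infos.is_divergent)[np.newaxis, ...],
        "acceptance_rate": np.asarray(infos.acceptance_rate)[np.newaxis, ...],
    }
    return az.from_dict(posterior=posterior, sample_stats=sample_stats)

trace = arviz_trace_from_states(states, infos)
az.summary(trace)       # r_hat, effective sample size, etc.
az.plot_trace(trace)    # trace plots with divergences marked
```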
r_hat measures whether each chain has converged to the stationary distribution; it should be less than or equal to 1.01. Here we get r_hat values far from 1.01 for every latent variable.
The trace plots also look terrible and do not appear to have converged. Moreover, the black bands show that nearly every transition is divergent. So what is going wrong here?
Well, it is related to the support of the latent variables. In HMC, the latent variables must live in an unconstrained space, but in the model above theta is constrained to lie between 0 and 1. We can use the change-of-variables trick to solve this problem.
Change of variables
We can instead sample the logits, which live in an unconstrained space, and inside joint_logprob() convert the logits to theta with a suitable bijector (the sigmoid). We then add the log of the Jacobian (the first derivative) of the bijector to correctly transform the probability density from one space to the other.
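Concretely, with $\theta = \sigma(\ell)$ where $\sigma$ is the sigmoid, the log-density of the logits picks up a log-Jacobian term:

$$
\log p_{\ell}(\ell) = \log p_{\theta}\big(\sigma(\ell)\big) + \log\left|\frac{d\sigma(\ell)}{d\ell}\right|,
\qquad \frac{d\sigma(\ell)}{d\ell} = \sigma(\ell)\,\big(1 - \sigma(\ell)\big).
$$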
Apart from this change of variables in the joint_logprob() function, everything else remains the same.
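A sketch of the modified log-density, assuming the same hypothetical data and prior as above, is shown below. The Jacobian term uses the identity $\log \sigma'(\ell) = \log\sigma(\ell) + \log\sigma(-\ell)$.

```python
def joint_logprob_change_of_var(params):
    """Joint log-density parameterized by unconstrained logits instead of theta."""
    a, b, logits = params["a"], params["b"], params["logits"]
    theta = jax.nn.sigmoid(logits)  # map the logits back to (0, 1)
    # Same (assumed) improper prior on the Beta hyperparameters.
    logprob = -2.5 * jnp.log(a + b)
    logprob += tfd.Beta(a, b).log_prob(theta).sum()
    logprob += tfd.Binomial(total_count=n_of_rats, probs=theta).log_prob(n_of_positives).sum()
    # Log-Jacobian of the sigmoid: log sigma'(l) = log sigma(l) + log sigma(-l).
    logprob += (jax.nn.log_sigmoid(logits) + jax.nn.log_sigmoid(-logits)).sum()
    return logprob
```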
We can see that r_hat is now less than or equal to 1.01 for each latent variable, the trace plots look like they have converged to the stationary distribution, and only a few samples are divergent.