Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/bnn_hmc_gaussian.ipynb
1192 views
Kernel: Python 3

Open In Colab

import jax print(jax.devices())
[GpuDevice(id=0, process_index=0)]
!git clone https://github.com/google-research/google-research.git
fatal: destination path 'google-research' already exists and is not an empty directory.
%cd /content/google-research
/content/google-research
!ls bnn_hmc
core README.md run_sgd.py utils make_posterior_surface_plot.py requirements.txt run_sgmcmc.py notebooks run_hmc.py run_vi.py
!pip install optax
Collecting optax Downloading optax-0.0.9-py3-none-any.whl (118 kB) |████████████████████████████████| 118 kB 12.1 MB/s eta 0:00:01 Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0) Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5) Collecting chex>=0.0.4 Downloading chex-0.0.8-py3-none-any.whl (57 kB) |████████████████████████████████| 57 kB 5.7 MB/s eta 0:00:01 Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.19) Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.70+cuda110) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0) Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6) Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1) Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0) Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12) Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.4.1) Installing collected packages: chex, optax Successfully installed chex-0.0.8 optax-0.0.9

Setup

from jax.config import config import jax from jax import numpy as jnp import numpy as onp import numpy as np
import os import sys import time import tqdm import optax import functools from matplotlib import pyplot as plt from bnn_hmc.utils import losses from bnn_hmc.utils import train_utils from bnn_hmc.utils import tree_utils %matplotlib inline %load_ext autoreload %autoreload 2

Data and model

mu = jnp.zeros( [ 2, ] ) # sigma = jnp.array([[1., .5], [.5, 1.]]) sigma = jnp.array([[1.0e-4, 0], [0.0, 1.0]]) sigma_l = jnp.linalg.cholesky(sigma) sigma_inv = jnp.linalg.inv(sigma) sigma_det = jnp.linalg.det(sigma)
# Draw 1000 reference samples from N(mu, sigma) with NumPy; the global NumPy
# RNG is seeded first so the draw is repeatable.
onp.random.seed(0)
samples = onp.random.multivariate_normal(onp.asarray(mu), onp.asarray(sigma), size=1000)
# Scatter the two coordinates of the reference draws against each other.
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.3)
plt.grid()
Image in a Jupyter notebook
def log_density_fn(params):
    """Evaluate the log-density of the Gaussian N(mu, sigma) at ``params``.

    Reads the module-level ``mu``, ``sigma_inv`` and ``sigma_det``.

    Args:
        params: point at which to evaluate; must have the same shape as ``mu``.

    Returns:
        The scalar log-density at ``params``.

    Raises:
        AssertionError: if ``params.shape`` differs from ``mu.shape``.
    """
    assert params.shape == mu.shape, "Shape error"

    dim = mu.size
    centered = params - mu

    # The three additive pieces of the multivariate normal log-density.
    normalizer = -jnp.log(2 * jnp.pi) * dim / 2
    half_log_det = jnp.log(sigma_det) / 2
    half_quadratic_form = centered.T @ sigma_inv @ centered / 2

    return normalizer - half_log_det - half_quadratic_form
def log_likelihood_fn(_, params, *args, **kwargs):
    """Return the target log-density at ``params`` and a NaN placeholder.

    The first positional argument and any extra positional or keyword
    arguments are accepted but ignored.

    Returns:
        A 2-tuple ``(log_density_fn(params), jnp.array(jnp.nan))``.
    """
    target_log_density = log_density_fn(params)
    nan_placeholder = jnp.array(jnp.nan)
    return target_log_density, nan_placeholder


def log_prior_fn(_):
    """Return a constant log-prior of ``0.0``; the argument is ignored."""
    return 0.0


def log_prior_diff_fn(*args):
    """Return a constant log-prior difference of ``0.0``; arguments are ignored."""
    return 0.0
# Placeholder inputs for the HMC helpers built below: no network apply
# function, and NaN-filled arrays for the data pair and the network state.
# log_likelihood_fn ignores its first argument and its extra arguments.
# NOTE(review): presumably train_utils never reads these values for this toy
# target — confirm against bnn_hmc.utils.train_utils.
fake_net_apply = None
fake_data = jnp.array([[jnp.nan,],]), jnp.array(
    [
        [
            jnp.nan,
        ],
    ]
)
fake_net_state = jnp.array(
    [
        jnp.nan,
    ]
)

HMC

# Leapfrog step size and total trajectory length for each HMC iteration.
step_size = 1e-1
trajectory_len = jnp.pi / 2
# Upper bound on leapfrog steps: floor(trajectory_len / step_size) + 1.
max_num_leapfrog_steps = int(trajectory_len // step_size + 1)
print("Leapfrog steps per iteration:", max_num_leapfrog_steps)
Leapfrog steps per iteration: 16
# Build the HMC update step and the log-prob/gradient evaluator from the toy
# likelihood and the flat prior defined above.
# NOTE(review): the meaning of the trailing positional arguments (1.0, 0.0)
# is not visible here — confirm against train_utils.make_hmc_update.
update, get_log_prob_and_grad = train_utils.make_hmc_update(
    fake_net_apply, log_likelihood_fn, log_prior_fn, log_prior_diff_fn, max_num_leapfrog_steps, 1.0, 0.0
)
# Initial log-prob and grad values # params = jnp.ones_like(mu)[None, :] params = jnp.ones_like(mu) log_prob, state_grad, log_likelihood, net_state = get_log_prob_and_grad(fake_data, params, fake_net_state)
%%time
# Run the HMC chain for a fixed number of iterations, keeping the accepted states.
num_iterations = 500
all_samples = []
key = jax.random.PRNGKey(0)
for iteration in tqdm.tqdm(range(num_iterations)):

    # Each call rebinds the chain state, the step size and the PRNG key to the
    # values returned by `update`.
    # NOTE(review): the meaning of the final positional `True` is not visible
    # here — confirm against train_utils.make_hmc_update.
    (params, net_state, log_likelihood, state_grad, step_size, key, accept_prob, accepted) = update(
        fake_data, params, net_state, log_likelihood, state_grad, key, step_size, trajectory_len, True
    )

    # Only iterations whose proposal was accepted contribute a sample, so
    # all_samples can be shorter than num_iterations (or empty).
    if accepted:
        all_samples.append(onp.asarray(params).copy())

    # Per-iteration debug output (disabled):
    # print("It: {} \t Accept P: {} \t Accepted {} \t Log-likelihood: {}".format(
    #     iteration, accept_prob, accepted, log_likelihood))
100%|██████████| 500/500 [03:26<00:00, 2.42it/s]
len(all_samples)
0
log_prob, state_grad, log_likelihood, net_state
(DeviceArray(-4997.733, dtype=float32), DeviceArray([-1.e+04, -1.e+00], dtype=float32), DeviceArray(-4997.733, dtype=float32), ShardedDeviceArray([nan], dtype=float32))
# Stack the accepted samples along a new leading axis.
# onp.stack raises ValueError on an empty list, which happens whenever no HMC
# proposal was accepted. In that case fall back to an empty array whose
# trailing shape matches mu, so later indexing such as all_samples_cat[:, 0]
# still works.
if all_samples:
    all_samples_cat = onp.stack(all_samples)
else:
    print("No HMC proposals were accepted; all_samples_cat is empty.")
    all_samples_cat = onp.empty((0,) + tuple(mu.shape))
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-38-dde50ab74494> in <module>() 1 ----> 2 all_samples_cat = onp.stack(all_samples) <__array_function__ internals> in stack(*args, **kwargs) /usr/local/lib/python3.7/dist-packages/numpy/core/shape_base.py in stack(arrays, axis, out) 421 arrays = [asanyarray(arr) for arr in arrays] 422 if not arrays: --> 423 raise ValueError('need at least one array to stack') 424 425 shapes = {arr.shape for arr in arrays} ValueError: need at least one array to stack
# Scatter the two coordinates of the accepted HMC samples against each other.
plt.scatter(all_samples_cat[:, 0], all_samples_cat[:, 1], alpha=0.3)
plt.grid()
--------------------------------------------------------------------------- NameError Traceback (most recent call last) <ipython-input-19-ccd922184a23> in <module>() ----> 1 plt.scatter(all_samples_cat[:, 0], all_samples_cat[:, 1], alpha=0.3) 2 plt.grid() NameError: name 'all_samples_cat' is not defined

Blackjax

!pip install blackjax
Requirement already satisfied: blackjax in /usr/local/lib/python3.7/dist-packages (0.2.1)
import jax import jax.numpy as jnp import jax.scipy.stats as stats import matplotlib.pyplot as plt import numpy as np import blackjax.hmc as hmc import blackjax.nuts as nuts import blackjax.stan_warmup as stan_warmup
print(jax.devices())
[GpuDevice(id=0, process_index=0)]
def potential(x):
    """Return the negated log-density for the position dict ``x``.

    ``x`` is unpacked as keyword arguments to ``log_density_fn``, so its keys
    must match that function's parameter names.
    """
    return -log_density_fn(**x)
# Fixed number of integration steps used by the HMC kernel during warmup.
num_integration_steps = 30
# Factory that builds an HMC kernel for a given step size and inverse mass
# matrix; it is handed to stan_warmup.run in a later cell.
kernel_generator = lambda step_size, inverse_mass_matrix: hmc.kernel(
    potential, step_size, inverse_mass_matrix, num_integration_steps
)
rng_key = jax.random.PRNGKey(0)
# The key "params" matches log_density_fn's parameter name, because
# `potential` unpacks the position dict as keyword arguments.
initial_position = {"params": np.zeros(2)}
initial_state = hmc.new_state(initial_position, potential)
print(initial_state)
HMCState(position={'params': array([0., 0.])}, potential_energy=DeviceArray(-2.7672932, dtype=float32), potential_energy_grad={'params': DeviceArray([0., 0.], dtype=float32)})
%%time
# Run the Stan-style warmup for nsteps iterations, using the HMC kernels
# produced by kernel_generator.
# This rebinds the module-level `step_size` (set earlier for the bnn_hmc run)
# and defines `inverse_mass_matrix`, both of which are used to build the
# sampling kernel in the next cell.
nsteps = 500
final_state, (step_size, inverse_mass_matrix), info = stan_warmup.run(
    rng_key,
    kernel_generator,
    initial_state,
    nsteps,
)
CPU times: user 3.59 s, sys: 132 ms, total: 3.72 s Wall time: 2.07 s
%%time
# Build a NUTS kernel from the step size and inverse mass matrix returned by
# warmup, then wrap it with jax.jit.
# NOTE(review): warmup used the fixed-length HMC kernel from kernel_generator,
# while sampling uses NUTS — confirm the adapted step size is meant to carry over.
kernel = nuts.kernel(potential, step_size, inverse_mass_matrix)
kernel = jax.jit(kernel)
CPU times: user 1.81 ms, sys: 0 ns, total: 1.81 ms Wall time: 925 µs
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Apply ``kernel`` repeatedly and return every visited state.

    Args:
        rng_key: PRNG key that is split into one sub-key per step.
        kernel: callable ``kernel(key, state)`` returning a 2-tuple whose first
            element is the next state; the second element is discarded.
        initial_state: state the chain starts from (not included in the output).
        num_samples: number of kernel applications.

    Returns:
        The states produced by each step, stacked along a leading axis by
        ``jax.lax.scan``.
    """

    def _advance(current_state, step_key):
        next_state, _unused_info = kernel(step_key, current_state)
        # The new state is both the carry and the per-step output.
        return next_state, next_state

    step_keys = jax.random.split(rng_key, num_samples)
    _final_state, trajectory = jax.lax.scan(_advance, initial_state, step_keys)
    return trajectory
%%time
# Draw nsamples states with the jitted NUTS kernel and pull out the "params"
# entry of each position; block_until_ready() waits for the computation to
# finish so that %%time measures it.
# This rebinds `samples`, which previously held the NumPy reference draws.
# NOTE(review): the same rng_key was already passed to stan_warmup.run, and the
# chain restarts from initial_state rather than warmup's final_state — confirm
# both are intended.
nsamples = 500
states = inference_loop(rng_key, kernel, initial_state, nsamples)
samples = states.position["params"].block_until_ready()
print(samples.shape)
(500, 2) CPU times: user 2.54 s, sys: 69.8 ms, total: 2.61 s Wall time: 1.72 s
# Scatter the two coordinates of the Blackjax samples against each other.
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.3)
plt.grid()
Image in a Jupyter notebook