Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/29/supplementary/hmm_poisson_changepoint_jax.ipynb
1193 views
Kernel: Python 3

HMM with Poisson observations for detecting changepoints in the rate of a signal

from IPython.utils import io with io.capture_output() as captured: !pip install -qq distrax !pip install -qq flax
import logging logging.getLogger("absl").setLevel(logging.CRITICAL)
import numpy as np import jax from jax.random import split, PRNGKey import jax.numpy as jnp from jax import jit, lax, vmap from jax.experimental import optimizers try: import tensorflow_probability as tfp except ModuleNotFoundError: %pip install -qq tensorflow-probability import tensorflow_probability as tfp from matplotlib import pylab as plt %matplotlib inline import scipy.stats try: import distrax except ModuleNotFoundError: %pip install -qq distrax import distrax from distrax import HMM

Data

The synthetic data corresponds to a single time series of counts, where the rate of the underlying generative process changes at certain points in time.

true_rates = [40, 3, 20, 50] true_durations = [10, 20, 5, 35] random_state = 0 observed_counts = jnp.concatenate( [ scipy.stats.poisson(rate).rvs(num_steps, random_state=random_state) for (rate, num_steps) in zip(true_rates, true_durations) ] ).astype(jnp.float32) plt.plot(observed_counts);
Image in a Jupyter notebook

Model with fixed KK

To model the changing Poisson rate, we use an HMM. We initially assume the number of states is known to be K=4K=4. Later we will try comparing HMMs with different KK.

We fix the initial state distribution to be uniform, and fix the transition matrix to be the following, where we set p=0.05p=0.05:

z1Categorical({14,14,14,14})ztzt1Categorical({pif zt=zt11p41otherwise})\begin{align*} z_1 &\sim \text{Categorical}\left(\left\{\frac{1}{4}, \frac{1}{4}, \frac{1}{4}, \frac{1}{4}\right\}\right)\\ z_t | z_{t-1} &\sim \text{Categorical}\left(\left\{\begin{array}{cc}p & \text{if } z_t = z_{t-1} \\ \frac{1-p}{4-1} & \text{otherwise}\end{array}\right\}\right) \end{align*}
def build_latent_state(num_states, max_num_states, daily_change_prob): # Give probability 0 to states outside of the current model. def prob(s): return jnp.where(s < num_states + 1, 1 / num_states, 0.0) states = jnp.arange(1, max_num_states + 1) initial_state_probs = vmap(prob)(states) # Build a transition matrix that transitions only within the current # `num_states` states. def transition_prob(i, s): return jnp.where( (s <= num_states) & (i <= num_states) & (1 < num_states), jnp.where(s == i, 1 - daily_change_prob, daily_change_prob / (num_states - 1)), jnp.where(s == i, 1, 0), ) transition_probs = vmap(transition_prob, in_axes=(None, 0))(states, states) return initial_state_probs, transition_probs
num_states = 4 daily_change_prob = 0.05 initial_state_probs, transition_probs = build_latent_state(num_states, num_states, daily_change_prob) print("Initial state probs:\n{}".format(initial_state_probs)) print("Transition matrix:\n{}".format(transition_probs))
Initial state probs: [0.25 0.25 0.25 0.25] Transition matrix: [[0.95 0.01666667 0.01666667 0.01666667] [0.01666667 0.95 0.01666667 0.01666667] [0.01666667 0.01666667 0.95 0.01666667] [0.01666667 0.01666667 0.01666667 0.95 ]]

Now we create an HMM where the observation distribution is a Poisson with learnable parameters. We specify the parameters in log space and initialize them to random values around the log of the overall mean count (to set the scal

def make_hmm(log_rates, transition_probs, initial_state_probs): """Make a Hidden Markov Model with Poisson observation distribution.""" return HMM( obs_dist=tfp.substrates.jax.distributions.Poisson(log_rate=log_rates), trans_dist=distrax.Categorical(probs=transition_probs), init_dist=distrax.Categorical(probs=initial_state_probs), )
rng_key = PRNGKey(0) rng_key, rng_normal, rng_poisson = split(rng_key, 3) # Define variable to represent the unknown log rates. trainable_log_rates = jnp.log(jnp.mean(observed_counts)) + jax.random.normal(rng_normal, (num_states,)) hmm = make_hmm(trainable_log_rates, transition_probs, initial_state_probs)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5792: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")

Model fitting using Gradient Descent

We compute a MAP estimate of the Poisson rates λ\lambda using batch gradient descent, using the Adam optimizer applied to the log likelihood (from the HMM) plus the log prior for p(λ)p(\lambda).

def loss_fn(trainable_log_rates, transition_probs, initial_state_probs): cur_hmm = make_hmm(trainable_log_rates, transition_probs, initial_state_probs) return -(jnp.sum(rate_prior.log_prob(jnp.exp(trainable_log_rates))) + cur_hmm.forward(observed_counts)[0]) def update(i, opt_state, transition_probs, initial_state_probs): params = get_params(opt_state) loss, grads = jax.value_and_grad(loss_fn)(params, transition_probs, initial_state_probs) return opt_update(i, grads, opt_state), loss def fit(trainable_log_rates, transition_probs, initial_state_probs, n_steps): opt_state = opt_init(trainable_log_rates) def train_step(opt_state, step): opt_state, loss = update(step, opt_state, transition_probs, initial_state_probs) return opt_state, loss steps = jnp.arange(n_steps) opt_state, losses = lax.scan(train_step, opt_state, steps) return get_params(opt_state), losses
rate_prior = distrax.LogStddevNormal(5, 5) opt_init, opt_update, get_params = optimizers.adam(1e-1) n_steps = 201 params, losses = fit(trainable_log_rates, transition_probs, initial_state_probs, n_steps) rates = jnp.exp(params) hmm = make_hmm(params, transition_probs, initial_state_probs)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5792: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype") /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5792: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
print("Inferred rates: {}".format(rates)) print("True rates: {}".format(true_rates))
Inferred rates: [ 2.9000478 49.509304 40.368298 21.979261 ] True rates: [40, 3, 20, 50]
plt.plot(losses) plt.ylabel("Negative log marginal likelihood");
Image in a Jupyter notebook

We see that the method learned a good approximation to the true (generating) parameters, up to a permutation of the states (since the labels are unidentifiable). However, results can vary with different random seeds. We may find that the rates are the same for some states, which means those states are being treated as identical, and are therefore redundant.

Plotting the posterior over states

_, _, posterior_probs, _ = hmm.forward_backward(observed_counts)
def plot_state_posterior(ax, state_posterior_probs, title): ln1 = ax.plot(state_posterior_probs, c="tab:blue", lw=3, label="p(state | counts)") ax.set_ylim(0.0, 1.1) ax.set_ylabel("posterior probability") ax2 = ax.twinx() ln2 = ax2.plot(observed_counts, c="black", alpha=0.3, label="observed counts") ax2.set_title(title) ax2.set_xlabel("time") lns = ln1 + ln2 labs = [l.get_label() for l in lns] ax.legend(lns, labs, loc=4) ax.grid(True, color="white") ax2.grid(False) fig = plt.figure(figsize=(10, 10)) plot_state_posterior(fig.add_subplot(2, 2, 1), posterior_probs[:, 0], title="state 0 (rate {:.2f})".format(rates[0])) plot_state_posterior(fig.add_subplot(2, 2, 2), posterior_probs[:, 1], title="state 1 (rate {:.2f})".format(rates[1])) plot_state_posterior(fig.add_subplot(2, 2, 3), posterior_probs[:, 2], title="state 2 (rate {:.2f})".format(rates[2])) plot_state_posterior(fig.add_subplot(2, 2, 4), posterior_probs[:, 3], title="state 3 (rate {:.2f})".format(rates[3])) plt.tight_layout()
Image in a Jupyter notebook
print(rates)
[ 2.9000478 49.509304 40.368298 21.979261 ]
# max marginals most_probable_states = jnp.argmax(posterior_probs, axis=-1) most_probable_rates = rates[most_probable_states]
fig = plt.figure(figsize=(10, 4)) ax = fig.add_subplot(1, 1, 1) ax.plot(most_probable_rates, c="tab:green", lw=3, label="inferred rate") ax.plot(observed_counts, c="black", alpha=0.3, label="observed counts") ax.set_ylabel("latent rate") ax.set_xlabel("time") ax.set_title("Inferred latent rate over time") ax.legend(loc=4);
Image in a Jupyter notebook
# max probaility trajectory (Viterbi) most_probable_states = hmm.viterbi(observed_counts) most_probable_rates = rates[most_probable_states]
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5792: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
fig = plt.figure(figsize=(10, 4)) ax = fig.add_subplot(1, 1, 1) color_list = np.array(["tab:red", "tab:green", "tab:blue", "k"]) colors = color_list[most_probable_states] for i in range(len(colors)): ax.plot(i, most_probable_rates[i], "-o", c=colors[i], lw=3, alpha=0.75) ax.plot(observed_counts, c="black", alpha=0.3, label="observed counts") ax.set_ylabel("latent rate") ax.set_xlabel("time") ax.set_title("Inferred latent rate over time");
Image in a Jupyter notebook

Model with unknown KK

In general we don't know the true number of states. One way to select the 'best' model is to compute the one with the maximum marginal likelihood. Rather than summing over both discrete latent states and integrating over the unknown parameters λ\lambda, we just maximize over the parameters (empirical Bayes approximation).

p(x1:TK)maxλp(x1:T,z1:Tλ,K)dzp(x_{1:T}|K) \approx \max_\lambda \int p(x_{1:T}, z_{1:T} | \lambda, K) dz

We can do this by fitting a bank of separate HMMs in parallel, one for each value of KK. We need to make them all the same size so we can batch them efficiently. To do this, we pad the transition matrices (and other paraemeter vectors) so they all have the same shape, and then use masking.

max_num_states = 6 states = jnp.arange(1, max_num_states + 1) # For each candidate model, build the initial state prior and transition matrix. batch_initial_state_probs, batch_transition_probs = vmap(build_latent_state, in_axes=(0, None, None))( states, max_num_states, daily_change_prob ) print("Shape of initial_state_probs: {}".format(batch_initial_state_probs.shape)) print("Shape of transition probs: {}".format(batch_transition_probs.shape)) print("Example initial state probs for num_states==3:\n{}".format(batch_initial_state_probs[2, :])) print("Example transition_probs for num_states==3:\n{}".format(batch_transition_probs[2, :, :]))
Shape of initial_state_probs: (6, 6) Shape of transition probs: (6, 6, 6) Example initial state probs for num_states==3: [0.33333334 0.33333334 0.33333334 0. 0. 0. ] Example transition_probs for num_states==3: [[0.95 0.025 0.025 0. 0. 0. ] [0.025 0.95 0.025 0. 0. 0. ] [0.025 0.025 0.95 0. 0. 0. ] [0. 0. 0. 1. 0. 0. ] [0. 0. 0. 0. 1. 0. ] [0. 0. 0. 0. 0. 1. ]]
rng_key, rng_normal = split(rng_key) # Define variable to represent the unknown log rates. trainable_log_rates = jnp.log(jnp.mean(observed_counts)) + jax.random.normal(rng_normal, (max_num_states,))

Model fitting with gradient descent

n_steps = 201 params, losses = vmap(fit, in_axes=(None, 0, 0, None))( trainable_log_rates, batch_transition_probs, batch_initial_state_probs, n_steps ) rates = jnp.exp(params)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5792: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
plt.plot(losses.T) plt.ylabel("Negative log marginal likelihood");
Image in a Jupyter notebook

Plot marginal likelihood of each model

plt.plot(-losses[:, -1]) plt.ylim([-400, -200]) plt.ylabel("marginal likelihood $\\tilde{p}(x)$") plt.xlabel("number of latent states") plt.title("Model selection on latent states");
Image in a Jupyter notebook

Plot posteriors

for i, learned_model_rates in enumerate(rates): print("rates for {}-state model: {}".format(i + 1, learned_model_rates[: i + 1]))
rates for 1-state model: [32.970913] rates for 2-state model: [44.998596 2.9001565] rates for 3-state model: [26.061035 2.900157 48.456497] rates for 4-state model: [21.980347 2.9001298 49.50941 40.368587 ] rates for 5-state model: [40.430183 2.899876 49.539047 40.4293 21.923853] rates for 6-state model: [34.36414 2.899878 40.887142 40.887688 21.84138 49.52287 ]
def posterior_marginals(trainable_log_rates, initial_state_probs, transition_probs): hmm = make_hmm(trainable_log_rates, transition_probs, initial_state_probs) _, _, marginals, _ = hmm.forward_backward(observed_counts) return marginals posterior_probs = vmap(posterior_marginals, in_axes=(0, 0, 0))( params, batch_initial_state_probs, batch_transition_probs ) most_probable_states = jnp.argmax(posterior_probs, axis=-1)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:5792: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. lax._check_user_dtype_supported(dtype, "astype")
fig = plt.figure(figsize=(14, 12)) for i, learned_model_rates in enumerate(rates): ax = fig.add_subplot(4, 3, i + 1) ax.plot(learned_model_rates[most_probable_states[i]], c="green", lw=3, label="inferred rate") ax.plot(observed_counts, c="black", alpha=0.3, label="observed counts") ax.set_ylabel("latent rate") ax.set_xlabel("time") ax.set_title("{}-state model".format(i + 1)) ax.legend(loc=4) plt.tight_layout()
Image in a Jupyter notebook