GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/bernoulli_hmm_example.ipynb
Kernel: Python 3 (ipykernel)


# GitHub no longer serves the unauthenticated git:// protocol, so install over https
!pip install git+https://github.com/lindermanlab/ssm-jax-refactor.git
Collecting jax==0.2.21
Successfully built ssm jax
Installing collected packages: jax, ssm
Found existing installation: jax 0.2.25
Successfully uninstalled jax-0.2.25
Successfully installed jax-0.2.21 ssm-0.1
import ssm

Imports and Plotting Functions

import jax.random as jr
import jax.numpy as np
import matplotlib.pyplot as plt
from tensorflow_probability.substrates import jax as tfp

from ssm.hmm import BernoulliHMM
from ssm.plots import gradient_cmap
from ssm.utils import find_permutation

import warnings

import seaborn as sns

sns.set_style("white")
sns.set_context("talk")

color_names = ["windows blue", "red", "amber", "faded green", "dusty purple", "orange"]

colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)


def plot_transition_matrix(transition_matrix):
    plt.imshow(transition_matrix, vmin=0, vmax=1, cmap="Greys")
    plt.xlabel("next state")
    plt.ylabel("current state")
    plt.colorbar()
    plt.show()


def compare_transition_matrix(true_matrix, test_matrix):
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    out = axs[0].imshow(true_matrix, vmin=0, vmax=1, cmap="Greys")
    axs[1].imshow(test_matrix, vmin=0, vmax=1, cmap="Greys")
    axs[0].set_title("True Transition Matrix")
    axs[1].set_title("Test Transition Matrix")
    # Place a shared colorbar to the right of the second panel
    cax = fig.add_axes(
        [
            axs[1].get_position().x1 + 0.07,
            axs[1].get_position().y0,
            0.02,
            axs[1].get_position().y1 - axs[1].get_position().y0,
        ]
    )
    plt.colorbar(out, cax=cax)
    plt.show()


def plot_hmm_data(obs, states):
    lim = 1.01 * abs(obs).max()
    time_bins, obs_dim = obs.shape
    plt.figure(figsize=(8, 3))
    plt.imshow(
        states[None, :],
        aspect="auto",
        cmap=cmap,
        vmin=0,
        vmax=len(colors) - 1,
        extent=(0, time_bins, -lim, (obs_dim) * lim),
    )

    for d in range(obs_dim):
        plt.plot(obs[:, d] + lim * d, "-k")

    plt.xlim(0, time_bins)
    plt.xlabel("time")
    plt.yticks(lim * np.arange(obs_dim), ["$x_{}$".format(d + 1) for d in range(obs_dim)])

    plt.title("Simulated data from an HMM")
    plt.tight_layout()


def plot_posterior_states(Ez, states, perm):
    plt.figure(figsize=(25, 5))
    plt.imshow(Ez.T[perm], aspect="auto", interpolation="none", cmap="Greys")
    plt.plot(states, label="True State")
    plt.plot(Ez.T[perm].argmax(axis=0), "--", label="Predicted State")
    plt.xlabel("time")
    plt.ylabel("latent state")
    # plt.legend(bbox_to_anchor=(1,1))
    plt.title("Predicted vs. Ground Truth Latent State")
    # plt.show()

Bernoulli HMM

Let's create a true model

num_states = 5
num_channels = 10

# Sticky transitions: mostly stay in the current state, otherwise jump uniformly
transition_matrix = 0.90 * np.eye(num_states) + 0.10 * np.ones((num_states, num_states)) / num_states

true_hmm = BernoulliHMM(
    num_states, num_emission_dims=num_channels, transition_matrix=transition_matrix, seed=jr.PRNGKey(0)
)
plot_transition_matrix(true_hmm.transition_matrix)
[Figure: transition matrix of the true model; x-axis: next state, y-axis: current state]
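To make the structure of this transition matrix concrete, the short sketch below rebuilds it with plain NumPy (imported as `onp` so it does not clash with the notebook's `np` alias for `jax.numpy`) and checks two properties: each row sums to 1, and the self-transition probability is 0.92, which implies an expected dwell time of about 12.5 steps per state (the mean of a geometric distribution). The variable names `P` and `stay` are introduced here only for illustration.

```python
import numpy as onp  # plain NumPy; the notebook's `np` alias is jax.numpy

num_states = 5
P = 0.90 * onp.eye(num_states) + 0.10 * onp.ones((num_states, num_states)) / num_states

print(P[0])            # first row: 0.92 on the diagonal, 0.02 elsewhere
print(P.sum(axis=1))   # every row sums to 1

stay = P[0, 0]                # self-transition probability
dwell = 1.0 / (1.0 - stay)    # mean of the geometric dwell-time distribution
print(dwell)                  # about 12.5 time steps
```

With 500 time steps and a dwell time of roughly 12.5 steps, a sampled sequence should contain on the order of a few dozen state switches.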

From the true model, we can sample synthetic data

rng = jr.PRNGKey(0)
num_timesteps = 500

states, data = true_hmm.sample(rng, num_timesteps)
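Sampling from an HMM is ancestral sampling: draw an initial state, then alternate between emitting an observation from the current state and drawing the next state from the corresponding row of the transition matrix. The sketch below shows this for Bernoulli emissions in plain NumPy. It is not the `ssm` implementation; the function `sample_bernoulli_hmm`, the uniform initial distribution, and the randomly drawn `emission_probs` are assumptions made for illustration.

```python
import numpy as onp  # plain NumPy; the notebook's `np` alias is jax.numpy


def sample_bernoulli_hmm(rng, initial_probs, transition_matrix, emission_probs, num_timesteps):
    """Ancestral sampling from an HMM with independent Bernoulli emissions.

    emission_probs[k, d] is the probability that channel d equals 1 in state k.
    """
    num_states, num_dims = emission_probs.shape
    states = onp.zeros(num_timesteps, dtype=int)
    data = onp.zeros((num_timesteps, num_dims), dtype=int)
    z = rng.choice(num_states, p=initial_probs)
    for t in range(num_timesteps):
        states[t] = z
        # Each channel is an independent coin flip with a state-dependent bias
        data[t] = rng.random(num_dims) < emission_probs[z]
        z = rng.choice(num_states, p=transition_matrix[z])
    return states, data


rng = onp.random.default_rng(0)
K, D, T = 5, 10, 500
P = 0.90 * onp.eye(K) + 0.10 * onp.ones((K, K)) / K
emission_probs = rng.uniform(size=(K, D))  # assumed: arbitrary per-state rates
states, data = sample_bernoulli_hmm(rng, onp.full(K, 1.0 / K), P, emission_probs, T)
print(states.shape, data.shape)  # (500,) (500, 10)
```

The return order `(states, data)` mirrors the notebook's `true_hmm.sample` call.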

Let's view the synthetic data

fig, axs = plt.subplots(2, 1, sharex=True, figsize=(20, 8))
axs[0].imshow(data.T, aspect="auto", interpolation="none")
# axs[0].set_ylabel("neuron")
axs[0].set_title("Observations")
axs[1].plot(states)
axs[1].set_title("Latent State")
axs[1].set_xlabel("time")
axs[1].set_ylabel("state")
plt.savefig("bernoulli-hmm-data.pdf")
plt.savefig("bernoulli-hmm-data.png")
plt.show()
[Figure: two panels sharing the time axis. Top: "Observations". Bottom: "Latent State" (state vs. time).]

Fit HMM using exact EM update

test_hmm = BernoulliHMM(num_states, num_channels, seed=jr.PRNGKey(32))
lps, test_hmm, posterior = test_hmm.fit(data, method="em", tol=-1)
Initializing... Done.
0%| | 0/100 [00:00<?, ?it/s]
# Plot the log probabilities
plt.plot(lps)
plt.xlabel("iteration")
plt.ylabel("log likelihood")
Text(0, 0.5, 'log likelihood')
[Figure: log likelihood vs. EM iteration]
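As a rough illustration of what an EM fit of a Bernoulli HMM involves, here is a minimal NumPy sketch. The E-step runs a scaled forward-backward pass to get posterior state marginals and expected transition counts; the M-step re-estimates the initial distribution, transition matrix, and emission probabilities from them. This is not the `ssm` library's implementation, which is JAX-based and may differ in priors, initialization, and batching. The names `bernoulli_log_likes`, `e_step`, and `m_step`, the clipping constant `eps`, and the small synthetic dataset are all assumptions made for this example.

```python
import numpy as onp  # plain NumPy; the notebook's `np` alias is jax.numpy


def bernoulli_log_likes(data, probs):
    """log p(x_t | z_t = k); data is (T, D) binary, probs is (K, D)."""
    return data @ onp.log(probs).T + (1 - data) @ onp.log1p(-probs).T


def e_step(pi0, P, log_likes):
    """Scaled forward-backward pass for a single sequence."""
    T, K = log_likes.shape
    shift = log_likes.max(axis=1)              # per-step shift for numerical stability
    lik = onp.exp(log_likes - shift[:, None])
    alpha, c = onp.zeros((T, K)), onp.zeros(T)
    for t in range(T):
        pred = pi0 if t == 0 else alpha[t - 1] @ P
        a = pred * lik[t]
        c[t] = a.sum()                         # normalizer: p(x_t | x_{1:t-1}), up to the shift
        alpha[t] = a / c[t]
    beta = onp.ones((T, K))
    for t in range(T - 2, -1, -1):
        beta[t] = P @ (lik[t + 1] * beta[t + 1]) / c[t + 1]
    gamma = alpha * beta                       # posterior marginals p(z_t | x_{1:T})
    # Expected transition counts, summed over time
    xi = (alpha[:-1].T @ (lik[1:] * beta[1:] / c[1:, None])) * P
    return gamma, xi, onp.log(c).sum() + shift.sum()


def m_step(data, gamma, xi, eps=1e-6):
    """Closed-form maximum-likelihood updates (no priors)."""
    pi0 = gamma[0] / gamma[0].sum()
    P = xi / xi.sum(axis=1, keepdims=True)
    probs = (gamma.T @ data) / gamma.sum(axis=0)[:, None]
    return pi0, P, onp.clip(probs, eps, 1 - eps)  # keep logs finite


# Small synthetic problem (assumed sizes, smaller than the notebook's)
rng = onp.random.default_rng(0)
K, D, T = 3, 6, 400
true_P = 0.9 * onp.eye(K) + 0.1 / K
true_probs = rng.uniform(0.1, 0.9, size=(K, D))
data, z = onp.zeros((T, D)), 0
for t in range(T):
    data[t] = rng.random(D) < true_probs[z]
    z = rng.choice(K, p=true_P[z])

# Random initialization, then a fixed number of EM iterations
pi0, P = onp.full(K, 1.0 / K), onp.full((K, K), 1.0 / K)
probs = rng.uniform(0.3, 0.7, size=(K, D))
lps = []
for _ in range(20):
    gamma, xi, lp = e_step(pi0, P, bernoulli_log_likes(data, probs))
    lps.append(lp)
    pi0, P, probs = m_step(data, gamma, xi)
print(lps[0], lps[-1])
```

Because each M-step here maximizes the expected complete-data log likelihood exactly, the values in `lps` should never decrease from one iteration to the next, which is a useful sanity check on any EM implementation.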
test_hmm.transition_matrix
DeviceArray([[0.84915197, 0.04904636, 0.00188023, 0.07304415, 0.02687721],
             [0.00975706, 0.9331732 , 0.02004754, 0.02332079, 0.01370143],
             [0.00132545, 0.02912631, 0.9396162 , 0.0288132 , 0.00111887],
             [0.06562801, 0.03183502, 0.00170273, 0.850693  , 0.0501412 ],
             [0.01220625, 0.03129839, 0.00972043, 0.01064176, 0.93613315]], dtype=float32)
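# Compare the transition matrices
compare_transition_matrix(true_hmm.transition_matrix, test_hmm.transition_matrix)
# Note: compare_transition_matrix() already calls plt.show(), so this savefig
# call writes an empty figure (hence the "0 Axes" output below).
plt.savefig("bernoulli-hmm-transmat-comparison.pdf")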
[Figure: "True Transition Matrix" and "Test Transition Matrix" side by side with a shared colorbar]
<Figure size 432x288 with 0 Axes>
# Posterior distribution
Ez = posterior.expected_states.reshape(-1, num_states)
perm = find_permutation(states, np.argmax(Ez, axis=-1))
plot_posterior_states(Ez, states, perm)
plt.savefig("bernoulli-hmm-state-est-comparison.pdf")
plt.savefig("bernoulli-hmm-state-est-comparison.png")
plt.show()
[Figure: "Predicted vs. Ground Truth Latent State"; x-axis: time, y-axis: latent state]
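The state labels of a fitted HMM are arbitrary, so the inferred labels have to be matched to the true ones before the two sequences can be compared; that is the role of `perm` above. The sketch below shows one way to do such a matching: count how often each (true, inferred) label pair co-occurs, then pick the assignment that maximizes total overlap. The helper `best_permutation` is hypothetical and uses brute force over all permutations, which is only practical for a small number of states; it is not the `ssm.utils.find_permutation` implementation. Its convention, `perm[k]` is the inferred label matched to true label `k`, is chosen to be consistent with how `perm` is used to index `Ez.T[perm]`.

```python
import numpy as onp  # plain NumPy; the notebook's `np` alias is jax.numpy
from itertools import permutations


def best_permutation(true_states, inferred_states, num_states):
    """Return perm where perm[k] is the inferred label matched to true label k."""
    overlap = onp.zeros((num_states, num_states), dtype=int)
    for z_true, z_inf in zip(true_states, inferred_states):
        overlap[z_true, z_inf] += 1
    # Brute force over all K! assignments; fine for small K
    return max(
        permutations(range(num_states)),
        key=lambda perm: sum(overlap[k, perm[k]] for k in range(num_states)),
    )


true_states = onp.array([0, 0, 1, 1, 2, 2, 2])
inferred = onp.array([2, 2, 0, 0, 1, 1, 0])
perm = best_permutation(true_states, inferred, 3)
print(perm)  # (2, 0, 1)
```

For larger state counts, an assignment solver such as the Hungarian algorithm would replace the brute-force search.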

Fit Bernoulli HMM Over Multiple Trials

rng = jr.PRNGKey(0)
num_timesteps = 500
num_trials = 5

all_states, all_data = true_hmm.sample(rng, num_timesteps, num_samples=num_trials)
# Now we have a batch dimension of size `num_trials`
print(all_states.shape)
print(all_data.shape)
(5, 500)
(5, 500, 10)
lps, test_hmm, posterior = test_hmm.fit(all_data, method="em", tol=-1)
Initializing... Done.
LP: -13234.113: 100%|██████████| 100/100 [00:01<00:00, 53.80it/s]
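# Plot marginal log probabilities, normalized per observed entry
plt.title("Marginal Log Probability")
plt.ylabel("lp")
plt.xlabel("idx")
plt.plot(lps / all_data.size)  # the fit used all_data, so normalize by its size rather than data.size

compare_transition_matrix(true_hmm.transition_matrix, test_hmm.transition_matrix)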
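[Figure: "Marginal Log Probability"; x-axis: idx, y-axis: lp]
[Figure: "True Transition Matrix" and "Test Transition Matrix" side by side with a shared colorbar]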
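To put the final log probability in perspective, it helps to normalize it by the number of binary entries that were fitted. The arithmetic below uses the shape `(5, 500, 10)` printed for `all_data` and the final `LP: -13234.113` shown in the progress bar, and compares the result with the per-entry log probability of a fair coin. It assumes the reported LP is a sum over all entries of `all_data`; if the library also folds a prior term into that number, the per-entry value is only approximate.

```python
import math

num_trials, num_timesteps, num_channels = 5, 500, 10  # shape of all_data
final_lp = -13234.113                                 # last value in the progress bar

num_obs = num_trials * num_timesteps * num_channels
per_obs = final_lp / num_obs

print(num_obs)                      # 25000
print(round(per_obs, 4))            # -0.5294
print(round(math.log(0.5), 4))      # -0.6931, a fair-coin baseline per entry
```

A per-entry value above the fair-coin baseline indicates that the fitted model predicts the binary observations better than chance.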
# For the first few trials, let's see how good our predicted states are
for trial_idx in range(3):
    print("=" * 5, f"Trial: {trial_idx}", "=" * 5)
    Ez = posterior.expected_states[trial_idx]
    states = all_states[trial_idx]
    perm = find_permutation(states, np.argmax(Ez, axis=-1))
    plot_posterior_states(Ez, states, perm)
===== Trial: 0 =====
[Figure: "Predicted vs. Ground Truth Latent State" for trial 0]
===== Trial: 1 =====
[Figure: "Predicted vs. Ground Truth Latent State" for trial 1]
===== Trial: 2 =====
[Figure: "Predicted vs. Ground Truth Latent State" for trial 2]