probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/poisson_lds_example.ipynb
Kernel: Python 3 (ipykernel)


Linear Dynamical System with Poisson likelihood

Code modified from

https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/poisson-lds-example.ipynb

```python
!pip install git+https://github.com/lindermanlab/ssm-jax-refactor.git
import ssm
```
Output (abridged):
Collecting jax==0.2.21 ...
Successfully built ssm jax
Installing collected packages: jax, ssm
Attempting uninstall: jax Found existing installation: jax 0.2.25 Uninstalling jax-0.2.25: Successfully uninstalled jax-0.2.25
Successfully installed jax-0.2.21 ssm-0.1

Imports and Plotting Functions

```python
import jax.numpy as np
import jax.random as jr
import jax.experimental.optimizers as optimizers  # unused below; later JAX releases move it to jax.example_libraries.optimizers
from jax import jit, value_and_grad, vmap
from tqdm.auto import trange
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from tensorflow_probability.substrates import jax as tfp

from ssm.lds.models import GaussianLDS, PoissonLDS
from ssm.distributions.linreg import GaussianLinearRegression
from ssm.utils import random_rotation
from ssm.plots import plot_dynamics_2d


def plot_emissions(states, data):
    latent_dim = states.shape[-1]
    emissions_dim = data.shape[-1]
    num_timesteps = data.shape[0]

    plt.figure(figsize=(8, 6))
    gs = GridSpec(2, 1, height_ratios=(1, emissions_dim / latent_dim))

    # Plot the continuous latent states, one offset trace per dimension
    lim = abs(states).max()
    plt.subplot(gs[0])
    for d in range(latent_dim):
        plt.plot(states[:, d] + lim * d, "-")
    plt.yticks(np.arange(latent_dim) * lim, ["$x_{}$".format(d + 1) for d in range(latent_dim)])
    plt.xticks([])
    plt.xlim(0, num_timesteps)
    plt.title("Sampled Latent States")

    # Plot the emissions, one offset trace per dimension
    lim = abs(data).max()
    plt.subplot(gs[1])
    for n in range(emissions_dim):
        plt.plot(data[:, n] - lim * n, "-k")
    plt.yticks(-np.arange(emissions_dim) * lim, ["$y_{{ {} }}$".format(n + 1) for n in range(emissions_dim)])
    plt.xlabel("time")
    plt.xlim(0, num_timesteps)
    plt.title("Sampled Emissions")

    plt.tight_layout()


def plot_emissions_poisson(states, data):
    latent_dim = states.shape[-1]
    emissions_dim = data.shape[-1]
    num_timesteps = data.shape[0]

    plt.figure(figsize=(8, 6))
    gs = GridSpec(2, 1, height_ratios=(1, emissions_dim / latent_dim))

    # Plot the continuous latent states, one offset trace per dimension
    lim = abs(states).max()
    plt.subplot(gs[0])
    for d in range(latent_dim):
        plt.plot(states[:, d] + lim * d, "-")
    plt.yticks(np.arange(latent_dim) * lim, ["$z_{}$".format(d + 1) for d in range(latent_dim)])
    plt.xticks([])
    plt.xlim(0, num_timesteps)
    plt.title("Sampled Latent States")

    # Plot the counts as a (neuron x time) heatmap
    plt.subplot(gs[1])
    plt.imshow(data.T, aspect="auto", interpolation="none")
    plt.xlabel("time")
    plt.xlim(0, num_timesteps)
    plt.yticks(ticks=np.arange(emissions_dim))
    # plt.ylabel("Neuron")
    plt.title("Sampled Emissions (Counts / Time Bin)")
    plt.tight_layout()
    plt.colorbar()


def plot_dynamics(lds, states):
    q = plot_dynamics_2d(
        lds._dynamics.weights,
        bias_vector=lds._dynamics.bias,
        mins=states.min(axis=0),
        maxs=states.max(axis=0),
        color="blue",
    )
    plt.plot(states[:, 0], states[:, 1], lw=2, label="Latent State")
    plt.plot(states[0, 0], states[0, 1], "*r", markersize=10, label="Initial State")
    plt.xlabel("$z_1$")
    plt.ylabel("$z_2$")
    plt.title("Latent States & Dynamics")
    plt.legend(bbox_to_anchor=(1, 1))
    # plt.show()


def extract_trial_stats(trial_idx, posterior, all_data, all_states, fitted_lds, true_lds):
    # Posterior mean of the latent states for this trial
    Ex = posterior.mean()[trial_idx]
    states = all_states[trial_idx]
    data = all_data[trial_idx]

    # Map the posterior through the fitted emissions: mean C z + d and covariance C Sigma_t C^T
    C = fitted_lds.emissions_matrix
    d = fitted_lds.emissions_bias
    Ey = Ex @ C.T + d
    Covy = C @ posterior.covariance()[trial_idx] @ C.T

    # The same quantity under the true model, i.e. the "true" input to the Poisson GLM
    Ey_true = states @ true_lds.emissions_matrix.T + true_lds.emissions_bias
    return states, data, Ex, Ey, Covy, Ey_true


def compare_dynamics(Ex, states, data):
    # Relies on the globals `true_lds` and `fitted_lds` defined elsewhere in the notebook
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))

    # Left: true dynamics and true latent trajectory
    q = plot_dynamics_2d(
        true_lds._dynamics.weights,
        bias_vector=true_lds._dynamics.bias,
        mins=states.min(axis=0),
        maxs=states.max(axis=0),
        color="blue",
        axis=axs[0],
    )
    axs[0].plot(states[:, 0], states[:, 1], lw=2)
    axs[0].plot(states[0, 0], states[0, 1], "*r", markersize=10, label="$z_{init}$")
    axs[0].set_xlabel("$z_1$")
    axs[0].set_ylabel("$z_2$")
    axs[0].set_title("True Latent States & Dynamics")

    # Right: fitted dynamics and posterior-mean latent trajectory
    q = plot_dynamics_2d(
        fitted_lds._dynamics.weights,
        bias_vector=fitted_lds._dynamics.bias,
        mins=Ex.min(axis=0),
        maxs=Ex.max(axis=0),
        color="red",
        axis=axs[1],
    )
    axs[1].plot(Ex[:, 0], Ex[:, 1], lw=2)
    axs[1].plot(Ex[0, 0], Ex[0, 1], "*r", markersize=10, label="$z_{init}$")
    axs[1].set_xlabel("$z_1$")
    axs[1].set_ylabel("$z_2$")
    axs[1].set_title("Simulated Latent States & Dynamics")
    plt.tight_layout()
    # plt.show()


def compare_smoothened_predictions(Ey, Ey_true, Covy, data):
    data_dim = data.shape[-1]

    plt.figure(figsize=(15, 6))
    # Offset each channel by 10 so the traces do not overlap
    plt.plot(Ey_true + 10 * np.arange(data_dim))
    plt.plot(Ey + 10 * np.arange(data_dim), "--k")
    for i in range(data_dim):
        # Shade +/- 2 posterior standard deviations around each prediction
        plt.fill_between(
            np.arange(len(data)),
            10 * i + Ey[:, i] - 2 * np.sqrt(Covy[:, i, i]),
            10 * i + Ey[:, i] + 2 * np.sqrt(Covy[:, i, i]),
            color="k",
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.ylabel("data and predictions (for each neuron)")

    plt.plot([0], "--k", label="Predicted")  # dummy trace for legend
    plt.plot([0], "-k", label="True")
    plt.legend(loc="upper right")
    # plt.show()
```
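The values returned by `extract_trial_stats` live on the scale of the Poisson GLM input: `Ey` is the posterior mean of C z_t + d and `Covy` its covariance. If the link is exponential (an assumption; the notebook does not state it), the approximate posterior mean of the rate follows from the Gaussian moment generating function, E[exp(η)] = exp(μ + σ²/2). A minimal sketch in plain NumPy (imported as `onp` so it does not shadow the notebook's `np`, which is `jax.numpy`); `expected_rates` is a hypothetical helper, checked here against Monte Carlo on made-up numbers:

```python
import numpy as onp


def expected_rates(Ey, Covy):
    """Mean of exp(eta) per time bin and channel when eta_t ~ N(Ey[t], Covy[t]).

    Ey has shape (T, N); Covy has shape (T, N, N). Assumes an exponential link.
    """
    var = onp.diagonal(Covy, axis1=-2, axis2=-1)  # marginal variances, shape (T, N)
    return onp.exp(Ey + 0.5 * var)


# Toy check for a single time bin and two channels (values are made up)
rng_toy = onp.random.default_rng(0)
Ey_toy = onp.array([[0.5, -1.0]])
Covy_toy = onp.array([[[0.2, 0.05], [0.05, 0.3]]])
analytic = expected_rates(Ey_toy, Covy_toy)
samples = rng_toy.multivariate_normal(Ey_toy[0], Covy_toy[0], size=200_000)
mc = onp.exp(samples).mean(axis=0)  # Monte Carlo estimate of E[exp(eta)]
print(analytic, mc)
```

Because the variance term is non-negative, exp(`Ey`) on its own can only understate this posterior mean rate.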
```python
# Some parameters to define our model
emissions_dim = 5  # number of observed channels (emission dimensions)
latent_dim = 2
seed = jr.PRNGKey(0)

# Initialize our true Poisson LDS model
true_lds = PoissonLDS(num_latent_dims=latent_dim, num_emission_dims=emissions_dim, seed=seed)
```
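To give a rough picture of what `true_lds` represents, here is a minimal sketch of sampling from a Poisson LDS in plain NumPy (imported as `onp` so it does not shadow the notebook's `np`). It assumes Gaussian transitions z_t ~ N(A z_{t-1} + b, Q) and an exponential link y_t ~ Poisson(exp(C z_t + d)); `sample_poisson_lds` and all parameter values are hypothetical and are not read from ssm's `PoissonLDS`, whose exact parameterization may differ.

```python
import numpy as onp


def sample_poisson_lds(A, b, Q, C, d, mu0, Q0, num_steps, rng):
    """Sample one latent trajectory and its counts from a Poisson LDS (exp link assumed)."""
    latent_dim, emissions_dim = A.shape[0], C.shape[0]
    states = onp.zeros((num_steps, latent_dim))
    counts = onp.zeros((num_steps, emissions_dim), dtype=onp.int64)
    z = rng.multivariate_normal(mu0, Q0)  # initial state z_1 ~ N(mu0, Q0)
    for t in range(num_steps):
        if t > 0:
            z = rng.multivariate_normal(A @ z + b, Q)  # z_t ~ N(A z_{t-1} + b, Q)
        states[t] = z
        counts[t] = rng.poisson(onp.exp(C @ z + d))  # y_t ~ Poisson(exp(C z_t + d))
    return states, counts


# Hypothetical parameters: a slowly decaying 2D rotation observed through 5 channels
rng_toy = onp.random.default_rng(0)
theta = 0.1
A_toy = 0.99 * onp.array([[onp.cos(theta), -onp.sin(theta)],
                          [onp.sin(theta), onp.cos(theta)]])
b_toy = onp.zeros(2)
Q_toy = 0.01 * onp.eye(2)
C_toy = rng_toy.normal(size=(5, 2))
d_toy = onp.zeros(5)
states_toy, counts_toy = sample_poisson_lds(
    A_toy, b_toy, Q_toy, C_toy, d_toy, onp.zeros(2), onp.eye(2), num_steps=200, rng=rng_toy
)
print(states_toy.shape, counts_toy.shape)  # (200, 2) (200, 5)
```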

Sample some synthetic data from the Poisson LDS

```python
import warnings

num_trials = 5
time_bins = 200

# Suppress the UserWarnings that TFP emits during Poisson sampling
rng = jr.PRNGKey(0)
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=UserWarning)
    all_states, all_data = true_lds.sample(key=rng, num_steps=time_bins, num_samples=num_trials)
```
```python
plot_emissions_poisson(all_states[0], all_data[0])
plt.savefig("poisson-hmm-data.pdf")
plt.savefig("poisson-hmm-data.png")
```
Figure: Sampled Latent States (top) and Sampled Emissions (Counts / Time Bin) (bottom).

Inference: let's fit a Poisson LDS to our data

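Because the Poisson emissions are not conjugate to the Gaussian latent dynamics, the posterior over the latent states is no longer Gaussian, so we can no longer perform exact EM (for a Gaussian LDS, the E-step computes this posterior exactly with the Kalman smoother).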
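To see why, write out the log joint of latents z_{1:T} and counts y_{1:T}. This assumes (the notebook does not state it) Gaussian dynamics z_t = A z_{t-1} + b + ε_t with ε_t ~ N(0, Q), an initial state z_1 ~ N(μ_0, Q_0), and an exponential link, so that y_{t,n} ~ Poisson(exp(c_n^T z_t + d_n)) with c_n the n-th row of the emission matrix C:

$$
\log p(z_{1:T}, y_{1:T}) = \log \mathcal{N}(z_1 \mid \mu_0, Q_0) + \sum_{t=2}^{T} \log \mathcal{N}(z_t \mid A z_{t-1} + b, Q) + \sum_{t=1}^{T} \sum_{n=1}^{N} \Big[ y_{t,n}\, \eta_{t,n} - e^{\eta_{t,n}} - \log y_{t,n}! \Big], \qquad \eta_{t,n} = c_n^\top z_t + d_n
$$

The Gaussian terms are quadratic in z, but the -e^{η} terms are not, so p(z_{1:T} | y_{1:T}) is not Gaussian. The log joint is still concave in z (a sum of concave quadratics and concave functions of affine maps of z), so its mode is well defined and easy to find.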

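Instead, we perform Laplace EM. In the E-step, we approximate the posterior with a Laplace (Gaussian) approximation: a Gaussian centered at the posterior mode of the latent states, with covariance equal to the inverse of the negative Hessian of the log joint at that mode. The M-step then updates the model parameters under this approximate posterior.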
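A minimal sketch of the Laplace E-step in plain NumPy (imported as `onp` so it does not shadow the notebook's `np`). It assumes Gaussian dynamics z_t ~ N(A z_{t-1} + b, Q), z_1 ~ N(mu0, Q0), and an exponential link y_{t,n} ~ Poisson(exp(c_n^T z_t + d_n)); `prior_precision` and `laplace_approx` are hypothetical helpers, not ssm's implementation. The mode is found with damped Newton steps on the stacked latent vector (the commented-out `laplace_em` call in this notebook mentions BFGS for this step instead), and the covariance is the inverse negative Hessian at the mode. For clarity it builds the dense (TD × TD) Hessian, which costs O((TD)^3); practical implementations exploit its block-tridiagonal structure.

```python
import numpy as onp


def prior_precision(A, b, Q, mu0, Q0, T):
    """Write the Gaussian LDS prior as -0.5 z^T J z + h^T z + const over stacked z of shape (T*D,)."""
    D = A.shape[0]
    J = onp.zeros((T * D, T * D))
    h = onp.zeros(T * D)
    Qi, Q0i = onp.linalg.inv(Q), onp.linalg.inv(Q0)

    def blk(t):
        return slice(t * D, (t + 1) * D)

    J[blk(0), blk(0)] += Q0i
    h[blk(0)] += Q0i @ mu0
    for t in range(1, T):
        # Terms of -0.5 (z_t - A z_{t-1} - b)^T Q^{-1} (z_t - A z_{t-1} - b)
        J[blk(t), blk(t)] += Qi
        J[blk(t - 1), blk(t - 1)] += A.T @ Qi @ A
        J[blk(t), blk(t - 1)] -= Qi @ A
        J[blk(t - 1), blk(t)] -= A.T @ Qi
        h[blk(t)] += Qi @ b
        h[blk(t - 1)] -= A.T @ Qi @ b
    return J, h


def laplace_approx(y, A, b, Q, C, d, mu0, Q0, num_iters=50, tol=1e-8):
    """Return the posterior mode (T, D) and Laplace covariance (T*D, T*D)."""
    T, D = y.shape[0], A.shape[0]
    J, h = prior_precision(A, b, Q, mu0, Q0, T)

    def log_joint(z):
        eta = z.reshape(T, D) @ C.T + d  # log rates, shape (T, N)
        return -0.5 * z @ J @ z + h @ z + onp.sum(y * eta - onp.exp(eta))

    def grad_and_hessian(z):
        rate = onp.exp(z.reshape(T, D) @ C.T + d)  # shape (T, N)
        grad = h - J @ z + ((y - rate) @ C).ravel()
        H = -J.copy()
        for t in range(T):  # Poisson term adds -C^T diag(rate_t) C to each diagonal block
            s = slice(t * D, (t + 1) * D)
            H[s, s] -= C.T @ (rate[t][:, None] * C)
        return grad, H

    z = onp.zeros(T * D)
    for _ in range(num_iters):
        grad, H = grad_and_hessian(z)
        step = -onp.linalg.solve(H, grad)  # Newton ascent direction
        alpha, f0 = 1.0, log_joint(z)
        while log_joint(z + alpha * step) < f0 and alpha > 1e-8:
            alpha *= 0.5  # backtrack until the log joint does not decrease
        z = z + alpha * step
        if onp.max(onp.abs(alpha * step)) < tol:
            break
    _, H = grad_and_hessian(z)  # Hessian at the final mode
    return z.reshape(T, D), onp.linalg.inv(-H)


# Usage on toy counts (made-up parameters)
rng_toy = onp.random.default_rng(1)
T, D, N = 50, 2, 5
A = 0.95 * onp.eye(D)
b = onp.zeros(D)
Q = 0.1 * onp.eye(D)
C = rng_toy.normal(size=(N, D))
d = onp.zeros(N)
y = rng_toy.poisson(1.0, size=(T, N))
z_mode, cov = laplace_approx(y, A, b, Q, C, d, onp.zeros(D), onp.eye(D))
print(z_mode.shape, cov.shape)  # (50, 2) (100, 100)
```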

```python
latent_dim = 2
seed = jr.PRNGKey(32)  # NOTE: different seed from the true model!
test_lds = PoissonLDS(num_emission_dims=emissions_dim, num_latent_dims=latent_dim, seed=seed)
```
```python
rng = jr.PRNGKey(10)
elbos, fitted_lds, posteriors = test_lds.fit(all_data, method="laplace_em", rng=rng, num_iters=25)

# NOTE: you could also call the laplace_em routine directly like this
# from ssm.inference.laplace_em import laplace_em
# elbos, fitted_lds, posteriors = laplace_em(rng,
#                                            test_lds,
#                                            all_data,
#                                            num_iters=25,
#                                            laplace_mode_fit_method="BFGS")
```
```python
plt.plot(elbos)
plt.xlabel("iteration")
plt.ylabel("ELBO")
plt.show()
```
Figure: ELBO over the Laplace EM iterations.
```python
num_emissions_channels_to_view = 5
num_trials_to_view = 1

# Ex is the posterior mean of the hidden states; Ey is the predicted input to the Poisson GLM
for trial_idx in range(num_trials_to_view):
    states, data, Ex, Ey, Covy, Ey_true = extract_trial_stats(
        trial_idx, posteriors, all_data, all_states, fitted_lds, true_lds
    )

    compare_dynamics(Ex, states, data)
    plt.savefig("poisson-hmm-dynamics.pdf")
    plt.savefig("poisson-hmm-dynamics.png")
    plt.show()

    compare_smoothened_predictions(
        Ey[:, :num_emissions_channels_to_view],
        Ey_true[:, :num_emissions_channels_to_view],
        Covy,
        data[:, :num_emissions_channels_to_view],
    )
    plt.savefig("poisson-hmm-trajectory.pdf")
    plt.savefig("poisson-hmm-trajectory.png")
    plt.show()
```
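Figure: True Latent States & Dynamics (left) and Simulated Latent States & Dynamics (right).

Figure: true (solid) and predicted (dashed, with ±2 standard deviation bands) inputs to the Poisson GLM for each neuron over time.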
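The latent coordinates of an LDS are only identifiable up to an invertible affine change of variables: replacing z with M z + c and adjusting the dynamics and emission parameters accordingly leaves the distribution of the counts unchanged. The fitted trajectory in the dynamics comparison can therefore look rotated, scaled, or shifted relative to the true one. With synthetic data, where the true states are known, one way to compare them is to regress the true states on the fitted ones. A minimal NumPy sketch (imported as `onp`); `align_latents` is a hypothetical helper, not part of ssm, and the toy data are made up:

```python
import numpy as onp


def align_latents(Ex, states):
    """Least-squares affine map from fitted latents Ex (T, D) onto true latents states (T, D)."""
    X = onp.hstack([Ex, onp.ones((Ex.shape[0], 1))])  # append a column for the offset
    W, *_ = onp.linalg.lstsq(X, states, rcond=None)  # W has shape (D + 1, D)
    return X @ W, W


# Toy check: "fitted" latents that are an exact affine transform of the truth
rng_toy = onp.random.default_rng(0)
states_toy = rng_toy.normal(size=(200, 2))
M = onp.array([[0.0, -2.0], [1.5, 0.3]])  # invertible (determinant 3)
Ex_toy = states_toy @ M.T + onp.array([1.0, -0.5])
aligned, W = align_latents(Ex_toy, states_toy)
print(onp.max(onp.abs(aligned - states_toy)))  # floating-point-level error
```

Inside the final loop, `aligned, _ = align_latents(Ex, states)` would express the posterior means in the true coordinates before plotting.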