GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/arhmm_example.ipynb
Kernel: Python 3 (ipykernel)


Autoregressive (AR) HMM Demo

Modified from

https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/arhmm-example.ipynb

This notebook illustrates the use of the auto_regression observation model. Let $x_t$ denote the observation at time $t$. Let $z_t$ denote the corresponding discrete latent state.

The autoregressive hidden Markov model has the following likelihood,

\begin{align}
x_t \mid x_{t-1}, z_t &\sim \mathcal{N}\left(A_{z_t} x_{t-1} + b_{z_t}, Q_{z_t} \right),
\end{align}

where $A_{z_t}$, $b_{z_t}$, and $Q_{z_t}$ are the dynamics matrix, bias, and noise covariance of state $z_t$. (Technically, higher-order autoregressive processes with extra linear terms from inputs are also implemented.)
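The NumPy sketch below makes this likelihood concrete by evaluating $\log p(x_t \mid x_{t-1}, z_t)$ for a first-order model. It is not the `ssm` implementation. The helper `ar_step_loglik` and the stacked per-state arrays `As`, `bs`, `Qs` are illustrative assumptions.

```python
import numpy as np


def ar_step_loglik(x_t, x_prev, z_t, As, bs, Qs):
    """log N(x_t | A_z x_prev + b_z, Q_z) for one first-order Gaussian AR step.

    Hypothetical helper. As: (K, D, D) dynamics, bs: (K, D) biases, Qs: (K, D, D) covariances.
    """
    mean = As[z_t] @ x_prev + bs[z_t]
    resid = x_t - mean
    D = x_t.shape[0]
    # Cholesky factor Q = L L^T gives a stable log-determinant and Mahalanobis term
    L = np.linalg.cholesky(Qs[z_t])
    sol = np.linalg.solve(L, resid)  # L^{-1} (x_t - mean)
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (D * np.log(2 * np.pi) + log_det + sol @ sol)


# Two illustrative states in 2D: a slow contraction, and a faster one with a bias
As = np.stack([0.9 * np.eye(2), 0.5 * np.eye(2)])
bs = np.array([[0.0, 0.0], [1.0, -1.0]])
Qs = np.tile(0.01 * np.eye(2), (2, 1, 1))

x_prev = np.array([1.0, 2.0])
x_t = np.array([0.9, 1.8])  # exactly the mean under state 0
print(ar_step_loglik(x_t, x_prev, 0, As, bs, Qs))
print(ar_step_loglik(x_t, x_prev, 1, As, bs, Qs))
```

Here $x_t$ equals the state-0 mean exactly, so state 0 scores about 2.77 nats. State 1's mean is far away relative to the small noise covariance, so it scores about $-177$.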

!pip install git+https://github.com/lindermanlab/ssm-jax-refactor.git
Successfully installed jax-0.2.21 ssm-0.1
import ssm

import copy

import jax.numpy as np
import jax.random as jr
from tensorflow_probability.substrates import jax as tfp

from ssm.distributions.linreg import GaussianLinearRegression
from ssm.arhmm import GaussianARHMM
from ssm.utils import find_permutation, random_rotation
from ssm.plots import gradient_cmap  # , white_to_color_cmap

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

sns.set_style("white")
sns.set_context("talk")

color_names = ["windows blue", "red", "amber", "faded green", "dusty purple", "orange", "brown", "pink"]
colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)
# Make a transition matrix
num_states = 5
transition_probs = (np.arange(num_states) ** 10).astype(float)
transition_probs /= transition_probs.sum()
transition_matrix = np.zeros((num_states, num_states))
for k, p in enumerate(transition_probs[::-1]):
    transition_matrix += np.roll(p * np.eye(num_states), k, axis=1)

plt.imshow(transition_matrix, vmin=0, vmax=1, cmap="Greys")
plt.xlabel("next state")
plt.ylabel("current state")
plt.title("transition matrix")
plt.colorbar()
plt.savefig("arhmm-transmat.pdf")
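The transition matrix puts most of each row's mass on the diagonal and a little on the next state, modulo `num_states`. The NumPy sketch below rebuilds the same matrix and checks its structure. It uses plain NumPy instead of `jax.numpy` only so that it runs without JAX.

```python
import numpy as np

num_states = 5
# Unnormalized weights k**10 for k = 0..4, normalized to sum to one
transition_probs = (np.arange(num_states) ** 10).astype(float)
transition_probs /= transition_probs.sum()

# Largest weight goes on the diagonal (shift 0), the next largest one step ahead, etc.
transition_matrix = np.zeros((num_states, num_states))
for k, p in enumerate(transition_probs[::-1]):
    transition_matrix += np.roll(p * np.eye(num_states), k, axis=1)

# Each row is a probability distribution over the next state
print(np.round(transition_matrix, 4))
print("row sums:", transition_matrix.sum(axis=1))
```

Each row $i$ puts probability $4^{10}/1108650 \approx 0.946$ on staying in state $i$ and $3^{10}/1108650 \approx 0.053$ on moving to state $i+1$ (mod 5). Almost nothing goes further ahead and exactly zero goes to state $i-1$, so the chain cycles through the states in order.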
# Make observation distributions
data_dim = 2
num_lags = 1

keys = jr.split(jr.PRNGKey(0), num_states)
angles = np.linspace(0, 2 * np.pi, num_states, endpoint=False)
theta = np.pi / 25  # rotational frequency
weights = np.array([0.8 * random_rotation(key, data_dim, theta=theta) for key in keys])
biases = np.column_stack([np.cos(angles), np.sin(angles), np.zeros((num_states, data_dim - 2))])
covariances = np.tile(0.001 * np.eye(data_dim), (num_states, 1, 1))

# Compute the stationary points
stationary_points = np.linalg.solve(np.eye(data_dim) - weights, biases)
print(theta / (2 * np.pi) * 360)  # rotation angle per time step, in degrees
print(360 / 5)  # angular spacing between the state biases, in degrees
7.2
72.0
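The stationary point of state $k$ is the fixed point of its noise-free dynamics, $x^\star_k = A_k x^\star_k + b_k$, i.e. $x^\star_k = (I - A_k)^{-1} b_k$, which is what `np.linalg.solve(np.eye(data_dim) - weights, biases)` computes. Each $A_k$ is 0.8 times a rotation, so its eigenvalues have modulus 0.8 < 1, and iterating the mean dynamics from any start converges to $x^\star_k$. The NumPy sketch below checks this for one state. It assumes a plain 2D rotation by $\theta = \pi/25$ in place of `ssm.utils.random_rotation`, so the individual coordinates need not match the printed output.

```python
import numpy as np

theta = np.pi / 25                      # 7.2 degrees per step, as in the notebook
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
A = 0.8 * R                             # contraction times rotation
angle = 2 * np.pi / 5                   # bias direction of the second state (72 degrees)
b = np.array([np.cos(angle), np.sin(angle)])

# Fixed point of the map x -> A x + b
x_star = np.linalg.solve(np.eye(2) - A, b)

# All eigenvalues of A have modulus 0.8, so the map is a contraction
eig_moduli = np.abs(np.linalg.eigvals(A))

# Iterate the noise-free dynamics from an arbitrary starting point
x = np.array([-3.0, 2.0])
for _ in range(200):
    x = A @ x + b

print("stationary point:", x_star)
print("eigenvalue moduli:", eig_moduli)
print("after 200 steps:", x)
```

Treat the plane as the complex numbers, so that $A$ acts as multiplication by $0.8e^{\pm i\theta}$. Then $|x^\star| = |b| / |1 - 0.8e^{i\theta}| \approx 4.36$ for a unit-length bias, whatever the sign of the rotation. This matches the norms of the stationary points printed in this notebook.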

Plot dynamics functions

if data_dim == 2:
    lim = 5
    x = np.linspace(-lim, lim, 10)
    y = np.linspace(-lim, lim, 10)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))

    fig, axs = plt.subplots(1, num_states, figsize=(3 * num_states, 6))
    for k in range(num_states):
        A, b = weights[k], biases[k]
        # Expected one-step displacement under state k: A x + b - x
        dxydt_m = xy.dot(A.T) + b - xy
        axs[k].quiver(xy[:, 0], xy[:, 1], dxydt_m[:, 0], dxydt_m[:, 1], color=colors[k % len(colors)])

        axs[k].set_xlabel("$y_1$")
        # axs[k].set_xticks([])
        if k == 0:
            axs[k].set_ylabel("$y_2$")
        # axs[k].set_yticks([])
        axs[k].set_aspect("equal")

    plt.tight_layout()
    plt.savefig("arhmm-flow-matrices.pdf")
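Each quiver arrow is the expected one-step displacement $\mathbb{E}[x_{t+1}] - x_t = A_k x_t + b_k - x_t$. Substituting $b_k = (I - A_k) x^\star_k$ gives $(A_k - I)(x_t - x^\star_k)$. The field therefore vanishes at the stationary point, and every arrow has a component pointing toward it. The NumPy check below uses the same 10×10 grid and assumes an illustrative rotation and bias rather than the notebook's `weights` and `biases`.

```python
import numpy as np

theta = np.pi / 25
A = 0.8 * np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
b = np.array([1.0, 0.0])
x_star = np.linalg.solve(np.eye(2) - A, b)

# Grid of points, as in the plotting cell
lim = 5
X, Y = np.meshgrid(np.linspace(-lim, lim, 10), np.linspace(-lim, lim, 10))
xy = np.column_stack((X.ravel(), Y.ravel()))

# Displacement as computed in the notebook (row-vector convention: x @ A.T)
disp = xy.dot(A.T) + b - xy
# Same field written relative to the stationary point
disp_centered = (xy - x_star).dot((A - np.eye(2)).T)

# Inner product of each arrow with (x - x_star); negative means it points partly inward
inward = np.einsum("ij,ij->i", disp, xy - x_star)

print("max difference:", np.max(np.abs(disp - disp_centered)))
print("field at x_star:", x_star.dot(A.T) + b - x_star)
print("all arrows point inward:", np.all(inward < 0))
```

The inward property holds because the symmetric part of $A - I$ is $(0.8\cos\theta - 1) I$, which is negative definite.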
print(stationary_points)
[[ 3.9209812  -1.9056123 ]
 [ 3.0239954   3.140209  ]
 [-2.0520504   3.8463693 ]
 [-4.2922325  -0.7630231 ]
 [-0.60069436 -4.3179426 ]]

Sample data from the ARHMM

# Make an Autoregressive (AR) HMM
true_initial_distribution = tfp.distributions.Categorical(logits=np.zeros(num_states))
true_transition_distribution = tfp.distributions.Categorical(probs=transition_matrix)

true_arhmm = GaussianARHMM(
    num_states,
    transition_matrix=transition_matrix,
    emission_weights=weights,
    emission_biases=biases,
    emission_covariances=covariances,
)

time_bins = 10000
true_states, data = true_arhmm.sample(jr.PRNGKey(0), time_bins)
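`true_arhmm.sample` draws the latent states and observations together. The NumPy sketch below shows ancestral sampling for a lag-1 Gaussian ARHMM on a smaller illustrative model: 3 states with identical contractions toward different biases. The helper `sample_arhmm`, the uniform initial state, and the start from a zero previous observation are assumptions and need not match what `GaussianARHMM.sample` does internally.

```python
import numpy as np


def sample_arhmm(rng, P, As, bs, Qs, T, x0=None):
    """Ancestral sampling from a lag-1 Gaussian ARHMM (hypothetical sketch)."""
    K, D = bs.shape
    z = np.empty(T, dtype=int)
    x = np.empty((T, D))
    x_prev = np.zeros(D) if x0 is None else np.asarray(x0)  # assumed x_{-1}
    z_prev = None
    for t in range(T):
        # Discrete state: uniform at t = 0, then a row of the transition matrix
        z[t] = rng.integers(K) if z_prev is None else rng.choice(K, p=P[z_prev])
        # Continuous observation: Gaussian around the state's linear prediction
        mean = As[z[t]] @ x_prev + bs[z[t]]
        x[t] = rng.multivariate_normal(mean, Qs[z[t]])
        z_prev, x_prev = z[t], x[t]
    return z, x


rng = np.random.default_rng(0)
K, D = 3, 2
P = np.full((K, K), 0.05) + 0.85 * np.eye(K)   # sticky; rows sum to 1
As = np.tile(0.8 * np.eye(D), (K, 1, 1))
angles = 2 * np.pi * np.arange(K) / K
bs = np.column_stack([np.cos(angles), np.sin(angles)])
Qs = np.tile(0.001 * np.eye(D), (K, 1, 1))

z, x = sample_arhmm(rng, P, As, bs, Qs, T=2000)
print(z[:20])
print(x[:3])
```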
fig = plt.figure(figsize=(8, 8))
for k in range(num_states):
    plt.plot(*data[true_states == k].T, "o", color=colors[k], alpha=0.75, markersize=3)

plt.plot(*data[:1000].T, "-k", lw=0.5, alpha=0.2)
plt.xlabel("$y_1$")
plt.ylabel("$y_2$")
# plt.gca().set_aspect("equal")
plt.savefig("arhmm-samples-2d.pdf")
fig = plt.figure(figsize=(8, 8))
T = 12  # number of visits to show per state
for k in range(num_states):
    # First T observations assigned to state k; the label t counts visits to k, not absolute time
    data_k = data[true_states == k][:T, :]
    plt.plot(data_k[:, 0], data_k[:, 1], "o", color=colors[k], alpha=0.75, markersize=3)
    for t in range(len(data_k)):
        plt.text(data_k[t, 0], data_k[t, 1], t, color=colors[k], fontsize=12)

# plt.plot(*data[:1000].T, '-k', lw=0.5, alpha=0.2)
plt.xlabel("$y_1$")
plt.ylabel("$y_2$")
# plt.gca().set_aspect("equal")
plt.savefig("arhmm-samples-2d-temporal.pdf")
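The samples come in runs of consecutive time steps spent in the same latent state. For a Markov chain with self-transition probability $p_{kk}$, run lengths are geometric with mean $1/(1 - p_{kk})$. The NumPy sketch below simulates a state sequence from a rebuilt copy of the notebook's transition matrix and compares the measured run lengths with that prediction. The helper `run_lengths` is hypothetical, written only for this illustration.

```python
import numpy as np


def run_lengths(states):
    """Lengths of maximal runs of identical consecutive values."""
    states = np.asarray(states)
    # Positions where the value changes, plus both ends of the sequence
    change = np.flatnonzero(states[1:] != states[:-1]) + 1
    boundaries = np.concatenate(([0], change, [len(states)]))
    return np.diff(boundaries)


# Rebuild the notebook's sticky transition matrix (plain NumPy)
K = 5
p = (np.arange(K) ** 10).astype(float)
p /= p.sum()
P = sum(np.roll(q * np.eye(K), k, axis=1) for k, q in enumerate(p[::-1]))

# Simulate a long chain by inverse-CDF sampling from each row
rng = np.random.default_rng(0)
T = 50_000
cdf = np.cumsum(P, axis=1)
u = rng.random(T)
z = np.empty(T, dtype=int)
z[0] = 0
for t in range(1, T):
    z[t] = min(np.searchsorted(cdf[z[t - 1]], u[t], side="right"), K - 1)

lengths = run_lengths(z)
print("mean run length:", lengths.mean())
print("geometric prediction 1/(1 - p_kk):", 1.0 / (1.0 - P[0, 0]))
```

With the notebook's numbers the prediction is about 18 steps per visit, consistent with the long stretches of one color in the plots above.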
print(biases)
[[ 1.          0.        ]
 [ 0.30901697  0.95105654]
 [-0.80901706  0.58778524]
 [-0.80901694 -0.5877853 ]
 [ 0.30901712 -0.9510565 ]]

Below, we visualize each component of the observation variable as a time series. The background colors correspond to the latent state. The dotted lines mark the stationary point of the corresponding AR state, while the solid lines are the actual observations sampled from the ARHMM.

# Plot each observation dimension and the stationary point of the active state
plot_slice = (0, 200)
lim = 1.05 * abs(data).max()
plt.figure(figsize=(8, 6))
plt.imshow(
    true_states[None, :],
    aspect="auto",
    cmap=cmap,
    vmin=0,
    vmax=len(colors) - 1,
    extent=(0, time_bins, -lim, (data_dim) * lim),
)

Ey = np.array(stationary_points)[true_states]
for d in range(data_dim):
    plt.plot(data[:, d] + lim * d, "-k")
    plt.plot(Ey[:, d] + lim * d, ":k")

plt.xlim(plot_slice)
plt.xlabel("time")
# plt.yticks(lim * np.arange(data_dim), ["$y_{{{}}}$".format(d+1) for d in range(data_dim)])
plt.ylabel("observations")
plt.tight_layout()
plt.savefig("arhmm-samples-1d.pdf")
data.shape
(10000, 2)
data[:10, :]
DeviceArray([[-0.8169615 ,  0.55239207],
             [-1.423961  ,  1.1161395 ],
             [-1.8636721 ,  1.6114323 ],
             [-2.194045  ,  2.0294168 ],
             [-2.3697448 ,  2.4288142 ],
             [-2.4682324 ,  2.7841942 ],
             [-2.4781866 ,  2.9988554 ],
             [-2.4822226 ,  3.1525304 ],
             [-2.489001  ,  3.346456  ],
             [-2.4774575 ,  3.4129188 ]], dtype=float32)

Fit an ARHMM

# Now fit an ARHMM to the data
key1, key2 = jr.split(jr.PRNGKey(0), 2)
test_num_states = num_states

# Note: these distributions are constructed but not passed to GaussianARHMM below,
# so they do not affect the fit.
initial_distribution = tfp.distributions.Categorical(logits=np.zeros(test_num_states))
transition_distribution = tfp.distributions.Categorical(logits=np.zeros((test_num_states, test_num_states)))
emission_distribution = GaussianLinearRegression(
    weights=np.tile(0.99 * np.eye(data_dim), (test_num_states, 1, 1)),
    bias=0.01 * jr.normal(key2, (test_num_states, data_dim)),
    scale_tril=np.tile(np.eye(data_dim), (test_num_states, 1, 1)),
)

arhmm = GaussianARHMM(test_num_states, data_dim, num_lags, seed=jr.PRNGKey(0))
lps, arhmm, posterior = arhmm.fit(data, method="em")
Initializing... Done.
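Each EM iteration alternates two steps. The E-step computes the posterior probability of each state at each time step with the forward-backward algorithm. The M-step re-estimates the parameters. For a Gaussian AR state, the M-step is a weighted linear regression of $x_t$ on $[x_{t-1}, 1]$, with weights equal to the posterior probability that the state is active at time $t$. The NumPy sketch below covers that update for a single state. `ar_mstep` is a hypothetical helper, and it omits any prior or regularization that `ssm` may apply.

```python
import numpy as np


def ar_mstep(x, w):
    """Weighted least-squares update of (A, b, Q) for one lag-1 AR state.

    x: (T, D) observations; w: (T-1,) posterior weights for time steps 1..T-1.
    """
    X = np.column_stack([x[:-1], np.ones(len(x) - 1)])   # regressors [x_{t-1}, 1]
    Y = x[1:]                                             # targets x_t
    XtW = X.T * w                                         # X^T diag(w)
    W = np.linalg.solve(XtW @ X, XtW @ Y).T               # (D, D+1) = [A | b]
    A, b = W[:, :-1], W[:, -1]
    resid = Y - X @ W.T
    Q = (resid.T * w) @ resid / w.sum()                   # weighted residual covariance
    return A, b, Q


# Simulate a single AR state and check that the update recovers it
rng = np.random.default_rng(1)
D, T = 2, 5000
A_true = np.array([[0.7, -0.2], [0.2, 0.7]])
b_true = np.array([1.0, -0.5])
x = np.zeros((T, D))
for t in range(1, T):
    x[t] = A_true @ x[t - 1] + b_true + 0.1 * rng.standard_normal(D)

# With all weights equal to one this reduces to ordinary least squares
A_hat, b_hat, Q_hat = ar_mstep(x, np.ones(T - 1))
print(np.round(A_hat, 3), np.round(b_hat, 3), np.round(Q_hat, 4), sep="\n")
```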
# Plot the log likelihoods against the true likelihood, for comparison
true_lp = true_arhmm.marginal_likelihood(data)
plt.plot(lps, label="EM")
plt.plot(true_lp * np.ones(len(lps)), ":k", label="True")
plt.xlabel("EM Iteration")
plt.ylabel("Log Probability")
plt.legend(loc="lower right")
plt.show()
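The dotted reference line is the marginal likelihood of the data under the true model, which sums over all $K^T$ state sequences. The standard way to compute such a sum is the forward algorithm in log space. The NumPy sketch below takes precomputed per-step emission log-likelihoods `log_lik[t, k]` $= \log p(x_t \mid x_{t-1}, z_t = k)$. That input layout is an assumption, and how `ssm` treats the first `num_lags` observations may differ. The sketch is checked against brute-force enumeration on a tiny problem.

```python
import itertools

import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a)))."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def forward_loglik(log_pi0, log_P, log_lik):
    """log p(x_{1:T}) via the forward recursion (hypothetical helper)."""
    log_alpha = log_pi0 + log_lik[0]
    for t in range(1, len(log_lik)):
        # alpha_t(k) = sum_j alpha_{t-1}(j) P[j, k] * p(x_t | z_t = k)
        log_alpha = logsumexp(log_alpha[:, None] + log_P, axis=0) + log_lik[t]
    return logsumexp(log_alpha)


rng = np.random.default_rng(0)
K, T = 3, 6
pi0 = np.full(K, 1.0 / K)
P = rng.dirichlet(np.ones(K), size=K)
log_lik = rng.normal(size=(T, K))

fast = forward_loglik(np.log(pi0), np.log(P), log_lik)

# Brute force: enumerate all K**T state sequences
total = 0.0
for zs in itertools.product(range(K), repeat=T):
    lp = np.log(pi0[zs[0]]) + log_lik[0, zs[0]]
    for t in range(1, T):
        lp += np.log(P[zs[t - 1], zs[t]]) + log_lik[t, zs[t]]
    total += np.exp(lp)
print(fast, np.log(total))
```

The recursion costs $O(TK^2)$ instead of $O(K^T)$, which is what makes evaluating a 10,000-step sequence practical.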
# # Find a permutation of the states that best matches the true and inferred states
# most_likely_states = posterior.most_likely_states()
# arhmm.permute(find_permutation(true_states[num_lags:], most_likely_states))
# posterior.update()
# most_likely_states = posterior.most_likely_states()
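EM recovers the states only up to relabeling, which is why the commented-out cell calls `find_permutation` before comparing states. One common approach builds the $K \times K$ overlap matrix between true and inferred labels and picks the relabeling with the largest total overlap. The sketch below brute-forces all $K!$ permutations, which is fine for $K = 5$. For larger $K$ the same assignment problem is usually solved with the Hungarian algorithm, e.g. `scipy.optimize.linear_sum_assignment`. The helper `best_permutation` is hypothetical, and its output orientation may differ from `find_permutation`'s.

```python
import itertools

import numpy as np


def best_permutation(true_states, inferred_states, K):
    """Return perm such that perm[j] is the true label matched to inferred label j."""
    # overlap[i, j] = number of time steps with true label i and inferred label j
    overlap = np.zeros((K, K), dtype=int)
    np.add.at(overlap, (true_states, inferred_states), 1)
    best, best_score = None, -1
    for perm in itertools.permutations(range(K)):
        score = sum(overlap[perm[j], j] for j in range(K))
        if score > best_score:
            best, best_score = np.array(perm), score
    return best


rng = np.random.default_rng(0)
K = 5
true_z = rng.integers(K, size=1000)
relabel = np.array([3, 0, 4, 1, 2])        # inferred label = relabel[true label]
inferred_z = relabel[true_z]
flip = rng.random(1000) < 0.1               # corrupt roughly 10% of the labels
inferred_z[flip] = rng.integers(K, size=flip.sum())

perm = best_permutation(true_z, inferred_z, K)
aligned = perm[inferred_z]                  # map inferred labels back to true labels
print(perm, np.mean(aligned == true_z))
```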
if data_dim == 2:
    lim = abs(data).max()
    x = np.linspace(-lim, lim, 10)
    y = np.linspace(-lim, lim, 10)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))

    # Top row: true model; bottom row: fitted model
    fig, axs = plt.subplots(2, max(num_states, test_num_states), figsize=(3 * num_states, 6))
    for i, model in enumerate([true_arhmm, arhmm]):
        for j in range(model.num_states):
            dist = model._emissions._distribution[j]
            A, b = dist.weights, dist.bias
            dxydt_m = xy.dot(A.T) + b - xy
            axs[i, j].quiver(xy[:, 0], xy[:, 1], dxydt_m[:, 0], dxydt_m[:, 1], color=colors[j % len(colors)])

            axs[i, j].set_xlabel("$x_1$")
            axs[i, j].set_xticks([])
            if j == 0:
                axs[i, j].set_ylabel("$x_2$")
            axs[i, j].set_yticks([])
            axs[i, j].set_aspect("equal")

    plt.tight_layout()
    plt.savefig("arhmm-flow-matrices-true-and-estimated.pdf")
if data_dim == 2:
    lim = abs(data).max()
    x = np.linspace(-lim, lim, 10)
    y = np.linspace(-lim, lim, 10)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))

    fig, axs = plt.subplots(1, max(num_states, test_num_states), figsize=(3 * num_states, 6))
    for i, model in enumerate([arhmm]):
        for j in range(model.num_states):
            dist = model._emissions._distribution[j]
            A, b = dist.weights, dist.bias
            dxydt_m = xy.dot(A.T) + b - xy
            axs[j].quiver(xy[:, 0], xy[:, 1], dxydt_m[:, 0], dxydt_m[:, 1], color=colors[j % len(colors)])

            axs[j].set_xlabel("$y_1$")
            axs[j].set_xticks([])
            if j == 0:
                axs[j].set_ylabel("$y_2$")
            axs[j].set_yticks([])
            axs[j].set_aspect("equal")

    plt.tight_layout()
    plt.savefig("arhmm-flow-matrices-estimated.pdf")
# Plot the true and inferred discrete states
plot_slice = (0, 1000)
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(true_states[None, num_lags:], aspect="auto", interpolation="none", cmap=cmap, vmin=0, vmax=len(colors) - 1)
plt.xlim(plot_slice)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])

plt.subplot(212)
# plt.imshow(most_likely_states[None,: :], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.imshow(posterior.expected_states[0].T, aspect="auto", interpolation="none", cmap="Greys", vmin=0, vmax=1)
plt.xlim(plot_slice)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
plt.xlabel("time")
plt.tight_layout()
plt.savefig("arhmm-state-est.pdf")
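The lower panel shows the posterior marginals $p(z_t \mid x_{1:T})$. The commented-out `most_likely_states` call would instead return the single most probable state sequence, which is conventionally computed with the Viterbi algorithm. The NumPy sketch below works in log space on precomputed per-step emission log-likelihoods `log_lik[t, k]` (an assumed input layout) and is checked against brute force on a tiny problem. `viterbi` is a hypothetical helper, not the `ssm` implementation.

```python
import itertools

import numpy as np


def viterbi(log_pi0, log_P, log_lik):
    """Most probable state sequence argmax_z log p(z_{1:T}, x_{1:T})."""
    T, K = log_lik.shape
    score = log_pi0 + log_lik[0]              # best log-prob of a path ending in each state
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        cand = score[:, None] + log_P         # cand[j, k]: arrive in k from j
        backptr[t] = np.argmax(cand, axis=0)
        score = cand[backptr[t], np.arange(K)] + log_lik[t]
    # Trace back from the best final state
    z = np.empty(T, dtype=int)
    z[-1] = np.argmax(score)
    for t in range(T - 1, 0, -1):
        z[t - 1] = backptr[t, z[t]]
    return z, score.max()


rng = np.random.default_rng(0)
K, T = 3, 6
log_pi0 = np.log(np.full(K, 1.0 / K))
log_P = np.log(rng.dirichlet(np.ones(K), size=K))
log_lik = rng.normal(size=(T, K))

z_hat, best = viterbi(log_pi0, log_P, log_lik)


def path_logprob(zs):
    """Joint log-probability of one complete state path."""
    lp = log_pi0[zs[0]] + log_lik[0, zs[0]]
    for t in range(1, T):
        lp += log_P[zs[t - 1], zs[t]] + log_lik[t, zs[t]]
    return lp


brute = max(itertools.product(range(K), repeat=T), key=path_logprob)
print(z_hat, brute)
```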
# Sample the fitted model
sampled_states, sampled_data = arhmm.sample(jr.PRNGKey(0), time_bins)

fig = plt.figure(figsize=(8, 8))
for k in range(test_num_states):
    plt.plot(*sampled_data[sampled_states == k].T, "o", color=colors[k], alpha=0.75, markersize=3)

# plt.plot(*sampled_data.T, '-k', lw=0.5, alpha=0.2)
plt.plot(*sampled_data[:1000].T, "-k", lw=0.5, alpha=0.2)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
# plt.gca().set_aspect("equal")
plt.savefig("arhmm-samples-2d-estimated.pdf")