Kernel: Python 3 (ipykernel)
Autoregressive (AR) HMM Demo
Modified from
https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/arhmm-example.ipynb
This notebook illustrates the use of the auto_regression observation model. Let $x_t$ denote the observation at time $t$. Let $z_t$ denote the corresponding discrete latent state.
The autoregressive hidden Markov model has the following likelihood,
$$x_t \mid x_{t-1}, z_t \sim \mathcal{N}\left(A_{z_t} x_{t-1} + b_{z_t},\; Q_{z_t}\right).$$
(Technically, higher-order autoregressive processes with extra linear terms from inputs are also implemented.)
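To make the generative process concrete, here is a minimal NumPy sketch of sampling from a first-order ARHMM: at each step the discrete state follows a Markov chain, and the observation is drawn from a Gaussian whose mean is a linear function of the previous observation, with the weights, bias, and covariance selected by the current state. It does not use the ssm library. The function name `sample_arhmm`, the uniform initial state, the zero initial observation, and all parameter values are assumptions made for illustration.

```python
import numpy as np

def sample_arhmm(num_steps, trans, weights, biases, covs, rng):
    """Sample states z and observations x from a first-order ARHMM.

    trans:   (K, K) row-stochastic transition matrix
    weights: (K, D, D) per-state dynamics matrices
    biases:  (K, D) per-state bias vectors
    covs:    (K, D, D) per-state noise covariances
    """
    num_states, dim = biases.shape
    z = np.zeros(num_steps, dtype=int)
    x = np.zeros((num_steps, dim))
    z[0] = rng.integers(num_states)  # assumption: uniform initial state
    # assumption: x[0] is left at the origin
    for t in range(1, num_steps):
        z[t] = rng.choice(num_states, p=trans[z[t - 1]])
        mean = weights[z[t]] @ x[t - 1] + biases[z[t]]
        x[t] = rng.multivariate_normal(mean, covs[z[t]])
    return z, x

# Hypothetical two-state, two-dimensional example.
rng = np.random.default_rng(0)
trans = np.array([[0.95, 0.05],
                  [0.05, 0.95]])
weights = np.tile(0.8 * np.eye(2), (2, 1, 1))
biases = np.array([[1.0, 0.0],
                   [-1.0, 0.0]])
covs = np.tile(0.001 * np.eye(2), (2, 1, 1))

states, data = sample_arhmm(200, trans, weights, biases, covs, rng)
print(states.shape, data.shape)
```

Because the transition matrix here is "sticky" (large diagonal entries), the sampled trajectory tends to stay in one state for many steps and drift toward that state's fixed point before switching.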
In [4]:
Out[4]:
Collecting git+git://github.com/lindermanlab/ssm-jax-refactor.git
Cloning git://github.com/lindermanlab/ssm-jax-refactor.git to /tmp/pip-req-build-2x6xx8nv
Running command git clone -q git://github.com/lindermanlab/ssm-jax-refactor.git /tmp/pip-req-build-2x6xx8nv
Collecting jax==0.2.21
Downloading jax-0.2.21.tar.gz (756 kB)
Building wheels for collected packages: ssm, jax
Building wheel for ssm (setup.py) ... done
Created wheel for ssm: filename=ssm-0.1-py3-none-any.whl size=75282 sha256=bde077b314fa6664e8403f9bed8b69a2ae93fde31ed7843c500b456c71f7a819
Stored in directory: /tmp/pip-ephem-wheel-cache-s91lg85w/wheels/78/93/24/866323c03bb6444c9ad2485bc0abe61ad5e6828d66c2c2fda3
Building wheel for jax (setup.py) ... done
Created wheel for jax: filename=jax-0.2.21-py3-none-any.whl size=869303 sha256=dc33be4778a51583c3656c00dad98a420e13b2905773a4f7fd321de9ac478eb7
Stored in directory: /root/.cache/pip/wheels/5c/69/0d/3784dd6d281be0837d8cef1db0c8b37d108c8bff727b961178
Successfully built ssm jax
Installing collected packages: jax, ssm
Attempting uninstall: jax
Found existing installation: jax 0.2.25
Uninstalling jax-0.2.25:
Successfully uninstalled jax-0.2.25
Successfully installed jax-0.2.21 ssm-0.1
In [34]:
Out[34]:
7.2
72.0
Plot dynamics functions
In [61]:
Out[61]:
[[ 3.9209812 -1.9056123 ]
[ 3.0239954 3.140209 ]
[-2.0520504 3.8463693 ]
[-4.2922325 -0.7630231 ]
[-0.60069436 -4.3179426 ]]
Sample data from the ARHMM
In [57]:
Out[57]:
[[ 1. 0. ]
[ 0.30901697 0.95105654]
[-0.80901706 0.58778524]
[-0.80901694 -0.5877853 ]
[ 0.30901712 -0.9510565 ]]
In [56]:
Out[56]:
[[ 3.9209812 -1.9056123 ]
[ 3.0239954 3.140209 ]
[-2.0520504 3.8463693 ]
[-4.2922325 -0.7630231 ]
[-0.60069436 -4.3179426 ]]
Below, we visualize each component of the observation variable as a time series. The colors correspond to the latent state. The dotted lines mark the stationary point of the corresponding AR state, while the solid lines are the actual observations sampled from the ARHMM.
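The stationary point of an AR state is the value that the noise-free dynamics map to itself: if $x^* = A x^* + b$, then $x^* = (I - A)^{-1} b$, which exists and attracts trajectories when all eigenvalues of $A$ lie inside the unit circle. The sketch below computes it with NumPy. The parameter values (five states, a contraction of 0.8 combined with a rotation by $\pi/25$, biases on the unit circle) are assumptions chosen for illustration; the notebook's own parameter-setting code is not shown here.

```python
import numpy as np

num_states = 5
theta = np.pi / 25  # assumed rotation angle (7.2 degrees)

# Assumed dynamics: shrink by 0.8 and rotate clockwise by theta.
rotation = np.array([[np.cos(theta), np.sin(theta)],
                     [-np.sin(theta), np.cos(theta)]])
A = 0.8 * rotation

# Assumed biases: evenly spaced points on the unit circle (72 degrees apart).
angles = np.linspace(0, 2 * np.pi, num_states, endpoint=False)
biases = np.column_stack([np.cos(angles), np.sin(angles)])

def stationary_point(A, b):
    """Solve x = A x + b for x, i.e. (I - A) x = b."""
    return np.linalg.solve(np.eye(A.shape[0]) - A, b)

fixed_points = np.array([stationary_point(A, b) for b in biases])
print(fixed_points)
```

Under these assumed values the computed points appear to agree with the 5×2 array printed above (first row approximately 3.92, -1.91), which suggests that array holds the stationary points drawn as dotted lines.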
In [52]:
Out[52]:
DeviceArray(4.7118726, dtype=float32)
In [59]:
Out[59]:
(10000, 2)
In [55]:
Out[55]:
DeviceArray([[-0.8169615 , 0.55239207],
[-1.423961 , 1.1161395 ],
[-1.8636721 , 1.6114323 ],
[-2.194045 , 2.0294168 ],
[-2.3697448 , 2.4288142 ],
[-2.4682324 , 2.7841942 ],
[-2.4781866 , 2.9988554 ],
[-2.4822226 , 3.1525304 ],
[-2.489001 , 3.346456 ],
[-2.4774575 , 3.4129188 ]], dtype=float32)
Fit an ARHMM
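The fitting code for this section is not shown, so the following is only a hedged sketch of one ingredient that an EM-style fit of an ARHMM commonly uses: given per-time-step weights for a single state (for example, posterior state probabilities from an E-step), the state's dynamics matrix, bias, and noise covariance can be estimated by weighted least squares. The function name `fit_ar_params` and the demo values are hypothetical, and this is not the ssm library's implementation.

```python
import numpy as np

def fit_ar_params(x, weights):
    """Weighted least-squares estimate of (A, b, Q) for one AR state.

    x:       (T, D) observations
    weights: (T-1,) nonnegative weight of each transition x[t-1] -> x[t]
    """
    num_steps, dim = x.shape
    # Regress x[t] on [x[t-1], 1] so the last coefficient row is the bias.
    features = np.column_stack([x[:-1], np.ones(num_steps - 1)])
    targets = x[1:]
    weighted = features * weights[:, None]
    coef = np.linalg.solve(features.T @ weighted, weighted.T @ targets)
    A = coef[:dim].T
    b = coef[dim]
    resid = targets - features @ coef
    Q = (resid * weights[:, None]).T @ resid / weights.sum()
    return A, b, Q

# Hypothetical demo: simulate one AR state and recover its parameters.
rng = np.random.default_rng(0)
A_true = np.array([[0.8, 0.1],
                   [-0.1, 0.8]])
b_true = np.array([1.0, 0.0])
x = np.zeros((5000, 2))
for t in range(1, len(x)):
    x[t] = A_true @ x[t - 1] + b_true + 0.1 * rng.standard_normal(2)

A_hat, b_hat, Q_hat = fit_ar_params(x, np.ones(len(x) - 1))
print(A_hat, b_hat, Q_hat, sep="\n")
```

In a full EM loop this update would be applied once per state, each with its own weight vector, and alternated with a forward-backward pass that recomputes the weights.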
In [16]:
Out[16]:
Initializing...
Done.
0%| | 0/100 [00:00<?, ?it/s]