Path: blob/master/deprecated/notebooks/bernoulli_hmm_example.ipynb
Kernel: Python 3 (ipykernel)
Bernoulli HMM Example Notebook
Modified from https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/bernoulli-hmm-example.ipynb
In [2]:
Out[2]:
Collecting git+git://github.com/lindermanlab/ssm-jax-refactor.git
Cloning git://github.com/lindermanlab/ssm-jax-refactor.git to /tmp/pip-req-build-j0n1k4xi
Running command git clone -q git://github.com/lindermanlab/ssm-jax-refactor.git /tmp/pip-req-build-j0n1k4xi
Collecting jax==0.2.21
Downloading jax-0.2.21.tar.gz (756 kB)
|████████████████████████████████| 756 kB 13.7 MB/s
Building wheels for collected packages: ssm, jax
Building wheel for ssm (setup.py) ... done
Created wheel for ssm: filename=ssm-0.1-py3-none-any.whl size=75282 sha256=07a1ee07356d240c4bf185042fad0cc0216b5f6a00fb7ba950c498cc897b251e
Stored in directory: /tmp/pip-ephem-wheel-cache-1b2n7kks/wheels/78/93/24/866323c03bb6444c9ad2485bc0abe61ad5e6828d66c2c2fda3
Building wheel for jax (setup.py) ... done
Created wheel for jax: filename=jax-0.2.21-py3-none-any.whl size=869303 sha256=ea52c5240af54eab126396cb09cc3730651e7b07fd3cef4ded61073ccd7e50fd
Stored in directory: /root/.cache/pip/wheels/5c/69/0d/3784dd6d281be0837d8cef1db0c8b37d108c8bff727b961178
Successfully built ssm jax
Installing collected packages: jax, ssm
Attempting uninstall: jax
Found existing installation: jax 0.2.25
Uninstalling jax-0.2.25:
Successfully uninstalled jax-0.2.25
Successfully installed jax-0.2.21 ssm-0.1
In [3]:
Imports and Plotting Functions
In [18]:
Bernoulli HMM
Let's create a true model
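The original code cell is not preserved here, so the following is only a minimal sketch of what the parameters of a "true" Bernoulli HMM might look like. It uses plain Python rather than the `ssm` package; the function name `make_bernoulli_hmm`, the `stay_prob` argument, and the sticky transition structure are illustrative assumptions, not the library's API. The sizes (5 states, 10 emission dimensions) are inferred from the shapes printed later in the notebook.

```python
import random

def make_bernoulli_hmm(num_states, num_emission_dims, seed=0, stay_prob=0.9):
    """Build toy Bernoulli HMM parameters (hypothetical helper, not the ssm API)."""
    rng = random.Random(seed)
    # Uniform distribution over the initial hidden state.
    initial_probs = [1.0 / num_states] * num_states
    # "Sticky" transitions: stay put with probability stay_prob,
    # otherwise move to one of the other states uniformly at random.
    off_diag = (1.0 - stay_prob) / (num_states - 1)
    transition_matrix = [
        [stay_prob if i == j else off_diag for j in range(num_states)]
        for i in range(num_states)
    ]
    # emission_probs[k][d] = P(x_d = 1 | z = k), drawn uniformly from [0, 1).
    emission_probs = [
        [rng.random() for _ in range(num_emission_dims)]
        for _ in range(num_states)
    ]
    return initial_probs, transition_matrix, emission_probs

initial_probs, transition_matrix, emission_probs = make_bernoulli_hmm(5, 10)
```

Each row of `transition_matrix` is a probability distribution over next states, and each row of `emission_probs` holds one independent Bernoulli success probability per observed dimension.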
In [5]:
In [6]:
Out[6]:
From the true model, we can sample synthetic data
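As a rough illustration of what sampling from a Bernoulli HMM involves, the sketch below draws a hidden state path and then, at each step, one independent coin flip per observed dimension. It is written in plain Python with a tiny 2-state, 3-dimensional model chosen purely for readability; `sample_bernoulli_hmm` is a hypothetical helper and is not the sampling method of the `ssm` package.

```python
import random

def sample_bernoulli_hmm(initial_probs, transition_matrix, emission_probs,
                         num_steps, seed=0):
    """Draw (states, data) from a Bernoulli HMM (illustrative helper)."""
    rng = random.Random(seed)
    num_states = len(initial_probs)
    states, data = [], []
    for t in range(num_steps):
        # The first state comes from the initial distribution; later states
        # come from the transition row of the previous state.
        weights = initial_probs if t == 0 else transition_matrix[states[-1]]
        state = rng.choices(range(num_states), weights=weights)[0]
        states.append(state)
        # Each dimension is an independent Bernoulli draw given the state.
        data.append([1 if rng.random() < p else 0 for p in emission_probs[state]])
    return states, data

# Toy model: two sticky states with clearly different emission patterns.
initial_probs = [0.5, 0.5]
transition_matrix = [[0.95, 0.05], [0.05, 0.95]]
emission_probs = [[0.9, 0.1, 0.5], [0.1, 0.9, 0.5]]

states, data = sample_bernoulli_hmm(
    initial_probs, transition_matrix, emission_probs, num_steps=500
)
```

Here `states` is a list of 500 integers and `data` is a list of 500 binary vectors of length 3.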
In [7]:
Let's view the synthetic data
In [20]:
Out[20]:
Fit HMM using exact EM updates
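For readers without the original cell, here is a minimal sketch of exact EM for a Bernoulli HMM, assuming a single sequence of binary vectors. The E-step is a scaled forward-backward pass and the M-step has closed-form updates. All function names are illustrative, the code is plain Python rather than the `ssm` implementation, and it assumes every state keeps some posterior mass (a more careful version would guard the divisions).

```python
import math
import random

def emission_likelihood(x_t, probs):
    """p(x_t | z_t = k) for independent Bernoulli dimensions."""
    lik = 1.0
    for x, p in zip(x_t, probs):
        lik *= p if x else 1.0 - p
    return lik

def e_step(data, pi, A, B):
    """Scaled forward-backward: posteriors, summed pairwise posteriors, log-likelihood."""
    T, K = len(data), len(pi)
    lik = [[emission_likelihood(x_t, B[k]) for k in range(K)] for x_t in data]
    alpha, scale = [], []
    for t in range(T):
        if t == 0:
            a = [pi[k] * lik[0][k] for k in range(K)]
        else:
            a = [lik[t][k] * sum(alpha[-1][j] * A[j][k] for j in range(K))
                 for k in range(K)]
        c = sum(a)  # c_t = p(x_t | x_1..x_{t-1})
        alpha.append([v / c for v in a])
        scale.append(c)
    beta = [[1.0] * K for _ in range(T)]
    for t in range(T - 2, -1, -1):
        beta[t] = [sum(A[k][j] * lik[t + 1][j] * beta[t + 1][j] for j in range(K))
                   / scale[t + 1] for k in range(K)]
    gamma = [[alpha[t][k] * beta[t][k] for k in range(K)] for t in range(T)]
    xi_sum = [[0.0] * K for _ in range(K)]
    for t in range(T - 1):
        for j in range(K):
            for k in range(K):
                xi_sum[j][k] += (alpha[t][j] * A[j][k] * lik[t + 1][k]
                                 * beta[t + 1][k] / scale[t + 1])
    return gamma, xi_sum, sum(math.log(c) for c in scale)

def m_step(data, gamma, xi_sum):
    """Closed-form updates from expected sufficient statistics."""
    T, K, D = len(data), len(gamma[0]), len(data[0])
    pi = list(gamma[0])
    A = [[v / sum(row) for v in row] for row in xi_sum]
    B = []
    for k in range(K):
        weight = sum(gamma[t][k] for t in range(T))
        B.append([sum(gamma[t][k] * data[t][d] for t in range(T)) / weight
                  for d in range(D)])
    return pi, A, B

def fit_em(data, pi, A, B, num_iters=20):
    """Alternate E and M steps, recording the log-likelihood before each update."""
    log_probs = []
    for _ in range(num_iters):
        gamma, xi_sum, log_prob = e_step(data, pi, A, B)
        log_probs.append(log_prob)
        pi, A, B = m_step(data, gamma, xi_sum)
    return (pi, A, B), log_probs

# Toy data from a known 2-state, 3-dimensional model.
rng = random.Random(0)
true_A = [[0.95, 0.05], [0.05, 0.95]]
true_B = [[0.9, 0.1, 0.8], [0.1, 0.9, 0.2]]
state, data = 0, []
for _ in range(200):
    data.append([int(rng.random() < p) for p in true_B[state]])
    state = rng.choices([0, 1], weights=true_A[state])[0]

init_B = [[rng.uniform(0.3, 0.7) for _ in range(3)] for _ in range(2)]
params, log_probs = fit_em(data, [0.5, 0.5], [[0.8, 0.2], [0.2, 0.8]], init_B)
```

Because each EM iteration cannot decrease the marginal log-likelihood, `log_probs` should be non-decreasing up to floating-point error, which is the behaviour a log-likelihood curve is meant to show.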
In [9]:
Out[9]:
Initializing...
Done.
0%| | 0/100 [00:00<?, ?it/s]
In [10]:
Out[10]:
Text(0, 0.5, 'log likelihood')
In [11]:
Out[11]:
DeviceArray([[0.84915197, 0.04904636, 0.00188023, 0.07304415, 0.02687721],
[0.00975706, 0.9331732 , 0.02004754, 0.02332079, 0.01370143],
[0.00132545, 0.02912631, 0.9396162 , 0.0288132 , 0.00111887],
[0.06562801, 0.03183502, 0.00170273, 0.850693 , 0.0501412 ],
[0.01220625, 0.03129839, 0.00972043, 0.01064176, 0.93613315]], dtype=float32)
In [12]:
Out[12]:
<Figure size 432x288 with 0 Axes>
In [21]:
Out[21]:
Fit Bernoulli HMM Over Multiple Trials
In [ ]:
In [ ]:
(5, 500)
(5, 500, 10)
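The two shapes printed above are consistent with 5 trials of 500 time steps each, with hidden states of shape `(trials, steps)` and 10-dimensional binary observations of shape `(trials, steps, dims)`. The sketch below reproduces that layout with nested Python lists, assuming 5 hidden states as suggested by the 5x5 transition matrix shown earlier. The helpers `sample_trial` and `shape` are hypothetical and do not come from the `ssm` package.

```python
import random

def sample_trial(initial_probs, transition_matrix, emission_probs, num_steps, rng):
    """Sample one trial: a state path and one binary vector per time step."""
    num_states = len(initial_probs)
    states, data = [], []
    for t in range(num_steps):
        weights = initial_probs if t == 0 else transition_matrix[states[-1]]
        states.append(rng.choices(range(num_states), weights=weights)[0])
        data.append([int(rng.random() < p) for p in emission_probs[states[-1]]])
    return states, data

def shape(nested):
    """Shape of a rectangular nested list, read off the first element at each level."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)

num_trials, num_steps, num_states, num_dims = 5, 500, 5, 10
rng = random.Random(0)

# Illustrative parameters: uniform start, sticky transitions, random emissions.
initial_probs = [1.0 / num_states] * num_states
transition_matrix = [
    [0.9 if i == j else 0.1 / (num_states - 1) for j in range(num_states)]
    for i in range(num_states)
]
emission_probs = [[rng.random() for _ in range(num_dims)] for _ in range(num_states)]

all_states, all_data = [], []
for _ in range(num_trials):
    trial_states, trial_data = sample_trial(
        initial_probs, transition_matrix, emission_probs, num_steps, rng
    )
    all_states.append(trial_states)
    all_data.append(trial_data)

print(shape(all_states))  # (5, 500)
print(shape(all_data))    # (5, 500, 10)
```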
In [ ]:
Initializing...
Done.
LP: -13234.113: 100%|██████████| 100/100 [00:01<00:00, 53.80it/s]
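One common way to fit a single HMM to several trials is to run the E-step on each trial separately and then pool the expected sufficient statistics before the M-step. The sketch below shows only the pooled emission update, given per-trial posterior state probabilities; whether the `ssm` package does exactly this is an assumption, and `pooled_emission_update` is a hypothetical name. The posteriors in the example are hard one-hot assignments so the result can be checked by hand.

```python
def pooled_emission_update(trial_data, trial_gammas):
    """M-step for Bernoulli emission probabilities, pooling over trials.

    trial_data[n][t][d]   : binary observation, trial n, step t, dimension d
    trial_gammas[n][t][k] : posterior probability of state k at step t of trial n
    """
    num_states = len(trial_gammas[0][0])
    num_dims = len(trial_data[0][0])
    numer = [[0.0] * num_dims for _ in range(num_states)]
    denom = [0.0] * num_states
    for data, gamma in zip(trial_data, trial_gammas):
        for x_t, g_t in zip(data, gamma):
            for k in range(num_states):
                denom[k] += g_t[k]
                for d in range(num_dims):
                    numer[k][d] += g_t[k] * x_t[d]
    # Weighted fraction of ones per state and dimension, across all trials.
    return [[numer[k][d] / denom[k] for d in range(num_dims)]
            for k in range(num_states)]

# Two trials, two steps each, two observed dimensions.
trial_data = [
    [[1, 0], [1, 1]],
    [[0, 0], [0, 1]],
]
# Hard assignments: step 0 belongs to state 0, step 1 to state 1, in both trials.
trial_gammas = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, 1.0]],
]

emission_probs = pooled_emission_update(trial_data, trial_gammas)
print(emission_probs)  # [[0.5, 0.0], [0.5, 1.0]]
```

State 0 sees the observations `[1, 0]` and `[0, 0]`, so its updated probabilities are `[0.5, 0.0]`; state 1 sees `[1, 1]` and `[0, 1]`, giving `[0.5, 1.0]`.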
In [ ]:
In [ ]:
===== Trial: 0 =====
===== Trial: 1 =====
===== Trial: 2 =====