Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/casino_hmm.py
1192 views
1
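"""
Catching an "occasionally dishonest casino" with an HMM.

Simulates die rolls from a two-state (fair/loaded) categorical HMM, then plots
the filtered and smoothed posterior of the loaded state and the Viterbi path.

Based on https://github.com/probml/JSL/blob/main/jsl/demos/hmm_casino.py
"""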
5
import jax.numpy as jnp
6
import jax.random as jr
7
8
import matplotlib.pyplot as plt
9
import numpy as np
10
11
from ssm_jax.hmm.models import CategoricalHMM
12
13
14
# Helper functions for plotting
15
def find_dishonest_intervals(states, state=1):
    """
    Find the index spans during which the simulated system is in `state`.

    Parameters
    ----------
    states: array(n_timesteps)
        Sequence of latent states, e.g. the result of simulating a system with
        two latent states. A column of shape (n_timesteps, 1) is also accepted.
    state: int
        Latent state whose runs are reported (default 1).

    Returns
    -------
    list of (start, end) tuples of ints such that ``states[start:end]`` is a
    maximal run of ``state`` (0-based, half-open). Empty input gives ``[]``.
    """
    states = np.asarray(states).ravel()
    if states.size == 0:
        # Without this guard, states[starts] below would index an empty array.
        return []

    # Positions where the value changes, bracketed by the sequence boundaries,
    # split the sequence into constant runs [start, end).
    changepoints = np.concatenate([[0], np.flatnonzero(np.diff(states)) + 1, [states.size]])
    starts, ends = changepoints[:-1], changepoints[1:]

    # Each run is constant, so its first element identifies the run's state.
    selected = states[starts] == state
    return [(int(start), int(end)) for start, end in zip(starts[selected], ends[selected])]
35
36
37
def plot_inference(inference_values, states, ax, state=1, map_estimate=False):
    """
    Plot the estimated smoothing/filtering/MAP trajectory of a sequence of hidden states.

    "Vertical gray bars denote times when the hidden state corresponded to
    state 1. Blue lines represent the posterior probability of being in that
    state given different subsets of observed data." See the Markov and Hidden
    Markov models section for more info.

    Parameters
    ----------
    inference_values: array(n_timesteps, state_size) or array(n_timesteps)
        Posterior state probabilities (filtering/smoothing), or, when
        ``map_estimate`` is True, the 1-D sequence of MAP state indices.
    states: array(n_timesteps)
        Latent simulation (true hidden states), used to draw the gray bars.
    ax: matplotlib.axes.Axes
        Axes to draw on.
    state: int
        Column of ``inference_values`` to plot; ignored when ``map_estimate``.
    map_estimate: bool
        If True, draw ``inference_values`` as a step plot; otherwise plot
        column ``state`` as a line.
    """
    n_timesteps = len(inference_values)
    # Timesteps are drawn 1-based: index t is plotted at x = t + 1.
    xspan = np.arange(1, n_timesteps + 1)
    if map_estimate:
        ax.step(xspan, inference_values, where="post")
    else:
        ax.plot(xspan, inference_values[:, state])

    # find_dishonest_intervals returns 0-based half-open [start, end) index
    # spans; shift them onto the same 1-based x axis as the curve so the gray
    # bars line up with the plotted timesteps.
    for start, end in find_dishonest_intervals(states):
        ax.axvspan(start + 1, end + 1, alpha=0.5, facecolor="tab:gray", edgecolor="none")

    ax.set_xlim(1, n_timesteps)
    # Pad slightly beyond [0, 1] so curves touching the bounds stay visible.
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")
70
71
72
def make_model_and_data(n_timesteps=300, seed=0):
    """
    Build the occasionally-dishonest-casino HMM and sample a trajectory from it.

    State 0 emits from a fair die; state 1 emits from a loaded die whose last
    face has probability 5/10.

    Parameters
    ----------
    n_timesteps: int
        Number of rolls to simulate (default 300).
    seed: int
        Seed for the ``jax.random.PRNGKey`` used for sampling (default 0).

    Returns
    -------
    hmm: CategoricalHMM
        The constructed model.
    true_states, emissions:
        The sampled latent states and observations, as returned by ``hmm.sample``.
    """
    # Row i gives p(next state | current state i): the casino stays fair with
    # probability 0.95 and stays loaded with probability 0.90.
    transition_matrix = jnp.array([[0.95, 0.05], [0.10, 0.90]])
    fair_die = [1 / 6] * 6
    loaded_die = [1 / 10] * 5 + [5 / 10]
    # NOTE(review): shape (num_states, 1, num_faces) presumably matches the old
    # CategoricalHMM's (states, emission_dim, classes) layout -- confirm.
    emission_probs = jnp.array([fair_die, loaded_die]).reshape((2, 1, 6))
    init_state_probs = jnp.array([1 / 2, 1 / 2])
    # Positional constructor and sample() follow the old ssm_jax API.
    hmm = CategoricalHMM(init_state_probs, transition_matrix, emission_probs)

    true_states, emissions = hmm.sample(jr.PRNGKey(seed), n_timesteps)
    return hmm, true_states, emissions
89
90
91
def _to_string(values, max_len=60):
    """Render an integer sequence as 1-based digits, truncated to `max_len` characters."""
    # ravel() so column-shaped input such as (n_timesteps, 1) also works;
    # "".join over a 2-D array would otherwise fail on its row sub-arrays.
    return "".join((np.asarray(values).ravel() + 1).astype(str))[:max_len]


def plot_results(true_states, emissions, posterior, most_likely_states):
    """
    Print a preview of the simulation and plot the filtered, smoothed and Viterbi results.

    Parameters
    ----------
    true_states: array(n_timesteps)
        Simulated latent states (drawn as gray bars where the die is loaded).
    emissions: array
        Simulated observations; only a printed preview is shown.
    posterior:
        Object exposing ``marginal_loglik``, ``filtered_probs`` and
        ``smoothed_probs``.
    most_likely_states: array(n_timesteps)
        MAP (Viterbi) state sequence.

    Returns
    -------
    dict mapping figure name to matplotlib Figure.
    """
    print("Printing sample observed/latent...")
    print("hid: ", _to_string(true_states))
    print("obs: ", _to_string(emissions))
    print("Log likelihood: ", posterior.marginal_loglik)

    # (figure key, values to plot, y-label, title, draw as MAP step plot?)
    panels = [
        ("hmm_casino_filter", posterior.filtered_probs, "p(loaded)", "Filtered", False),
        ("hmm_casino_smooth", posterior.smoothed_probs, "p(loaded)", "Smoothed", False),
        ("hmm_casino_map", most_likely_states, "MAP state", "Viterbi", True),
    ]
    dict_figures = {}
    for key, values, ylabel, title, is_map in panels:
        fig, ax = plt.subplots()
        plot_inference(values, true_states, ax, map_estimate=is_map)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        dict_figures[key] = fig
    return dict_figures
118
119
120
def main(test_mode=False):
    """
    Run the casino demo: simulate data, run HMM inference and optionally plot.

    Parameters
    ----------
    test_mode: bool
        If True, run simulation and inference only, skipping printing and plotting.
    """
    hmm, true_states, emissions = make_model_and_data()
    # The posterior provides filtered/smoothed probabilities and the marginal
    # log-likelihood consumed by plot_results.
    posterior = hmm.smoother(emissions)
    most_likely_states = hmm.most_likely_states(emissions)
    if not test_mode:
        plot_results(true_states, emissions, posterior, most_likely_states)
        plt.show()
127
128
129
# Entry point: run the full demo (with plots) only when executed as a script.
if __name__ == "__main__":
    main(test_mode=False)
132
133