Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/beta_binom_post_plot.py
1192 views
1
import superimport
2
import jax.numpy as jnp
3
import matplotlib.pyplot as plt
4
import pyprobml_utils as pml
5
from jax.scipy.stats import beta, bernoulli
6
7
# Evaluation grid: 100 evenly spaced success probabilities in [0.001, 0.999].
# The prior/posterior pdfs and the likelihood are all evaluated on this grid.
x = jnp.linspace(0.001, 0.999, num=100)
9
10
11
# Forms graph given the parameters of the prior, likelihood and posterior:
12
def make_graph(data, save_name):
13
prior = beta.pdf(x, a=data["prior"]["a"], b=data["prior"]["b"])
14
n_0 = data["likelihood"]["n_0"]
15
n_1 = data["likelihood"]["n_1"]
16
samples = jnp.concatenate([jnp.zeros(n_0), jnp.ones(n_1)])
17
likelihood_function = jnp.vectorize(
18
lambda p: jnp.exp(bernoulli.logpmf(samples, p).sum())
19
)
20
likelihood = likelihood_function(x)
21
posterior = beta.pdf(x, a=data["posterior"]["a"], b=data["posterior"]["b"])
22
23
fig, ax = plt.subplots()
24
axt = ax.twinx()
25
fig1 = ax.plot(
26
x,
27
prior,
28
"k",
29
label=f"prior Beta({data['prior']['a']}, {data['prior']['b']})",
30
linewidth=2.0,
31
)
32
fig2 = axt.plot(x, likelihood, "r:", label=f"likelihood Bernoulli", linewidth=2.0)
33
fig3 = ax.plot(
34
x,
35
posterior,
36
"b-.",
37
label=f"posterior Beta({data['posterior']['a']}, {data['posterior']['b']})",
38
linewidth=2.0,
39
)
40
fig_list = fig1 + fig2 + fig3
41
labels = [fig.get_label() for fig in fig_list]
42
ax.legend(fig_list, labels, loc="upper left", shadow=True)
43
axt.set_ylabel("Likelihood")
44
ax.set_ylabel("Prior/Posterior")
45
ax.set_title(f"$N_0$:{n_0}, $N_1$:{n_1}")
46
pml.savefig(save_name)
47
48
49
def _beta_bernoulli_config(prior_a, prior_b, n_0, n_1):
    """Build a ``make_graph`` config, deriving the posterior by conjugacy.

    With a Beta(a, b) prior and n_1 ones / n_0 zeros observed from a Bernoulli,
    the posterior is Beta(a + n_1, b + n_0). Computing it here, instead of
    writing it out by hand, keeps the posterior consistent with the prior and counts.
    """
    return {
        "prior": {"a": prior_a, "b": prior_b},
        "likelihood": {"n_0": n_0, "n_1": n_1},
        "posterior": {"a": prior_a + n_1, "b": prior_b + n_0},
    }


# Beta(1, 1) ("Uninf") and Beta(2, 2) ("Inf") priors, each paired with a small
# (5 observations) and a large (50 observations) sample.
data1 = _beta_bernoulli_config(1, 1, n_0=1, n_1=4)
make_graph(data1, "betaPostUninfSmallSample.pdf")

data2 = _beta_bernoulli_config(1, 1, n_0=10, n_1=40)
make_graph(data2, "betaPostUninfLargeSample.pdf")

data3 = _beta_bernoulli_config(2, 2, n_0=1, n_1=4)
make_graph(data3, "betaPostInfSmallSample.pdf")

data4 = _beta_bernoulli_config(2, 2, n_0=10, n_1=40)
make_graph(data4, "betaPostInfLargeSample.pdf")

plt.show()
79
80