Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/beta_binom_post_pred.py
1192 views
1
import superimport
2
3
import arviz as az
4
import matplotlib.pyplot as plt
5
import numpy as np
6
from scipy.special import comb, beta
7
from scipy.stats import binom
8
from scipy import stats
9
10
# Reproducible simulated data: 20 Bernoulli(0.7) coin flips.
np.random.seed(0)

# Beta(a_prior, b_prior) prior on the success probability theta;
# Beta(1, 1) is the uniform distribution on [0, 1].
a_prior, b_prior = 1, 1

Y = stats.bernoulli(0.7).rvs(20)

# Sufficient statistics: number of successes (N1) and failures (N0).
N1 = Y.sum()
N0 = len(Y) - N1

# Conjugate update: Beta(a, b) prior + Bernoulli data -> Beta(a + N1, b + N0).
a_post = a_prior + N1
b_post = b_prior + N0

# Number of future trials the predictive distributions are defined over.
N = 20

# Predictive pmf of k successes in N new trials (k = 0..N) with theta
# integrated out under the prior / posterior -- the beta-binomial
#     p(k) = C(N, k) * B(k + a, N - k + b) / B(a, b).
# scipy's betabinom computes this in log space, so unlike the direct product
# comb(...) * beta(...) / beta(...) it does not overflow / underflow
# (inf * 0 -> nan) once N grows past roughly a thousand.
k_values = np.arange(N + 1)
prior_pred_dist = stats.betabinom(N, a_prior, b_prior).pmf(k_values)
post_pred_dist = stats.betabinom(N, a_post, b_post).pmf(k_values)
27
28
# Bar charts of the prior and posterior predictive distributions over the
# number of successes k = 0..N, one figure each. Both figures share one
# y-limit so they can be compared by eye.
ks = np.arange(N + 1)
# Keep the original 0.15 ceiling, but raise it if a bar would otherwise be
# clipped (the posterior predictive is more peaked than the prior one).
pred_ylim = max(0.15, 1.1 * max(max(prior_pred_dist), max(post_pred_dist)))

for dist, title in ((prior_pred_dist, "Prior predictive distribution"),
                    (post_pred_dist, "Posterior predictive distribution")):
    fig, ax = plt.subplots()
    ax.bar(ks, dist, align='center', color='grey')
    ax.set_title(title, fontweight='bold')
    # Derived from N; the previous hard-coded 21 was only right for N == 20.
    ax.set_xlim(-1, N + 1)
    ax.set_xticks(ks)
    ax.set_ylim(0, pred_ylim)
    ax.set_xlabel("number of success")
45
46
# Prior and posterior densities of theta, each drawn with arviz from 10,000
# Monte Carlo samples of the corresponding Beta distribution. The target axes
# are passed explicitly instead of relying on pyplot's implicit current axes.
fig, ax = plt.subplots()
prior_samples = np.random.beta(a_prior, b_prior, 10000)
az.plot_dist(prior_samples, plot_kwargs={"color": "0.5"},
             fill_kwargs={'alpha': 1}, ax=ax)
ax.set_title("Prior distribution", fontweight='bold')
ax.set_xlim(0, 1)
ax.set_ylim(0, 4)
ax.tick_params(axis='both', pad=7)
ax.set_xlabel("θ")

fig, ax = plt.subplots()
post_samples = np.random.beta(a_post, b_post, 10000)
az.plot_dist(post_samples, plot_kwargs={"color": "0.5"},
             fill_kwargs={'alpha': 1}, ax=ax)
ax.set_title("Posterior distribution", fontweight='bold')
# Axis limits are left to autoscaling: the posterior is narrower than the
# flat prior, so the prior's fixed (0, 4) y-range may clip its peak.
ax.tick_params(axis='both', pad=7)
ax.set_xlabel("θ")

plt.show()
65
66