# Source: GitHub repository probml/pyprobml, path blob/master/deprecated/scripts/bayes_unigauss_2d_pyro.py
1
# Approximate 2d posterior using pyro SVI
2
# https://www.ritchievink.com/blog/2019/06/10/bayesian-inference-how-we-are-able-to-chase-the-posterior/
3
# We use the same data and model as in posteriorGrid2d.py
4
5
import superimport
6
7
import numpy as np
8
import matplotlib.pyplot as plt
9
from scipy import stats
10
11
12
import pyro
13
import pyro.distributions as dist
14
import pyro.optim
15
from pyro.infer import SVI, Trace_ELBO
16
import torch
17
import torch.distributions.constraints as constraints
18
import numpy as np
19
20
# Directory that figures are written to; a falsy value disables saving.
figdir = "../figures"
import os


def save_fig(fname):
    """Save the current matplotlib figure as ``fname`` inside ``figdir``.

    Nothing is written when ``figdir`` is falsy. The target directory is
    created if it does not exist yet, so a missing figures folder does not
    abort the script at the final save.

    Args:
        fname: File name (including extension) to create inside ``figdir``.
    """
    if not figdir:
        return
    os.makedirs(figdir, exist_ok=True)
    plt.savefig(os.path.join(figdir, fname))
24
25
# Seed NumPy's global random number generator.
np.random.seed(0)

# The two observations used by the grid likelihood.
data = np.array([195, 182])

# Parameter grid: mu varies along the columns, sigma along the rows.
# sigma is reversed so that row 0 holds the largest value.
mu = np.linspace(150, 250)
sigma = np.linspace(0, 15)[::-1]
mm, ss = np.meshgrid(mu, sigma)  # broadcast both parameters over the grid

# Likelihood of both observations at every (mu, sigma) grid point.
grid_normal = stats.norm(mm, ss)
likelihood = grid_normal.pdf(data[0]) * grid_normal.pdf(data[1])

# imshow geometry; extent is [left, right, bottom, top].
extent = [mm.min(), mm.max(), ss.min(), ss.max()]
aspect = mm.max() / ss.max() / 3

# Prior: Normal(200, 15) density on mu times Cauchy(0, 10) density on sigma.
prior = stats.norm(200, 15).pdf(mm) * stats.cauchy(0, 10).pdf(ss)

# Posterior on the grid: prior * likelihood, divided by its sum over the
# grid with NaN entries counted as zero.
unnormalized_posterior = prior * likelihood
normalizer = np.nan_to_num(unnormalized_posterior).sum()
posterior = unnormalized_posterior / normalizer

plt.figure()
plt.imshow(posterior, cmap='Blues', aspect=aspect, extent=extent)
plt.xlabel(r'$\mu$')
plt.ylabel(r'$\sigma$')
plt.title('Grid approximation')
plt.show()
49
50
def model():
    """Pyro model: Gaussian observations with unknown mean and scale.

    Priors are ``mu ~ Normal(200, 15)`` and ``sigma ~ HalfCauchy(10)``.
    The observations are taken from the module-level ``data`` array so that
    this likelihood uses the same values as the grid approximation, instead
    of a separately hard-coded (and previously different) pair.
    """
    # Observed values as a float tensor; ``data`` is an integer NumPy array.
    observations = torch.tensor(data, dtype=torch.float)

    # priors
    mu = pyro.sample('mu', dist.Normal(loc=torch.tensor(200.),
                                       scale=torch.tensor(15.)))
    sigma = pyro.sample('sigma', dist.HalfCauchy(scale=torch.tensor(10.)))

    # likelihood: one plate entry per observation
    with pyro.plate('plate', size=len(observations)):
        pyro.sample('obs', dist.Normal(loc=mu, scale=sigma),
                    obs=observations)
60
61
def guide():
    """Factorized variational family for the sites 'mu' and 'sigma'.

    ``mu`` is drawn from ``Normal(var_mu, var_mu_sig)`` and ``sigma`` from
    ``Chi2(var_sig)``; the three ``var_*`` values are learnable parameters
    registered in the Pyro param store.
    """
    # variational parameters
    var_mu = pyro.param('var_mu', torch.tensor(180.))
    var_mu_sig = pyro.param('var_mu_sig', torch.tensor(5.),
                            constraint=constraints.positive)
    # The Chi2 degrees of freedom must stay positive; without the constraint
    # an optimizer step can push it to an invalid value.
    var_sig = pyro.param('var_sig', torch.tensor(5.),
                         constraint=constraints.positive)

    # factorized distribution
    pyro.sample('mu', dist.Normal(loc=var_mu, scale=var_mu_sig))
    pyro.sample('sigma', dist.Chi2(var_sig))
71
72
# Seed Pyro's random number generators: the sampling in the model and guide
# goes through torch, which a NumPy-only seed does not control.
pyro.set_rng_seed(0)
pyro.clear_param_store()
pyro.enable_validation(True)

svi = SVI(model, guide,
          optim=pyro.optim.ClippedAdam({"lr": 0.01}),
          loss=Trace_ELBO())

# do gradient steps, reporting the loss every 100 iterations
num_steps = 1000
for step in range(num_steps):
    loss = svi.step()
    if step % 100 == 0:
        # Iterations are reported 1-based.
        print("[iteration {:>4}] loss: {:.4f}".format(step + 1, loss))
86
87
88
# Draw samples from the fitted variational distributions, using the learned
# values in the Pyro param store (sigma first, then mu).
num_samples = 10000
sigma = dist.Chi2(pyro.param('var_sig')).sample((num_samples,)).numpy()
mu = dist.Normal(pyro.param('var_mu'),
                 pyro.param('var_mu_sig')).sample((num_samples,)).numpy()

# Scatter the samples on the same axis limits as the grid plot.
plt.figure()
plt.scatter(mu, sigma, alpha=0.01)
plt.xlim([extent[0], extent[1]])
plt.ylim([extent[2], extent[3]])
# Raw strings: '\s' and '\m' are invalid escape sequences in normal strings.
plt.ylabel(r'$\sigma$')
plt.xlabel(r'$\mu$')
plt.title('VI samples')
save_fig('bayes_unigauss_2d_pyro_post.pdf')
plt.show()
100