# Source: probml/pyprobml GitHub repository,
# path deprecated/scripts/bayes_linreg_2d_demo.py
1
2
# Bayesian inference for simple linear regression with known noise variance.
# The goal is to reproduce Figure 3.7 of Bishop, "Pattern Recognition and Machine Learning" (2006).
# We fit the linear model f(x, w) = w0 + w1*x and plot the posterior over w.
5
6
import superimport
7
8
import numpy as np
9
import matplotlib.pyplot as plt
10
import pyprobml_utils as pml
11
12
13
from scipy.stats import uniform, norm, multivariate_normal
14
15
np.random.seed(0)

# How many parameter vectors to draw from each prior/posterior (one line each).
NSamples = 10

# One figure row per entry: the number of data points the posterior in that
# row conditions on. The first entry must be zero so the top row shows the prior.
DataIndices = [0, 1, 2, 100]

# True regression parameters we hope to recover. Keep them within [-1, 1]:
# every plot's axes are fixed to that range.
a0, a1 = -0.3, 0.5

NPoints = 100         # Number of (x, y) training points.
noiseSD = 0.2         # True observation-noise standard deviation.
priorPrecision = 2.0  # Precision alpha of the zero-mean isotropic Gaussian prior.

# The likelihood noise level (precision beta) is treated as known: it equals the truth.
likelihoodSD = noiseSD
likelihoodPrecision = 1.0 / (likelihoodSD ** 2)

# Inputs are drawn uniformly on [-1, 1) so that data space shares the
# coefficient axes' range; targets are the true line plus Gaussian noise.
x = uniform().rvs(NPoints) * 2 - 1
y = a0 + a1 * x + norm(0, noiseSD).rvs(NPoints)
37
38
def MeanCovPost(x, y, prior_precision=None, likelihood_precision=None):
    """Return the Gaussian posterior over w = (w0, w1) for f(x, w) = w0 + w1*x.

    Uses a zero-mean isotropic Gaussian prior with precision alpha and a
    Gaussian likelihood with known precision beta:

        S_N^{-1} = alpha * I + beta * X^T X
        m_N      = beta * S_N X^T y

    where X is the design matrix with rows [1, x_n].

    Args:
        x: 1-D sequence of inputs (may be empty, giving back the prior).
        y: 1-D sequence of targets, same length as x.
        prior_precision: alpha; defaults to the module-level ``priorPrecision``.
        likelihood_precision: beta; defaults to the module-level
            ``likelihoodPrecision``.

    Returns:
        dict with 'Mean' (shape (2,)) and 'Cov' (shape (2, 2)).
    """
    if prior_precision is None:
        prior_precision = priorPrecision
    if likelihood_precision is None:
        likelihood_precision = likelihoodPrecision
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # column_stack keeps the design matrix at shape (N, 2) even when N == 0,
    # so empty data yields exactly the prior instead of a mis-shaped mean.
    X = np.column_stack([np.ones_like(x), x])
    Precision = prior_precision * np.eye(2) + likelihood_precision * (X.T @ X)
    Cov = np.linalg.inv(Precision)
    Mean = likelihood_precision * (Cov @ (X.T @ y))
    return {'Mean': Mean, 'Cov': Cov}
45
46
def GaussPdfMaker(mean, cov):
    """Return a function evaluating the N(mean, cov) density at points (w1, w2).

    The returned callable accepts scalars or broadcast-compatible arrays
    ``w1`` and ``w2`` (e.g. the outputs of ``np.meshgrid``) and returns the
    bivariate Gaussian pdf at each point, with the broadcast shape of its inputs.

    Args:
        mean: length-2 mean vector.
        cov: 2x2 covariance matrix.
    """
    # Freeze once so the covariance is validated/factorized a single time
    # rather than once per grid point.
    dist = multivariate_normal(mean=mean, cov=cov)

    def out(w1, w2):
        w1, w2 = np.broadcast_arrays(np.asarray(w1, dtype=float),
                                     np.asarray(w2, dtype=float))
        # Stack into (..., 2) points so scipy evaluates the whole grid in one call.
        points = np.stack([w1, w2], axis=-1)
        # pdf() squeezes singleton dimensions; restore the broadcast shape.
        return np.reshape(dist.pdf(points), w1.shape)

    return out
51
52
def LikeFMaker(x0, y0, likelihood_sd=None):
    """Return the likelihood of one observation (x0, y0) as a function of w.

    The returned callable maps (w1, w2), interpreted as intercept and slope,
    to N(y0 - (w1 + w2 * x0); 0, sd^2). It accepts scalars or
    broadcast-compatible arrays (e.g. ``np.meshgrid`` output).

    Args:
        x0: observed input.
        y0: observed target.
        likelihood_sd: noise standard deviation; defaults to the module-level
            ``likelihoodSD``.
    """
    scale = likelihoodSD if likelihood_sd is None else likelihood_sd

    def out(w1, w2):
        # norm.pdf broadcasts natively, so no np.vectorize is needed.
        w1 = np.asarray(w1, dtype=float)
        w2 = np.asarray(w2, dtype=float)
        err = y0 - (w1 + w2 * x0)
        return norm.pdf(err, loc=0, scale=scale)

    return out
58
59
# Evaluation grid shared by coefficient space (w0, w1) and data space (x, y);
# both are plotted over [-1, 1].
grid = np.linspace(-1, 1, 50)
# Design matrix for the grid: row i is [1, grid[i]], so Xg @ w evaluates w0 + w1*x.
Xg = np.column_stack((np.ones_like(grid), grid))
G1, G2 = np.meshgrid(grid, grid)

# With many sampled lines, draw them semi-transparently; a few stay opaque.
if NSamples > 50:
    alph = 5.0 / NSamples
else:
    alph = 1.0
66
67
def adjustgraph(whitemark):
    """Apply the shared axis styling to the current subplot.

    Args:
        whitemark: True for a coefficient-space plot (axes w0/w1, with the
            true parameters (a0, a1) marked by a white cross); False for a
            data-space plot (axes x/y).

    Both kinds get limits [-1, 1] and ticks at -1, 0, 1 on each axis.
    """
    xlabel, ylabel = (r'$w_0$', r'$w_1$') if whitemark else ('x', 'y')
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    if whitemark:
        plt.scatter(a0, a1, marker='+', color='white', s=100)
    ticks = [-1, 0, 1]
    plt.ylim([-1, 1])
    plt.xlim([-1, 1])
    plt.xticks(ticks)
    plt.yticks(ticks)
    return None
81
82
figcounter = 1
fig = plt.figure(figsize=(10, 10))

# Top-left slot only carries the 'likelihood' column title: the prior row
# has no observation, hence no likelihood plot.
ax = fig.add_subplot(len(DataIndices), 3, figcounter)
ax.set_title('likelihood')
plt.axis('off')

# Build the figure one row at a time. Left: likelihood of the newest point;
# middle: prior/posterior over w; right: lines sampled from it plus the data seen.
for di in DataIndices:
    if di == 0:
        # Prior: zero mean, isotropic covariance (1/alpha) I.
        postM = [0, 0]
        postCov = np.diag([1.0 / priorPrecision] * 2)
    else:
        Post = MeanCovPost(x[:di], y[:di])
        postM = Post['Mean']
        postCov = Post['Cov']

        # Left graph: likelihood of the most recent observation (x[di-1], y[di-1]).
        # Drawn only when data exists; for the prior row the title-only subplot
        # above fills this slot (otherwise x[-1] would be used and the grid
        # would overflow to a 13th subplot).
        figcounter += 1
        fig.add_subplot(len(DataIndices), 3, figcounter)
        likfunc = LikeFMaker(x[di - 1], y[di - 1])
        plt.contourf(G1, G2, likfunc(G1, G2), 100)
        adjustgraph(True)

    # Middle graph: density of the current prior/posterior over (w0, w1).
    postfunc = GaussPdfMaker(postM, postCov)
    figcounter += 1
    ax = fig.add_subplot(len(DataIndices), 3, figcounter)
    plt.contourf(G1, G2, postfunc(G1, G2), 100)
    adjustgraph(True)
    if figcounter == 2:  # top middle subplot
        ax.set_title('prior/posterior')

    # Right graph: regression lines for parameter samples, plus observed data.
    # atleast_2d: rvs() returns a 1-D vector when NSamples == 1.
    Samples = np.atleast_2d(multivariate_normal(postM, postCov).rvs(NSamples))
    Lines = Xg.dot(Samples.T)  # one column per sampled line
    figcounter += 1
    ax = fig.add_subplot(len(DataIndices), 3, figcounter)
    if di != 0:
        plt.scatter(x[:di], y[:di], s=140, facecolors='none', edgecolors='b')
    # plot() draws each column of Lines as a separate line.
    plt.plot(grid, Lines, linewidth=2, color='r', alpha=alph)
    if figcounter == 3:  # top right subplot
        ax.set_title('data space')
    adjustgraph(False)

fig.tight_layout()
# Save before showing: plt.show() can close the figure, after which a save
# would write an empty canvas.
pml.savefig('bayesLinRegPlot2dB.pdf')
plt.show()
134