# Path: blob/master/deprecated/scripts/bayes_linreg_2d_demo.py
# 1192 views
# Bayesian inference for simple linear regression with known noise variance.
# The goal is to reproduce fig 3.7 from Bishop's book.
# We fit the linear model f(x,w) = w0 + w1*x and plot the posterior over w.

import superimport

import numpy as np
import matplotlib.pyplot as plt
import pyprobml_utils as pml


from scipy.stats import uniform, norm, multivariate_normal

np.random.seed(0)

# Number of samples to draw from the posterior distribution of the parameters.
NSamples = 10

# Each entry corresponds to a row in the graphic and to the number of data
# points the posterior in that row has seen.
# The first one must be zero, for the prior.
DataIndices = [0, 1, 2, 100]

# True regression parameters that we wish to recover. Do not set these outside
# the range [-1, 1]: every plot axis is fixed to that range.
a0 = -0.3
a1 = 0.5

NPoints = 100  # Number of (x, y) training points.
noiseSD = 0.2  # True noise standard deviation.
priorPrecision = 2.0  # Prior precision, alpha, of a zero-mean isotropic Gaussian prior.
likelihoodSD = noiseSD  # Assume the likelihood precision, beta, is known.
likelihoodPrecision = 1.0 / (likelihoodSD**2)

# Because of how the axes are set up, x and y values should be in the same
# range as the coefficients.
x = 2 * uniform().rvs(NPoints) - 1
y = a0 + a1 * x + norm(0, noiseSD).rvs(NPoints)


def MeanCovPost(x, y):
    """Return the Gaussian posterior over w = (w0, w1) given data vectors x, y.

    Conjugate update for the prior N(0, priorPrecision^-1 I) and a Gaussian
    likelihood with precision ``likelihoodPrecision`` (beta):

        Precision = alpha * I + beta * X^T X,  Cov = Precision^-1,
        Mean = beta * Cov X^T y,

    where each row of the design matrix X is [1, x_n].

    Returns a dict with keys 'Mean' (shape (2,)) and 'Cov' (shape (2, 2)).
    Empty x and y give back the prior.
    """
    x = np.asarray(x, dtype=float)
    # column_stack keeps a (n, 2) shape even for n == 0, so empty data
    # correctly yields the prior rather than a scalar mean.
    X = np.column_stack((np.ones_like(x), x))
    Precision = np.diag([priorPrecision] * 2) + likelihoodPrecision * X.T.dot(X)
    Cov = np.linalg.inv(Precision)
    Mean = likelihoodPrecision * Cov.dot(X.T.dot(np.asarray(y, dtype=float)))
    return {'Mean': Mean, 'Cov': Cov}


def GaussPdfMaker(mean, cov):
    """Return a pdf function (w1, w2) -> N([w1, w2] | mean, cov).

    The returned function accepts scalars or arrays. w1 and w2 are broadcast
    together, and the result has their broadcast shape.
    """
    dist = multivariate_normal(mean=mean, cov=cov)

    def out(w1, w2):
        # Stack into (..., 2) points so scipy evaluates the whole grid in one
        # call instead of one Python-level call per grid point.
        pts = np.stack(np.broadcast_arrays(w1, w2), axis=-1)
        # scipy squeezes its output; restore the broadcast shape explicitly.
        return np.reshape(dist.pdf(pts), pts.shape[:-1])

    return out


def LikeFMaker(x0, y0):
    """Return the likelihood function (w1, w2) -> p(y0 | x0, w) for one point.

    w1 is the intercept and w2 is the slope. The residual is
    y0 - (w1 + w2 * x0), scored under a zero-mean Gaussian with standard
    deviation ``likelihoodSD``. w1 and w2 may be scalars or arrays; NumPy
    broadcasting evaluates them elementwise.
    """
    def out(w1, w2):
        err = y0 - (np.asarray(w1) + np.asarray(w2) * x0)
        return norm.pdf(err, loc=0, scale=likelihoodSD)

    return out


# Grid on which values are evaluated. It is shared between the coefficient
# space (w0, w1) and the data space (x, y).
grid = np.linspace(-1, 1, 50)
Xg = np.column_stack((np.ones_like(grid), grid))  # Design-matrix rows [1, g].
G1, G2 = np.meshgrid(grid, grid)

# If we have many samples of lines, make them a bit transparent.
alph = 5.0 / NSamples if NSamples > 50 else 1.0


def adjustgraph(whitemark):
    """Apply the common axis adjustments to the current subplot.

    If ``whitemark`` is true, the subplot shows parameter space: the axes are
    labelled w0/w1 and the true parameters (a0, a1) are marked with a white
    '+'. Otherwise the axes are labelled x/y. Both cases fix the limits to
    [-1, 1] with ticks at -1, 0, 1.
    """
    if whitemark:
        plt.ylabel(r'$w_1$')
        plt.xlabel(r'$w_0$')
        plt.scatter(a0, a1, marker='+', color='white', s=100)
    else:
        plt.ylabel('y')
        plt.xlabel('x')
    plt.ylim([-1, 1])
    plt.xlim([-1, 1])
    plt.xticks([-1, 0, 1])
    plt.yticks([-1, 0, 1])
    return None


figcounter = 1
fig = plt.figure(figsize=(10, 10))

# The top-left plot only has a title: the prior row has no likelihood to show.
ax = fig.add_subplot(len(DataIndices), 3, figcounter)
ax.set_title('likelihood')
plt.axis('off')

# Build the graph one row at a time.
for di in DataIndices:
    if di == 0:
        postM = [0, 0]
        postCov = np.diag([1.0 / priorPrecision] * 2)
    else:
        Post = MeanCovPost(x[:di], y[:di])
        postM = Post['Mean']
        postCov = Post['Cov']

        # Left graph: likelihood of the most recently added data point only.
        figcounter += 1
        fig.add_subplot(len(DataIndices), 3, figcounter)
        likfunc = LikeFMaker(x[di - 1], y[di - 1])
        plt.contourf(G1, G2, likfunc(G1, G2), 100)
        adjustgraph(True)

    # Middle graph: prior (first row) or posterior density over w.
    postfunc = GaussPdfMaker(postM, postCov)
    figcounter += 1
    ax = fig.add_subplot(len(DataIndices), 3, figcounter)
    plt.contourf(G1, G2, postfunc(G1, G2), 100)
    adjustgraph(True)
    # Set the title if this is the top middle graph.
    if figcounter == 2:
        ax.set_title('prior/posterior')

    # Right graph: regression lines for sampled parameters, plus the data seen.
    # atleast_2d keeps a (NSamples, 2) shape even when NSamples == 1, where
    # rvs() would otherwise return a flat (2,) vector.
    Samples = np.atleast_2d(multivariate_normal(postM, postCov).rvs(NSamples))
    Lines = Xg.dot(Samples.T)  # Column j is the line for sample j over the grid.
    figcounter += 1
    ax = fig.add_subplot(len(DataIndices), 3, figcounter)
    if di != 0:
        plt.scatter(x[:di], y[:di], s=140, facecolors='none', edgecolors='b')
    # A 2-D y argument draws one line per column.
    plt.plot(grid, Lines, linewidth=2, color='r', alpha=alph)
    # Set the title if this is the top right graph.
    if figcounter == 3:
        ax.set_title('data space')
    adjustgraph(False)

fig.tight_layout()
# Save before show(): with an interactive backend show() blocks until the
# window is closed, after which the figure is gone and the saved file would be
# blank.
pml.savefig('bayesLinRegPlot2dB.pdf')
plt.show()