Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/12/rjmcmc_rbf/rjdemo1.m
1193 views
% PURPOSE : Approximate a noisy nonlinear function with radial basis
%           functions (RBFs), where both the number of basis functions and
%           their parameter values are estimated via reversible jump
%           Markov chain Monte Carlo (MCMC) simulation.
%
% AUTHOR  : Nando de Freitas - Thanks for the acknowledgement :-)
% DATE    : 21-01-99

% Start from an empty workspace with command echoing switched off.
clear;
echo('off');

% INITIALISATION AND PARAMETERS:
% =============================
% The fields of `par` are passed unchanged to rjnn below.

N           = 100;           % Number of data points (time steps).
t           = 1:N;           % Time index.
chainLength = 2000;          % Length of the Markov chain simulation.
burnIn      = 1000;          % Burn-in period (first retained iteration index).
bFunction   = 'rjGaussian';  % Type of basis function.

% Later code indexes burnIn:chainLength, so burnIn must be a valid index.
assert(burnIn >= 1 && burnIn <= chainLength, ...
       'burnIn must satisfy 1 <= burnIn <= chainLength.');

par.doPlot  = 0;       % 1: plot, 0: don't plot.
par.a       = 2;       % Hyperparameter for delta.
par.b       = 10;      % Hyperparameter for delta.
par.e1      = 0.0001;  % Hyperparameter for nabla.
par.e2      = 0.0001;  % Hyperparameter for nabla.
par.v       = 0.1;     % Hyperparameter for sigma (alternative value: 0).
par.gamma   = 0.1;     % Hyperparameter for sigma (alternative value: 0).
par.kMax    = 50;      % Maximum number of basis functions.
par.arbC    = 0.25;    % Constant for birth and death moves.
par.merge   = .1;      % Split-merge parameter.
par.Lambda  = .5;      % Hybrid Metropolis decision parameter.
par.sRW     = .001;    % Variance of noise in the random walk.
par.walkPer = 0.1;     % Percentage of random walk interval.

% GENERATE THE DATA:
% =================
% Target function on the original input range [-2,2]:
%   f(z) = z + 2*exp(-16*z^2) + 2*exp(-16*(z-0.7)^2)
% Training and validation sets are drawn independently from the same model;
% both input sets are then rescaled from [-2,2] to [0,1].
% Random draws happen in the same order as before (x, u, xv, uv), so a fixed
% RNG seed reproduces the same data set.
trueFn   = @(z) z + 2*exp(-16*(z.^2)) + 2*exp(-16*((z-.7).^2));
noiseVar = 0.5;

% Training set.
x         = 4*rand(N,1)-2;       % Input data - uniform in [-2,2].
u         = randn(N,1);
noise     = sqrt(noiseVar)*u;    % Measurement noise.
varianceN = var(noise)           % Empirical noise variance (printed).
ynn       = trueFn(x);           % Noise-free target values.
y         = ynn + noise;         % Output data.
x         = (x+2)/4;             % Rescaling to [0,1].

% Validation (test) set.
xv     = 4*rand(N,1)-2;          % Validation inputs - uniform in [-2,2].
uv     = randn(N,1);
noisev = sqrt(noiseVar)*uv;      % Validation measurement noise.
yvnn   = trueFn(xv);             % Noise-free validation targets.
yv     = yvnn + noisev;          % Validation output data.
xv     = (xv+2)/4;               % Rescaling to [0,1].

% Figure 1: generated training data (top) and the noise sequence (bottom).
figure(1)

subplot(2,1,1)
plot(x, y, 'b+');
xlabel('Input data', 'fontsize', 15);
ylabel('Output data', 'fontsize', 15);
%axis([0 1 -3 3]);

subplot(2,1,2)
plot(noise)
xlabel('Time', 'fontsize', 15);
ylabel('Measurement noise', 'fontsize', 15);

% PERFORM REVERSIBLE JUMP MCMC WITH RADIAL BASIS FUNCTIONS:
% ========================================================
% rjnn is defined elsewhere. Below, yp and ypv are treated as per-iteration
% outputs for the training inputs and the validation inputs respectively.
[k, mu, alpha, sigma, nabla, delta, yp, ypv, post] = ...
    rjnn(x, y, chainLength, N, bFunction, par, xv, yv);

% COMPUTE POSTERIOR-MEAN PREDICTIONS AND FRACTION OF ERROR VARIANCE:
% ================================================================
% yp / ypv are indexed as (point, 1, iteration). Averaging the iterations
% burnIn..chainLength gives a Monte Carlo estimate of the mean prediction.
Nv   = size(xv, 1);
keep = burnIn:chainLength;       % Retained iterations (inclusive).

% Reshape to (points x samples) and average over samples (dim 2). Giving the
% dimension explicitly keeps this correct even when a single sample is kept,
% where mean() of a row vector would otherwise collapse to a scalar.
ypred  = mean(reshape(yp(1:N, 1, keep),  N,  []), 2)';   % 1 x N row.
ypredv = mean(reshape(ypv(1:Nv, 1, keep), Nv, []), 2)';  % 1 x Nv row.

% Fraction of error variance: residual sum of squares over total sum of
% squares about the mean (i.e. 1 - R^2). Plain scalar division replaces the
% original inv() of a 1x1 matrix. Results are printed.
fevTrain = sum((y  - ypred').^2)  / sum((y  - mean(y)).^2)
fevTest  = sum((yv - ypredv').^2) / sum((yv - mean(yv)).^2)

% PLOTS:
% =====
% Sort each data set by input so the true function and the prediction are
% drawn as curves. The same permutation is applied to every quantity aligned
% with that set's inputs.
figure;
[xv, order] = sort(xv);
yvnn   = yvnn(order);
ypredv = ypredv(order);
yv     = yv(order);

[x, order] = sort(x);
ynn   = ynn(order);
ypred = ypred(order);
y     = y(order);

% Training fit (now labelled with a legend, matching the test figure).
plot(x,ynn,'k:',x,y,'b+',x,ypred,'r','linewidth',3)
ylabel('Train output','fontsize',18)
xlabel('Train input','fontsize',18)
legend('True function','Training data','Prediction');
print(gcf, '-dpdf', 'rjmcmc_train');

% Validation fit.
figure
plot(xv,yvnn,'k:',xv,yv,'b+',xv,ypredv,'r','linewidth',3)
ylabel('Test output','fontsize',18)
xlabel('Test input','fontsize',18)
legend('True function','Test data','Prediction');
print(gcf, '-dpdf', 'rjmcmc_test');

% COMPUTE THE MOST LIKELY MODES:
% =============================
% Running estimate of p(k|y) for k in `support`, recomputed every pInt
% iterations from the first p samples of the chain (burn-in included, so the
% plot shows how the estimate settles as the chain grows).
pInt        = 2;
support     = 1:4;
nEval       = floor(chainLength/pInt);   % Number of checkpoints.
checkpoints = pInt*(1:nEval);            % Chain lengths at which we evaluate.
probk       = zeros(nEval, length(support));
for row = 1:nEval
    p      = checkpoints(row);
    kSoFar = k(1:p);
    % Count exact matches. hist(k, support) treats support as bin centres and
    % folds every k outside 1..4 into the first/last bin, which would report
    % P(k<=1) and P(k>=4) under the labels k=1 and k=4.
    for j = 1:length(support)
        probk(row, j) = sum(kSoFar == support(j)) / p;
    end
end

figure;
plot(checkpoints,probk(:,1),'k--',...
     checkpoints,probk(:,2),'b:',...
     checkpoints,probk(:,3),'r',...
     checkpoints,probk(:,4),'g-.','linewidth',3);
xlabel('Chain length','fontsize',15)
ylabel('p(k|y)','fontsize',15)
legend('k=1','k=2','k=3','k=4')
% Final estimate: use the last checkpoint rather than a hard-coded row index,
% so this stays valid for any pInt.
modes = probk(end,:);
print(gcf, '-dpdf', 'rjmcmc_K_vs_time')

%KPM
% Bar chart of the final running estimate of p(k|data) for k = 1..4,
% i.e. the first four columns of the last row of probk.
figure;
bar(probk(end, 1:4))
title('p(k|data)')
print(gcf, '-dpdf', 'rjmcmc_K_hist')

% HISTOGRAMS:
% ==========
% One row per hyperparameter. The left panel shows a histogram (80 bins) of
% iterations burnIn..chainLength; the right panel shows the full trace.
traces = {delta, sigma, nabla};
labels = {'Regularisation parameter', 'Noise variance', 'Poisson parameter'};

figure;
for r = 1:numel(traces)
    subplot(3, 2, 2*r - 1)
    hist(traces{r}(burnIn:chainLength), 80)
    ylabel(labels{r}, 'fontsize', 15);

    subplot(3, 2, 2*r)
    plot(traces{r})
    ylabel(labels{r}, 'fontsize', 15);
end