Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/12/rjmcmc_rbf/rjnn.m
1193 views
1
function [k,mu,alpha,sigma,nabla,delta,ypred,ypredv,post] = rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv)
% PURPOSE : Computes the parameters and number of parameters of a radial basis
%           function (RBF) network using the reversible jump MCMC algorithm.
%           Please have a look at the paper first.
% USAGE   : rjnn(x,y,chainLength,Ndata,bFunction)            - default parameters
%           rjnn(x,y,chainLength,Ndata,bFunction,par)        - user parameters
%           rjnn(x,y,chainLength,Ndata,bFunction,xv,yv)      - defaults + validation
%           rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv)  - user parameters + validation
% INPUTS  : - x : Input data (N by d).
%           - y : Target data (N by c).
%           - chainLength: Number of iterations of the Markov chain.
%           - Ndata: Number of time steps in the training data set (must equal N).
%           - bFunction: Basis function; feval(bFunction,centre,x) must return
%             one value per row of x.
%           - par: Record of simulation parameters with fields a, b, e1, e2, v,
%             gamma, kMax, arbC, doPlot, merge, sRW, Lambda, walkPer and,
%             optionally, kInit (initial model order; default 40, or [] to use
%             the draw from the Poisson prior). See the defaults below.
%           - {xv,yv}: Validation data, Nv by d and Nv by c (optional).
% OUTPUTS : - k : Model order.
%           - mu: Basis centres (cell array, one entry per iteration).
%           - alpha: Coefficients + linear weights (see paper).
%           - sigma: Measurement noise variance.
%           - nabla: Hyperparameter for k.
%           - delta: Signal to noise ratio.
%           - ypred: Prediction on the train set (N by c by chainLength).
%           - ypredv: Prediction on the test set, or [] without validation data.
%           - post: Log of the joint posterior density. post(1) is a placeholder
%             (set to 1); the chain fills post(2:end).

% AUTHOR : Nando de Freitas - Thanks for the acknowledgement :-)
% DATE   : 21-01-99

% CHECK INPUTS AND SET DEFAULTS:
% =============================
if nargin < 5, error('Not enough input arguments.'); end
switch nargin
  case {5,7}
    if nargin == 7
      % Called as rjnn(x,y,chainLength,Ndata,bFunction,xv,yv): MATLAB binds the
      % validation data to the 'par' and 'xv' slots, so shift them back.
      yv = xv;
      xv = par;
    end
    par = struct();
    par.a       = 2;      % Hyperparameter for delta.
    par.b       = 10;     % Hyperparameter for delta.
    par.e1      = 0.0001; % Hyperparameter for nabla.
    par.e2      = 0.0001; % Hyperparameter for nabla.
    par.v       = 0;      % Hyperparameter for sigma.
    par.gamma   = 0;      % Hyperparameter for sigma.
    par.kMax    = 50;     % Maximum number of basis.
    par.arbC    = 0.5;    % Constant for birth and death moves.
    par.doPlot  = 1;      % To plot or not to plot? That's ...
    par.merge   = .1;     % Merge-split parameter.
    par.sRW     = .001;   % Step parameter passed to radialRW.
    par.Lambda  = .5;     % Probability of an update (rather than random-walk) move.
    par.walkPer = 0.1;    % Fraction of the input range centres may extend beyond the data.
  case {6,8}
    % Use the caller's parameter record as given.
  otherwise
    error('Wrong Number of input arguments.');
end
Validation = (nargin >= 7);

hyper.a     = par.a;
hyper.b     = par.b;
hyper.e1    = par.e1;
hyper.e2    = par.e2;
hyper.v     = par.v;
hyper.gamma = par.gamma;
kMax    = par.kMax;
arbC    = par.arbC;
doPlot  = par.doPlot;
sigStar = par.merge;
sWalk   = par.sRW;
Lambda  = par.Lambda;
walkPer = par.walkPer;
kInit   = 40;              % Historical initial order (used for the demo1 comparison).
if isfield(par,'kInit'), kInit = par.kInit; end

[N,d] = size(x);           % N = number of train data, d = dimension of x.
c = size(y,2);             % c = dimension of y, i.e. number of outputs.
if (Ndata ~= N) || (size(y,1) ~= N), error('input must be N by d and output N by c.'); end
if Validation
  [Nv,dv] = size(xv);      % Nv = number of test data, dv = dimension of xv.
  if (dv ~= d) || ~isequal(size(yv),[Nv c])
    error('validation input must be Nv by d and output Nv by c.');
  end
end

% INITIALISATION:
% ==============
post  = ones(chainLength,1);          % p(centres,k|y).
ypred = zeros(N,c,chainLength);       % Output fit (train set).
if Validation
  ypredv = zeros(Nv,c,chainLength);   % Output fit (test set).
else
  ypredv = [];                        % Keep the output argument defined.
end
nabla = zeros(chainLength,1);         % Poisson parameter.
delta = zeros(chainLength,c);         % Regularisation parameter.
k     = ones(chainLength,1);          % Model order - number of basis.
sigma = ones(chainLength,c);          % Output noise variance.
mu    = cell(chainLength,1);          % Radial basis centres.
alpha = cell(chainLength,c);          % Radial basis coefficients.

% DEFINE WALK INTERVAL FOR MU:
% ===========================
% Centres live in [min(x)-walk, max(x)+walk] per input dimension.
xMin = min(x,[],1);
xMax = max(x,[],1);
walk = walkPer*(xMax-xMin);
walkInt = ((xMax-xMin) + 2*walk)';    % d by 1 length of each centre interval.

% SAMPLE INITIAL CONDITIONS FROM THEIR PRIORS:
% ===========================================
nabla(1) = gengamma(0.5 + hyper.e1,hyper.e2);
% The prior draw is always made so the random-number stream is the same
% whether or not kInit overrides it.
k(1) = poissrnd(nabla(1));
if ~isempty(kInit)
  k(1) = kInit;
end
k(1) = min(max(k(1),1),kMax);
for i=1:c
  delta(1,i) = 1/gengamma(hyper.a,hyper.b);
  sigma(1,i) = 1/gengamma(hyper.v/2,hyper.gamma/2);
  alpha{1,i} = mvnrnd(zeros(1,k(1)+d+1),sigma(1,i)*delta(1,i)*eye(k(1)+d+1),1)';
end

% DRAW THE INITIAL RADIAL CENTRES:
% ===============================
mu{1} = zeros(k(1),d);
for i=1:d
  mu{1}(:,i) = (xMin(i)-walk(i)) + walkInt(i)*rand(k(1),1);
end

% FILL THE REGRESSION MATRIX:
% ==========================
M = rbfDesign(x,mu{1},k(1),bFunction);
for i=1:c
  ypred(:,i,1) = M*alpha{1,i};
end
if Validation
  Mv = rbfDesign(xv,mu{1},k(1),bFunction);
  for i=1:c
    ypredv(:,i,1) = Mv*alpha{1,i};
  end
end

% INITIALISE COUNTERS:
% ===================
aUpdate=0; rUpdate=0;
aBirth=0;  rBirth=0;
aDeath=0;  rDeath=0;
aMerge=0;  rMerge=0;
aSplit=0;  rSplit=0;
aRW=0;     rRW=0;
match=0;
if doPlot
  figure(3)
  clf;
end

% ITERATE THE MARKOV CHAIN:
% ========================
for t=1:chainLength-1
  iteration=t
  % COMPUTE THE CENTRES AND DIMENSION WITH METROPOLIS, BIRTH AND DEATH MOVES:
  % ========================================================================
  decision=rand(1);
  birth=arbC*min(1,(nabla(t)/(k(t)+1)));
  death=arbC*min(1,((k(t)+1)/nabla(t)));
  if (decision <= birth) && (k(t)<kMax)
    [k,mu,M,match,aBirth,rBirth] = radialBirth(match,aBirth,rBirth,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
  elseif (decision <= birth+death) && (k(t)>0)
    [k,mu,M,aDeath,rDeath] = radialDeath(aDeath,rDeath,k,mu,M,delta,x,y,hyper,t,nabla);
  elseif (decision <= 2*birth+death) && (k(t)<kMax) && (k(t)>1)
    [k,mu,M,aSplit,rSplit] = radialSplit(aSplit,rSplit,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt,walk);
  elseif (decision <= 2*birth+2*death) && (k(t)>1)
    [k,mu,M,aMerge,rMerge] = radialMerge(aMerge,rMerge,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt);
  else
    uLambda = rand(1);
    if (uLambda>Lambda) && (k(t)>0)
      [k,mu,M,match,aRW,rRW] = radialRW(match,aRW,rRW,k,mu,M,delta,x,y,hyper,t,bFunction,sWalk,walk);
    else
      [k,mu,M,match,aUpdate,rUpdate] = radialUpdate(match,aUpdate,rUpdate,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
    end
  end

  % UPDATE OTHER PARAMETERS WITH GIBBS:
  % ==================================
  kt = k(t+1);
  p  = kt+d+1;                         % Number of regression coefficients.
  H  = zeros(p,p,c);
  F  = zeros(p,c);
  yPy = zeros(1,c);                    % y'*P*y with P = I - M*H*M'.
  logDetH = zeros(1,c);
  MtM = M'*M;
  for i=1:c
    A = MtM + (1/delta(t,i))*eye(p);
    Hi = inv(A);
    H(:,:,i) = (Hi+Hi')/2;             % Exact symmetry for mvnrnd's covariance check.
    Mty = M'*y(:,i);
    F(:,i) = H(:,:,i)*Mty;
    % y'*(I - M*H*M')*y without forming the N by N projection matrix.
    yPy(i) = y(:,i)'*y(:,i) - Mty'*F(:,i);
    sigma(t+1,i) = 1/gengamma((hyper.v+N)/2,(hyper.gamma+yPy(i))/2);
    alpha{t+1,i} = mvnrnd(F(:,i),sigma(t+1,i)*H(:,:,i),1)';
    delta(t+1,i) = 1/gengamma(hyper.a+p/2,hyper.b+(alpha{t+1,i}'*alpha{t+1,i})/(2*sigma(t+1,i)));
    % log(det(H)) = -log(det(A)), via Cholesky to avoid under/overflow.
    [R,notPD] = chol(A);
    if notPD
      logDetH(i) = log(det(H(:,:,i)));
    else
      logDetH(i) = -2*sum(log(diag(R)));
    end
  end
  nabla(t+1) = gengamma(0.5+hyper.e1+kt,1+hyper.e2);

  % COMPUTE THE (LOG) POSTERIOR FOR MONITORING:
  % ==========================================
  % Same quantity as the product form, accumulated in the log domain so that
  % large N or k do not make it overflow/underflow before the log is taken.
  logPost = -nabla(t+1) - gammaln(kt+1);
  if kt > 0
    logPost = logPost + kt*(log(nabla(t+1)) - sum(log(walkInt)));
  end
  for i=1:c
    logPost = logPost - (p/2)*log(delta(t+1,i)) + 0.5*logDetH(i) ...
              - ((hyper.v+N)/2)*log(hyper.gamma+yPy(i));
  end
  post(t+1) = logPost;

  % PLOT FOR FUN AND MONITORING:
  % ============================
  for i=1:c
    ypred(:,i,t+1) = M*alpha{t+1,i};
  end
  resid = y-ypred(:,:,t+1);
  msError = sum(resid(:).^2)/N;

  if Validation
    Mv = rbfDesign(xv,mu{t+1},kt,bFunction);
    for i=1:c
      ypredv(:,i,t+1) = Mv*alpha{t+1,i};
    end
    residv = yv-ypredv(:,:,t+1);
    msErrorv = sum(residv(:).^2)/Nv;
  end

  if doPlot
    figure(1)
    clf
    if (c==2) && (d>=2)
      plot(x(:,1),y(:,1),'b+',x(:,2),y(:,2),'r+',x(:,1),ypred(:,1,t+1),'bo',x(:,2),ypred(:,2,t+1),'ro');
    elseif c==1
      plot(x,y,'b+',x,ypred(:,:,t+1),'ro');
    end
    ylabel('Output','fontsize',15)
    xlabel('Input','fontsize',15)
    figure(3)
    subplot(511);
    hold on;
    plot(t,kt,'*');                    % Same iteration as the traces below.
    ylabel('k','fontsize',15);
    subplot(512);
    hold on;
    plot(t,post(t+1),'*');
    ylabel('p(k,mu|y)','fontsize',15);
    subplot(513);
    hold on;
    plot(t,msError,'r*');
    ylabel('Train error','fontsize',15);
    if Validation
      subplot(514);
      hold on;
      plot(t,msErrorv,'r*');
      ylabel('Test error','fontsize',15);
    end
    subplot(515);
    hold on;
    bar(1:13,[match aUpdate rUpdate aBirth rBirth aDeath rDeath aMerge rMerge aSplit rSplit aRW rRW]);
    ylabel('Acceptance','fontsize',15);
    xlabel('match aU rU aB rB aD rD aM rM aS rS aRW rRW','fontsize',15)
  end
end


function M = rbfDesign(x,centres,nCentres,bFunction)
% RBFDESIGN  Regression matrix [1, x, phi_1(x), ..., phi_nCentres(x)], where
% phi_j(x) = feval(bFunction,centres(j,:),x) for the first nCentres rows of centres.
[n,d] = size(x);
M = zeros(n,nCentres+d+1);
M(:,1) = ones(n,1);
M(:,2:d+1) = x;
for j=1:nCentres
  M(:,d+1+j) = feval(bFunction,centres(j,:),x);
end