% Path: blob/master/notebooks/book2/12/rjmcmc_rbf/rjdemo1.m
% PURPOSE : To approximate a noisy nonlinear function with RBFs, where the number
%           of parameters and parameter values are estimated via reversible jump
%           Markov Chain Monte Carlo (MCMC) simulation.
%
% AUTHOR  : Nando de Freitas - Thanks for the acknowledgement :-)
% DATE    : 21-01-99

clear;
echo off;

% INITIALISATION AND PARAMETERS:
% =============================

N = 100;                   % Number of samples in each of the training and test sets.
t = 1:1:N;                 % Sample index; x-axis of the noise plot below.
chainLength = 2000;        % Length of the Markov chain simulation.
burnIn = 1000;             % Burn-in period.
bFunction = 'rjGaussian'   % Type of basis function (no semicolon: value is echoed).
par.doPlot = 0;            % 1 plot. 0 don't plot.
par.a = 2;                 % Hyperparameter for delta.
par.b = 10;                % Hyperparameter for delta.
par.e1 = 0.0001;           % Hyperparameter for nabla.
par.e2 = 0.0001;           % Hyperparameter for nabla.
par.v = 0.1;               % Hyperparameter for sigma (alternative value: 0).
par.gamma = 0.1;           % Hyperparameter for sigma (alternative value: 0).
par.kMax = 50;             % Maximum number of basis functions.
par.arbC = 0.25;           % Constant for birth and death moves.
par.merge = .1;            % Split-merge parameter.
par.Lambda = .5;           % Hybrid Metropolis decision parameter.
par.sRW = .001;            % Variance of noise in the random walk.
par.walkPer = 0.1;         % Percentage of random walk interval.

% GENERATE THE DATA:
% =================
noiseVar = 0.5;
% Noise-free target function. It is defined once so that the training and
% test sets are guaranteed to be generated from the same function.
trueFn = @(z) z + 2*exp(-16*(z.^2)) + 2*exp(-16*((z-.7).^2));

% Training set.
x = 4*rand(N,1)-2;            % Input data - uniform in [-2,2].
u = randn(N,1);
noise = sqrt(noiseVar)*u;     % Measurement noise.
varianceN = var(noise)        % Empirical noise variance (no semicolon: echoed).
ynn = trueFn(x);              % Noise-free output, evaluated directly.
y = ynn + noise;              % Output data.
x = (x+2)/4;                  % Rescaling to [0,1].

% Test set, drawn independently from the same model.
xv = 4*rand(N,1)-2;           % Input data - uniform in [-2,2].
uv = randn(N,1);
noisev = sqrt(noiseVar)*uv;
yvnn = trueFn(xv);            % Noise-free output, evaluated directly.
yv = yvnn + noisev;           % Output data.
xv = (xv+2)/4;                % Rescaling to [0,1].

% Plot the training data and the noise realisation.
figure(1)
subplot(211)
plot(x,y,'b+');
ylabel('Output data','fontsize',15);
xlabel('Input data','fontsize',15);
%axis([0 1 -3 3]);
subplot(212)
plot(t,noise)
ylabel('Measurement noise','fontsize',15);
xlabel('Time','fontsize',15);
% PERFORM REVERSIBLE JUMP MCMC WITH RADIAL BASIS:
% ==============================================
% Outputs used below: k (indexed per chain iteration; presumably the number of
% basis functions - TODO confirm against rjnn), yp/ypv (training/test
% predictions, third index = chain iteration), and the delta/sigma/nabla
% traces. mu, alpha and post are not used by this script.
[k,mu,alpha,sigma,nabla,delta,yp,ypv,post] = rjnn(x,y,chainLength,N,bFunction,par,xv,yv);

% COMPUTE MCMC-AVERAGED PREDICTIONS AND NORMALISED ERRORS:
% =======================================================
% Samples 1..burnIn form the burn-in period and are discarded; the remaining
% chainLength-burnIn samples are averaged.
if burnIn >= chainLength
  error('burnIn (%d) must be smaller than chainLength (%d).', burnIn, chainLength);
end
keep = burnIn+1:chainLength;
Nv = size(xv,1);
% Reshape to (points x kept samples) and average over samples along dim 2.
% Using mean(...,2) keeps the result an N-by-1 column even when only a single
% sample is kept (mean(X') would collapse a 1-by-N row to a scalar).
ypred  = mean(reshape(yp(1:N,1,keep),   N,  numel(keep)), 2);
ypredv = mean(reshape(ypv(1:Nv,1,keep), Nv, numel(keep)), 2);
% Residual sum of squares over total sum of squares about the mean (1 - R^2).
% The denominator is a scalar, so plain division is used instead of inv().
fevTrain = sum((y - ypred).^2) / sum((y - mean(y)).^2)
fevTest  = sum((yv - ypredv).^2) / sum((yv - mean(yv)).^2)

% PLOTS:
% =====
% Sort each set by input so that the curves are drawn left to right.
[xv,idx] = sort(xv);
yvnn = yvnn(idx);
ypredv = ypredv(idx);
yv = yv(idx);
[x,idx] = sort(x);
ynn = ynn(idx);
ypred = ypred(idx);
y = y(idx);

figure;
plot(x,ynn,'k:',x,y,'b+',x,ypred,'r','linewidth',3)
ylabel('Train output','fontsize',18)
xlabel('Train input','fontsize',18)
print(gcf, '-dpdf', 'rjmcmc_train');

figure;
plot(xv,yvnn,'k:',xv,yv,'b+',xv,ypredv,'r','linewidth',3)
ylabel('Test output','fontsize',18)
xlabel('Test input','fontsize',18)
legend('True function','Test data','Prediction');
print(gcf, '-dpdf', 'rjmcmc_test')

% COMPUTE THE MOST LIKELY MODES:
% =============================
% Running estimate of p(k|y) for each value in support, using the chain prefix
% k(1:p) at every checkpoint p = pInt, 2*pInt, ... . Only exact matches are
% counted (hist with bin centres would fold k=0 into "k=1" and every k>4 into
% "k=4"). A single cumulative sum replaces re-histogramming each prefix.
pInt = 2;
support = 1:4;
checkpoints = pInt:pInt:chainLength;
kChain = k(1:chainLength);
hits = double(bsxfun(@eq, kChain(:), support));   % chainLength x numel(support)
cumHits = cumsum(hits, 1);                         % row p = counts over k(1:p)
probk = bsxfun(@rdivide, cumHits(checkpoints,:), checkpoints(:));

figure;
plot(checkpoints,probk(:,1),'k--',...
     checkpoints,probk(:,2),'b:',...
     checkpoints,probk(:,3),'r',...
     checkpoints,probk(:,4),'g-.','linewidth',3);
xlabel('Chain length','fontsize',15)
ylabel('p(k|y)','fontsize',15)
legend('k=1','k=2','k=3','k=4')
modes = probk(end,:);   % Estimate from the longest prefix (final checkpoint).
print(gcf, '-dpdf', 'rjmcmc_K_vs_time')

%KPM
figure;
bar(probk(end,:))
title('p(k|data)')
print(gcf, '-dpdf', 'rjmcmc_K_hist')

% HISTOGRAMS:
% ==========
% Left column: histogram of post-burn-in samples; right column: full trace.
traces = {delta, sigma, nabla};
labels = {'Regularisation parameter', 'Noise variance', 'Poisson parameter'};
figure;
for r = 1:numel(traces)
  subplot(3,2,2*r-1)
  hist(traces{r}(keep),80)
  ylabel(labels{r},'fontsize',15);
  subplot(3,2,2*r)
  plot(traces{r})
  ylabel(labels{r},'fontsize',15);
end