Path: blob/master/Week 7/Programming Assignment - 6/ex6/svmTrain.m
function [model] = svmTrain(X, Y, C, kernelFunction, ...
                            tol, max_passes)
%SVMTRAIN Trains an SVM classifier using a simplified version of the SMO
%algorithm.
%   [model] = SVMTRAIN(X, Y, C, kernelFunction, tol, max_passes) trains an
%   SVM classifier and returns the trained model. X is the matrix of training
%   examples. Each row is a training example, and the jth column holds the
%   jth feature. Y is a column vector containing 1 for positive examples
%   and 0 for negative examples. C is the standard SVM regularization
%   parameter. tol is a tolerance value used for determining equality of
%   floating point numbers. max_passes controls the number of iterations
%   over the dataset (without changes to alpha) before the algorithm quits.
%
% Note: This is a simplified version of the SMO algorithm for training
%       SVMs. In practice, if you want to train an SVM classifier, we
%       recommend using an optimized package such as:
%
%           LIBSVM   (http://www.csie.ntu.edu.tw/~cjlin/libsvm/)
%           SVMLight (http://svmlight.joachims.org/)
%
%

if ~exist('tol', 'var') || isempty(tol)
    tol = 1e-3;
end

if ~exist('max_passes', 'var') || isempty(max_passes)
    max_passes = 5;
end

% Data parameters
m = size(X, 1);
n = size(X, 2);

% Map 0 to -1
Y(Y==0) = -1;

% Variables
alphas = zeros(m, 1);
b = 0;
E = zeros(m, 1);
passes = 0;
eta = 0;
L = 0;
H = 0;

% Pre-compute the Kernel Matrix since our dataset is small
% (in practice, optimized SVM packages that handle large datasets
%  gracefully will _not_ do this)
%
% We have implemented an optimized, vectorized version of the kernels here
% so that the SVM training will run faster.
if strcmp(func2str(kernelFunction), 'linearKernel')
    % Vectorized computation for the Linear Kernel
    % This is equivalent to computing the kernel on every pair of examples
    K = X*X';
elseif strfind(func2str(kernelFunction), 'gaussianKernel')
    % Vectorized RBF Kernel
    % This is equivalent to computing the kernel on every pair of examples
    X2 = sum(X.^2, 2);
    K = bsxfun(@plus, X2, bsxfun(@plus, X2', - 2 * (X * X')));
    K = kernelFunction(1, 0) .^ K;
else
    % Pre-compute the Kernel Matrix
    % The following can be slow due to the lack of vectorization
    K = zeros(m);
    for i = 1:m
        for j = i:m
            K(i,j) = kernelFunction(X(i,:)', X(j,:)');
            K(j,i) = K(i,j); % the matrix is symmetric
        end
    end
end

% Train
fprintf('\nTraining ...');
dots = 12;
while passes < max_passes,

    num_changed_alphas = 0;
    for i = 1:m,

        % Calculate Ei = f(x(i)) - y(i) using (2).
        % E(i) = b + sum (X(i, :) * (repmat(alphas.*Y,1,n).*X)') - Y(i);
        E(i) = b + sum (alphas.*Y.*K(:,i)) - Y(i);

        if ((Y(i)*E(i) < -tol && alphas(i) < C) || (Y(i)*E(i) > tol && alphas(i) > 0)),

            % In practice, there are many heuristics one can use to select
            % the i and j. In this simplified code, we select them randomly.
            j = ceil(m * rand());
            while j == i,  % Make sure i \neq j
                j = ceil(m * rand());
            end

            % Calculate Ej = f(x(j)) - y(j) using (2).
            E(j) = b + sum (alphas.*Y.*K(:,j)) - Y(j);

            % Save old alphas
            alpha_i_old = alphas(i);
            alpha_j_old = alphas(j);

            % Compute L and H by (10) or (11).
            if (Y(i) == Y(j)),
                L = max(0, alphas(j) + alphas(i) - C);
                H = min(C, alphas(j) + alphas(i));
            else
                L = max(0, alphas(j) - alphas(i));
                H = min(C, C + alphas(j) - alphas(i));
            end

            if (L == H),
                % continue to next i.
                continue;
            end

            % Compute eta by (14).
            eta = 2 * K(i,j) - K(i,i) - K(j,j);
            if (eta >= 0),
                % continue to next i.
                continue;
            end

            % Compute and clip new value for alpha j using (12) and (15).
            alphas(j) = alphas(j) - (Y(j) * (E(i) - E(j))) / eta;

            % Clip
            alphas(j) = min (H, alphas(j));
            alphas(j) = max (L, alphas(j));

            % Check if change in alpha is significant
            if (abs(alphas(j) - alpha_j_old) < tol),
                % continue to next i.
                % replace anyway
                alphas(j) = alpha_j_old;
                continue;
            end

            % Determine value for alpha i using (16).
            alphas(i) = alphas(i) + Y(i)*Y(j)*(alpha_j_old - alphas(j));

            % Compute b1 and b2 using (17) and (18) respectively.
            b1 = b - E(i) ...
                 - Y(i) * (alphas(i) - alpha_i_old) * K(i,i) ...
                 - Y(j) * (alphas(j) - alpha_j_old) * K(i,j);
            b2 = b - E(j) ...
                 - Y(i) * (alphas(i) - alpha_i_old) * K(i,j) ...
                 - Y(j) * (alphas(j) - alpha_j_old) * K(j,j);

            % Compute b by (19).
            if (0 < alphas(i) && alphas(i) < C),
                b = b1;
            elseif (0 < alphas(j) && alphas(j) < C),
                b = b2;
            else
                b = (b1+b2)/2;
            end

            num_changed_alphas = num_changed_alphas + 1;

        end

    end

    if (num_changed_alphas == 0),
        passes = passes + 1;
    else
        passes = 0;
    end

    fprintf('.');
    dots = dots + 1;
    if dots > 78
        dots = 0;
        fprintf('\n');
    end
    if exist('OCTAVE_VERSION')
        fflush(stdout);
    end
end
fprintf(' Done! \n\n');

% Save the model
idx = alphas > 0;
model.X = X(idx,:);
model.y = Y(idx);
model.kernelFunction = kernelFunction;
model.b = b;
model.alphas = alphas(idx);
model.w = ((alphas.*Y)'*X)';

end
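
% Example usage (illustrative sketch only; linearKernel.m is the
% linear-kernel helper provided elsewhere in this assignment):
%
%   X = [2 2; 3 3; -1 -1; -2 -2];   % toy training set, one example per row
%   Y = [1; 1; 0; 0];               % labels in {0, 1}
%   model = svmTrain(X, Y, 1, @linearKernel);
%   % model.w and model.b then define the learned linear decision boundary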