Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/ml/src/svmsgd.cpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// License Agreement
11
// For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2000, Intel Corporation, all rights reserved.
14
// Copyright (C) 2016, Itseez Inc, all rights reserved.
15
// Third party copyrights are property of their respective owners.
16
//
17
// Redistribution and use in source and binary forms, with or without modification,
18
// are permitted provided that the following conditions are met:
19
//
20
// * Redistribution's of source code must retain the above copyright notice,
21
// this list of conditions and the following disclaimer.
22
//
23
// * Redistribution's in binary form must reproduce the above copyright notice,
24
// this list of conditions and the following disclaimer in the documentation
25
// and/or other materials provided with the distribution.
26
//
27
// * The name of the copyright holders may not be used to endorse or promote products
28
// derived from this software without specific prior written permission.
29
//
30
// This software is provided by the copyright holders and contributors "as is" and
31
// any express or implied warranties, including, but not limited to, the implied
32
// warranties of merchantability and fitness for a particular purpose are disclaimed.
33
// In no event shall the Intel Corporation or contributors be liable for any direct,
34
// indirect, incidental, special, exemplary, or consequential damages
35
// (including, but not limited to, procurement of substitute goods or services;
36
// loss of use, data, or profits; or business interruption) however caused
37
// and on any theory of liability, whether in contract, strict liability,
38
// or tort (including negligence or otherwise) arising in any way out of
39
// the use of this software, even if advised of the possibility of such damage.
40
//
41
//M*/
42
43
#include "precomp.hpp"
44
#include "limits"
45
46
#include <iostream>
47
48
using std::cout;
49
using std::endl;
50
51
/****************************************************************************************\
52
* Stochastic Gradient Descent SVM Classifier *
53
\****************************************************************************************/
54
55
namespace cv
56
{
57
namespace ml
58
{
59
60
/**
 * Concrete implementation of the SVMSGD interface.
 *
 * The trained model is the pair (weights_, shift_): predict() labels a sample
 * by the sign of dot(sample, weights_) + shift_. All learning hyper-parameters
 * live in the private `params` aggregate and are exposed through the
 * getter/setter pairs below.
 */
class SVMSGDImpl CV_FINAL : public SVMSGD
{
public:
    // Starts untrained, with the default parameter preset applied.
    SVMSGDImpl();

    virtual ~SVMSGDImpl() {}

    // ---- overrides of the base-class model interface ----

    virtual bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE;

    virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const CV_OVERRIDE;

    virtual bool isClassifier() const CV_OVERRIDE;

    virtual bool isTrained() const CV_OVERRIDE;

    virtual void clear() CV_OVERRIDE;

    virtual void write(FileStorage &fs) const CV_OVERRIDE;

    virtual void read(const FileNode &fn) CV_OVERRIDE;

    // Learned weight row (one column per feature); empty until trained.
    virtual Mat getWeights() CV_OVERRIDE { return weights_; }

    // Learned bias term added to the dot product in predict().
    virtual float getShift() CV_OVERRIDE { return shift_; }

    // Number of features, taken from the width of the weight row.
    virtual int getVarCount() const CV_OVERRIDE { return weights_.cols; }

    virtual String getDefaultName() const CV_OVERRIDE {return "opencv_ml_svmsgd";}

    // Resets every hyper-parameter to the preset for the given solver type.
    virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN) CV_OVERRIDE;

    // ---- plain accessors for the individual hyper-parameters ----
    // (functions defined inside the class body are implicitly inline)

    int getSvmsgdType() const CV_OVERRIDE { return params.svmsgdType; }
    void setSvmsgdType(int val) CV_OVERRIDE { params.svmsgdType = val; }

    int getMarginType() const CV_OVERRIDE { return params.marginType; }
    void setMarginType(int val) CV_OVERRIDE { params.marginType = val; }

    float getMarginRegularization() const CV_OVERRIDE { return params.marginRegularization; }
    void setMarginRegularization(float val) CV_OVERRIDE { params.marginRegularization = val; }

    float getInitialStepSize() const CV_OVERRIDE { return params.initialStepSize; }
    void setInitialStepSize(float val) CV_OVERRIDE { params.initialStepSize = val; }

    float getStepDecreasingPower() const CV_OVERRIDE { return params.stepDecreasingPower; }
    void setStepDecreasingPower(float val) CV_OVERRIDE { params.stepDecreasingPower = val; }

    cv::TermCriteria getTermCriteria() const CV_OVERRIDE { return params.termCrit; }
    void setTermCriteria(const cv::TermCriteria& val) CV_OVERRIDE { params.termCrit = val; }

private:
    // One SGD step on `weights` for a single (bias-extended) sample.
    void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);

    // Serialise / restore the hyper-parameters only (not the learned model).
    void writeParams( FileStorage &fs ) const;

    void readParams( const FileNode &fn );

    // A response is "positive" only when strictly greater than zero.
    static inline bool isPositive(float val) { return val > 0; }

    // Centre each column on its mean and rescale the whole matrix in place.
    static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);

    // Bias term derived from the smallest signed projections of both classes.
    float calcShift(InputArray _samples, InputArray _responses) const;

    // Normalised copy of the samples with an extra all-ones column appended.
    static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);

    // Vector with SVM weights
    Mat weights_;
    float shift_;

    // Parameters for learning
    struct SVMSGDParams
    {
        float marginRegularization;
        float initialStepSize;
        float stepDecreasingPower;
        TermCriteria termCrit;
        int svmsgdType;
        int marginType;
    };

    SVMSGDParams params;
};
137
138
/** Factory: builds a default-configured SVMSGDImpl and hands it out as SVMSGD. */
Ptr<SVMSGD> SVMSGD::create()
{
    Ptr<SVMSGDImpl> model = makePtr<SVMSGDImpl>();
    return model;
}
142
143
/**
 * Loads a serialised model by delegating to Algorithm::load<SVMSGD>.
 *
 * @param filepath  file to read the model from
 * @param nodeName  node name forwarded unchanged to Algorithm::load
 */
Ptr<SVMSGD> SVMSGD::load(const String& filepath, const String& nodeName)
{
    Ptr<SVMSGD> model = Algorithm::load<SVMSGD>(filepath, nodeName);
    return model;
}
147
148
149
/**
 * Centres and rescales a CV_32FC1 sample matrix in place.
 *
 * Each column has its mean subtracted, then the whole matrix is multiplied by
 * sqrt(total elements) / norm(matrix).
 *
 * @param samples     [in,out] one sample per row; overwritten with the normalised data
 * @param average     [out] 1 x cols row holding the per-column means that were subtracted
 * @param multiplier  [out] scale factor that was applied after centring
 *
 * If the centred matrix has (numerically) zero norm -- e.g. every row is
 * identical, or the matrix is empty -- the scale factor would be inf/NaN and
 * would poison every later computation. In that case the data is left
 * unscaled and `multiplier` is set to 1.
 */
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
{
    // average.at<float>() below is only valid for single-channel float data.
    CV_Assert(samples.type() == CV_32FC1);

    const int featuresCount = samples.cols;
    const int samplesCount = samples.rows;

    average = Mat(1, featuresCount, samples.type());
    for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
    {
        average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
    }

    for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
    {
        samples.row(sampleIndex) -= average;
    }

    const double normValue = norm(samples);

    // Guard the division: a zero norm would otherwise yield inf (or NaN for an
    // empty matrix), and 0 * inf in the scaling step below turns the data into NaN.
    if (normValue > DBL_EPSILON)
    {
        multiplier = static_cast<float>(sqrt(static_cast<double>(samples.total())) / normValue);
    }
    else
    {
        multiplier = 1.f;
    }

    samples *= multiplier;
}
172
173
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
174
{
175
Mat normalizedTrainSamples = trainSamples.clone();
176
int samplesCount = normalizedTrainSamples.rows;
177
178
normalizeSamples(normalizedTrainSamples, average, multiplier);
179
180
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
181
cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
182
}
183
184
/**
 * Applies a single stochastic-gradient step to `weights`.
 *
 * @param _sample   one (bias-extended) sample row
 * @param positive  true when the sample's label is +1, false for -1
 * @param stepSize  learning rate for this step
 * @param weights   [in,out] weight row that is updated in place
 */
void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
{
    const Mat sample = _sample.getMat();

    // Labels are forced to exactly -1 or +1 regardless of the stored response.
    const int label = positive ? 1 : -1;

    const bool outsideMargin = sample.dot(weights) * label > 1;
    if (outsideMargin)
    {
        // Not a support vector: only the regularisation decay is applied.
        weights *= (1.f - stepSize * params.marginRegularization);
        return;
    }

    // Support vector: decay the weights and pull them towards the sample.
    weights -= (stepSize * params.marginRegularization) * weights - (stepSize * label) * sample;
}
201
202
/**
 * Computes the bias from the extreme projections of the two classes.
 *
 * For every sample the projection onto weights_ is taken (negated for
 * negative samples); the smallest such value is tracked per class and the
 * result is -(minPositive - minNegative) / 2.
 *
 * @param _samples    samples, one per row
 * @param _responses  CV_32FC1 responses; values > 0 count as the positive class
 */
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
    const Mat samples = _samples.getMat();
    const Mat responses = _responses.getMat();

    CV_Assert(responses.type() == CV_32FC1);

    // Slot 0 tracks the positive class, slot 1 the negative class.
    float smallestMargin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };

    const int rowCount = samples.rows;
    for (int row = 0; row < rowCount; row++)
    {
        const float projection = static_cast<float>(samples.row(row).dot(weights_));
        const bool positive = isPositive(responses.at<float>(row));

        const float signedMargin = projection * (positive ? 1.f : -1.f);
        float &best = smallestMargin[positive ? 0 : 1];
        if (signedMargin < best)
        {
            best = signedMargin;
        }
    }

    return -(smallestMargin[0] - smallestMargin[1]) / 2.f;
}
230
231
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
232
{
233
clear();
234
CV_Assert( isClassifier() ); //toDo: consider
235
236
Mat trainSamples = data->getTrainSamples();
237
238
int featureCount = trainSamples.cols;
239
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
240
241
CV_Assert(trainResponses.rows == trainSamples.rows);
242
243
if (trainResponses.empty())
244
{
245
return false;
246
}
247
248
int positiveCount = countNonZero(trainResponses >= 0);
249
int negativeCount = countNonZero(trainResponses < 0);
250
251
if ( positiveCount <= 0 || negativeCount <= 0 )
252
{
253
weights_ = Mat::zeros(1, featureCount, CV_32F);
254
shift_ = (positiveCount > 0) ? 1.f : -1.f;
255
return true;
256
}
257
258
Mat extendedTrainSamples;
259
Mat average;
260
float multiplier = 0;
261
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
262
263
int extendedTrainSamplesCount = extendedTrainSamples.rows;
264
int extendedFeatureCount = extendedTrainSamples.cols;
265
266
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
267
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
268
Mat averageExtendedWeights;
269
if (params.svmsgdType == ASGD)
270
{
271
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
272
}
273
274
RNG rng(0);
275
276
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
277
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
278
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
279
280
double err = DBL_MAX;
281
CV_Assert (trainResponses.type() == CV_32FC1);
282
// Stochastic gradient descent SVM
283
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
284
{
285
int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
286
287
Mat currentSample = extendedTrainSamples.row(randomNumber);
288
289
float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower)); //update stepSize
290
291
updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );
292
293
//average weights (only for ASGD model)
294
if (params.svmsgdType == ASGD)
295
{
296
averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float) iter);
297
err = norm(averageExtendedWeights - previousWeights);
298
averageExtendedWeights.copyTo(previousWeights);
299
}
300
else
301
{
302
err = norm(extendedWeights - previousWeights);
303
extendedWeights.copyTo(previousWeights);
304
}
305
}
306
307
if (params.svmsgdType == ASGD)
308
{
309
extendedWeights = averageExtendedWeights;
310
}
311
312
Rect roi(0, 0, featureCount, 1);
313
weights_ = extendedWeights(roi);
314
weights_ *= multiplier;
315
316
CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() == CV_32FC1));
317
318
if (params.marginType == SOFT_MARGIN)
319
{
320
shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
321
}
322
else
323
{
324
shift_ = calcShift(trainSamples, trainResponses);
325
}
326
327
return true;
328
}
329
330
/**
 * Labels each row of `_samples` as +1 or -1 from the sign of
 * dot(row, weights_) + shift_ (a criterion of exactly 0 maps to +1).
 *
 * @param _samples  CV_32FC1 matrix, one sample per row, weights_.cols columns
 * @param _results  optional output; when requested it is (re)created as an
 *                  nSamples x 1 matrix of the samples' type and filled with labels
 * @return the label when no output array is requested (exactly one sample is
 *         required in that case); 0 when the labels go to `_results`.
 */
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
    float singleLabel = 0;
    cv::Mat samples = _samples.getMat();
    const int sampleCount = samples.rows;

    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);

    cv::Mat labels;
    if( !_results.needed() )
    {
        // No output array: write straight into the scalar that gets returned.
        CV_Assert( sampleCount == 1 );
        labels = Mat(1, 1, CV_32FC1, &singleLabel);
    }
    else
    {
        _results.create( sampleCount, 1, samples.type() );
        labels = _results.getMat();
    }

    for (int row = 0; row < sampleCount; row++)
    {
        const float criterion = static_cast<float>(samples.row(row).dot(weights_)) + shift_;
        labels.at<float>(row) = (criterion >= 0) ? 1.f : -1.f;
    }

    return singleLabel;
}
359
360
bool SVMSGDImpl::isClassifier() const
361
{
362
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
363
&&
364
(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
365
&&
366
(params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
367
}
368
369
bool SVMSGDImpl::isTrained() const
370
{
371
return !weights_.empty();
372
}
373
374
/**
 * Serialises the model: format header, hyper-parameters, then the learned
 * "weights" and "shift" entries. Raises CV_StsParseError for an untrained model.
 */
void SVMSGDImpl::write(FileStorage& fs) const
{
    if( !isTrained() )
    {
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
    }

    writeFormat(fs);
    writeParams( fs );

    fs << "weights" << weights_
       << "shift" << shift_;
}
385
386
/**
 * Writes the hyper-parameters to `fs`.
 *
 * Solver and margin types are stored as symbolic names; a value outside the
 * known enumerators is written as "Unknown_<value>". The termination criteria
 * go into a "term_criteria" map that contains only the fields enabled in
 * termCrit.type.
 */
void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
    String solverName;
    if (params.svmsgdType == SGD)
    {
        solverName = "SGD";
    }
    else if (params.svmsgdType == ASGD)
    {
        solverName = "ASGD";
    }
    else
    {
        solverName = format("Unknown_%d", params.svmsgdType);
    }
    fs << "svmsgdType" << solverName;

    String marginName;
    if (params.marginType == SOFT_MARGIN)
    {
        marginName = "SOFT_MARGIN";
    }
    else if (params.marginType == HARD_MARGIN)
    {
        marginName = "HARD_MARGIN";
    }
    else
    {
        marginName = format("Unknown_%d", params.marginType);
    }
    fs << "marginType" << marginName;

    fs << "marginRegularization" << params.marginRegularization;
    fs << "initialStepSize" << params.initialStepSize;
    fs << "stepDecreasingPower" << params.stepDecreasingPower;

    const bool hasEpsilon = (params.termCrit.type & TermCriteria::EPS) != 0;
    const bool hasMaxCount = (params.termCrit.type & TermCriteria::COUNT) != 0;

    fs << "term_criteria" << "{:";
    if( hasEpsilon )
    {
        fs << "epsilon" << params.termCrit.epsilon;
    }
    if( hasMaxCount )
    {
        fs << "iterations" << params.termCrit.maxCount;
    }
    fs << "}";
}
431
void SVMSGDImpl::readParams( const FileNode& fn )
432
{
433
String svmsgdTypeStr = (String)fn["svmsgdType"];
434
int svmsgdType =
435
svmsgdTypeStr == "SGD" ? SGD :
436
svmsgdTypeStr == "ASGD" ? ASGD : -1;
437
438
if( svmsgdType < 0 )
439
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
440
441
params.svmsgdType = svmsgdType;
442
443
String marginTypeStr = (String)fn["marginType"];
444
int marginType =
445
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
446
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
447
448
if( marginType < 0 )
449
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
450
451
params.marginType = marginType;
452
453
CV_Assert ( fn["marginRegularization"].isReal() );
454
params.marginRegularization = (float)fn["marginRegularization"];
455
456
CV_Assert ( fn["initialStepSize"].isReal() );
457
params.initialStepSize = (float)fn["initialStepSize"];
458
459
CV_Assert ( fn["stepDecreasingPower"].isReal() );
460
params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
461
462
FileNode tcnode = fn["term_criteria"];
463
CV_Assert(!tcnode.empty());
464
params.termCrit.epsilon = (double)tcnode["epsilon"];
465
params.termCrit.maxCount = (int)tcnode["iterations"];
466
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
467
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
468
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
469
}
470
471
void SVMSGDImpl::read(const FileNode& fn)
472
{
473
clear();
474
475
readParams(fn);
476
477
fn["weights"] >> weights_;
478
fn["shift"] >> shift_;
479
}
480
481
void SVMSGDImpl::clear()
482
{
483
weights_.release();
484
shift_ = 0;
485
}
486
487
488
/** Creates an untrained model configured with the default parameter preset. */
SVMSGDImpl::SVMSGDImpl()
{
    // Start from an empty model ...
    clear();

    // ... then apply the defaults declared on setOptimalParameters().
    setOptimalParameters();
}
493
494
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
495
{
496
switch (svmsgdType)
497
{
498
case SGD:
499
params.svmsgdType = SGD;
500
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
501
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
502
params.marginRegularization = 0.0001f;
503
params.initialStepSize = 0.05f;
504
params.stepDecreasingPower = 1.f;
505
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
506
break;
507
508
case ASGD:
509
params.svmsgdType = ASGD;
510
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
511
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
512
params.marginRegularization = 0.00001f;
513
params.initialStepSize = 0.05f;
514
params.stepDecreasingPower = 0.75f;
515
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
516
break;
517
518
default:
519
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
520
}
521
}
522
} //ml
523
} //cv
524
525