Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/ml/src/boost.cpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// License Agreement
11
// For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2000, Intel Corporation, all rights reserved.
14
// Copyright (C) 2014, Itseez Inc, all rights reserved.
15
// Third party copyrights are property of their respective owners.
16
//
17
// Redistribution and use in source and binary forms, with or without modification,
18
// are permitted provided that the following conditions are met:
19
//
20
// * Redistribution's of source code must retain the above copyright notice,
21
// this list of conditions and the following disclaimer.
22
//
23
// * Redistribution's in binary form must reproduce the above copyright notice,
24
// this list of conditions and the following disclaimer in the documentation
25
// and/or other materials provided with the distribution.
26
//
27
// * The name of the copyright holders may not be used to endorse or promote products
28
// derived from this software without specific prior written permission.
29
//
30
// This software is provided by the copyright holders and contributors "as is" and
31
// any express or implied warranties, including, but not limited to, the implied
32
// warranties of merchantability and fitness for a particular purpose are disclaimed.
33
// In no event shall the Intel Corporation or contributors be liable for any direct,
34
// indirect, incidental, special, exemplary, or consequential damages
35
// (including, but not limited to, procurement of substitute goods or services;
36
// loss of use, data, or profits; or business interruption) however caused
37
// and on any theory of liability, whether in contract, strict liability,
38
// or tort (including negligence or otherwise) arising in any way out of
39
// the use of this software, even if advised of the possibility of such damage.
40
//
41
//M*/
42
43
#include "precomp.hpp"
44
45
namespace cv { namespace ml {
46
47
static inline double
48
log_ratio( double val )
49
{
50
const double eps = 1e-5;
51
val = std::max( val, eps );
52
val = std::min( val, 1. - eps );
53
return log( val/(1. - val) );
54
}
55
56
57
// Default boosting parameters: Real AdaBoost, 100 weak learners,
// and weight trimming at 95%.
BoostTreeParams::BoostTreeParams()
{
    weakCount = 100;
    weightTrimRate = 0.95;
    boostType = Boost::REAL;
}
63
64
// Explicit boosting parameters: flavour of boosting, number of weak
// learners to train, and the weight-trimming rate.
BoostTreeParams::BoostTreeParams( int _boostType, int _weak_count,
                                  double _weightTrimRate)
{
    weightTrimRate = _weightTrimRate;
    weakCount = _weak_count;
    boostType = _boostType;
}
71
72
// Decision-tree implementation specialized for boosting. Each weak learner
// is (by default) a depth-1 tree, and the per-sample weights are updated
// between trees according to the selected boosting flavour
// (Discrete/Real/Logit/Gentle AdaBoost).
class DTreesImplForBoost CV_FINAL : public DTreesImpl
{
public:
    DTreesImplForBoost()
    {
        // No cross-validation pruning; weak learners default to stumps.
        params.setCVFolds(0);
        params.setMaxDepth(1);
    }
    virtual ~DTreesImplForBoost() {}

    bool isClassifier() const CV_OVERRIDE { return true; }

    void clear() CV_OVERRIDE
    {
        DTreesImpl::clear();
    }

    // Prepares per-sample state before the first weak learner: zeroes the
    // accumulated responses and, for the non-discrete variants, converts
    // categorical labels into the ordinal targets the weight-update formulas
    // expect (+/-1, or +/-2 for LogitBoost), then normalizes the weights.
    void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
    {
        DTreesImpl::startTraining(trainData, flags);
        sumResult.assign(w->sidx.size(), 0.);

        if( bparams.boostType != Boost::DISCRETE )
        {
            // Real/Gentle/Logit boost train regression trees on +/-a targets.
            _isClassifier = false;
            int i, n = (int)w->cat_responses.size();
            w->ord_responses.resize(n);

            double a = -1, b = 1;
            if( bparams.boostType == Boost::LOGIT )
            {
                a = -2, b = 2;
            }
            for( i = 0; i < n; i++ )
                w->ord_responses[i] = w->cat_responses[i] > 0 ? b : a;
        }

        normalizeWeights();
    }

    // Rescales the active samples' weights to sum to 1. If the total weight
    // has collapsed to ~0, every weight is reset to 1 instead.
    void normalizeWeights()
    {
        int i, n = (int)w->sidx.size();
        double sumw = 0, a, b;
        for( i = 0; i < n; i++ )
            sumw += w->sample_weights[w->sidx[i]];
        if( sumw > DBL_EPSILON )
        {
            a = 1./sumw;
            b = 0;
        }
        else
        {
            a = 0;
            b = 1;
        }
        for( i = 0; i < n; i++ )
        {
            double& wval = w->sample_weights[w->sidx[i]];
            wval = wval*a + b;
        }
    }

    void endTraining() CV_OVERRIDE
    {
        DTreesImpl::endTraining();
        // Release sumResult's capacity, not just its contents.
        vector<double> e;
        std::swap(sumResult, e);
    }

    // Multiplies every node value in the tree rooted at 'root' by 'scale'
    // (used to fold the Discrete AdaBoost coefficient C into the tree).
    void scaleTree( int root, double scale )
    {
        int nidx = root, pidx = 0;
        Node *node = 0;

        // traverse the tree and save all the nodes in depth-first order
        for(;;)
        {
            for(;;)
            {
                node = &nodes[nidx];
                node->value *= scale;
                if( node->left < 0 )
                    break;
                nidx = node->left;
            }

            // climb back up past right-subtrees that are already done
            for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
                 nidx = pidx, pidx = nodes[pidx].parent )
                ;

            if( pidx < 0 )
                break;

            nidx = nodes[pidx].right;
        }
    }

    // Post-processes a freshly computed node value for the boosting flavour:
    // Discrete maps the class index to -1/+1; Real replaces the mean response
    // with half the log-odds of the implied probability.
    void calcValue( int nidx, const vector<int>& _sidx ) CV_OVERRIDE
    {
        DTreesImpl::calcValue(nidx, _sidx);
        WNode* node = &w->wnodes[nidx];
        if( bparams.boostType == Boost::DISCRETE )
        {
            node->value = node->class_idx == 0 ? -1 : 1;
        }
        else if( bparams.boostType == Boost::REAL )
        {
            // node->value is in [-1,1]; convert to a probability, then to
            // 0.5*log-odds (log_ratio clamps p away from 0 and 1).
            double p = (node->value+1)*0.5;
            node->value = 0.5*log_ratio(p);
        }
    }

    // Trains the requested number of weak learners, updating (and possibly
    // trimming) the working sample set after each one. A negative weakCount
    // is treated as "effectively unlimited" (capped at 10000).
    bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
    {
        startTraining(trainData, flags);
        int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
        vector<int> sidx = w->sidx;

        for( treeidx = 0; treeidx < ntrees; treeidx++ )
        {
            int root = addTree( sidx );
            if( root < 0 )
                return false;
            updateWeightsAndTrim( treeidx, sidx );
        }
        endTraining();
        return true;
    }

    // After training tree 'treeidx': re-evaluates it on all samples, applies
    // the boosting-flavour-specific weight update, renormalizes the weights,
    // and (if weightTrimRate is in (0,1)) shrinks 'sidx' to the samples whose
    // weight is above the trimming threshold.
    void updateWeightsAndTrim( int treeidx, vector<int>& sidx )
    {
        int i, n = (int)w->sidx.size();
        int nvars = (int)varIdx.size();
        double sumw = 0., C = 1.;
        // Single buffer: n doubles for per-sample responses, then nvars
        // floats (reinterpreted) for the compressed sample row.
        cv::AutoBuffer<double> buf(n + nvars);
        double* result = buf.data();
        float* sbuf = (float*)(result + n);
        Mat sample(1, nvars, CV_32F, sbuf);
        int predictFlags = bparams.boostType == Boost::DISCRETE ? (PREDICT_MAX_VOTE | RAW_OUTPUT) : PREDICT_SUM;
        predictFlags |= COMPRESSED_INPUT;

        // Evaluate only the newly added tree on every training sample.
        for( i = 0; i < n; i++ )
        {
            w->data->getSample(varIdx, w->sidx[i], sbuf );
            result[i] = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
        }

        // now update weights and other parameters for each type of boosting
        if( bparams.boostType == Boost::DISCRETE )
        {
            // Discrete AdaBoost:
            //   weak_eval[i] (=f(x_i)) is in {-1,1}
            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
            //   C = log((1-err)/err)
            //   w_i *= exp(C*(f(x_i) != y_i))
            double err = 0.;

            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                double wval = w->sample_weights[si];
                sumw += wval;
                err += wval*(result[i] != w->cat_responses[si]);
            }

            if( sumw != 0 )
                err /= sumw;
            C = -log_ratio( err );
            double scale = std::exp(C);

            // Up-weight the misclassified samples by exp(C).
            sumw = 0;
            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                double wval = w->sample_weights[si];
                if( result[i] != w->cat_responses[si] )
                    wval *= scale;
                sumw += wval;
                w->sample_weights[si] = wval;
            }

            // Bake the learner coefficient C into the tree's node values so
            // prediction is a plain sum over trees.
            scaleTree(roots[treeidx], C);
        }
        else if( bparams.boostType == Boost::REAL || bparams.boostType == Boost::GENTLE )
        {
            // Real AdaBoost:
            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
            //   w_i *= exp(-y_i*f(x_i))

            // Gentle AdaBoost:
            //   weak_eval[i] = f(x_i) in [-1,1]
            //   w_i *= exp(-y_i*f(x_i))
            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                CV_Assert( std::abs(w->ord_responses[si]) == 1 );
                double wval = w->sample_weights[si]*std::exp(-result[i]*w->ord_responses[si]);
                sumw += wval;
                w->sample_weights[si] = wval;
            }
        }
        else if( bparams.boostType == Boost::LOGIT )
        {
            // LogitBoost:
            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
            //   sum_response = F(x_i).
            //   F(x_i) += 0.5*f(x_i)
            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
            //   reuse weak_eval: weak_eval[i] <- p(x_i)
            //   w_i = p(x_i)*1(1 - p(x_i))
            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
            //   store z_i to the data->data_root as the new target responses
            const double lb_weight_thresh = FLT_EPSILON;
            const double lb_z_max = 10.;

            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                sumResult[i] += 0.5*result[i];
                double p = 1./(1 + std::exp(-2*sumResult[i]));
                // Floor the weight so degenerate p (0 or 1) cannot zero it out.
                double wval = std::max( p*(1 - p), lb_weight_thresh ), z;
                w->sample_weights[si] = wval;
                sumw += wval;
                // Working responses z are clipped to +/-lb_z_max for stability.
                if( w->ord_responses[si] > 0 )
                {
                    z = 1./p;
                    w->ord_responses[si] = std::min(z, lb_z_max);
                }
                else
                {
                    z = 1./(1-p);
                    w->ord_responses[si] = -std::min(z, lb_z_max);
                }
            }
        }
        else
            CV_Error(CV_StsNotImplemented, "Unknown boosting type");

        /*if( bparams.boostType != Boost::LOGIT )
        {
            double err = 0;
            for( i = 0; i < n; i++ )
            {
                sumResult[i] += result[i]*C;
                if( bparams.boostType != Boost::DISCRETE )
                    err += sumResult[i]*w->ord_responses[w->sidx[i]] < 0;
                else
                    err += sumResult[i]*w->cat_responses[w->sidx[i]] < 0;
            }
            printf("%d trees. C=%.2f, training error=%.1f%%, working set size=%d (out of %d)\n", (int)roots.size(), C, err*100./n, (int)sidx.size(), n);
        }*/

        // renormalize weights
        if( sumw > FLT_EPSILON )
            normalizeWeights();

        if( bparams.weightTrimRate <= 0. || bparams.weightTrimRate >= 1. )
            return;

        // Sort a copy of the weights to find the trimming threshold: the
        // smallest weight such that the kept samples carry at least
        // weightTrimRate of the total weight.
        for( i = 0; i < n; i++ )
            result[i] = w->sample_weights[w->sidx[i]];
        std::sort(result, result + n);

        // as weight trimming occurs immediately after updating the weights,
        // where they are renormalized, we assume that the weight sum = 1.
        sumw = 1. - bparams.weightTrimRate;

        for( i = 0; i < n; i++ )
        {
            double wval = result[i];
            if( sumw <= 0 )
                break;
            sumw -= wval;
        }

        double threshold = i < n ? result[i] : DBL_MAX;
        sidx.clear();

        // Keep only samples whose weight meets the threshold for the next tree.
        for( i = 0; i < n; i++ )
        {
            int si = w->sidx[i];
            if( w->sample_weights[si] >= threshold )
                sidx.push_back(si);
        }
    }

    // Predicts by summing the (already coefficient-scaled) tree responses.
    // If the caller asked for a vote-style result, the sum is thresholded at
    // zero and mapped to a class label unless RAW_OUTPUT was requested.
    float predictTrees( const Range& range, const Mat& sample, int flags0 ) const CV_OVERRIDE
    {
        int flags = (flags0 & ~PREDICT_MASK) | PREDICT_SUM;
        float val = DTreesImpl::predictTrees(range, sample, flags);
        if( flags != flags0 )
        {
            int ival = (int)(val > 0);
            if( !(flags0 & RAW_OUTPUT) )
                ival = classLabels[ival];
            val = (float)ival;
        }
        return val;
    }

    void writeTrainingParams( FileStorage& fs ) const CV_OVERRIDE
    {
        fs << "boosting_type" <<
        (bparams.boostType == Boost::DISCRETE ? "DiscreteAdaboost" :
        bparams.boostType == Boost::REAL ? "RealAdaboost" :
        bparams.boostType == Boost::LOGIT ? "LogitBoost" :
        bparams.boostType == Boost::GENTLE ? "GentleAdaboost" : "Unknown");

        DTreesImpl::writeTrainingParams(fs);
        fs << "weight_trimming_rate" << bparams.weightTrimRate;
    }

    // Serializes the trained ensemble: parameters plus one entry per tree.
    void write( FileStorage& fs ) const CV_OVERRIDE
    {
        if( roots.empty() )
            CV_Error( CV_StsBadArg, "RTrees have not been trained" );

        writeFormat(fs);
        writeParams(fs);

        int k, ntrees = (int)roots.size();

        fs << "ntrees" << ntrees
        << "trees" << "[";

        for( k = 0; k < ntrees; k++ )
        {
            fs << "{";
            writeTree(fs, roots[k]);
            fs << "}";
        }

        fs << "]";
    }

    void readParams( const FileNode& fn ) CV_OVERRIDE
    {
        DTreesImpl::readParams(fn);

        FileNode tparams_node = fn["training_params"];
        // check for old layout: parameters may live at the top level or
        // under "training_params".
        String bts = (String)(fn["boosting_type"].empty() ?
                         tparams_node["boosting_type"] : fn["boosting_type"]);
        bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
                             bts == "RealAdaboost" ? Boost::REAL :
                             bts == "LogitBoost" ? Boost::LOGIT :
                             bts == "GentleAdaboost" ? Boost::GENTLE : -1);
        _isClassifier = bparams.boostType == Boost::DISCRETE;
        // check for old layout
        bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
                                    tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
    }

    // Deserializes an ensemble previously written by write().
    void read( const FileNode& fn ) CV_OVERRIDE
    {
        clear();

        int ntrees = (int)fn["ntrees"];
        readParams(fn);

        FileNode trees_node = fn["trees"];
        FileNodeIterator it = trees_node.begin();
        CV_Assert( ntrees == (int)trees_node.size() );

        for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
        {
            FileNode nfn = (*it)["nodes"];
            readTree(nfn);
        }
    }

    BoostTreeParams bparams;     // boosting-specific parameters
    vector<double> sumResult;    // accumulated F(x_i) per sample (LogitBoost)
};
447
448
449
// Public Boost model: a thin facade that forwards every operation to the
// boosting-aware tree implementation held in 'impl'.
class BoostImpl : public Boost
{
public:
    BoostImpl() {}
    virtual ~BoostImpl() {}

    // Boosting-specific parameters (forwarded to impl.bparams).
    inline int getBoostType() const CV_OVERRIDE { return impl.bparams.boostType; }
    inline void setBoostType(int val) CV_OVERRIDE { impl.bparams.boostType = val; }
    inline int getWeakCount() const CV_OVERRIDE { return impl.bparams.weakCount; }
    inline void setWeakCount(int val) CV_OVERRIDE { impl.bparams.weakCount = val; }
    inline double getWeightTrimRate() const CV_OVERRIDE { return impl.bparams.weightTrimRate; }
    inline void setWeightTrimRate(double val) CV_OVERRIDE { impl.bparams.weightTrimRate = val; }

    // Generic decision-tree parameters (forwarded to impl.params).
    inline int getMaxCategories() const CV_OVERRIDE { return impl.params.getMaxCategories(); }
    inline void setMaxCategories(int val) CV_OVERRIDE { impl.params.setMaxCategories(val); }
    inline int getMaxDepth() const CV_OVERRIDE { return impl.params.getMaxDepth(); }
    inline void setMaxDepth(int val) CV_OVERRIDE { impl.params.setMaxDepth(val); }
    inline int getMinSampleCount() const CV_OVERRIDE { return impl.params.getMinSampleCount(); }
    inline void setMinSampleCount(int val) CV_OVERRIDE { impl.params.setMinSampleCount(val); }
    inline int getCVFolds() const CV_OVERRIDE { return impl.params.getCVFolds(); }
    inline void setCVFolds(int val) CV_OVERRIDE { impl.params.setCVFolds(val); }
    inline bool getUseSurrogates() const CV_OVERRIDE { return impl.params.getUseSurrogates(); }
    inline void setUseSurrogates(bool val) CV_OVERRIDE { impl.params.setUseSurrogates(val); }
    inline bool getUse1SERule() const CV_OVERRIDE { return impl.params.getUse1SERule(); }
    inline void setUse1SERule(bool val) CV_OVERRIDE { impl.params.setUse1SERule(val); }
    inline bool getTruncatePrunedTree() const CV_OVERRIDE { return impl.params.getTruncatePrunedTree(); }
    inline void setTruncatePrunedTree(bool val) CV_OVERRIDE { impl.params.setTruncatePrunedTree(val); }
    inline float getRegressionAccuracy() const CV_OVERRIDE { return impl.params.getRegressionAccuracy(); }
    inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
    inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
    inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }

    String getDefaultName() const CV_OVERRIDE { return "opencv_ml_boost"; }

    bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
    {
        return impl.train(trainData, flags);
    }

    float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
    {
        return impl.predict(samples, results, flags);
    }

    void write( FileStorage& fs ) const CV_OVERRIDE
    {
        impl.write(fs);
    }

    void read( const FileNode& fn ) CV_OVERRIDE
    {
        impl.read(fn);
    }

    int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }

    bool isTrained() const CV_OVERRIDE { return impl.isTrained(); }
    bool isClassifier() const CV_OVERRIDE { return impl.isClassifier(); }

    // Read-only access to the trained ensemble structure.
    const vector<int>& getRoots() const CV_OVERRIDE { return impl.getRoots(); }
    const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
    const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
    const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }

    DTreesImplForBoost impl;   // the actual boosted-trees implementation
};
515
516
517
// Factory: creates an empty (untrained) Boost model with default parameters.
Ptr<Boost> Boost::create()
{
    return makePtr<BoostImpl>();
}
521
522
// Loads a serialized Boost model from 'filepath'; 'nodeName' optionally
// selects a named node within the file (empty = first top-level model node).
Ptr<Boost> Boost::load(const String& filepath, const String& nodeName)
{
    return Algorithm::load<Boost>(filepath, nodeName);
}
526
527
}}
528
529
/* End of file. */
530
531