// NOTE(review): this file was scraped from a web code viewer (CoCalc view of
// GitHub repository Tetragramm/opencv, path modules/ml/src/tree.cpp). The
// viewer's page header and per-line numbering artifacts are not part of the
// original source and should be stripped before compiling.
1
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                          License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
42
43
#include "precomp.hpp"
44
#include <ctype.h>
45
46
namespace cv {
47
namespace ml {
48
49
using std::vector;
50
51
// Default decision-tree training parameters: depth effectively unlimited,
// leaves may hold as few as 10 samples, 10 categories max before clustering,
// 10-fold cross-validation pruning with the 1SE rule, no surrogate splits,
// and no class priors.
TreeParams::TreeParams()
    : maxDepth(INT_MAX),
      minSampleCount(10),
      regressionAccuracy(0.01f),
      useSurrogates(false),
      maxCategories(10),
      CVFolds(10),
      use1SERule(true),
      truncatePrunedTree(true),
      priors(Mat())
{
}
63
64
// Fully parameterized constructor; simply stores each training knob.
// _regressionAccuracy is narrowed to float, matching the member's type.
TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
    : maxDepth(_maxDepth),
      minSampleCount(_minSampleCount),
      regressionAccuracy(static_cast<float>(_regressionAccuracy)),
      useSurrogates(_useSurrogates),
      maxCategories(_maxCategories),
      CVFolds(_CVFolds),
      use1SERule(_use1SERule),
      truncatePrunedTree(_truncatePrunedTree),
      priors(_priors)
{
}
80
81
// A freshly constructed node is an isolated leaf: every index field
// (parent/children/split) is -1 ("none"), value and class index are zero.
DTrees::Node::Node()
{
    value = 0;
    classIdx = 0;
    parent = -1;
    left = -1;
    right = -1;
    split = -1;
    defaultDir = -1;
}
87
88
// A default split references no variable subset chain (next = -1) and has
// zero quality, so any real candidate split compares better than it.
DTrees::Split::Split()
{
    next = -1;
    varIdx = 0;
    subsetOfs = 0;
    c = 0.f;
    quality = 0.f;
    inversed = false;
}
97
98
99
// Per-training scratch state. Captures the training data handle and builds
// `sidx`, the ascending list of sample indices the tree will be trained on.
// (The unused local `subsampleIdx` present in the original has been removed.)
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    data = _data;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        // An explicit training subsample was supplied: copy it and sort it so
        // downstream passes can rely on ascending order.
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        // No subsample specified: train on all samples, indices 0..n-1.
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    // Recomputed in startTraining() from the categorical variable counts.
    maxSubsetSize = 0;
}
117
118
// A freshly constructed model is untrained; _isClassifier is determined later
// from the response type in startTraining().
DTreesImpl::DTreesImpl() : _isClassifier(false) {}
DTreesImpl::~DTreesImpl() {}
120
// Resets the model to the untrained state: drops the active-variable maps,
// category lookup tables, all learned trees (roots/nodes/splits/subsets),
// the class labels, and any in-progress training scratch data.
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();

    // Release the per-training working data, if any training was in flight.
    w.release();
    _isClassifier = false;
}
136
137
// Prepares the model and the working data for training:
// copies variable/category metadata out of TrainData, builds the active
// variable list, computes the bit-subset size for categorical splits,
// and initializes sample weights (folding class priors in for classifiers).
// The second (flags) parameter is currently unused.
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        // No explicit variable subset: use all variables 0..nallvars-1.
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    w->maxSubsetSize = 0;

    // The largest category count among active variables determines how many
    // 32-bit words are needed to represent a category subset as a bit mask.
    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            // Normalize the priors matrix to a contiguous CV_64F vector
            // before indexing it per class.
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

            // Fold the class priors directly into the per-sample weights so
            // the rest of training never needs to consult priors again.
            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}
203
204
205
void DTreesImpl::initCompVarIdx()
206
{
207
int nallvars = (int)varType.size();
208
compVarIdx.assign(nallvars, -1);
209
int i, nvars = (int)varIdx.size(), prevIdx = -1;
210
for( i = 0; i < nvars; i++ )
211
{
212
int vi = varIdx[i];
213
CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
214
prevIdx = vi;
215
compVarIdx[vi] = i;
216
}
217
}
218
219
// Releases the per-training scratch data once training is finished.
// The learned tree itself lives in nodes/splits/subsets and is kept.
void DTreesImpl::endTraining()
{
    w.release();
}
223
224
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
225
{
226
startTraining(trainData, flags);
227
bool ok = addTree( w->sidx ) >= 0;
228
w.release();
229
endTraining();
230
return ok;
231
}
232
233
// Returns the list of variable indices the tree is trained on — either all
// variables or the subset supplied via TrainData::getVarIdx().
const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}
237
238
// Builds one tree on the given sample indices and copies it from the mutable
// working representation (WNode/WSplit in `w`) into the compact persistent
// arrays (nodes/splits/subsets). Returns the index of the new root in
// `nodes`, which is also appended to `roots`.
int DTreesImpl::addTree(const vector<int>& sidx )
{
    // Rough upper bound on node count, used only to pre-reserve buffers.
    size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();

    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();

    int cv_n = params.getCVFolds();

    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    // Pruning is disabled here: copy the full tree.
    int maxdepth = INT_MAX;//pruneCV(root);

    // Iterative depth-first traversal of the working tree, emitting compact
    // Node/Split records. w_nidx walks working nodes, nidx/pidx track the
    // corresponding compact indices.
    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();

    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;

        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                // Categorical split: copy its category bit-subset into the
                // persistent `subsets` pool.
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // This check verifies that subsets index is in the correct range
                // as in case ssize == 0 no real resize performed.
                // Thus memory kept safe.
                // Also this skips useless memcpy call when size parameter is zero
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        // Link this compact node into its compact parent's left/right slot,
        // mirroring the working-tree topology.
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }

        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            // Descend into the left child first.
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            // Leaf (or depth limit): climb up past every ancestor whose right
            // subtree is already done, then switch to the next right child.
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;  // climbed past the root: traversal complete

            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}
345
346
// Stores a copy of the decision-tree training parameters.
void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}
350
351
// Appends a working node trained on the samples in `sidx` (child of `parent`,
// or a root if parent < 0), computes its value/risk, and — if the stopping
// criteria allow — finds the best split and recurses into both children.
// Returns the new node's index in w->wnodes.
int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    // NOTE: this reference stays valid only until the recursive calls below
    // push more nodes; after recursion the code re-indexes via w->wnodes[nidx].
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    if( nfolds > 0 )
    {
        // Keep one cross-validation slot per (node, fold) pair.
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    // Stopping criteria: too few samples or maximum depth reached.
    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        // A pure node (all responses equal) cannot be usefully split.
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        // Regression: stop once the node's RMS error is below the threshold.
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        // Re-index through w->wnodes: the recursion above may have
        // reallocated the vector, invalidating `node`.
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        // Children are pushed after this node, so their indices are > 0.
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}
413
414
// Evaluates every active variable on the given samples and returns the index
// (in w->wsplits) of the best split found, or -1 if no split improves on
// quality 0. Categorical splits also record their category bit-subset in
// w->wsubsets.
int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    // Two subset buffers: the candidate written by findSplitCat*, and the
    // best one retained so far; they are swapped rather than copied.
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf.data(), *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        // Dispatch on variable type (categorical vs ordered) and on task
        // (classification vs regression).
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            // Keep the winning subset in best_subset; the other buffer
            // becomes scratch for the next candidate.
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        // Persist the subset (empty for ordered variables: ssize == 0) and
        // the split record.
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}
463
464
// Computes the value and risk of working node `nidx` from the samples in
// `_sidx`, plus the per-fold value/risk/error used by cross-validation
// pruning when CVFolds > 0. Classification and regression are handled by
// the two branches below (see the bullet comments inside each).
void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    // Scratch: m*(cv_n+1) doubles for classification counters, or
    // 3*(cv_n+1) for the regression per-fold sums (hence the max).
    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf.data();
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            // Accumulate per-fold class weights first, then sum them into
            // the overall class weights.
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        // Pick the heaviest class; everything else counts as risk.
        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            // For fold j: recompute the best class over the samples NOT in
            // fold j (train part), then measure error on fold j itself.
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            // Per-fold sums; the totals are reconstructed from them below.
            double *cv_sum = buf.data(), *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                // s/s2/c: fold-j quantities removed from the totals;
                // si/s2i/ci: the complementary ("train") quantities.
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                // r: mean response of the train part, guarded against ci==0.
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }
        CV_Assert(fabs(sumw) > 0);
        node->node_risk = sum2 - (sum/sumw)*sum;
        node->node_risk /= sumw;
        node->value = sum/sumw;
    }
}
639
640
// Finds the best threshold split on ordered variable `vi` for classification.
// Samples are sorted by value; left/right class-weight histograms and their
// sums of squares are updated incrementally as the threshold sweeps, and the
// Gini-style quality (lsum2/L + rsum2/R) is maximized. Returns a WSplit with
// quality > initQuality, or a default WSplit if no candidate beat it.
DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    // Scratch layout: [lcw | rcw | values | sorted_idx].
    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)buf.data();
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    // Initially every sample sits on the right side of the threshold.
    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        // Move sample `curr` from the right side to the left, updating the
        // squared-sum accumulators in O(1) via (x±w)^2 expansion.
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        // Only consider a threshold where the midpoint strictly separates
        // the two values (skips runs of equal values).
        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
713
714
// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
// Clusters n m-dimensional rows of `vectors` into k groups; on return,
// `csums` holds the per-cluster component sums and `labels[i]` the cluster
// assigned to row i. Distances are computed between L1-normalized vectors
// and L1-normalized cluster sums.
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf.data(), *c_weights = buf.data() + n;
    bool modified = true;
    // Fixed seed: the clustering is deterministic across runs.
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        // The first k rows seed distinct clusters so none starts empty.
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    // Shuffle the initial labels.
    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                // Squared distance between the normalized vector and the
                // normalized cluster sum.
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}
807
808
// Finds the best category-subset split on categorical variable `vi` for
// classification. Two regimes:
//  * m == 2 classes: categories are sorted by one class's weight and the best
//    prefix is taken (subset_n = mi candidates).
//  * m > 2 classes: all 2^mi subsets are enumerated in Gray-code order so
//    each step moves exactly one category between sides; if mi exceeds
//    maxCategories, categories are first k-means-clustered down to mi groups.
// On success, the chosen categories are written as a bit mask into `subset`.
DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    // Scratch layout (doubles): [lc(m) | rc(2m) | cjk(m*mi) | c_weights(...)],
    // plus clustered counters or sort pointers, plus n ints for labels.
    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = buf.data();
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf.data() + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    // NOTE(review): the loop starts at j == -1, zeroing the m doubles just
    // before cjk (inside the rc area) — presumably a slot for samples whose
    // normalized category code is -1; confirm against getNormCatValues.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            // Too many categories for exhaustive enumeration: cluster them
            // down to maxCategories groups and enumerate over the groups.
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        // Two-class case: sort categories by the weight of class 1 and scan
        // prefixes of the sorted order instead of all subsets.
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    // Everything starts on the right side: rc holds per-class totals.
    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            // Gray code: consecutive subsets differ in exactly one bit, so
            // only one category moves per iteration.
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            // Extract log2 of the single set bit via the float exponent field.
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;  // empty category: moving it changes nothing

        if( !subtract )
        {
            // Move category `idx` from the right side to the left.
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            // Move category `idx` back from the left side to the right.
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            // Prefix of the sorted category order goes to the left side.
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            // Expand cluster membership back to original category indices.
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}
986
987
// Finds the best threshold split on ordered variable `vi` for regression.
// Samples are sorted by value; weighted response sums on each side are
// updated incrementally and the quality lsum^2/L + rsum^2/R is maximized
// (equivalent to minimizing weighted within-side variance).
DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)buf.data();
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    // Initially all weight and response mass sits on the right side.
    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        // Move the current sample from the right to the left side.
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        // Only accept a threshold whose midpoint strictly separates the two
        // adjacent values (skips ties).
        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
1046
1047
// Finds the best category-subset split on categorical variable `vi` for
// regression. Categories are ordered by their mean response; the best prefix
// of that ordering (maximizing lsum^2/L + rsum^2/R) defines the left subset,
// which is written as a bit mask into `subset` on success.
DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    // Scratch: sums (with one extra slot below index 0), counts, sort
    // pointers, and n category labels.
    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = buf.data() + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    // NOTE(review): starts at i == -1 — sum[-1] is the extra slot reserved by
    // `buf.data() + 1`, presumably for a -1 category code; confirm against
    // getNormCatValues.
    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    // Sort category pointers by mean response.
    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be a very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            // Move this category from the right side to the left.
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        // Set the bit of every category in the winning prefix.
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
1132
1133
// Partitions the samples in `_sidx` into left/right subsets according to the
// given split (threshold for ordered variables, category bit mask for
// categorical ones). Returns the default direction for missing values:
// -1 if the left side carries more weight, +1 otherwise.
int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();

    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];

    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf.data();
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            // values <= threshold go left, the rest go right.
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
        // Categorical split: route by the category's bit in the subset mask.
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)buf.data();
        w->data->getNormCatValues(vi, _sidx, cat_labels);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    // A valid split must send at least one sample to each side.
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
    return wleft > wright ? -1 : 1;
}
1194
1195
// Cost-complexity pruning with cross-validation: builds the nested tree
// sequence T0 ⊃ T1 ⊃ ... with its alpha thresholds, evaluates each tree on
// every CV fold, and returns the index of the tree with the smallest total
// CV error (optionally relaxed by the 1SE rule). Returns -1 if no sequence
// was produced.
int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        // Use geometric means of consecutive alphas as the comparison
        // thresholds; the last one is effectively +infinity.
        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;

        // err_jk(j, ti): error of the ti-th tree measured on fold j.
        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                // This fold's current tree covers every alpha threshold up
                // to min_alpha.
                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                // 1SE rule: prefer the smaller (later) tree whose error is
                // within one standard error of the minimum.
                min_idx = ti;
        }
    }

    return min_idx;
}
1266
1267
// One pass of cost-complexity bookkeeping for pruning: recomputes, for every
// node of the (possibly already partially pruned) tree, the accumulated
// subtree risk/error/complexity and the alpha value of each internal node,
// using a non-recursive depth-first traversal.
// T is the current pruning-step index (nodes with Tn <= T are treated as
// leaves); fold >= 0 selects per-fold statistics, fold < 0 the full tree.
// Returns the minimum alpha over all internal nodes (the next cut threshold).
double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        // descend to the leftmost effective leaf
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                // effective leaf: its subtree statistics are its own
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        // climb up while the current node is a right child, accumulating
        // subtree stats into the parent and computing the parent's alpha
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            // alpha = (risk as a leaf - risk of subtree) / (leaves - 1)
            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        // current node is a left child: seed the parent's accumulators
        // and continue with the right subtree
        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}
1322
1323
// Performs one pruning step: marks as "cut at step T" every node whose
// alpha is within FLT_EPSILON of min_alpha (turning its subtree into an
// effective leaf for subsequent passes). fold >= 0 updates per-fold Tn's,
// fold < 0 the main tree's Tn.
// Returns true when the tree has been cut down to a single node (the root),
// i.e. the pruning sequence is finished.
bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    // a root with no children cannot be pruned further
    if( node->left < 0 )
        return true;

    for(;;)
    {
        // descend left until an effective leaf or a node to be cut
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                // cut here: record the pruning step in Tn
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true; // the whole tree collapsed to the root
                break;
            }
            nidx = node->left;
        }

        // climb up past all right-child links
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        // continue with the right sibling subtree
        nidx = w->wnodes[pidx].right;
    }

    return false;
}
1363
1364
// Runs the sample down every tree whose root index lies in [range.start,
// range.end) and combines the per-tree results: either the sum of leaf
// values (PREDICT_SUM, used for regression and boosted 2-class output) or
// the majority-voted class label (PREDICT_MAX_VOTE).
// The sample must be a single CV_32F row; missing values (equal to
// TrainData::missingValue()) follow the node's default direction unless a
// missing-value substitution table is available.
float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    // catbuf caches the normalized category index per variable (computed once
    // per sample); only needed when there are categorical variables
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf.data();
    int* catbuf = votes + nclasses;
    // cvidx maps a split's variable index to the sample's column when the
    // input is neither compressed nor preprocessed
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    // -1 marks "category index not computed yet" for this variable
    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        // walk down the tree until a leaf (node.split < 0) is reached;
        // prev keeps the last visited node, i.e. the leaf at loop exit
        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    // no substitution table: take the split's default branch
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val); // already a normalized category index
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        // binary-search the raw category value in catMap
                        // within this variable's [cofs[vi][0], cofs[vi][1]) range
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of input categorical variable is not an integer" );

                        CV_Assert(cmap != NULL);
                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );

                        // convert to a 0-based index within the variable
                        // and cache it for the remaining trees
                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        // a single tree needs no vote counting; its leaf class wins
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        // RAW_OUTPUT yields the normalized class index, otherwise the label
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}
1494
1495
1496
// Predicts the response for each row of _samples.
// When _results is needed, it is filled with one value per sample
// (CV_32S for max-vote classification, CV_32F otherwise); when it is not,
// only the first sample is processed. Returns the first sample's prediction.
float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat();
    Mat results;
    const bool classification = isClassifier();
    const bool wantResults = _results.needed();
    const int ntrees = (int)roots.size();
    // regression ensembles average the per-tree sums; classification does not
    const float scale = classification ? 1.f : 1.f/ntrees;
    const int outType =
        classification && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE ? CV_32S : CV_32F;
    int count = samples.rows;

    if( wantResults )
    {
        _results.create(count, 1, outType);
        results = _results.getMat();
    }
    else
        count = std::min(count, 1); // caller only wants the return value

    float firstVal = 0.f;
    for( int si = 0; si < count; si++ )
    {
        float val = predictTrees( Range(0, ntrees), samples.row(si), flags )*scale;
        if( wantResults )
        {
            if( outType == CV_32F )
                results.at<float>(si) = val;
            else
                results.at<int>(si) = cvRound(val);
        }
        if( si == 0 )
            firstVal = val;
    }
    return firstVal;
}
1533
1534
// Serializes the subset of the training parameters that is needed to
// reproduce the training configuration of this model.
void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    const TreeParams& p = params;

    fs << "use_surrogates" << (p.useSurrogates ? 1 : 0);
    fs << "max_categories" << p.getMaxCategories();
    fs << "regression_accuracy" << p.getRegressionAccuracy();

    fs << "max_depth" << p.getMaxDepth();
    fs << "min_sample_count" << p.getMinSampleCount();
    fs << "cross_validation_folds" << p.getCVFolds();

    // the 1SE-rule flag only matters when cross-validation pruning is on
    if( p.getCVFolds() > 1 )
        fs << "use_1se_rule" << (p.use1SERule ? 1 : 0);

    if( !p.priors.empty() )
        fs << "priors" << p.priors;
}
1550
1551
// Serializes model-level meta-data: variable counts and types, training
// parameters, category tables and class labels.
void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();

    // split the variables into ordered and categorical counts
    const int nvars = (int)varType.size();
    int ncat = 0;
    for( int vi = 0; vi < nvars; vi++ )
        if( varType[vi] != VAR_ORDERED )
            ncat++;
    fs << "ord_var_count" << (nvars - ncat);
    fs << "cat_var_count" << ncat;

    fs << "training_params" << "{";
    writeTrainingParams(fs);
    fs << "}";

    if( !varIdx.empty() )
    {
        // splits are stored with global (pre-compression) variable indices
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}
1589
1590
// Serializes a single split. Categorical splits are written as an "in" or
// "not_in" category list (whichever is shorter/clearer); ordered splits as a
// "le" (<=) or "gt" threshold depending on the inversion flag.
void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        // count how many categories go to the right branch
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;

        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";

        // list only the categories that deviate from the default direction
        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i, subset);
            if( dir*default_dir < 0 )
                fs << i;
        }

        fs << "]";
    }
    else
        fs << (!split.inversed ? "le" : "gt") << split.c;

    fs << "}";
}
1627
1628
// Serializes a single tree node: its depth, value, (for classifiers) the
// normalized class index, and the chain of splits (primary + surrogates).
void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& nd = nodes[nidx];

    fs << "{";
    fs << "depth" << depth;
    fs << "value" << nd.value;

    if( _isClassifier )
        fs << "norm_class_idx" << nd.classIdx;

    // internal node: write the split list (split < 0 marks a leaf)
    if( nd.split >= 0 )
    {
        fs << "splits" << "[";
        int si = nd.split;
        while( si >= 0 )
        {
            writeSplit( fs, si );
            si = splits[si].next; // surrogate splits follow the primary one
        }
        fs << "]";
    }

    fs << "}";
}
1650
1651
// Serializes the whole tree rooted at 'root' as a flat "nodes" list in
// depth-first (pre-order) order, using a non-recursive traversal.
void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        // descend the left spine, writing every node on the way down
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        // climb up past all nodes whose right subtree is already done
        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )
            break;

        // switch to the unvisited right subtree
        nidx = nodes[pidx].right;
    }

    fs << "]";
}
1683
1684
// Serializes the whole model: the format tag, the parameters/meta-data,
// and the (single) decision tree.
void DTreesImpl::write( FileStorage& fs ) const
{
    writeFormat(fs);
    writeParams(fs);
    // a plain decision tree has exactly one root
    const int rootIndex = roots[0];
    writeTree(fs, rootIndex);
}
1690
1691
void DTreesImpl::readParams( const FileNode& fn )
1692
{
1693
_isClassifier = (int)fn["is_classifier"] != 0;
1694
/*int var_all = (int)fn["var_all"];
1695
int var_count = (int)fn["var_count"];
1696
int cat_var_count = (int)fn["cat_var_count"];
1697
int ord_var_count = (int)fn["ord_var_count"];*/
1698
1699
FileNode tparams_node = fn["training_params"];
1700
1701
TreeParams params0 = TreeParams();
1702
1703
if( !tparams_node.empty() ) // training parameters are not necessary
1704
{
1705
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
1706
params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
1707
params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
1708
params0.setMaxDepth((int)tparams_node["max_depth"]);
1709
params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
1710
params0.setCVFolds((int)tparams_node["cross_validation_folds"]);
1711
1712
if( params0.getCVFolds() > 1 )
1713
{
1714
params.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
1715
}
1716
1717
tparams_node["priors"] >> params0.priors;
1718
}
1719
1720
readVectorOrMat(fn["var_idx"], varIdx);
1721
fn["var_type"] >> varType;
1722
1723
int format = 0;
1724
fn["format"] >> format;
1725
bool isLegacy = format < 3;
1726
1727
int varAll = (int)fn["var_all"];
1728
if (isLegacy && (int)varType.size() <= varAll)
1729
{
1730
std::vector<uchar> extendedTypes(varAll + 1, 0);
1731
1732
int i = 0, n;
1733
if (!varIdx.empty())
1734
{
1735
n = (int)varIdx.size();
1736
for (; i < n; ++i)
1737
{
1738
int var = varIdx[i];
1739
extendedTypes[var] = varType[i];
1740
}
1741
}
1742
else
1743
{
1744
n = (int)varType.size();
1745
for (; i < n; ++i)
1746
{
1747
extendedTypes[i] = varType[i];
1748
}
1749
}
1750
extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
1751
extendedTypes.swap(varType);
1752
}
1753
1754
readVectorOrMat(fn["cat_map"], catMap);
1755
1756
if (isLegacy)
1757
{
1758
// generating "catOfs" from "cat_count"
1759
catOfs.clear();
1760
classLabels.clear();
1761
std::vector<int> counts;
1762
readVectorOrMat(fn["cat_count"], counts);
1763
unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
1764
for (; i < size; ++i)
1765
{
1766
Vec2i newOffsets(0, 0);
1767
if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
1768
{
1769
newOffsets[0] = curShift;
1770
curShift += counts[j];
1771
newOffsets[1] = curShift;
1772
++j;
1773
}
1774
catOfs.push_back(newOffsets);
1775
}
1776
// other elements in "catMap" are "classLabels"
1777
if (curShift < catMap.size())
1778
{
1779
classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
1780
catMap.erase(catMap.begin() + curShift, catMap.end());
1781
}
1782
}
1783
else
1784
{
1785
fn["cat_ofs"] >> catOfs;
1786
fn["missing_subst"] >> missingSubst;
1787
fn["class_labels"] >> classLabels;
1788
}
1789
1790
// init var mapping for node reading (var indexes or varIdx indexes)
1791
bool globalVarIdx = false;
1792
fn["global_var_idx"] >> globalVarIdx;
1793
if (globalVarIdx || varIdx.empty())
1794
setRangeVector(varMapping, (int)varType.size());
1795
else
1796
varMapping = varIdx;
1797
1798
initCompVarIdx();
1799
setDParams(params0);
1800
}
1801
1802
int DTreesImpl::readSplit( const FileNode& fn )
1803
{
1804
Split split;
1805
1806
int vi = (int)fn["var"];
1807
CV_Assert( 0 <= vi && vi <= (int)varType.size() );
1808
vi = varMapping[vi]; // convert to varIdx if needed
1809
split.varIdx = vi;
1810
1811
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
1812
{
1813
int i, val, ssize = getSubsetSize(vi);
1814
split.subsetOfs = (int)subsets.size();
1815
for( i = 0; i < ssize; i++ )
1816
subsets.push_back(0);
1817
int* subset = &subsets[split.subsetOfs];
1818
FileNode fns = fn["in"];
1819
if( fns.empty() )
1820
{
1821
fns = fn["not_in"];
1822
split.inversed = true;
1823
}
1824
1825
if( fns.isInt() )
1826
{
1827
val = (int)fns;
1828
subset[val >> 5] |= 1 << (val & 31);
1829
}
1830
else
1831
{
1832
FileNodeIterator it = fns.begin();
1833
int n = (int)fns.size();
1834
for( i = 0; i < n; i++, ++it )
1835
{
1836
val = (int)*it;
1837
subset[val >> 5] |= 1 << (val & 31);
1838
}
1839
}
1840
1841
// for categorical splits we do not use inversed splits,
1842
// instead we inverse the variable set in the split
1843
if( split.inversed )
1844
{
1845
for( i = 0; i < ssize; i++ )
1846
subset[i] ^= -1;
1847
split.inversed = false;
1848
}
1849
}
1850
else
1851
{
1852
FileNode cmpNode = fn["le"];
1853
if( cmpNode.empty() )
1854
{
1855
cmpNode = fn["gt"];
1856
split.inversed = true;
1857
}
1858
split.c = (float)cmpNode;
1859
}
1860
1861
split.quality = (float)fn["quality"];
1862
splits.push_back(split);
1863
1864
return (int)(splits.size() - 1);
1865
}
1866
1867
int DTreesImpl::readNode( const FileNode& fn )
1868
{
1869
Node node;
1870
node.value = (double)fn["value"];
1871
1872
if( _isClassifier )
1873
node.classIdx = (int)fn["norm_class_idx"];
1874
1875
FileNode sfn = fn["splits"];
1876
if( !sfn.empty() )
1877
{
1878
int i, n = (int)sfn.size(), prevsplit = -1;
1879
FileNodeIterator it = sfn.begin();
1880
1881
for( i = 0; i < n; i++, ++it )
1882
{
1883
int splitidx = readSplit(*it);
1884
if( splitidx < 0 )
1885
break;
1886
if( prevsplit < 0 )
1887
node.split = splitidx;
1888
else
1889
splits[prevsplit].next = splitidx;
1890
prevsplit = splitidx;
1891
}
1892
}
1893
nodes.push_back(node);
1894
return (int)(nodes.size() - 1);
1895
}
1896
1897
// Reads a whole tree from a flat depth-first list of nodes, re-linking
// parent/left/right pointers as it goes. Returns the root node index
// (also appended to 'roots').
int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        // pidx tracks the parent of the node being read (depth-first order)
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            // attach as the left child if that slot is free, else as right
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx; // internal node: next node read is its child
        else
        {
            // leaf: climb up to the nearest ancestor still missing a right child
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}
1930
1931
void DTreesImpl::read( const FileNode& fn )
1932
{
1933
clear();
1934
readParams(fn);
1935
1936
FileNode fnodes = fn["nodes"];
1937
CV_Assert( !fnodes.empty() );
1938
readTree(fnodes);
1939
}
1940
1941
// Factory: creates an empty decision tree model with default parameters.
Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}
1945
1946
// Loads a serialized DTrees model from the given file, optionally from the
// named top-level node of the storage.
Ptr<DTrees> DTrees::load(const String& filepath, const String& nodeName)
{
    return Algorithm::load<DTrees>(filepath, nodeName);
}
1950
1951
1952
}
1953
}
1954
1955
/* End of file. */
1956
1957