Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/apps/traincascade/old_ml_tree.cpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
//
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
13
// Third party copyrights are property of their respective owners.
14
//
15
// Redistribution and use in source and binary forms, with or without modification,
16
// are permitted provided that the following conditions are met:
17
//
18
// * Redistribution's of source code must retain the above copyright notice,
19
// this list of conditions and the following disclaimer.
20
//
21
// * Redistribution's in binary form must reproduce the above copyright notice,
22
// this list of conditions and the following disclaimer in the documentation
23
// and/or other materials provided with the distribution.
24
//
25
// * The name of Intel Corporation may not be used to endorse or promote products
26
// derived from this software without specific prior written permission.
27
//
28
// This software is provided by the copyright holders and contributors "as is" and
29
// any express or implied warranties, including, but not limited to, the implied
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
32
// indirect, incidental, special, exemplary, or consequential damages
33
// (including, but not limited to, procurement of substitute goods or services;
34
// loss of use, data, or profits; or business interruption) however caused
35
// and on any theory of liability, whether in contract, strict liability,
36
// or tort (including negligence or otherwise) arising in any way out of
37
// the use of this software, even if advised of the possibility of such damage.
38
//
39
//M*/
40
41
#include "old_ml_precomp.hpp"
42
#include <ctype.h>
43
44
using namespace cv;
45
46
// Value stored for a missing ordered measurement; set_data() rejects any real
// value whose magnitude reaches it, so missing entries always sort last.
static const float ord_nan = 0.5f*FLT_MAX;
// Lower bound (64 KiB) for the block size of the tree/temporary memory storages.
static const int min_block_size = 65536;
// Extra bytes added on top of the computed storage block sizes.
static const int block_size_delta = 1024;
49
50
/*
 * Creates an empty training-data object.
 *
 * Every owned matrix / memory storage pointer is nulled first so that the
 * release calls performed by clear() operate on well-defined pointers.
 */
CvDTreeTrainData::CvDTreeTrainData()
{
    var_idx = 0;
    var_type = 0;
    cat_count = 0;
    cat_ofs = 0;
    cat_map = 0;
    priors = 0;
    priors_mult = 0;
    counts = 0;
    direction = 0;
    split_buf = 0;
    responses_copy = 0;
    buf = 0;

    tree_storage = 0;
    temp_storage = 0;

    clear();
}
59
60
61
/*
 * Creates a training-data object and immediately loads the given data set
 * via set_data(); see set_data() for the meaning of the arguments.
 */
CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                                    const CvMat* _responses, const CvMat* _var_idx,
                                    const CvMat* _sample_idx, const CvMat* _var_type,
                                    const CvMat* _missing_mask, const CvDTreeParams& _params,
                                    bool _shared, bool _add_labels )
{
    // null the owned pointers before set_data() (which starts with clear())
    var_idx = 0;
    var_type = 0;
    cat_count = 0;
    cat_ofs = 0;
    cat_map = 0;
    priors = 0;
    priors_mult = 0;
    counts = 0;
    direction = 0;
    split_buf = 0;
    responses_copy = 0;
    buf = 0;

    tree_storage = 0;
    temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
}
76
77
78
// Destructor: releases every buffer and storage owned by the object via clear().
CvDTreeTrainData::~CvDTreeTrainData()
{
    this->clear();
}
82
83
84
bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
85
{
86
bool ok = false;
87
88
CV_FUNCNAME( "CvDTreeTrainData::set_params" );
89
90
__BEGIN__;
91
92
// set parameters
93
params = _params;
94
95
if( params.max_categories < 2 )
96
CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
97
params.max_categories = MIN( params.max_categories, 15 );
98
99
if( params.max_depth < 0 )
100
CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
101
params.max_depth = MIN( params.max_depth, 25 );
102
103
params.min_sample_count = MAX(params.min_sample_count,1);
104
105
if( params.cv_folds < 0 )
106
CV_ERROR( CV_StsOutOfRange,
107
"params.cv_folds should be =0 (the tree is not pruned) "
108
"or n>0 (tree is pruned using n-fold cross-validation)" );
109
110
if( params.cv_folds == 1 )
111
params.cv_folds = 0;
112
113
if( params.regression_accuracy < 0 )
114
CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
115
116
ok = true;
117
118
__END__;
119
120
return ok;
121
}
122
123
// Comparator that orders pointers by the values they point to
// (used to sort int* arrays of category labels).
template<typename T>
struct LessThanPtr
{
    bool operator()( T* lhs, T* rhs ) const
    {
        return *lhs < *rhs;
    }
};
129
130
// Comparator that orders indices of type Idx by the values arr[index]
// (used to sort sample indices by the value of an ordered variable).
template<typename T, typename Idx>
struct LessThanIdx
{
    LessThanIdx( const T* _arr ) : arr(_arr) {}

    bool operator()( Idx lhs, Idx rhs ) const
    {
        return arr[lhs] < arr[rhs];
    }

    const T* arr;
};
138
139
class LessThanPairs
140
{
141
public:
142
bool operator()(const CvPair16u32s& a, const CvPair16u32s& b) const { return *a.i < *b.i; }
143
};
144
145
void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
146
const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
147
const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
148
bool _shared, bool _add_labels, bool _update_data )
149
{
150
CvMat* sample_indices = 0;
151
CvMat* var_type0 = 0;
152
CvMat* tmp_map = 0;
153
int** int_ptr = 0;
154
CvPair16u32s* pair16u32s_ptr = 0;
155
CvDTreeTrainData* data = 0;
156
float *_fdst = 0;
157
int *_idst = 0;
158
unsigned short* udst = 0;
159
int* idst = 0;
160
161
CV_FUNCNAME( "CvDTreeTrainData::set_data" );
162
163
__BEGIN__;
164
165
int sample_all = 0, r_type, cv_n;
166
int total_c_count = 0;
167
int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
168
int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
169
int vi, i, size;
170
char err[100];
171
const int *sidx = 0, *vidx = 0;
172
173
uint64 effective_buf_size = 0;
174
int effective_buf_height = 0, effective_buf_width = 0;
175
176
if( _update_data && data_root )
177
{
178
data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
179
_sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
180
181
// compare new and old train data
182
if( !(data->var_count == var_count &&
183
cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
184
cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
185
cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
186
CV_ERROR( CV_StsBadArg,
187
"The new training data must have the same types and the input and output variables "
188
"and the same categories for categorical variables" );
189
190
cvReleaseMat( &priors );
191
cvReleaseMat( &priors_mult );
192
cvReleaseMat( &buf );
193
cvReleaseMat( &direction );
194
cvReleaseMat( &split_buf );
195
cvReleaseMemStorage( &temp_storage );
196
197
priors = data->priors; data->priors = 0;
198
priors_mult = data->priors_mult; data->priors_mult = 0;
199
buf = data->buf; data->buf = 0;
200
buf_count = data->buf_count; buf_size = data->buf_size;
201
sample_count = data->sample_count;
202
203
direction = data->direction; data->direction = 0;
204
split_buf = data->split_buf; data->split_buf = 0;
205
temp_storage = data->temp_storage; data->temp_storage = 0;
206
nv_heap = data->nv_heap; cv_heap = data->cv_heap;
207
208
data_root = new_node( 0, sample_count, 0, 0 );
209
EXIT;
210
}
211
212
clear();
213
214
var_all = 0;
215
rng = &cv::theRNG();
216
217
CV_CALL( set_params( _params ));
218
219
// check parameter types and sizes
220
CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
221
222
train_data = _train_data;
223
responses = _responses;
224
225
if( _tflag == CV_ROW_SAMPLE )
226
{
227
ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
228
dv_step = 1;
229
if( _missing_mask )
230
ms_step = _missing_mask->step, mv_step = 1;
231
}
232
else
233
{
234
dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
235
ds_step = 1;
236
if( _missing_mask )
237
mv_step = _missing_mask->step, ms_step = 1;
238
}
239
tflag = _tflag;
240
241
sample_count = sample_all;
242
var_count = var_all;
243
244
if( _sample_idx )
245
{
246
CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
247
sidx = sample_indices->data.i;
248
sample_count = sample_indices->rows + sample_indices->cols - 1;
249
}
250
251
if( _var_idx )
252
{
253
CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
254
vidx = var_idx->data.i;
255
var_count = var_idx->rows + var_idx->cols - 1;
256
}
257
258
is_buf_16u = false;
259
if ( sample_count < 65536 )
260
is_buf_16u = true;
261
262
if( !CV_IS_MAT(_responses) ||
263
(CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
264
CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
265
(_responses->rows != 1 && _responses->cols != 1) ||
266
_responses->rows + _responses->cols - 1 != sample_all )
267
CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
268
"floating-point vector containing as many elements as "
269
"the total number of samples in the training data matrix" );
270
271
r_type = CV_VAR_CATEGORICAL;
272
if( _var_type )
273
CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
274
275
CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
276
277
cat_var_count = 0;
278
ord_var_count = -1;
279
280
is_classifier = r_type == CV_VAR_CATEGORICAL;
281
282
// step 0. calc the number of categorical vars
283
for( vi = 0; vi < var_count; vi++ )
284
{
285
char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
286
var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
287
}
288
289
ord_var_count = ~ord_var_count;
290
cv_n = params.cv_folds;
291
// set the two last elements of var_type array to be able
292
// to locate responses and cross-validation labels using
293
// the corresponding get_* functions.
294
var_type->data.i[var_count] = cat_var_count;
295
var_type->data.i[var_count+1] = cat_var_count+1;
296
297
// in case of single ordered predictor we need dummy cv_labels
298
// for safe split_node_data() operation
299
have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
300
301
work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
302
+ (have_labels ? 1 : 0); // for cv_labels
303
304
shared = _shared;
305
buf_count = shared ? 2 : 1;
306
307
buf_size = -1; // the member buf_size is obsolete
308
309
effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
310
effective_buf_width = sample_count;
311
effective_buf_height = work_var_count+1;
312
313
if (effective_buf_width >= effective_buf_height)
314
effective_buf_height *= buf_count;
315
else
316
effective_buf_width *= buf_count;
317
318
if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
319
{
320
CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
321
}
322
323
324
325
if ( is_buf_16u )
326
{
327
CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
328
CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
329
}
330
else
331
{
332
CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
333
CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
334
}
335
336
size = is_classifier ? (cat_var_count+1) : cat_var_count;
337
size = !size ? 1 : size;
338
CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
339
CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
340
341
size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
342
size = !size ? 1 : size;
343
CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
344
345
// now calculate the maximum size of split,
346
// create memory storage that will keep nodes and splits of the decision tree
347
// allocate root node and the buffer for the whole training data
348
max_split_size = cvAlign(sizeof(CvDTreeSplit) +
349
(MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
350
tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
351
tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
352
CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
353
CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
354
355
nv_size = var_count*sizeof(int);
356
nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
357
358
temp_block_size = nv_size;
359
360
if( cv_n )
361
{
362
if( sample_count < cv_n*MAX(params.min_sample_count,10) )
363
CV_ERROR( CV_StsOutOfRange,
364
"The many folds in cross-validation for such a small dataset" );
365
366
cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
367
temp_block_size = MAX(temp_block_size, cv_size);
368
}
369
370
temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
371
CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
372
CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
373
if( cv_size )
374
CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
375
376
CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
377
378
max_c_count = 1;
379
380
_fdst = 0;
381
_idst = 0;
382
if (ord_var_count)
383
_fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
384
if (is_buf_16u && (cat_var_count || is_classifier))
385
_idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
386
387
// transform the training data to convenient representation
388
for( vi = 0; vi <= var_count; vi++ )
389
{
390
int ci;
391
const uchar* mask = 0;
392
int64 m_step = 0, step;
393
const int* idata = 0;
394
const float* fdata = 0;
395
int num_valid = 0;
396
397
if( vi < var_count ) // analyze i-th input variable
398
{
399
int vi0 = vidx ? vidx[vi] : vi;
400
ci = get_var_type(vi);
401
step = ds_step; m_step = ms_step;
402
if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
403
idata = _train_data->data.i + vi0*dv_step;
404
else
405
fdata = _train_data->data.fl + vi0*dv_step;
406
if( _missing_mask )
407
mask = _missing_mask->data.ptr + vi0*mv_step;
408
}
409
else // analyze _responses
410
{
411
ci = cat_var_count;
412
step = CV_IS_MAT_CONT(_responses->type) ?
413
1 : _responses->step / CV_ELEM_SIZE(_responses->type);
414
if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
415
idata = _responses->data.i;
416
else
417
fdata = _responses->data.fl;
418
}
419
420
if( (vi < var_count && ci>=0) ||
421
(vi == var_count && is_classifier) ) // process categorical variable or response
422
{
423
int c_count, prev_label;
424
int* c_map;
425
426
if (is_buf_16u)
427
udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
428
else
429
idst = buf->data.i + (size_t)vi*sample_count;
430
431
// copy data
432
for( i = 0; i < sample_count; i++ )
433
{
434
int val = INT_MAX, si = sidx ? sidx[i] : i;
435
if( !mask || !mask[(size_t)si*m_step] )
436
{
437
if( idata )
438
val = idata[(size_t)si*step];
439
else
440
{
441
float t = fdata[(size_t)si*step];
442
val = cvRound(t);
443
if( fabs(t - val) > FLT_EPSILON )
444
{
445
sprintf( err, "%d-th value of %d-th (categorical) "
446
"variable is not an integer", i, vi );
447
CV_ERROR( CV_StsBadArg, err );
448
}
449
}
450
451
if( val == INT_MAX )
452
{
453
sprintf( err, "%d-th value of %d-th (categorical) "
454
"variable is too large", i, vi );
455
CV_ERROR( CV_StsBadArg, err );
456
}
457
num_valid++;
458
}
459
if (is_buf_16u)
460
{
461
_idst[i] = val;
462
pair16u32s_ptr[i].u = udst + i;
463
pair16u32s_ptr[i].i = _idst + i;
464
}
465
else
466
{
467
idst[i] = val;
468
int_ptr[i] = idst + i;
469
}
470
}
471
472
c_count = num_valid > 0;
473
if (is_buf_16u)
474
{
475
std::sort(pair16u32s_ptr, pair16u32s_ptr + sample_count, LessThanPairs());
476
// count the categories
477
for( i = 1; i < num_valid; i++ )
478
if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
479
c_count ++ ;
480
}
481
else
482
{
483
std::sort(int_ptr, int_ptr + sample_count, LessThanPtr<int>());
484
// count the categories
485
for( i = 1; i < num_valid; i++ )
486
c_count += *int_ptr[i] != *int_ptr[i-1];
487
}
488
489
if( vi > 0 )
490
max_c_count = MAX( max_c_count, c_count );
491
cat_count->data.i[ci] = c_count;
492
cat_ofs->data.i[ci] = total_c_count;
493
494
// resize cat_map, if need
495
if( cat_map->cols < total_c_count + c_count )
496
{
497
tmp_map = cat_map;
498
CV_CALL( cat_map = cvCreateMat( 1,
499
MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
500
for( i = 0; i < total_c_count; i++ )
501
cat_map->data.i[i] = tmp_map->data.i[i];
502
cvReleaseMat( &tmp_map );
503
}
504
505
c_map = cat_map->data.i + total_c_count;
506
total_c_count += c_count;
507
508
c_count = -1;
509
if (is_buf_16u)
510
{
511
// compact the class indices and build the map
512
prev_label = ~*pair16u32s_ptr[0].i;
513
for( i = 0; i < num_valid; i++ )
514
{
515
int cur_label = *pair16u32s_ptr[i].i;
516
if( cur_label != prev_label )
517
c_map[++c_count] = prev_label = cur_label;
518
*pair16u32s_ptr[i].u = (unsigned short)c_count;
519
}
520
// replace labels for missing values with -1
521
for( ; i < sample_count; i++ )
522
*pair16u32s_ptr[i].u = 65535;
523
}
524
else
525
{
526
// compact the class indices and build the map
527
prev_label = ~*int_ptr[0];
528
for( i = 0; i < num_valid; i++ )
529
{
530
int cur_label = *int_ptr[i];
531
if( cur_label != prev_label )
532
c_map[++c_count] = prev_label = cur_label;
533
*int_ptr[i] = c_count;
534
}
535
// replace labels for missing values with -1
536
for( ; i < sample_count; i++ )
537
*int_ptr[i] = -1;
538
}
539
}
540
else if( ci < 0 ) // process ordered variable
541
{
542
if (is_buf_16u)
543
udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
544
else
545
idst = buf->data.i + (size_t)vi*sample_count;
546
547
for( i = 0; i < sample_count; i++ )
548
{
549
float val = ord_nan;
550
int si = sidx ? sidx[i] : i;
551
if( !mask || !mask[(size_t)si*m_step] )
552
{
553
if( idata )
554
val = (float)idata[(size_t)si*step];
555
else
556
val = fdata[(size_t)si*step];
557
558
if( fabs(val) >= ord_nan )
559
{
560
sprintf( err, "%d-th value of %d-th (ordered) "
561
"variable (=%g) is too large", i, vi, val );
562
CV_ERROR( CV_StsBadArg, err );
563
}
564
num_valid++;
565
}
566
567
if (is_buf_16u)
568
udst[i] = (unsigned short)i; // TODO: memory corruption may be here
569
else
570
idst[i] = i;
571
_fdst[i] = val;
572
573
}
574
if (is_buf_16u)
575
std::sort(udst, udst + sample_count, LessThanIdx<float, unsigned short>(_fdst));
576
else
577
std::sort(idst, idst + sample_count, LessThanIdx<float, int>(_fdst));
578
}
579
580
if( vi < var_count )
581
data_root->set_num_valid(vi, num_valid);
582
}
583
584
// set sample labels
585
if (is_buf_16u)
586
udst = (unsigned short*)(buf->data.s + (size_t)work_var_count*sample_count);
587
else
588
idst = buf->data.i + (size_t)work_var_count*sample_count;
589
590
for (i = 0; i < sample_count; i++)
591
{
592
if (udst)
593
udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
594
else
595
idst[i] = sidx ? sidx[i] : i;
596
}
597
598
if( cv_n )
599
{
600
unsigned short* usdst = 0;
601
int* idst2 = 0;
602
603
if (is_buf_16u)
604
{
605
usdst = (unsigned short*)(buf->data.s + (size_t)(get_work_var_count()-1)*sample_count);
606
for( i = vi = 0; i < sample_count; i++ )
607
{
608
usdst[i] = (unsigned short)vi++;
609
vi &= vi < cv_n ? -1 : 0;
610
}
611
612
for( i = 0; i < sample_count; i++ )
613
{
614
int a = (*rng)(sample_count);
615
int b = (*rng)(sample_count);
616
unsigned short unsh = (unsigned short)vi;
617
CV_SWAP( usdst[a], usdst[b], unsh );
618
}
619
}
620
else
621
{
622
idst2 = buf->data.i + (size_t)(get_work_var_count()-1)*sample_count;
623
for( i = vi = 0; i < sample_count; i++ )
624
{
625
idst2[i] = vi++;
626
vi &= vi < cv_n ? -1 : 0;
627
}
628
629
for( i = 0; i < sample_count; i++ )
630
{
631
int a = (*rng)(sample_count);
632
int b = (*rng)(sample_count);
633
CV_SWAP( idst2[a], idst2[b], vi );
634
}
635
}
636
}
637
638
if ( cat_map )
639
cat_map->cols = MAX( total_c_count, 1 );
640
641
max_split_size = cvAlign(sizeof(CvDTreeSplit) +
642
(MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
643
CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
644
645
have_priors = is_classifier && params.priors;
646
if( is_classifier )
647
{
648
int m = get_num_classes();
649
double sum = 0;
650
CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
651
for( i = 0; i < m; i++ )
652
{
653
double val = have_priors ? params.priors[i] : 1.;
654
if( val <= 0 )
655
CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
656
priors->data.db[i] = val;
657
sum += val;
658
}
659
660
// normalize weights
661
if( have_priors )
662
cvScale( priors, priors, 1./sum );
663
664
CV_CALL( priors_mult = cvCloneMat( priors ));
665
CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
666
}
667
668
669
CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
670
CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
671
672
__END__;
673
674
if( data )
675
delete data;
676
677
if (_fdst)
678
cvFree( &_fdst );
679
if (_idst)
680
cvFree( &_idst );
681
cvFree( &int_ptr );
682
cvFree( &pair16u32s_ptr);
683
cvReleaseMat( &var_type0 );
684
cvReleaseMat( &sample_indices );
685
cvReleaseMat( &tmp_map );
686
}
687
688
void CvDTreeTrainData::do_responses_copy()
689
{
690
responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
691
cvCopy( responses, responses_copy);
692
responses = responses_copy;
693
}
694
695
/*
 * Creates the root node for training on a subset of the samples.
 *
 * _subsample_idx - optional index array over 0..sample_count-1; repeated
 *                  indices are allowed and yield repeated samples.
 *
 * If no subset is given, or it is exactly 0..sample_count-1 in order, the
 * returned root refers to the same buffer region as data_root; only the node
 * itself and its num_valid array are copied.  Otherwise the selected samples
 * are written into buffer slot 1 (root->buf_idx == 1) using the same row
 * layout as data_root.  NOTE(review): slot 1 presumably exists only when the
 * data was created with _shared == true -- confirm against callers.
 * Returns the new root node.
 */
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    bool isMakeRootCopy = true;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
    {
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

        // an identity subset needs no data copying
        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
        {
            const int* sidx = isubsample_idx->data.i;
            for( int i = 0; i < sample_count; i++ )
            {
                if( sidx[i] != i )
                {
                    isMakeRootCopy = false;
                    break;
                }
            }
        }
        else
            isMakeRootCopy = false;
    }

    if( isMakeRootCopy )
    {
        // make a copy of the root node, but keep the num_valid and
        // cross-validation arrays freshly allocated for the new node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i;
        int workVarCount = get_work_var_count();
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;

        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < sample_count; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
        for( vi = 0; vi < workVarCount; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                // categorical variable, class labels or cv labels:
                // gather the values of the selected samples
                int num_valid = 0;
                const int* src = CvDTreeTrainData::get_cat_var_data(data_root, vi, (int*)inn_buf.data());

                if (is_buf_16u)
                {
                    unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset);
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        udst[i] = (unsigned short)val;
                        // missing categories are stored as 65535 in the
                        // 16-bit buffer (a plain `val >= 0` is always true)
                        num_valid += val != 65535;
                    }
                }
                else
                {
                    int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        idst[i] = val;
                        num_valid += val >= 0;
                    }
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                // ordered variable: walk the samples in sorted order and emit
                // the new position of every copy of each selected sample
                int *src_idx_buf = (int*)inn_buf.data();
                float *src_val_buf = (float*)(src_idx_buf + sample_count);
                int* sample_indices_buf = (int*)(src_val_buf + sample_count);
                const int* src_idx = 0;
                const float* src_val = 0;
                get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                if (is_buf_16u)
                {
                    unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset);
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }
                }
                else
                {
                    int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }
                }
            }
        }
        // sample indices subsampling
        const int* sample_idx_src = get_sample_indices(data_root, (int*)inn_buf.data());
        if (is_buf_16u)
        {
            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset);
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
        }
        else
        {
            int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset;
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = sample_idx_src[sidx[i]];
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}
897
898
899
void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
900
float* values, uchar* missing,
901
float* _responses, bool get_class_idx )
902
{
903
CvMat* subsample_idx = 0;
904
CvMat* subsample_co = 0;
905
906
CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
907
908
__BEGIN__;
909
910
int i, vi, total = sample_count, count = total, cur_ofs = 0;
911
int* sidx = 0;
912
int* co = 0;
913
914
cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
915
if( _subsample_idx )
916
{
917
CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
918
sidx = subsample_idx->data.i;
919
CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
920
co = subsample_co->data.i;
921
cvZero( subsample_co );
922
count = subsample_idx->cols + subsample_idx->rows - 1;
923
for( i = 0; i < count; i++ )
924
co[sidx[i]*2]++;
925
for( i = 0; i < total; i++ )
926
{
927
int count_i = co[i*2];
928
if( count_i )
929
{
930
co[i*2+1] = cur_ofs*var_count;
931
cur_ofs += count_i;
932
}
933
}
934
}
935
936
if( missing )
937
memset( missing, 1, count*var_count );
938
939
for( vi = 0; vi < var_count; vi++ )
940
{
941
int ci = get_var_type(vi);
942
if( ci >= 0 ) // categorical
943
{
944
float* dst = values + vi;
945
uchar* m = missing ? missing + vi : 0;
946
const int* src = get_cat_var_data(data_root, vi, (int*)inn_buf.data());
947
948
for( i = 0; i < count; i++, dst += var_count )
949
{
950
int idx = sidx ? sidx[i] : i;
951
int val = src[idx];
952
*dst = (float)val;
953
if( m )
954
{
955
*m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
956
m += var_count;
957
}
958
}
959
}
960
else // ordered
961
{
962
float* dst = values + vi;
963
uchar* m = missing ? missing + vi : 0;
964
int count1 = data_root->get_num_valid(vi);
965
float *src_val_buf = (float*)inn_buf.data();
966
int* src_idx_buf = (int*)(src_val_buf + sample_count);
967
int* sample_indices_buf = src_idx_buf + sample_count;
968
const float *src_val = 0;
969
const int* src_idx = 0;
970
get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);
971
972
for( i = 0; i < count1; i++ )
973
{
974
int idx = src_idx[i];
975
int count_i = 1;
976
if( co )
977
{
978
count_i = co[idx*2];
979
cur_ofs = co[idx*2+1];
980
}
981
else
982
cur_ofs = idx*var_count;
983
if( count_i )
984
{
985
float val = src_val[i];
986
for( ; count_i > 0; count_i--, cur_ofs += var_count )
987
{
988
dst[cur_ofs] = val;
989
if( m )
990
m[cur_ofs] = 0;
991
}
992
}
993
}
994
}
995
}
996
997
// copy responses
998
if( _responses )
999
{
1000
if( is_classifier )
1001
{
1002
const int* src = get_class_labels(data_root, (int*)inn_buf.data());
1003
for( i = 0; i < count; i++ )
1004
{
1005
int idx = sidx ? sidx[i] : i;
1006
int val = get_class_idx ? src[idx] :
1007
cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
1008
_responses[i] = (float)val;
1009
}
1010
}
1011
else
1012
{
1013
float* val_buf = (float*)inn_buf.data();
1014
int* sample_idx_buf = (int*)(val_buf + sample_count);
1015
const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
1016
for( i = 0; i < count; i++ )
1017
{
1018
int idx = sidx ? sidx[i] : i;
1019
_responses[i] = _values[idx];
1020
}
1021
}
1022
}
1023
1024
__END__;
1025
1026
cvReleaseMat( &subsample_idx );
1027
cvReleaseMat( &subsample_co );
1028
}
1029
1030
1031
/*
 * Allocates a tree node from node_heap and initializes every field.
 *
 * parent      - parent node or 0 for a root (depth = parent depth + 1).
 * count       - number of samples reaching the node.
 * storage_idx - which copy of the work buffer holds the node data (buf_idx).
 * offset      - offset of the node's samples inside that buffer copy.
 *
 * A num_valid array is taken from nv_heap when it exists; cross-validation
 * arrays (cv_Tn, cv_node_risk, cv_node_error, cv_folds entries each) are
 * carved from one cv_heap element when cv_folds > 0 and cv_heap exists.
 */
CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
{
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    // tree links
    node->parent = parent;
    node->depth = parent ? parent->depth + 1 : 0;
    node->left = 0;
    node->right = 0;
    node->split = 0;

    // node statistics
    node->sample_count = count;
    node->value = 0;
    node->class_idx = 0;
    node->maxlr = 0.;

    // location of the node data
    node->buf_idx = storage_idx;
    node->offset = offset;
    node->num_valid = nv_heap ? (int*)cvSetNew( nv_heap ) : 0;

    // pruning state
    node->alpha = 0.;
    node->node_risk = 0.;
    node->tree_risk = 0.;
    node->tree_error = 0.;
    node->complexity = 0;

    if( !(params.cv_folds > 0 && cv_heap) )
    {
        node->Tn = 0;
        node->cv_Tn = 0;
        node->cv_node_risk = 0;
        node->cv_node_error = 0;
        return node;
    }

    const int folds = params.cv_folds;
    node->Tn = INT_MAX;
    node->cv_Tn = (int*)cvSetNew( cv_heap );
    node->cv_node_risk = (double*)cvAlignPtr( node->cv_Tn + folds, sizeof(double) );
    node->cv_node_error = node->cv_node_risk + folds;

    return node;
}
1072
1073
1074
/*
 * Allocates a split on ordered variable `vi` from split_heap.
 * The threshold goes to split->ord.c and the split position to
 * split->ord.split_point; condensed_idx is reset to INT_MIN.
 */
CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality )
{
    CvDTreeSplit* s = (CvDTreeSplit*)cvSetNew( split_heap );

    s->var_idx = vi;
    s->condensed_idx = INT_MIN;
    s->inversed = inversed;
    s->quality = quality;
    s->next = 0;

    s->ord.c = cmp_val;
    s->ord.split_point = split_point;

    return s;
}
1088
1089
1090
// Allocates a split on categorical variable `vi` from split_heap.
// The category bitmask is cleared: one bit per category, (max_c_count + 31)/32
// words. The split starts unchained with condensed_idx == INT_MIN.
CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
{
    CvDTreeSplit* s = (CvDTreeSplit*)cvSetNew( split_heap );
    const int subset_words = (max_c_count + 31)/32;

    s->var_idx = vi;
    s->condensed_idx = INT_MIN;
    s->inversed = 0;
    s->quality = quality;

    for( int w = 0; w < subset_words; w++ )
        s->subset[w] = 0;

    s->next = 0;
    return s;
}
1105
1106
1107
// Returns `node` and its splits to the heaps they came from.
// This covers the primary split and every surrogate chained after it through
// `next`, plus the node's num_valid array (via free_node_data).
void CvDTreeTrainData::free_node( CvDTreeNode* node )
{
    CvDTreeSplit* s = node->split;
    free_node_data( node );

    for( ; s != 0; )
    {
        CvDTreeSplit* following = s->next;   // read before the split is recycled
        cvSetRemoveByPtr( split_heap, s );
        s = following;
    }

    node->split = 0;
    cvSetRemoveByPtr( node_heap, node );
}
1120
1121
1122
// Releases the per-node training data that can be dropped early: the num_valid
// array taken from nv_heap.
// The cv_* fields are deliberately left untouched, because all cross-validation
// data is released at once.
void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
{
    if( !node->num_valid )
        return;

    cvSetRemoveByPtr( nv_heap, node->num_valid );
    node->num_valid = 0;
}
1131
1132
1133
void CvDTreeTrainData::free_train_data()
1134
{
1135
cvReleaseMat( &counts );
1136
cvReleaseMat( &buf );
1137
cvReleaseMat( &direction );
1138
cvReleaseMat( &split_buf );
1139
cvReleaseMemStorage( &temp_storage );
1140
cvReleaseMat( &responses_copy );
1141
cv_heap = nv_heap = 0;
1142
}
1143
1144
1145
void CvDTreeTrainData::clear()
1146
{
1147
free_train_data();
1148
1149
cvReleaseMemStorage( &tree_storage );
1150
1151
cvReleaseMat( &var_idx );
1152
cvReleaseMat( &var_type );
1153
cvReleaseMat( &cat_count );
1154
cvReleaseMat( &cat_ofs );
1155
cvReleaseMat( &cat_map );
1156
cvReleaseMat( &priors );
1157
cvReleaseMat( &priors_mult );
1158
1159
node_heap = split_heap = 0;
1160
1161
sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
1162
have_labels = have_priors = is_classifier = false;
1163
1164
buf_count = buf_size = 0;
1165
shared = false;
1166
1167
data_root = 0;
1168
1169
rng = &cv::theRNG();
1170
}
1171
1172
1173
int CvDTreeTrainData::get_num_classes() const
1174
{
1175
return is_classifier ? cat_count->data.i[cat_var_count] : 0;
1176
}
1177
1178
1179
int CvDTreeTrainData::get_var_type(int vi) const
1180
{
1181
return var_type->data.i[vi];
1182
}
1183
1184
// Fetches ordered variable `vi` for the samples of node `n`.
//   *sorted_indices - node-local sample positions in sorted order. It points into
//                     `buf` directly, or into sorted_indices_buf when the work
//                     buffer is 16-bit.
//   *ord_values     - ord_values_buf; entry k is the raw training value of sample
//                     sorted[k].
// Filling stops at the first invalid sorted index: negative in the 32-bit buffer,
// 65535 in the 16-bit one. Those trailing entries are presumably the node's
// missing values; calc_node_dir treats sorted[n1..n) as missing.
void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                         const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
{
    const int vidx = var_idx ? var_idx->data.i[vi] : vi;
    const int count = n->sample_count;
    const int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
    const int* sample_indices = get_sample_indices( n, sample_indices_buf );

    if( is_buf_16u )
    {
        const unsigned short* packed = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
                                                               (size_t)vi*sample_count + n->offset );
        for( int k = 0; k < count; k++ )
            sorted_indices_buf[k] = packed[k];
        *sorted_indices = sorted_indices_buf;
    }
    else
    {
        *sorted_indices = buf->data.i + n->buf_idx*get_length_subbuf() +
                          (size_t)vi*sample_count + n->offset;
    }

    const int* order = *sorted_indices;
    const float* samples = train_data->data.fl;
    const bool row_layout = tflag == CV_ROW_SAMPLE;

    for( int k = 0; k < count; k++ )
    {
        const int pos = order[k];
        if( is_buf_16u ? pos == 65535 : pos < 0 )
            break;
        const int sample = sample_indices[pos];
        ord_values_buf[k] = row_layout ? samples[sample*td_step + vidx]
                                       : samples[vidx*td_step + sample];
    }

    *ord_values = ord_values_buf;
}
1225
1226
1227
// Class labels of the samples of node `n`, or 0 when this is not a classifier.
// The labels are stored as the categorical work column at index var_count.
const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
{
    return is_classifier ? get_cat_var_data( n, var_count, labels_buf ) : 0;
}
1233
1234
// Original sample indices of the samples of node `n`.
// They are stored as the work column at index get_work_var_count().
const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
{
    const int column = get_work_var_count();
    return get_cat_var_data( n, column, indices_buf );
}
1238
1239
// Copies the float responses of the samples of node `n` into values_buf and
// returns values_buf.
// Copying stops at the first invalid sample index: negative for the 32-bit work
// buffer, 65535 for the 16-bit one.
const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf )
{
    const int count = n->sample_count;
    const int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
    const int* indices = get_sample_indices( n, sample_indices_buf );
    const float* src = responses->data.fl;

    for( int k = 0; k < count; k++ )
    {
        const int sample = indices[k];
        if( is_buf_16u ? sample == 65535 : sample < 0 )
            break;
        values_buf[k] = src[sample*r_step];
    }

    return values_buf;
}
1254
1255
1256
// Cross-validation fold labels of the samples of node `n`, or 0 when
// have_labels is false.
// They occupy the work column just before the sample-index column.
const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
{
    if( !have_labels )
        return 0;
    return get_cat_var_data( n, get_work_var_count() - 1, labels_buf );
}
1262
1263
1264
// Returns the integer work column `vi` for the samples of node `n`.
// With a 32-bit work buffer the result points straight into `buf`. With a 16-bit
// buffer the values are widened into cat_values_buf, which is returned instead;
// a stored 65535 stays 65535.
const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf )
{
    if( !is_buf_16u )
        return buf->data.i + n->buf_idx*get_length_subbuf() +
               (size_t)vi*sample_count + n->offset;

    const unsigned short* packed = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
                                                           (size_t)vi*sample_count + n->offset);
    for( int k = 0; k < n->sample_count; k++ )
        cat_values_buf[k] = packed[k];
    return cat_values_buf;
}
1279
1280
1281
// Index of the sub-buffer that the children of `n` should use.
// This is the next sub-buffer, wrapping around after buf_count - 1. When the data
// is shared the wrap goes to 1, not 0; sub-buffer 0 is presumably reserved in
// that mode (verify).
int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
{
    const int next = n->buf_idx + 1;
    if( next < buf_count )
        return next;
    return shared ? 1 : 0;
}
1288
1289
1290
// Serializes the data layout and training parameters into `fs`, in the format
// read back by read_params():
//   - top-level counters
//   - a "training_params" map
//   - optional var_idx
//   - a "var_type" flow sequence (1 = categorical, 0 = ordered)
//   - cat_count/cat_map, when categorical variables or class labels exist
void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
{
    CV_FUNCNAME( "CvDTreeTrainData::write_params" );

    __BEGIN__;

    int vi;

    // data layout
    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    // training parameters
    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );

    if( !is_classifier )
        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
    else
        cvWriteInt( fs, "max_categories", params.max_categories );

    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );

    // pruning flags only matter when cross-validation pruning is enabled
    if( params.cv_folds > 1 )
    {
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
    }

    if( priors )
        cvWrite( fs, "priors", priors );

    cvEndWriteStruct( fs );

    // variable description
    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
    for( vi = 0; vi < var_count; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );   // 1 = categorical, 0 = ordered
    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
    {
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
    }

    __END__;
}
1350
1351
1352
// Restores what write_params() produced from file node `node`, validating it.
// Reads the layout counters, the optional training parameters, var_idx, var_type
// and cat_count/cat_map. It then rebuilds cat_ofs and max_c_count and creates
// tree_storage with node_heap and split_heap on it.
// Malformed input raises CV_StsParseError, CV_StsOutOfRange or CV_StsBadSize.
void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    __BEGIN__;

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;

    // ---- data layout ----
    is_classifier = cvReadIntByName( fs, node, "is_classifier" ) != 0;
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );   // defaults to var_all
    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );

    // ---- training parameters (optional) ----
    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
    if( tparams_node )
    {
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( !is_classifier )
            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
        else
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        // write_params only emits the pruning flags when cv_folds > 1
        if( params.cv_folds > 1 )
        {
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
        }

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
        if( priors && !CV_IS_MAT(priors) )
            CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
        if( priors )
            priors_mult = cvCloneMat( priors );
    }

    // ---- optional subset of active input variables ----
    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
    if( var_idx )
    {
        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
            "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );

        // each entry must address one of the var_all input columns
        const int* ids = var_idx->data.i;
        for( vi = 0; vi < var_count; vi++ )
        {
            if( (unsigned)ids[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
        }
    }

    // ---- variable types ----
    // Categorical variables are numbered 0, 1, 2, ...; ordered ones -1, -2, -3, ...
    // The extra two slots of var_type are for the work columns after the inputs.
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
    {
        // a single variable may have been stored as a scalar instead of a sequence
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
    }
    else
    {
        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
        {
            CvFileNode* item = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(item->tag) != CV_NODE_INT || (item->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = item->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }
    var_type->data.i[var_count] = cat_var_count;

    // ord_var_count went -1, -2, ... -(k+1) for k ordered vars; ~(-(k+1)) == k
    ord_var_count = ~ord_var_count;

    // ---- category counts and maps ----
    if( cat_var_count > 0 || is_classifier )
    {
        // one entry per categorical input, plus one for the class labels of a classifier
        const int ccount = cat_var_count + is_classifier;
        int total_c_count = 0;

        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != ccount ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
            "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        // cat_ofs[i] is where variable i's categories start inside cat_map (prefix sums)
        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;
        max_c_count = 1;

        for( vi = 0; vi < ccount; vi++ )
        {
            const int categories = cat_count->data.i[vi];
            if( categories <= 0 )
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, categories );
            total_c_count += categories;
            cat_ofs->data.i[vi+1] = total_c_count;
        }

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
            "cat_map vector length is not equal to the total number of categories in all categorical vars" );
    }

    // ---- storage for the tree ----
    // Splits get extra trailing ints when max_c_count needs a longer category
    // bitmask; new_split_cat clears (max_c_count + 31)/32 subset words.
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
                                      sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
                                       max_split_size, tree_storage ));

    __END__;
}
1498
1499
/////////////////////// Decision Tree /////////////////////////
1500
// Default decision-tree parameters:
//   - unlimited depth, at least 10 samples per split, up to 10 categories
//   - 10-fold cross-validation pruning with the 1-SE rule and truncation
//   - surrogate splits enabled, regression accuracy 0.01, no class priors
CvDTreeParams::CvDTreeParams()
    : max_categories(10),
      max_depth(INT_MAX),
      min_sample_count(10),
      cv_folds(10),
      use_surrogates(true),
      use_1se_rule(true),
      truncate_pruned_tree(true),
      regression_accuracy(0.01f),
      priors(0)
{
}
1504
1505
// Builds a parameter set from explicit values.
// Every argument is stored unchanged in the member of the same name (_cv_folds
// goes to cv_folds). `_priors` is kept as a pointer, not copied.
CvDTreeParams::CvDTreeParams( int _max_depth, int _min_sample_count,
                              float _regression_accuracy, bool _use_surrogates,
                              int _max_categories, int _cv_folds,
                              bool _use_1se_rule, bool _truncate_pruned_tree,
                              const float* _priors )
    : max_categories(_max_categories),
      max_depth(_max_depth),
      min_sample_count(_min_sample_count),
      cv_folds(_cv_folds),
      use_surrogates(_use_surrogates),
      use_1se_rule(_use_1se_rule),
      truncate_pruned_tree(_truncate_pruned_tree),
      regression_accuracy(_regression_accuracy),
      priors(_priors)
{
}
1517
1518
// Creates an empty tree named "my_tree".
// clear() releases var_importance and inspects data, so both pointers must be
// null before it runs.
CvDTree::CvDTree()
{
    var_importance = 0;
    data = 0;
    default_model_name = "my_tree";

    clear();
}
1526
1527
1528
void CvDTree::clear()
1529
{
1530
cvReleaseMat( &var_importance );
1531
if( data )
1532
{
1533
if( !data->shared )
1534
delete data;
1535
else
1536
free_tree();
1537
data = 0;
1538
}
1539
root = 0;
1540
pruned_tree_idx = -1;
1541
}
1542
1543
1544
// Releases everything clear() releases.
CvDTree::~CvDTree()
{
    this->clear();
}
1548
1549
1550
// Root node of the trained tree (0 before training or after clear()).
const CvDTreeNode* CvDTree::get_root() const
{
    return this->root;
}
1554
1555
1556
int CvDTree::get_pruned_tree_idx() const
1557
{
1558
return pruned_tree_idx;
1559
}
1560
1561
1562
// Training data the tree currently uses, owned or shared (0 when cleared).
CvDTreeTrainData* CvDTree::get_data()
{
    return this->data;
}
1566
1567
1568
bool CvDTree::train( const CvMat* _train_data, int _tflag,
1569
const CvMat* _responses, const CvMat* _var_idx,
1570
const CvMat* _sample_idx, const CvMat* _var_type,
1571
const CvMat* _missing_mask, CvDTreeParams _params )
1572
{
1573
bool result = false;
1574
1575
CV_FUNCNAME( "CvDTree::train" );
1576
1577
__BEGIN__;
1578
1579
clear();
1580
data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1581
_var_idx, _sample_idx, _var_type,
1582
_missing_mask, _params, false );
1583
CV_CALL( result = do_train(0) );
1584
1585
__END__;
1586
1587
return result;
1588
}
1589
1590
bool CvDTree::train( const Mat& _train_data, int _tflag,
1591
const Mat& _responses, const Mat& _var_idx,
1592
const Mat& _sample_idx, const Mat& _var_type,
1593
const Mat& _missing_mask, CvDTreeParams _params )
1594
{
1595
train_data_hdr = cvMat(_train_data);
1596
train_data_mat = _train_data;
1597
responses_hdr = cvMat(_responses);
1598
responses_mat = _responses;
1599
1600
CvMat vidx=cvMat(_var_idx), sidx=cvMat(_sample_idx), vtype=cvMat(_var_type), mmask=cvMat(_missing_mask);
1601
1602
return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
1603
vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
1604
}
1605
1606
1607
// Trains on a CvMLData set, using its training-sample subset.
// Samples are passed as rows (CV_ROW_SAMPLE).
bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
{
    bool trained = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    const CvMat* samples = _data->get_values();
    const CvMat* targets = _data->get_responses();
    const CvMat* missing_mask = _data->get_missing();
    const CvMat* types = _data->get_var_types();
    const CvMat* train_subset = _data->get_train_sample_idx();
    const CvMat* active_vars = _data->get_var_idx();

    CV_CALL( trained = train( samples, CV_ROW_SAMPLE, targets, active_vars,
                              train_subset, types, missing_mask, _params ) );

    __END__;

    return trained;
}
1629
1630
// Trains on externally owned, pre-built training data, optionally restricted to
// `_subsample_idx`.
// The data is flagged as shared, which stops clear() from deleting it; the caller
// keeps ownership.
bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
    bool trained = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = _data;
    data->shared = true;
    CV_CALL( trained = do_train( _subsample_idx ));

    __END__;

    return trained;
}
1647
1648
1649
bool CvDTree::do_train( const CvMat* _subsample_idx )
1650
{
1651
bool result = false;
1652
1653
CV_FUNCNAME( "CvDTree::do_train" );
1654
1655
__BEGIN__;
1656
1657
root = data->subsample_data( _subsample_idx );
1658
1659
CV_CALL( try_split_node(root));
1660
1661
if( root->split )
1662
{
1663
CV_Assert( root->left );
1664
CV_Assert( root->right );
1665
1666
if( data->params.cv_folds > 0 )
1667
CV_CALL( prune_cv() );
1668
1669
if( !data->shared )
1670
data->free_train_data();
1671
1672
result = true;
1673
}
1674
1675
__END__;
1676
1677
return result;
1678
}
1679
1680
1681
// Recursively grows the subtree rooted at `node`.
// The node stays a leaf (and its num_valid data is freed) when any of these hold:
//   - it has too few samples or is at max depth
//   - it is pure (classification)
//   - its scaled risk is below regression_accuracy (regression)
//   - no split was found
// Otherwise the best split is stored, optional surrogate splits are chained after
// it in decreasing quality, the samples are partitioned and both children are
// processed.
void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    const int n = node->sample_count;

    calc_node_value( node );

    bool can_split = node->sample_count > data->params.min_sample_count &&
                     node->depth < data->params.max_depth;

    if( can_split )
    {
        if( data->is_classifier )
        {
            // a node containing a single class is "pure" and is not split further;
            // data->counts is assumed to have been filled by calc_node_value()
            const int* cls_count = data->counts->data.i;
            const int m = data->get_num_classes();
            int present = 0;
            for( int c = 0; c < m; c++ )
                present += cls_count[c] != 0;
            if( present == 1 )
                can_split = false;
        }
        else if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
        {
            can_split = false;
        }
    }

    if( can_split )
    {
        best_split = find_best_split( node );
        // TODO: check the split quality ...
        node->split = best_split;
    }

    if( !best_split )
    {
        data->free_node_data( node );
        return;
    }

    const double quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // collect surrogate splits and keep them sorted by similarity to the primary one
        for( int vi = 0; vi < data->var_count; vi++ )
        {
            if( vi == best_split->var_idx )
                continue;

            CvDTreeSplit* surrogate = data->get_var_type(vi) >= 0 ?
                find_surrogate_split_cat( node, vi ) :
                find_surrogate_split_ord( node, vi );
            if( !surrogate )
                continue;

            surrogate->quality = (float)(surrogate->quality*quality_scale);

            // insert after the last chained split whose quality is strictly higher
            CvDTreeSplit* prev = node->split;
            while( prev->next && prev->next->quality > surrogate->quality )
                prev = prev->next;
            surrogate->next = prev->next;
            prev->next = surrogate;
        }
    }

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}
1759
1760
1761
// calculate direction (left(-1),right(1),missing(0))
1762
// for each sample using the best split
1763
// the function returns scale coefficients for surrogate split quality factors.
1764
// the scale is applied to normalize surrogate split quality relatively to the
1765
// best (primary) split quality. That is, if a surrogate split is absolutely
1766
// identical to the primary split, its quality will be set to the maximum value =
1767
// quality of the primary split; otherwise, it will be lower.
1768
// besides, the function compute node->maxlr,
1769
// minimum possible quality (w/o considering the above mentioned scale)
1770
// for a surrogate split. Surrogate splits with quality less than node->maxlr
1771
// are not discarded.
1772
double CvDTree::calc_node_dir( CvDTreeNode* node )
1773
{
1774
char* dir = (char*)data->direction->data.ptr;
1775
int i, n = node->sample_count, vi = node->split->var_idx;
1776
double L, R;
1777
1778
assert( !node->split->inversed );
1779
1780
if( data->get_var_type(vi) >= 0 ) // split on categorical var
1781
{
1782
cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
1783
int* labels_buf = inn_buf.data();
1784
const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1785
const int* subset = node->split->subset;
1786
if( !data->have_priors )
1787
{
1788
int sum = 0, sum_abs = 0;
1789
1790
for( i = 0; i < n; i++ )
1791
{
1792
int idx = labels[i];
1793
int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
1794
CV_DTREE_CAT_DIR(idx,subset) : 0;
1795
sum += d; sum_abs += d & 1;
1796
dir[i] = (char)d;
1797
}
1798
1799
R = (sum_abs + sum) >> 1;
1800
L = (sum_abs - sum) >> 1;
1801
}
1802
else
1803
{
1804
const double* priors = data->priors_mult->data.db;
1805
double sum = 0, sum_abs = 0;
1806
int* responses_buf = labels_buf + n;
1807
const int* responses = data->get_class_labels(node, responses_buf);
1808
1809
for( i = 0; i < n; i++ )
1810
{
1811
int idx = labels[i];
1812
double w = priors[responses[i]];
1813
int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1814
sum += d*w; sum_abs += (d & 1)*w;
1815
dir[i] = (char)d;
1816
}
1817
1818
R = (sum_abs + sum) * 0.5;
1819
L = (sum_abs - sum) * 0.5;
1820
}
1821
}
1822
else // split on ordered var
1823
{
1824
int split_point = node->split->ord.split_point;
1825
int n1 = node->get_num_valid(vi);
1826
cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
1827
float* val_buf = (float*)inn_buf.data();
1828
int* sorted_buf = (int*)(val_buf + n);
1829
int* sample_idx_buf = sorted_buf + n;
1830
const float* val = 0;
1831
const int* sorted = 0;
1832
data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
1833
1834
assert( 0 <= split_point && split_point < n1-1 );
1835
1836
if( !data->have_priors )
1837
{
1838
for( i = 0; i <= split_point; i++ )
1839
dir[sorted[i]] = (char)-1;
1840
for( ; i < n1; i++ )
1841
dir[sorted[i]] = (char)1;
1842
for( ; i < n; i++ )
1843
dir[sorted[i]] = (char)0;
1844
1845
L = split_point-1;
1846
R = n1 - split_point + 1;
1847
}
1848
else
1849
{
1850
const double* priors = data->priors_mult->data.db;
1851
int* responses_buf = sample_idx_buf + n;
1852
const int* responses = data->get_class_labels(node, responses_buf);
1853
L = R = 0;
1854
1855
for( i = 0; i <= split_point; i++ )
1856
{
1857
int idx = sorted[i];
1858
double w = priors[responses[idx]];
1859
dir[idx] = (char)-1;
1860
L += w;
1861
}
1862
1863
for( ; i < n1; i++ )
1864
{
1865
int idx = sorted[i];
1866
double w = priors[responses[idx]];
1867
dir[idx] = (char)1;
1868
R += w;
1869
}
1870
1871
for( ; i < n; i++ )
1872
dir[sorted[i]] = (char)0;
1873
}
1874
}
1875
node->maxlr = MAX( L, R );
1876
return node->split->quality/(L + R);
1877
}
1878
1879
1880
namespace cv
1881
{
1882
1883
void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const { fastFree(obj); }
1884
1885
DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
1886
{
1887
tree = _tree;
1888
node = _node;
1889
splitSize = tree->get_data()->split_heap->elem_size;
1890
1891
bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
1892
memset(bestSplit.get(), 0, splitSize);
1893
bestSplit->quality = -1;
1894
bestSplit->condensed_idx = INT_MIN;
1895
split.reset((CvDTreeSplit*)fastMalloc(splitSize));
1896
memset(split.get(), 0, splitSize);
1897
//haveSplit = false;
1898
}
1899
1900
DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
1901
{
1902
tree = finder.tree;
1903
node = finder.node;
1904
splitSize = tree->get_data()->split_heap->elem_size;
1905
1906
bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
1907
memcpy(bestSplit.get(), finder.bestSplit.get(), splitSize);
1908
split.reset((CvDTreeSplit*)fastMalloc(splitSize));
1909
memset(split.get(), 0, splitSize);
1910
}
1911
1912
void DTreeBestSplitFinder::operator()(const BlockedRange& range)
1913
{
1914
int vi, vi1 = range.begin(), vi2 = range.end();
1915
int n = node->sample_count;
1916
CvDTreeTrainData* data = tree->get_data();
1917
AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
1918
1919
for( vi = vi1; vi < vi2; vi++ )
1920
{
1921
CvDTreeSplit *res;
1922
int ci = data->get_var_type(vi);
1923
if( node->get_num_valid(vi) <= 1 )
1924
continue;
1925
1926
if( data->is_classifier )
1927
{
1928
if( ci >= 0 )
1929
res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, inn_buf.data() );
1930
else
1931
res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, inn_buf.data() );
1932
}
1933
else
1934
{
1935
if( ci >= 0 )
1936
res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, inn_buf.data() );
1937
else
1938
res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, inn_buf.data() );
1939
}
1940
1941
if( res && bestSplit->quality < split->quality )
1942
memcpy( bestSplit.get(), split.get(), splitSize );
1943
}
1944
}
1945
1946
void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
1947
{
1948
if( bestSplit->quality < rhs.bestSplit->quality )
1949
memcpy( bestSplit.get(), rhs.bestSplit.get(), splitSize );
1950
}
1951
}
1952
1953
1954
// Searches all working variables, in parallel, for the best primary split of `node`.
// Returns a split allocated from split_heap (so it lives as long as the tree),
// or 0 when no split with positive quality exists.
CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
{
    DTreeBestSplitFinder finder( this, node );

    // each worker scans a range of variables; partial results are merged via join()
    cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);

    CvDTreeSplit *bestSplit = 0;
    if( finder.bestSplit->quality > 0 )
    {
        // new_split_cat only reserves a split_heap element of the right size;
        // every field is overwritten by the copy below
        bestSplit = data->new_split_cat( 0, -1.0f );
        memcpy( bestSplit, finder.bestSplit.get(), finder.splitSize );
    }

    return bestSplit;
}
1969
1970
// Finds the best threshold on ordered variable `vi` for a classification node.
// A split between positions i and i+1 of the sorted values is scored as
// (lsum2*R + rsum2*L)/(L*R), where lsum2/rsum2 are sums of squared (prior-weighted)
// per-class counts on each side and L/R are the side totals. Only positions whose
// neighbouring values differ by more than 2*FLT_EPSILON are candidates. Samples
// with missing values (sorted[n1..n)) are excluded.
// Returns the split (written into _split when given, else newly allocated) if some
// candidate beats init_quality; otherwise returns 0. _ext_buf, when non-null,
// provides the per-sample scratch memory.
CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
                                             float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const int n = node->sample_count;
    const int n1 = node->get_num_valid(vi);
    const int m = data->get_num_classes();

    // base_buf: lc[m] and rc[m]; ext_buf: values[n], sorted[n], sample idx[n], responses[n]
    int base_size = 2*m*sizeof(int);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    int* responses_buf = sample_indices_buf + n;

    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
                            &sorted_indices, sample_indices_buf );
    const int* responses = data->get_class_labels( node, responses_buf );

    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
    int* lc = (int*)base_buf;   // per-class counts left of the threshold
    int* rc = lc + m;           // per-class counts right of it

    // everything starts on the right ...
    const int* node_counts = data->counts->data.i;
    for( int c = 0; c < m; c++ )
    {
        lc[c] = 0;
        rc[c] = node_counts[c];
    }
    // ... except samples whose value is missing, which take no part in the split
    for( int k = n1; k < n; k++ )
        rc[responses[sorted_indices[k]]]--;

    double lsum2 = 0, rsum2 = 0, best_val = init_quality;
    int best_i = -1;

    if( !priors )
    {
        int L = 0, R = n1;

        for( int c = 0; c < m; c++ )
            rsum2 += (double)rc[c]*rc[c];

        for( int k = 0; k < n1 - 1; k++ )
        {
            const int cls = responses[sorted_indices[k]];
            const int lv = lc[cls], rv = rc[cls];
            L++; R--;
            // (lv+1)^2 - lv^2 = 2*lv + 1 and rv^2 - (rv-1)^2 = 2*rv - 1
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[cls] = lv + 1; rc[cls] = rv - 1;

            if( values[k] + epsilon < values[k+1] )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = k;
                }
            }
        }
    }
    else
    {
        double L = 0, R = 0;
        for( int c = 0; c < m; c++ )
        {
            double wv = rc[c]*priors[c];
            R += wv;
            rsum2 += wv*wv;
        }

        for( int k = 0; k < n1 - 1; k++ )
        {
            const int cls = responses[sorted_indices[k]];
            const int lv = lc[cls], rv = rc[cls];
            double p = priors[cls], p2 = p*p;
            L += p; R -= p;
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[cls] = lv + 1; rc[cls] = rv - 1;

            if( values[k] + epsilon < values[k+1] )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = k;
                }
            }
        }
    }

    if( best_i < 0 )
        return 0;

    CvDTreeSplit* split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
    split->var_idx = vi;
    split->ord.c = (values[best_i] + values[best_i+1])*0.5f;   // midpoint threshold
    split->ord.split_point = best_i;
    split->inversed = 0;
    split->quality = (float)best_val;
    return split;
}
2087
2088
2089
void CvDTree::cluster_categories( const int* vectors, int n, int m,
2090
int* csums, int k, int* labels )
2091
{
2092
// TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
2093
int iters = 0, max_iters = 100;
2094
int i, j, idx;
2095
cv::AutoBuffer<double> buf(n + k);
2096
double *v_weights = buf.data(), *c_weights = buf.data() + n;
2097
bool modified = true;
2098
RNG* r = data->rng;
2099
2100
// assign labels randomly
2101
for( i = 0; i < n; i++ )
2102
{
2103
int sum = 0;
2104
const int* v = vectors + i*m;
2105
labels[i] = i < k ? i : r->uniform(0, k);
2106
2107
// compute weight of each vector
2108
for( j = 0; j < m; j++ )
2109
sum += v[j];
2110
v_weights[i] = sum ? 1./sum : 0.;
2111
}
2112
2113
for( i = 0; i < n; i++ )
2114
{
2115
int i1 = (*r)(n);
2116
int i2 = (*r)(n);
2117
CV_SWAP( labels[i1], labels[i2], j );
2118
}
2119
2120
for( iters = 0; iters <= max_iters; iters++ )
2121
{
2122
// calculate csums
2123
for( i = 0; i < k; i++ )
2124
{
2125
for( j = 0; j < m; j++ )
2126
csums[i*m + j] = 0;
2127
}
2128
2129
for( i = 0; i < n; i++ )
2130
{
2131
const int* v = vectors + i*m;
2132
int* s = csums + labels[i]*m;
2133
for( j = 0; j < m; j++ )
2134
s[j] += v[j];
2135
}
2136
2137
// exit the loop here, when we have up-to-date csums
2138
if( iters == max_iters || !modified )
2139
break;
2140
2141
modified = false;
2142
2143
// calculate weight of each cluster
2144
for( i = 0; i < k; i++ )
2145
{
2146
const int* s = csums + i*m;
2147
int sum = 0;
2148
for( j = 0; j < m; j++ )
2149
sum += s[j];
2150
c_weights[i] = sum ? 1./sum : 0;
2151
}
2152
2153
// now for each vector determine the closest cluster
2154
for( i = 0; i < n; i++ )
2155
{
2156
const int* v = vectors + i*m;
2157
double alpha = v_weights[i];
2158
double min_dist2 = DBL_MAX;
2159
int min_idx = -1;
2160
2161
for( idx = 0; idx < k; idx++ )
2162
{
2163
const int* s = csums + idx*m;
2164
double dist2 = 0., beta = c_weights[idx];
2165
for( j = 0; j < m; j++ )
2166
{
2167
double t = v[j]*alpha - s[j]*beta;
2168
dist2 += t*t;
2169
}
2170
if( min_dist2 > dist2 )
2171
{
2172
min_dist2 = dist2;
2173
min_idx = idx;
2174
}
2175
}
2176
2177
if( min_idx != labels[i] )
2178
modified = true;
2179
labels[i] = min_idx;
2180
}
2181
}
2182
}
2183
2184
2185
/*
 * Finds the best split of `node` on the categorical variable `vi` for a
 * classification tree.
 *
 * The counts c_{jk} (samples with category j and class k) are collected first.
 * Candidate "left" subsets of categories are then scored with
 *     (lsum2*R + rsum2*L) / (L*R),
 * where lsum2 / rsum2 are the sums of squared prior-weighted class counts on each
 * side and L / R the prior-weighted totals of each side.
 *  - Two classes: categories are sorted by their class-1 count and only the mi
 *    prefixes of that order are tried.
 *  - More classes: all subsets are enumerated in Gray-code order, so each step moves
 *    exactly one category. If there are more than params.max_categories categories,
 *    they are first merged into at most min(max_categories, n) clusters by
 *    cluster_categories().
 *
 * Returns `_split` (or a newly allocated split when `_split` is null) describing the
 * best subset when its score exceeds `init_quality`, otherwise 0.
 * `_ext_buf`, when given, must provide room for 2*n ints.
 */
CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
                                             CvDTreeSplit* _split, uchar* _ext_buf )
{
    const int ci = data->get_var_type(vi);
    const int n = node->sample_count;
    const int m = data->get_num_classes();
    const int _mi = data->cat_count->data.i[ci];  // number of categories of the variable
    int mi = _mi;                                 // categories (or clusters) actually scanned
    const bool need_clustering = m > 2 && _mi > data->params.max_categories;

    // Base buffer layout:
    //   lc[m] | rc[m] | cjk rows -1..mi-1 | pad | c_weights |
    //   then either int_ptr[mi] (two classes) or clustered cjk + cluster_labels (many classes).
    int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
    if( need_clustering )
        base_size += (m*std::min(data->params.max_categories, n) + mi)*sizeof(int);
    else
        base_size += mi*sizeof(int*);

    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + 2*n*sizeof(int));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* lc = (int*)base_buf;   // per-class counts sent left
    int* rc = lc + m;           // per-class counts sent right
    int* _cjk = rc + m*2;       // row 0 of c_{jk}; row -1 (missing values) starts at rc + m
    int* cjk = _cjk;
    double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    int* responses_buf = labels_buf + n;
    const int* responses = data->get_class_labels(node, responses_buf);
    const double* priors = data->priors_mult->data.db;

    // c_{jk}: number of samples whose category is j (-1 = missing) and class is k.
    for( int t = -m; t < m*mi; t++ )
        cjk[t] = 0;
    for( int s = 0; s < n; s++ )
    {
        const bool missing = labels[s] == 65535 && data->is_buf_16u;
        const int cat = missing ? -1 : labels[s];
        cjk[cat*m + responses[s]]++;
    }

    int* cluster_labels = 0;
    int** int_ptr = 0;
    int subset_i, subset_n;

    if( m > 2 )
    {
        if( need_clustering )
        {
            mi = MIN(data->params.max_categories, n);
            cjk = (int*)(c_weights + _mi);
            cluster_labels = cjk + m*mi;
            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        int_ptr = (int**)(c_weights + _mi);
        for( int j = 0; j < mi; j++ )
            int_ptr[j] = cjk + j*2 + 1;
        std::sort(int_ptr, int_ptr + mi, LessThanPtr<int>());
        subset_i = 0;
        subset_n = mi;
    }

    // Initially every category is on the right.
    for( int k = 0; k < m; k++ )
    {
        int total = 0;
        for( int j = 0; j < mi; j++ )
            total += cjk[j*m + k];
        rc[k] = total;
        lc[k] = 0;
    }

    double L = 0, R = 0;
    for( int j = 0; j < mi; j++ )
    {
        double w = 0;
        for( int k = 0; k < m; k++ )
            w += cjk[j*m + k]*priors[k];
        c_weights[j] = w;
        R += c_weights[j];
    }

    double best_val = init_quality;
    int best_subset = -1;
    int prevcode = 0, subtract = 0;

    for( ; subset_i < subset_n; subset_i++ )
    {
        int idx;
        if( m == 2 )
            idx = (int)(int_ptr[subset_i] - cjk)/2;
        else
        {
            const int graycode = (subset_i>>1)^subset_i;
            const int diff = graycode ^ prevcode;

            // diff has exactly one bit set; read its position from the exponent of
            // its float representation (upper and lower 16 bits handled separately).
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        const int* crow = cjk + idx*m;
        const double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        // Move category idx to the left (or back to the right when subtracting)
        // and accumulate the squared prior-weighted class counts of both sides.
        double lsum2 = 0, rsum2 = 0;
        for( int k = 0; k < m; k++ )
        {
            const int t = subtract ? -crow[k] : crow[k];
            const int lval = lc[k] + t;
            const int rval = rc[k] - t;
            const double p = priors[k], p2 = p*p;
            lsum2 += p2*lval*lval;
            rsum2 += p2*rval*rval;
            lc[k] = lval;
            rc[k] = rval;
        }
        if( subtract )
        {
            L -= weight;
            R += weight;
        }
        else
        {
            L += weight;
            R -= weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            const double val = (lsum2*R + rsum2*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    if( best_subset < 0 )
        return 0;

    CvDTreeSplit* split = _split ? _split : data->new_split_cat( 0, -1.0f );
    split->var_idx = vi;
    split->quality = (float)best_val;
    memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
    if( m == 2 )
    {
        // The best subset is the prefix 0..best_subset of the sorted category order.
        for( int i = 0; i <= best_subset; i++ )
        {
            const int cat = (int)(int_ptr[i] - cjk) >> 1;
            split->subset[cat >> 5] |= 1 << (cat & 31);
        }
    }
    else
    {
        // best_subset is a bit mask over categories (or over clusters after clustering).
        for( int i = 0; i < _mi; i++ )
        {
            const int bit = cluster_labels ? cluster_labels[i] : i;
            if( best_subset & (1 << bit) )
                split->subset[i >> 5] |= 1 << (i & 31);
        }
    }
    return split;
}
2372
2373
2374
/*
 * Finds the best threshold split of `node` on the ordered variable `vi` for a
 * regression tree.
 *
 * Valid samples (the first n1 entries of the sorted order) are scanned in ascending
 * order of the variable. At every boundary between values that differ by more than
 * 2*FLT_EPSILON, the split is scored with lsum^2/L + rsum^2/R (sums of responses and
 * sample counts on each side). Responses of samples with missing values are
 * removed from the right-hand sum up front.
 *
 * Returns `_split` (or a newly allocated split) with the threshold halfway between
 * the two neighbouring values when the best score exceeds `init_quality`, otherwise 0.
 * `_ext_buf`, when given, must provide room for 2*n ints and 2*n floats.
 */
CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const int n = node->sample_count;
    const int n1 = node->get_num_valid(vi);

    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();

    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
    float* responses_buf = (float*)(sample_indices_buf + n);
    const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );

    // Sum of responses over the valid samples: total minus the missing-value tail.
    double rsum = node->value*n;
    for( int pos = n1; pos < n; pos++ )
        rsum -= responses[sorted_indices[pos]];

    double lsum = 0;
    double best_val = init_quality;
    int best_i = -1;

    for( int pos = 0; pos + 1 < n1; pos++ )
    {
        const float t = responses[sorted_indices[pos]];
        const int L = pos + 1;       // samples on the left after moving this one
        const int R = n1 - pos - 1;  // valid samples still on the right
        lsum += t;
        rsum -= t;

        // only cut between clearly distinct values
        if( !(values[pos] + epsilon < values[pos+1]) )
            continue;

        const double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
        if( best_val < val )
        {
            best_val = val;
            best_i = pos;
        }
    }

    if( best_i < 0 )
        return 0;

    CvDTreeSplit* split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
    split->var_idx = vi;
    split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
    split->ord.split_point = best_i;
    split->inversed = 0;
    split->quality = (float)best_val;
    return split;
}
2432
2433
/*
 * Finds the best split of `node` on the categorical variable `vi` for a
 * regression tree.
 *
 * Categories are ordered by their mean response. Only the mi-1 prefixes of that
 * order are tried as the "left" subset, each scored with lsum^2/L + rsum^2/R
 * (sums of responses and sample counts on each side). Samples with a missing
 * value (category -1) are counted but do not take part in either side.
 *
 * Returns `_split` (or a newly allocated split) describing the best subset when
 * its score exceeds `init_quality`, otherwise 0. `_ext_buf`, when given, must
 * provide room for 2*n ints and n floats.
 */
CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];

    // base_buf layout: [pad] sum[-1..mi-1] (doubles) | counts[-1..mi-1] (ints) | [pad] sum_ptr[mi].
    // The spare double and the spare pointer in base_size cover both alignment pads.
    int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    float* responses_buf = (float*)(labels_buf + n);
    int* sample_indices_buf = (int*)(responses_buf + n);
    const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);

    double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
    int* counts = (int*)(sum + mi) + 1;
    // counts + mi is only int-aligned. On 64-bit targets it is 4 bytes off a pointer
    // boundary whenever mi is even, so align it before storing pointers there.
    double** sum_ptr = (double**)cv::alignPtr(counts + mi, (int)sizeof(double*));
    int i, L = 0, R = 0;
    double best_val = init_quality, lsum = 0, rsum = 0;
    int best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
        double s = sum[idx] + responses[i];
        int nc = counts[idx] + 1;
        sum[idx] = s;
        counts[idx] = nc;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] /= MAX(counts[i],1);
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());

    // revert back to unnormalized sums
    // (there should be a very little loss of accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    // move categories to the left one by one in order of increasing mean response
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        int ni = counts[idx];

        if( ni )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L && R )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f);
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        // the left subset is the prefix 0..best_subset of the sorted category order
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
2525
2526
/*
 * Looks for a surrogate split of `node` on the ordered variable `vi`: a threshold on
 * vi whose left/right assignment agrees best with the primary split recorded in
 * data->direction (negative = left, positive = right, 0 = not routed).
 *
 * Agreement is LL + RR (both splits send samples to the same side) or, for an
 * inversed surrogate, RL + LR. It is counted in samples, or in prior weights when
 * the data has class priors. Only samples with a valid value of vi (the first n1
 * entries of the sorted order) take part, and thresholds are only placed between
 * values that differ by more than 2*FLT_EPSILON.
 *
 * Returns a newly allocated ordered split when the best agreement exceeds
 * node->maxlr, otherwise 0.
 */
CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count, n1 = node->get_num_valid(vi);
    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
    uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    int i, best_i = -1, best_inversed = 0;
    double best_val;

    if( !data->have_priors )
    {
        int LL = 0, RL = 0, LR, RR;
        int worst_val = cvFloor(node->maxlr);
        // Running maximum of the integer agreement. It must be updated together with
        // best_i. Otherwise every candidate is compared against worst_val only, and
        // copying worst_val into best_val makes the final
        // "best_val > node->maxlr" test fail for every split.
        int _best_val = worst_val;
        int sum = 0, sum_abs = 0;

        for( i = 0; i < n1; i++ )
        {
            int d = dir[sorted_indices[i]];
            sum += d; sum_abs += d & 1;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum) >> 1;
        LR = (sum_abs - sum) >> 1;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by primary split, and RR - to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int d = dir[sorted_indices[i]];

            if( d < 0 )
            {
                LL++; LR--;
                if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
                {
                    _best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL++; RR--;
                if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
                {
                    _best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
        best_val = _best_val;
    }
    else
    {
        double LL = 0, RL = 0, LR, RR;
        double worst_val = node->maxlr;
        double sum = 0, sum_abs = 0;
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = sample_indices_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);
        best_val = worst_val;

        for( i = 0; i < n1; i++ )
        {
            int idx = sorted_indices[i];
            double w = priors[responses[idx]];
            int d = dir[idx];
            sum += d*w; sum_abs += (d & 1)*w;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum)*0.5;
        LR = (sum_abs - sum)*0.5;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by primary split, and RR - to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted_indices[i];
            double w = priors[responses[idx]];
            int d = dir[idx];

            if( d < 0 )
            {
                LL += w; LR -= w;
                if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL += w; RR -= w;
                if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
    }
    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
        (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
}
2646
2647
2648
/*
 * Builds a surrogate split of `node` on the categorical variable `vi`.
 *
 * For every category, the weight of its samples that the primary split (recorded
 * in data->direction) sends left (lc) and right (rc) is accumulated. Weights are
 * sample counts, or class prior weights when the data has priors. Each category
 * is then sent to its majority side, and the split's quality is the total agreeing
 * weight. The split is released and 0 is returned if that quality does not exceed
 * node->maxlr, or if all categories end up on the same side.
 */
CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count;
    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;

    // base_buf layout: [pad] lc[-1..mi-1] | rc[-1..mi-1] (doubles), then, without
    // priors, _lc[-1..mi-1] | _rc[-1..mi-1] (ints).
    int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
    uchar* base_buf = inn_buf.data();
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    double best_val = 0;
    double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
    double* rc = lc + mi + 1;

    for( i = -1; i < mi; i++ )
        lc[i] = rc[i] = 0;

    // for each category calculate the weight of samples
    // sent to the left (lc) and to the right (rc) by the primary split
    if( !data->have_priors )
    {
        // The integer accumulators start right after rc[mi-1], i.e. at rc + mi.
        // They must not alias rc[]: the conversion loop below writes rc[i] while
        // later iterations still read _lc/_rc.
        int* _lc = (int*)(rc + mi) + 1;
        int* _rc = _lc + mi + 1;

        for( i = -1; i < mi; i++ )
            _lc[i] = _rc[i] = 0;

        for( i = 0; i < n; i++ )
        {
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            int d = dir[i];
            int sum = _lc[idx] + d;
            int sum_abs = _rc[idx] + (d & 1);
            _lc[idx] = sum; _rc[idx] = sum_abs;
        }

        // sum_abs = R + L; sum = R - L
        for( i = 0; i < mi; i++ )
        {
            int sum = _lc[i];
            int sum_abs = _rc[i];
            lc[i] = (sum_abs - sum) >> 1;
            rc[i] = (sum_abs + sum) >> 1;
        }
    }
    else
    {
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = labels_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);

        for( i = 0; i < n; i++ )
        {
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            double w = priors[responses[i]];
            int d = dir[i];
            double sum = lc[idx] + d*w;
            double sum_abs = rc[idx] + (d & 1)*w;
            lc[idx] = sum; rc[idx] = sum_abs;
        }

        // sum_abs = R + L; sum = R - L
        for( i = 0; i < mi; i++ )
        {
            double sum = lc[i];
            double sum_abs = rc[i];
            lc[i] = (sum_abs - sum) * 0.5;
            rc[i] = (sum_abs + sum) * 0.5;
        }
    }

    // 2. now form the split.
    // in each category send all the samples to the same direction as majority
    for( i = 0; i < mi; i++ )
    {
        double lval = lc[i], rval = rc[i];
        if( lval > rval )
        {
            split->subset[i >> 5] |= 1 << (i & 31);
            best_val += lval;
            l_win++;
        }
        else
            best_val += rval;
    }

    split->quality = (float)best_val;
    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
    {
        // not better than the baseline, or degenerate (all categories on one side)
        cvSetRemoveByPtr( data->split_heap, split );
        split = 0;
    }

    return split;
}
2748
2749
2750
/*
 * Computes node->value and node->node_risk for `node`. When params.cv_folds is
 * non-zero, it also fills the per-fold cv_node_risk / cv_node_error values and
 * resets cv_Tn to INT_MAX.
 *
 * Classification:
 *  - value is the class with the largest prior-weighted sample count (its index
 *    is stored in class_idx);
 *  - node_risk is the prior-weighted count of samples of all other classes;
 *  - fold j computes the same quantities on the samples with cv_labels != j
 *    (risk), and evaluates that fold's majority class on the samples with
 *    cv_labels == j (error);
 *  - at the root (node->parent == 0) with class priors, priors_mult is first
 *    recomputed as priors[k]/n_k, normalized to sum to 1.
 *
 * Regression:
 *  - value is the mean response and node_risk the sum of squared deviations from it;
 *  - fold j uses the mean r_j of the samples with cv_labels != j. Its risk is their
 *    squared error around r_j, and its error is the squared error of the samples
 *    with cv_labels == j around r_j.
 */
void CvDTree::calc_node_value( CvDTreeNode* node )
{
    const int n = node->sample_count;
    const int cv_n = data->params.cv_folds;
    const int m = data->get_num_classes();

    // base: per-fold class counts (classification) or per-fold sums/counts (regression);
    // ext: cv labels followed by class labels, or by responses plus their sample indices.
    int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
    int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
    cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
    uchar* base_buf = inn_buf.data();
    int* cv_labels_buf = (int*)(base_buf + base_size);
    const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);

    if( !data->is_classifier )
    {
        float* values_buf = (float*)(cv_labels_buf + n);
        int* sample_indices_buf = (int*)(values_buf + n);
        const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
        double sum = 0, sum2 = 0;
        double* cv_sum = 0;
        double* cv_sum2 = 0;
        int* cv_count = 0;

        if( cv_n == 0 )
        {
            for( int i = 0; i < n; i++ )
            {
                const double t = values[i];
                sum += t;
                sum2 += t*t;
            }
        }
        else
        {
            cv_sum = (double*)base_buf;
            cv_sum2 = cv_sum + cv_n;
            cv_count = (int*)(cv_sum2 + cv_n);

            for( int fold = 0; fold < cv_n; fold++ )
            {
                cv_sum[fold] = cv_sum2[fold] = 0.;
                cv_count[fold] = 0;
            }

            for( int i = 0; i < n; i++ )
            {
                const int fold = cv_labels[i];
                const double t = values[i];
                cv_sum[fold] += t;
                cv_sum2[fold] += t*t;
                cv_count[fold]++;
            }

            for( int fold = 0; fold < cv_n; fold++ )
            {
                sum += cv_sum[fold];
                sum2 += cv_sum2[fold];
            }
        }

        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        for( int fold = 0; fold < cv_n; fold++ )
        {
            // statistics of the fold (s, s2, c) and of the remaining samples (si, s2i, ci)
            const double s = cv_sum[fold], si = sum - s;
            const double s2 = cv_sum2[fold], s2i = sum2 - s2;
            const int c = cv_count[fold], ci = n - c;
            const double r = si/MAX(ci,1);
            node->cv_node_risk[fold] = s2i - r*r*ci;
            node->cv_node_error[fold] = s2 - 2*r*s + c*r*r;
            node->cv_Tn[fold] = INT_MAX;
        }
        return;
    }

    // ---- classification ----
    int* cls_count = data->counts->data.i;
    int* responses_buf = cv_labels_buf + n;
    const int* responses = data->get_class_labels(node, responses_buf);
    int* cv_cls_count = (int*)base_buf;
    double* priors = data->priors_mult->data.db;

    std::fill( cls_count, cls_count + m, 0 );

    if( cv_n == 0 )
    {
        for( int i = 0; i < n; i++ )
            cls_count[responses[i]]++;
    }
    else
    {
        for( int fold = 0; fold < cv_n; fold++ )
            for( int k = 0; k < m; k++ )
                cv_cls_count[fold*m + k] = 0;

        for( int i = 0; i < n; i++ )
            cv_cls_count[cv_labels[i]*m + responses[i]]++;

        for( int fold = 0; fold < cv_n; fold++ )
            for( int k = 0; k < m; k++ )
                cls_count[k] += cv_cls_count[fold*m + k];
    }

    if( data->have_priors && node->parent == 0 )
    {
        // compute priors_mult from priors, take the sample ratio into account.
        double norm = 0;
        for( int k = 0; k < m; k++ )
        {
            const int n_k = cls_count[k];
            priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
            norm += priors[k];
        }
        norm = 1./norm;
        for( int k = 0; k < m; k++ )
            priors[k] *= norm;
    }

    double max_val = -1, total_weight = 0;
    int max_k = -1;
    for( int k = 0; k < m; k++ )
    {
        const double val = cls_count[k]*priors[k];
        total_weight += val;
        if( max_val < val )
        {
            max_val = val;
            max_k = k;
        }
    }

    node->class_idx = max_k;
    node->value = data->cat_map->data.i[
        data->cat_ofs->data.i[data->cat_var_count] + max_k];
    node->node_risk = total_weight - max_val;

    for( int fold = 0; fold < cv_n; fold++ )
    {
        double sum_k = 0, sum = 0, max_val_k = 0;
        double fold_max = -1;

        for( int k = 0; k < m; k++ )
        {
            const double w = priors[k];
            const double val_k = cv_cls_count[fold*m + k]*w;  // weight inside the fold
            const double val = cls_count[k]*w - val_k;        // weight outside the fold
            sum_k += val_k;
            sum += val;
            if( fold_max < val )
            {
                fold_max = val;
                max_val_k = val_k;
            }
        }

        node->cv_Tn[fold] = INT_MAX;
        node->cv_node_risk[fold] = sum - fold_max;
        node->cv_node_error[fold] = sum_k - max_val_k;
    }
}
2941
2942
2943
/*
 * Makes sure every sample of `node` has a branch direction before the node is split.
 *
 * On entry data->direction holds a non-zero direction for samples routed by the
 * primary split and 0 for the nz samples whose primary-split variable is missing.
 * When params.use_surrogates is set, those samples are first routed by the
 * surrogate splits chained after node->split, in order, until none is left.
 * Any sample still undecided goes to the side that received more samples; on a
 * tie, undecided samples alternate between left and right, starting with left.
 * On exit every direction entry is 0 (left) or 1 (right).
 */
void CvDTree::complete_node_dir( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
    int nz = n - node->get_num_valid(node->split->var_idx);  // samples without a direction yet
    char* dir = (char*)data->direction->data.ptr;

    // try to complete direction using surrogate splits
    if( nz && data->params.use_surrogates )
    {
        cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
        CvDTreeSplit* split = node->split->next;
        for( ; split != 0 && nz; split = split->next )
        {
            // (d ^ mask) - mask negates d when the surrogate is inversed
            int inversed_mask = split->inversed ? -1 : 0;
            vi = split->var_idx;

            if( data->get_var_type(vi) >= 0 ) // split on categorical var
            {
                int* labels_buf = (int*)inn_buf.data();
                const int* labels = data->get_cat_var_data(node, vi, labels_buf);
                const int* subset = split->subset;

                for( i = 0; i < n; i++ )
                {
                    int idx = labels[i];
                    if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
                    {
                        int d = CV_DTREE_CAT_DIR(idx,subset);
                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
                        // stop as soon as every undecided sample has been routed
                        if( --nz == 0 )
                            break;
                    }
                }
            }
            else // split on ordered var
            {
                float* values_buf = (float*)inn_buf.data();
                int* sorted_indices_buf = (int*)(values_buf + n);
                int* sample_indices_buf = sorted_indices_buf + n;
                const float* values = 0;
                const int* sorted_indices = 0;
                data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
                int split_point = split->ord.split_point;
                int n1 = node->get_num_valid(vi);

                assert( 0 <= split_point && split_point < n-1 );

                for( i = 0; i < n1; i++ )
                {
                    int idx = sorted_indices[i];
                    if( !dir[idx] )
                    {
                        int d = i <= split_point ? -1 : 1;
                        dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
                        // stop as soon as every undecided sample has been routed
                        if( --nz == 0 )
                            break;
                    }
                }
            }
        }
    }

    // find the default direction for the rest
    if( nz )
    {
        for( i = nr = 0; i < n; i++ )
            nr += dir[i] > 0;
        nl = n - nr - nz;
        d0 = nl > nr ? -1 : nr > nl;
    }

    // make sure that every sample is directed either to the left or to the right
    for( i = 0; i < n; i++ )
    {
        int d = dir[i];
        if( !d )
        {
            d = d0;
            if( !d )
                d = d1, d1 = -d1;
        }
        d = d > 0;
        dir[i] = (char)d; // remap (-1,1) to (0,1)
    }
}
3029
3030
3031
/*
 * Distributes the samples of `node` between two newly created children.
 *
 * complete_node_dir() first resolves a 0 (left) / 1 (right) direction per sample,
 * and new_idx[i] becomes the position of sample i inside its child. When the
 * children may be split further (depth and min_sample_count permit), the following
 * are partitioned into the child buffers, preserving order and updating each
 * child's count of valid values:
 *  - the sorted index lists of all ordered input variables;
 *  - the labels of the categorical input variables.
 * Categorical work variables beyond var_count (per the comment below: responses and
 * cv_labels) and the sample indices are always partitioned. The parent's per-node
 * data is released at the end.
 */
void CvDTree::split_node_data( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* new_idx = data->split_buf->data.i;
    int new_buf_idx = data->get_child_buf_idx( node );
    int work_var_count = data->get_work_var_count();
    CvMat* buf = data->buf;
    size_t length_buf_row = data->get_length_subbuf();
    cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
    int* temp_buf = (int*)inn_buf.data();

    complete_node_dir(node);

    for( i = nl = nr = 0; i < n; i++ )
    {
        int d = dir[i];
        // initialize new indices for splitting ordered variables
        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
        nr += d;
        nl += d^1;
    }

    bool split_input_data;
    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );

    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);

        if( ci >= 0 || !split_input_data )
            continue;

        int n1 = node->get_num_valid(vi);
        float* src_val_buf = (float*)(uchar*)(temp_buf + n);
        int* src_sorted_idx_buf = (int*)(src_val_buf + n);
        int* src_sample_idx_buf = src_sorted_idx_buf + n;
        const float* src_val = 0;
        const int* src_sorted_idx = 0;
        data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);

        // work on a private copy (presumably because the source may live in the same
        // buffer the children are written to)
        for(i = 0; i < n; i++)
            temp_buf[i] = src_sorted_idx[i];

        if (data->is_buf_16u)
        {
            unsigned short *ldst, *rdst, *ldst0, *rdst0;
            ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            rdst0 = rdst = (unsigned short*)(ldst + nl);

            // split sorted
            for( i = 0; i < n1; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            // split missing
            for( ; i < n; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }
        }
        else
        {
            int *ldst0, *ldst, *rdst0, *rdst;
            ldst0 = ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            rdst0 = rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            // split sorted
            for( i = 0; i < n1; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            // split missing
            for( ; i < n; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }
        }
    }

    // split categorical vars, responses and cv_labels using new_idx relocation table
    for( vi = 0; vi < work_var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int nr1 = 0;

        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        // Valid-value counts are only read and stored for input variables. The extra
        // work variables past var_count are not guaranteed a num_valid slot
        // (NOTE(review): presumably num_valid holds var_count entries; confirm), so do
        // not read one for them.
        const bool is_input_var = vi < data->var_count;
        int n1 = is_input_var ? node->get_num_valid(vi) : n;

        int *src_lbls_buf = temp_buf + n;
        const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);

        for(i = 0; i < n; i++)
            temp_buf[i] = src_lbls[i];

        if (data->is_buf_16u)
        {
            unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
                vi*scount + right->offset);

            for( i = 0; i < n; i++ )
            {
                int d = dir[i];
                int idx = temp_buf[i];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                    nr1 += (idx != 65535 )&d;   // 65535 marks a missing value in 16-bit buffers
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }

            if( is_input_var )
            {
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);
            }
        }
        else
        {
            int *ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            int *rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            for( i = 0; i < n; i++ )
            {
                int d = dir[i];
                int idx = temp_buf[i];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                    nr1 += (idx >= 0)&d;        // negative labels mark missing values
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }

            if( is_input_var )
            {
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);
            }
        }
    }

    // split sample indices
    int *sample_idx_src_buf = temp_buf + n;
    const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);

    for(i = 0; i < n; i++)
        temp_buf[i] = sample_idx_src[i];

    int pos = data->get_work_var_count();
    if (data->is_buf_16u)
    {
        unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
            pos*scount + left->offset);
        unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
            pos*scount + right->offset);
        for (i = 0; i < n; i++)
        {
            int d = dir[i];
            unsigned short idx = (unsigned short)temp_buf[i];
            if (d)
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }
    else
    {
        int* ldst = buf->data.i + left->buf_idx*length_buf_row +
            pos*scount + left->offset;
        int* rdst = buf->data.i + right->buf_idx*length_buf_row +
            pos*scount + right->offset;
        for (i = 0; i < n; i++)
        {
            int d = dir[i];
            int idx = temp_buf[i];
            if (d)
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}
3313
3314
/*
 Evaluates the tree on the samples held by `_data`.

 type: CV_TEST_ERROR selects the test subset; any other value selects the train
       subset. For CV_TRAIN_ERROR an empty train subset means "all rows of values".
 resp: optional. When non-null and there is at least one sample, it is resized to
       the sample count and receives the prediction for every evaluated sample.

 Returns the misclassification rate in percent for classifiers, the mean squared
 error for regressors, or -FLT_MAX when there is nothing to evaluate.
*/
float CvDTree::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
{
    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
    const CvMat* var_types = _data->get_var_types();
    int* sidx = sample_idx ? sample_idx->data.i : 0;
    // element (not byte) step between consecutive responses
    int r_step = CV_IS_MAT_CONT(response->type) ?
                1 : response->step / CV_ELEM_SIZE(response->type);
    // the last entry of var_types describes the response variable
    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
    int sample_count = sample_idx ? sample_idx->cols : 0;
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
    float* pred_resp = 0;
    if( resp && (sample_count > 0) )
    {
        resp->resize( sample_count );
        pred_resp = &((*resp)[0]);
    }

    if( sample_count == 0 )
        return -FLT_MAX;

    // Accumulate in double: a float sum stops counting exactly beyond 2^24
    // misclassified samples and loses precision over many squared residuals.
    double err = 0;
    for( int i = 0; i < sample_count; i++ )
    {
        CvMat sample, miss;
        int si = sidx ? sidx[i] : i;
        cvGetRow( values, &sample, si );
        if( missing )
            cvGetRow( missing, &miss, si );
        float r = (float)predict( &sample, missing ? &miss : 0 )->value;
        if( pred_resp )
            pred_resp[i] = r;
        // NOTE(review): responses are read as 32-bit floats; assumes CV_32F responses
        float expected = response->data.fl[(size_t)si*r_step];
        if( is_classifier )
            err += fabs((double)r - expected) <= FLT_EPSILON ? 0 : 1;
        else
        {
            float d = r - expected;
            err += d*d;
        }
    }

    return is_classifier ? (float)(err / sample_count * 100)
                         : (float)(err / sample_count);
}
3371
3372
void CvDTree::prune_cv()
3373
{
3374
CvMat* ab = 0;
3375
CvMat* temp = 0;
3376
CvMat* err_jk = 0;
3377
3378
// 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
3379
// 2. choose the best tree index (if need, apply 1SE rule).
3380
// 3. store the best index and cut the branches.
3381
3382
CV_FUNCNAME( "CvDTree::prune_cv" );
3383
3384
__BEGIN__;
3385
3386
int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
3387
// currently, 1SE for regression is not implemented
3388
bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
3389
double* err;
3390
double min_err = 0, min_err_se = 0;
3391
int min_idx = -1;
3392
3393
CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
3394
3395
// build the main tree sequence, calculate alpha's
3396
for(;;tree_count++)
3397
{
3398
double min_alpha = update_tree_rnc(tree_count, -1);
3399
if( cut_tree(tree_count, -1, min_alpha) )
3400
break;
3401
3402
if( ab->cols <= tree_count )
3403
{
3404
CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
3405
for( ti = 0; ti < ab->cols; ti++ )
3406
temp->data.db[ti] = ab->data.db[ti];
3407
cvReleaseMat( &ab );
3408
ab = temp;
3409
temp = 0;
3410
}
3411
3412
ab->data.db[tree_count] = min_alpha;
3413
}
3414
3415
ab->data.db[0] = 0.;
3416
3417
if( tree_count > 0 )
3418
{
3419
for( ti = 1; ti < tree_count-1; ti++ )
3420
ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
3421
ab->data.db[tree_count-1] = DBL_MAX*0.5;
3422
3423
CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
3424
err = err_jk->data.db;
3425
3426
for( j = 0; j < cv_n; j++ )
3427
{
3428
int tj = 0, tk = 0;
3429
for( ; tk < tree_count; tj++ )
3430
{
3431
double min_alpha = update_tree_rnc(tj, j);
3432
if( cut_tree(tj, j, min_alpha) )
3433
min_alpha = DBL_MAX;
3434
3435
for( ; tk < tree_count; tk++ )
3436
{
3437
if( ab->data.db[tk] > min_alpha )
3438
break;
3439
err[j*tree_count + tk] = root->tree_error;
3440
}
3441
}
3442
}
3443
3444
for( ti = 0; ti < tree_count; ti++ )
3445
{
3446
double sum_err = 0;
3447
for( j = 0; j < cv_n; j++ )
3448
sum_err += err[j*tree_count + ti];
3449
if( ti == 0 || sum_err < min_err )
3450
{
3451
min_err = sum_err;
3452
min_idx = ti;
3453
if( use_1se )
3454
min_err_se = sqrt( sum_err*(n - sum_err) );
3455
}
3456
else if( sum_err < min_err + min_err_se )
3457
min_idx = ti;
3458
}
3459
}
3460
3461
pruned_tree_idx = min_idx;
3462
free_prune_data(data->params.truncate_pruned_tree != 0);
3463
3464
__END__;
3465
3466
cvReleaseMat( &err_jk );
3467
cvReleaseMat( &ab );
3468
cvReleaseMat( &temp );
3469
}
3470
3471
3472
double CvDTree::update_tree_rnc( int T, int fold )
3473
{
3474
CvDTreeNode* node = root;
3475
double min_alpha = DBL_MAX;
3476
3477
for(;;)
3478
{
3479
CvDTreeNode* parent;
3480
for(;;)
3481
{
3482
int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3483
if( t <= T || !node->left )
3484
{
3485
node->complexity = 1;
3486
node->tree_risk = node->node_risk;
3487
node->tree_error = 0.;
3488
if( fold >= 0 )
3489
{
3490
node->tree_risk = node->cv_node_risk[fold];
3491
node->tree_error = node->cv_node_error[fold];
3492
}
3493
break;
3494
}
3495
node = node->left;
3496
}
3497
3498
for( parent = node->parent; parent && parent->right == node;
3499
node = parent, parent = parent->parent )
3500
{
3501
parent->complexity += node->complexity;
3502
parent->tree_risk += node->tree_risk;
3503
parent->tree_error += node->tree_error;
3504
3505
parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
3506
- parent->tree_risk)/(parent->complexity - 1);
3507
min_alpha = MIN( min_alpha, parent->alpha );
3508
}
3509
3510
if( !parent )
3511
break;
3512
3513
parent->complexity = node->complexity;
3514
parent->tree_risk = node->tree_risk;
3515
parent->tree_error = node->tree_error;
3516
node = parent->right;
3517
}
3518
3519
return min_alpha;
3520
}
3521
3522
3523
int CvDTree::cut_tree( int T, int fold, double min_alpha )
3524
{
3525
CvDTreeNode* node = root;
3526
if( !node->left )
3527
return 1;
3528
3529
for(;;)
3530
{
3531
CvDTreeNode* parent;
3532
for(;;)
3533
{
3534
int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3535
if( t <= T || !node->left )
3536
break;
3537
if( node->alpha <= min_alpha + FLT_EPSILON )
3538
{
3539
if( fold >= 0 )
3540
node->cv_Tn[fold] = T;
3541
else
3542
node->Tn = T;
3543
if( node == root )
3544
return 1;
3545
break;
3546
}
3547
node = node->left;
3548
}
3549
3550
for( parent = node->parent; parent && parent->right == node;
3551
node = parent, parent = parent->parent )
3552
;
3553
3554
if( !parent )
3555
break;
3556
3557
node = parent->right;
3558
}
3559
3560
return 0;
3561
}
3562
3563
3564
void CvDTree::free_prune_data(bool _cut_tree)
3565
{
3566
CvDTreeNode* node = root;
3567
3568
for(;;)
3569
{
3570
CvDTreeNode* parent;
3571
for(;;)
3572
{
3573
// do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
3574
// as we will clear the whole cross-validation heap at the end
3575
node->cv_Tn = 0;
3576
node->cv_node_error = node->cv_node_risk = 0;
3577
if( !node->left )
3578
break;
3579
node = node->left;
3580
}
3581
3582
for( parent = node->parent; parent && parent->right == node;
3583
node = parent, parent = parent->parent )
3584
{
3585
if( _cut_tree && parent->Tn <= pruned_tree_idx )
3586
{
3587
data->free_node( parent->left );
3588
data->free_node( parent->right );
3589
parent->left = parent->right = 0;
3590
}
3591
}
3592
3593
if( !parent )
3594
break;
3595
3596
node = parent->right;
3597
}
3598
3599
if( data->cv_heap )
3600
cvClearSet( data->cv_heap );
3601
}
3602
3603
3604
void CvDTree::free_tree()
3605
{
3606
if( root && data && data->shared )
3607
{
3608
pruned_tree_idx = INT_MIN;
3609
free_prune_data(true);
3610
data->free_node(root);
3611
root = 0;
3612
}
3613
}
3614
3615
/*
 Routes one sample down the (possibly pruned) tree and returns the node it ends in.

 _sample:  1xN or Nx1 CV_32FC1 vector. N must equal data->var_all for raw input,
           or data->var_count when preprocessed_input is true.
 _missing: optional 8-bit mask of the same size. A non-zero entry marks that
           variable as missing, so splits on it are skipped.
 preprocessed_input: when true, categorical values are used directly as category
           indices (rounded) and no var_idx remapping is applied.

 A node is treated as a leaf when it has no children or when
 Tn <= pruned_tree_idx. At each internal node the split list is tried in order
 until one can be evaluated. If none can be evaluated, the child that received
 more training samples is taken.

 Errors (CV_Error): untrained tree, bad sample/mask type or size, or a
 non-integer value for a categorical variable.
*/
CvDTreeNode* CvDTree::predict( const CvMat* _sample,
    const CvMat* _missing, bool preprocessed_input ) const
{
    cv::AutoBuffer<int> catbuf;

    int i, mstep = 0;
    const uchar* m = 0;
    CvDTreeNode* node = root;

    if( !node )
        CV_Error( CV_StsError, "The tree has not been trained yet" );

    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
        (_sample->cols != 1 && _sample->rows != 1) ||
        (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
        (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
            CV_Error( CV_StsBadArg,
        "the input sample must be 1d floating-point vector with the same "
        "number of elements as the total number of variables used for training" );

    const float* sample = _sample->data.fl;
    // element step: 1 for a continuous vector, row stride for a non-continuous column
    int step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);

    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
    {
        // catbuf[ci] caches the normalized category of categorical variable ci
        // for this sample; -1 means "not looked up yet"
        int n = data->cat_count->cols;
        catbuf.allocate(n);
        for( i = 0; i < n; i++ )
            catbuf[i] = -1;
    }

    if( _missing )
    {
        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
            !CV_ARE_SIZES_EQ(_missing, _sample) )
            CV_Error( CV_StsBadArg,
        "the missing data mask must be 8-bit vector of the same size as input sample" );
        m = _missing->data.ptr;
        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
    }

    const int* vtype = data->var_type->data.i;
    const int* vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
    const int* cmap = data->cat_map ? data->cat_map->data.i : 0;
    const int* cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;

    while( node->Tn > pruned_tree_idx && node->left )
    {
        CvDTreeSplit* split = node->split;
        int dir = 0;
        for( ; !dir && split != 0; split = split->next )
        {
            int vi = split->var_idx;
            int ci = vtype[vi];
            i = vidx ? vidx[vi] : vi;
            float val = sample[(size_t)i*step];
            if( m && m[(size_t)i*mstep] )
                continue;
            if( ci < 0 ) // ordered
                dir = val <= split->ord.c ? -1 : 1;
            else // categorical
            {
                int c;
                if( preprocessed_input )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        // [lo, hi) is the slice of cat_map holding the sorted
                        // category values of categorical variable ci
                        int lo = cofs[ci];
                        int hi = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
                        int found = -1;

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of input categorical variable is not an integer" );

                        while( lo < hi )
                        {
                            int mid = (lo + hi) >> 1;
                            if( ival < cmap[mid] )
                                hi = mid;
                            else if( ival > cmap[mid] )
                                lo = mid+1;
                            else
                            {
                                found = mid;
                                break;
                            }
                        }

                        // Unknown category, or an empty range (which previously
                        // read cmap[] out of bounds): try the next split.
                        if( found < 0 )
                            continue;

                        catbuf[ci] = c = found - cofs[ci];
                    }
                }
                c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
                dir = CV_DTREE_CAT_DIR(c, split->subset);
            }

            if( split->inversed )
                dir = -dir;
        }

        if( !dir )
        {
            // no split could be evaluated: follow the more populated child
            double diff = node->right->sample_count - node->left->sample_count;
            dir = diff < 0 ? -1 : 1;
        }
        node = dir < 0 ? node->left : node->right;
    }

    return node;
}
3730
3731
3732
// cv::Mat overload: wraps both matrices as CvMat headers and forwards to the
// CvMat-based predict(). A mask whose data pointer is null is passed as "no mask".
CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
{
    CvMat sample = cvMat(_sample);
    CvMat mmask = cvMat(_missing);
    const CvMat* mask = 0;
    if( mmask.data.ptr )
        mask = &mmask;
    return predict( &sample, mask, preprocessed_input );
}
3737
3738
3739
/*
 Returns a 1 x var_count CV_64F row of variable importances.

 Each value is the sum of `quality` over every split (the whole split list of
 each node) on that variable, taken over the internal nodes that are active for
 pruned_tree_idx. The row is then L1-normalized. The result is computed once and
 cached in var_importance. Returns 0 if the tree has no root.
*/
const CvMat* CvDTree::get_var_importance()
{
    if( var_importance )
        return var_importance;

    CvDTreeNode* node = root;
    if( !node )
        return 0;

    var_importance = cvCreateMat( 1, data->var_count, CV_64F );
    cvZero( var_importance );
    double* importance = var_importance->data.db;

    for(;;)
    {
        // descend the left spine while the node is an active internal node
        while( node->left && node->Tn > pruned_tree_idx )
        {
            for( CvDTreeSplit* s = node->split; s != 0; s = s->next )
                importance[s->var_idx] += s->quality;
            node = node->left;
        }

        // climb while we are a right child, then move to the sibling subtree
        while( node->parent && node->parent->right == node )
            node = node->parent;

        if( !node->parent )
            break;

        node = node->parent->right;
    }

    cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
    return var_importance;
}
3780
3781
3782
/*
 Writes one split as a flow-style map containing "var" and "quality", followed by:
   - ordered variable: the threshold under "le", or under "gt" for inversed splits;
   - categorical variable: a flow sequence of category indices under "in" or
     "not_in". Which side is listed is chosen by an ad-hoc compactness rule.
*/
void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
{
    cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
    cvWriteInt( fs, "var", split->var_idx );
    cvWriteReal( fs, "quality", split->quality );

    const int ci = data->get_var_type(split->var_idx);
    if( ci < 0 )
    {
        cvWriteReal( fs, split->inversed ? "gt" : "le", split->ord.c );
    }
    else
    {
        const int n = data->cat_count->data.i[ci];
        int to_right = 0;
        for( int k = 0; k < n; k++ )
        {
            if( CV_DTREE_CAT_DIR(k, split->subset) > 0 )
                to_right++;
        }

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        const bool few_right = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3;
        const int default_dir = few_right ? -1 : 1;
        const int inv_sign = split->inversed ? -1 : 1;
        const char* tag = default_dir*inv_sign > 0 ? "in" : "not_in";

        cvStartWriteStruct( fs, tag, CV_NODE_SEQ+CV_NODE_FLOW );
        for( int k = 0; k < n; k++ )
        {
            int dir = CV_DTREE_CAT_DIR(k, split->subset);
            if( dir*default_dir < 0 )
                cvWriteInt( fs, 0, k );
        }
        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs );
}
3817
3818
3819
/*
 Writes one node as a map: depth, sample count, value, the class index for
 classifiers, the pruning bookkeeping fields, and, for nodes with children,
 the "splits" sequence. The fields are written in a fixed order that
 read_node() reads back by name.
*/
void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
{
    cvStartWriteStruct( fs, 0, CV_NODE_MAP );

    cvWriteInt( fs, "depth", node->depth );
    cvWriteInt( fs, "sample_count", node->sample_count );
    cvWriteReal( fs, "value", node->value );

    if( data->is_classifier )
        cvWriteInt( fs, "norm_class_idx", node->class_idx );

    // cost-complexity pruning state
    cvWriteInt( fs, "Tn", node->Tn );
    cvWriteInt( fs, "complexity", node->complexity );
    cvWriteReal( fs, "alpha", node->alpha );
    cvWriteReal( fs, "node_risk", node->node_risk );
    cvWriteReal( fs, "tree_risk", node->tree_risk );
    cvWriteReal( fs, "tree_error", node->tree_error );

    if( node->left )
    {
        cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
        CvDTreeSplit* s = node->split;
        while( s != 0 )
        {
            write_split( fs, s );
            s = s->next;
        }
        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs );
}
3851
3852
3853
/*
 Writes every node of the tree (pruned subtrees included) in pre-order:
 a node, then its left subtree, then its right subtree.
*/
void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
{
    CvDTreeNode* node = root;

    for(;;)
    {
        write_node( fs, node );
        if( node->left )
        {
            node = node->left;
            continue;
        }

        // leaf reached: climb while we are a right child, then visit the sibling subtree
        while( node->parent && node->parent->right == node )
            node = node->parent;

        if( !node->parent )
            break;

        node = node->parent->right;
    }
}
3885
3886
3887
// Writes a standalone tree: a map named `name`, tagged CV_TYPE_NAME_ML_TREE,
// containing the training parameters followed by the tree body.
void CvDTree::write( CvFileStorage* fs, const char* name ) const
{
    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
    data->write_params( fs );
    write( fs );
    cvEndWriteStruct( fs );
}
3905
3906
3907
// Writes the tree body: the selected pruning index ("best_tree_idx") and a
// "nodes" sequence in the pre-order layout that read_tree_nodes() reconstructs.
void CvDTree::write( CvFileStorage* fs ) const
{
    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );

    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
    write_tree_nodes( fs );
    cvEndWriteStruct( fs );
}
3921
3922
3923
/*
 Reads one split record produced by write_split().

 fnode: a map with "var" and "quality" plus either
   - "in" / "not_in": an int or a sequence of category indices (categorical var), or
   - "le" / "gt": the threshold (ordered var).
 Returns the split, allocated through data->new_split_cat()/new_split_ord().
 Errors: CV_StsParseError for a malformed record (including an ordered split
 without a threshold), CV_StsOutOfRange for bad indices.
*/
CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeSplit* split = 0;

    CV_FUNCNAME( "CvDTree::read_split" );

    __BEGIN__;

    int vi, ci;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );

    vi = cvReadIntByName( fs, fnode, "var", -1 );
    if( (unsigned)vi >= (unsigned)data->var_count )
        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );

    ci = data->get_var_type(vi);
    if( ci >= 0 ) // split on categorical var
    {
        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
        CvSeqReader reader;
        CvFileNode* inseq;
        split = data->new_split_cat( vi, 0 );
        inseq = cvGetFileNodeByName( fs, fnode, "in" );
        if( !inseq )
        {
            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
            inversed = 1;
        }
        if( !inseq ||
            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
            CV_ERROR( CV_StsParseError,
            "Either 'in' or 'not_in' tags should be inside a categorical split data" );

        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
        {
            val = inseq->data.i;
            if( (unsigned)val >= (unsigned)n )
                CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );

            split->subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            cvStartReadSeq( inseq->data.seq, &reader );

            for( i = 0; i < reader.seq->total; i++ )
            {
                CvFileNode* inode = (CvFileNode*)reader.ptr;
                // check the tag before reading the integer member of the node's union
                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT )
                    CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
                val = inode->data.i;
                if( (unsigned)val >= (unsigned)n )
                    CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );

                split->subset[val >> 5] |= 1 << (val & 31);
                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we inverse the variable set in the split
        if( inversed )
            for( i = 0; i < (n + 31) >> 5; i++ )
                split->subset[i] ^= -1;
    }
    else
    {
        CvFileNode* cmp_node;
        split = data->new_split_ord( vi, 0, 0, 0, 0 );

        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
        if( !cmp_node )
        {
            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
            split->inversed = 1;
        }

        // Without this check cvReadReal(NULL) would return its default (0)
        // and silently produce an inversed split at threshold 0.
        if( !cmp_node )
            CV_ERROR( CV_StsParseError,
            "Either 'le' or 'gt' tags should be inside an ordered split data" );

        split->ord.c = (float)cvReadReal( cmp_node );
    }

    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );

    __END__;

    return split;
}
4009
4010
4011
/*
 Reads one node map written by write_node() and allocates it as a child of `parent`
 (0 for the root).

 The stored "depth" must equal the depth that data->new_node() assigned, otherwise
 the node order in the file is inconsistent. The "splits" sequence, when present,
 becomes the node's split list in file order.
 Errors: CV_StsParseError on malformed input; errors from read_split() propagate.
*/
CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
{
    CvDTreeNode* node = 0;

    CV_FUNCNAME( "CvDTree::read_node" );

    __BEGIN__;

    CvFileNode* splits;
    int k, depth;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );

    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
    depth = cvReadIntByName( fs, fnode, "depth", -1 );
    if( depth != node->depth )
        CV_ERROR( CV_StsParseError, "incorrect node depth" );

    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
    node->value = cvReadRealByName( fs, fnode, "value" );
    if( data->is_classifier )
        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );

    // cost-complexity pruning state
    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );

    splits = cvGetFileNodeByName( fs, fnode, "splits" );
    if( splits )
    {
        CvSeqReader reader;
        // where the next split is linked: node->split first, then the previous split's `next`
        CvDTreeSplit** tail = &node->split;

        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
            CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );

        cvStartReadSeq( splits->data.seq, &reader );
        for( k = 0; k < reader.seq->total; k++ )
        {
            CvDTreeSplit* split;
            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
            *tail = split;
            tail = &split->next;

            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }

    __END__;

    return node;
}
4069
4070
4071
/*
 Rebuilds the tree from the "nodes" sequence written by write_tree_nodes().

 Nodes are stored in pre-order. A node that has splits is internal and is
 followed by its left subtree and then its right subtree; a node without
 splits is a leaf. The sequence is rejected (CV_StsParseError) if it contains
 more nodes than that structure allows, or ends before every internal node has
 both children. Previously, extra nodes led to a NULL `parent` dereference, and
 a truncated sequence left internal nodes without a right child.
*/
void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
{
    CV_FUNCNAME( "CvDTree::read_tree_nodes" );

    __BEGIN__;

    CvSeqReader reader;
    CvDTreeNode _root;              // sentinel: its left child becomes the real root
    CvDTreeNode* parent = &_root;   // node that receives the next child read
    int i;
    parent->left = parent->right = parent->parent = 0;

    cvStartReadSeq( fnode->data.seq, &reader );

    for( i = 0; i < reader.seq->total; i++ )
    {
        CvDTreeNode* node;

        // parent == 0: every internal node already has both children;
        // parent == &_root with a root present: the root was a leaf.
        // Either way, this entry does not belong to the tree.
        if( !parent || (parent == &_root && _root.left) )
            CV_ERROR( CV_StsParseError, "the tree contains more nodes than its structure allows" );

        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
        if( !parent->left )
            parent->left = node;
        else
            parent->right = node;
        if( node->split )
            parent = node;
        else
        {
            // leaf: climb to the nearest ancestor still waiting for its right child
            while( parent && parent->right )
                parent = parent->parent;
        }

        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
    }

    // an internal node still waiting for a child means the sequence was truncated
    if( parent && parent != &_root )
        CV_ERROR( CV_StsParseError, "the tree nodes sequence is incomplete" );

    root = _root.left;

    __END__;
}
4109
4110
4111
/*
 Loads a standalone tree (training parameters plus nodes) from `fnode`.

 A new CvDTreeTrainData is allocated, filled by read_params() and handed to the
 three-argument read(). If read_params() reports an error by throwing, the
 object has not been handed over yet, so it is deleted here instead of leaking.
 Variable importance is computed right after loading.
*/
void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeTrainData* _data = new CvDTreeTrainData();
    try
    {
        _data->read_params( fs, fnode );
    }
    catch( ... )
    {
        delete _data;
        throw;
    }

    read( fs, fnode, _data );
    get_var_importance();
}
4119
4120
4121
// a special entry point for reading weak decision trees from the tree ensembles
4122
void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
4123
{
4124
CV_FUNCNAME( "CvDTree::read" );
4125
4126
__BEGIN__;
4127
4128
CvFileNode* tree_nodes;
4129
4130
clear();
4131
data = _data;
4132
4133
tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
4134
if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
4135
CV_ERROR( CV_StsParseError, "nodes tag is missing" );
4136
4137
pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
4138
read_tree_nodes( fs, tree_nodes );
4139
4140
__END__;
4141
}
4142
4143
// cv::Mat wrapper around get_var_importance(): converts the cached CvMat
// importance row with cvarrToMat().
Mat CvDTree::getVarImportance()
{
    const CvMat* importance = get_var_importance();
    return cvarrToMat( importance );
}
4147
4148
/* End of file. */
4149
4150