Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/apps/traincascade/old_ml.hpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
//
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
13
// Third party copyrights are property of their respective owners.
14
//
15
// Redistribution and use in source and binary forms, with or without modification,
16
// are permitted provided that the following conditions are met:
17
//
18
// * Redistribution's of source code must retain the above copyright notice,
19
// this list of conditions and the following disclaimer.
20
//
21
// * Redistribution's in binary form must reproduce the above copyright notice,
22
// this list of conditions and the following disclaimer in the documentation
23
// and/or other materials provided with the distribution.
24
//
25
// * The name of Intel Corporation may not be used to endorse or promote products
26
// derived from this software without specific prior written permission.
27
//
28
// This software is provided by the copyright holders and contributors "as is" and
29
// any express or implied warranties, including, but not limited to, the implied
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
32
// indirect, incidental, special, exemplary, or consequential damages
33
// (including, but not limited to, procurement of substitute goods or services;
34
// loss of use, data, or profits; or business interruption) however caused
35
// and on any theory of liability, whether in contract, strict liability,
36
// or tort (including negligence or otherwise) arising in any way out of
37
// the use of this software, even if advised of the possibility of such damage.
38
//
39
//M*/
40
41
#ifndef OPENCV_OLD_ML_HPP
42
#define OPENCV_OLD_ML_HPP
43
44
#ifdef __cplusplus
45
# include "opencv2/core.hpp"
46
#endif
47
48
#include "opencv2/core/core_c.h"
49
#include <limits.h>
50
51
#ifdef __cplusplus
52
53
#include <map>
54
#include <iostream>
55
56
// Apple defines a check() macro somewhere in the debug headers
57
// that interferes with a method definition in this header
58
#undef check
59
60
/****************************************************************************************\
61
* Main struct definitions *
62
\****************************************************************************************/
63
64
/* Natural logarithm of 2*PI, spelled out as a floating-point literal. */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* Layout flag: columns of the <trainData> matrix are training samples. */
#define CV_COL_SAMPLE 0

/* Layout flag: rows of the <trainData> matrix are training samples. */
#define CV_ROW_SAMPLE 1

/* Non-zero when the CV_ROW_SAMPLE bit is set in <flags>, zero otherwise. */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
74
75
struct CvVectors
76
{
77
int type;
78
int dims, count;
79
CvVectors* next;
80
union
81
{
82
uchar** ptr;
83
float** fl;
84
double** db;
85
} data;
86
};
87
88
#if 0
/* NOTE: everything up to the matching #endif is compiled out. */

/* A structure representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by the cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice from two bounds and a step.  The bounds are passed through
   the MIN/MAX macros, so they may be given in either order; the step is
   passed through MAX( log_step, 1. ). */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* Returns a lattice whose three fields are all zero. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
116
117
/* Variable type.  CV_VAR_NUMERICAL and CV_VAR_ORDERED are two names for the
   same value. */
#define CV_VAR_NUMERICAL 0
#define CV_VAR_ORDERED 0
#define CV_VAR_CATEGORICAL 1

/* String identifiers, one per model kind. */
#define CV_TYPE_NAME_ML_SVM "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees"

/* Selector values accepted by the calc_error() methods declared in this header. */
#define CV_TRAIN_ERROR 0
#define CV_TEST_ERROR 1
135
136
class CvStatModel
137
{
138
public:
139
CvStatModel();
140
virtual ~CvStatModel();
141
142
virtual void clear();
143
144
CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
145
CV_WRAP virtual void load( const char* filename, const char* name=0 );
146
147
virtual void write( CvFileStorage* storage, const char* name ) const;
148
virtual void read( CvFileStorage* storage, CvFileNode* node );
149
150
protected:
151
const char* default_model_name;
152
};
153
154
/****************************************************************************************\
155
* Normal Bayes Classifier *
156
\****************************************************************************************/
157
158
/* The structure, representing the grid range of statmodel parameters.
159
It is used for optimizing statmodel accuracy by varying model parameters,
160
the accuracy estimate being computed by cross-validation.
161
The grid is logarithmic, so <step> must be greater than 1. */
162
163
class CvMLData;
164
165
struct CvParamGrid
166
{
167
// SVM params type
168
enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
169
170
CvParamGrid()
171
{
172
min_val = max_val = step = 0;
173
}
174
175
CvParamGrid( double min_val, double max_val, double log_step );
176
//CvParamGrid( int param_id );
177
bool check() const;
178
179
CV_PROP_RW double min_val;
180
CV_PROP_RW double max_val;
181
CV_PROP_RW double step;
182
};
183
184
/* Stores the three arguments unchanged in min_val, max_val and step
   (no ordering or range validation is done here). */
inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
    : min_val( _min_val ), max_val( _max_val ), step( _log_step )
{
}
190
191
class CvNormalBayesClassifier : public CvStatModel
192
{
193
public:
194
CV_WRAP CvNormalBayesClassifier();
195
virtual ~CvNormalBayesClassifier();
196
197
CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
198
const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
199
200
virtual bool train( const CvMat* trainData, const CvMat* responses,
201
const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
202
203
virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
204
CV_WRAP virtual void clear();
205
206
CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
207
const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
208
CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
209
const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
210
bool update=false );
211
CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;
212
213
virtual void write( CvFileStorage* storage, const char* name ) const;
214
virtual void read( CvFileStorage* storage, CvFileNode* node );
215
216
protected:
217
int var_count, var_all;
218
CvMat* var_idx;
219
CvMat* cls_labels;
220
CvMat** count;
221
CvMat** sum;
222
CvMat** productsum;
223
CvMat** avg;
224
CvMat** inv_eigen_values;
225
CvMat** cov_rotate_mats;
226
CvMat* c;
227
};
228
229
230
/****************************************************************************************\
231
* K-Nearest Neighbour Classifier *
232
\****************************************************************************************/
233
234
// k Nearest Neighbors
235
class CvKNearest : public CvStatModel
236
{
237
public:
238
239
CV_WRAP CvKNearest();
240
virtual ~CvKNearest();
241
242
CvKNearest( const CvMat* trainData, const CvMat* responses,
243
const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
244
245
virtual bool train( const CvMat* trainData, const CvMat* responses,
246
const CvMat* sampleIdx=0, bool is_regression=false,
247
int maxK=32, bool updateBase=false );
248
249
virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
250
const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
251
252
CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
253
const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
254
255
CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
256
const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
257
int maxK=32, bool updateBase=false );
258
259
virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
260
const float** neighbors=0, cv::Mat* neighborResponses=0,
261
cv::Mat* dist=0 ) const;
262
CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
263
CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
264
265
virtual void clear();
266
int get_max_k() const;
267
int get_var_count() const;
268
int get_sample_count() const;
269
bool is_regression() const;
270
271
virtual float write_results( int k, int k1, int start, int end,
272
const float* neighbor_responses, const float* dist, CvMat* _results,
273
CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
274
275
virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
276
float* neighbor_responses, const float** neighbors, float* dist ) const;
277
278
protected:
279
280
int max_k, var_count;
281
int total;
282
bool regression;
283
CvVectors* samples;
284
};
285
286
/****************************************************************************************\
287
* Support Vector Machines *
288
\****************************************************************************************/
289
290
// SVM training parameters
291
struct CvSVMParams
292
{
293
CvSVMParams();
294
CvSVMParams( int svm_type, int kernel_type,
295
double degree, double gamma, double coef0,
296
double Cvalue, double nu, double p,
297
CvMat* class_weights, CvTermCriteria term_crit );
298
299
CV_PROP_RW int svm_type;
300
CV_PROP_RW int kernel_type;
301
CV_PROP_RW double degree; // for poly
302
CV_PROP_RW double gamma; // for poly/rbf/sigmoid/chi2
303
CV_PROP_RW double coef0; // for poly/sigmoid
304
305
CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
306
CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
307
CV_PROP_RW double p; // for CV_SVM_EPS_SVR
308
CvMat* class_weights; // for CV_SVM_C_SVC
309
CV_PROP_RW CvTermCriteria term_crit; // termination criteria
310
};
311
312
313
struct CvSVMKernel
314
{
315
typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
316
const float* another, float* results );
317
CvSVMKernel();
318
CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
319
virtual bool create( const CvSVMParams* params, Calc _calc_func );
320
virtual ~CvSVMKernel();
321
322
virtual void clear();
323
virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
324
325
const CvSVMParams* params;
326
Calc calc_func;
327
328
virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
329
const float* another, float* results,
330
double alpha, double beta );
331
virtual void calc_intersec( int vcount, int var_count, const float** vecs,
332
const float* another, float* results );
333
virtual void calc_chi2( int vec_count, int vec_size, const float** vecs,
334
const float* another, float* results );
335
virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
336
const float* another, float* results );
337
virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
338
const float* another, float* results );
339
virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
340
const float* another, float* results );
341
virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
342
const float* another, float* results );
343
};
344
345
346
/* Node of a doubly linked list (<prev>/<next>) carrying a pointer to a float
   buffer.  CvSVMSolver holds one such node by value (lru_list) and a pointer
   to an array of them (rows). */
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;
    CvSVMKernelRow* next;
    float* data;
};
352
353
354
/* Scalar results filled in by the CvSVMSolver::solve_* methods, which take
   this structure by reference. */
struct CvSVMSolutionInfo
{
    double obj;
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;   // for Solver_NU
};
362
363
class CvSVMSolver
364
{
365
public:
366
typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
367
typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
368
typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
369
370
CvSVMSolver();
371
372
CvSVMSolver( int count, int var_count, const float** samples, schar* y,
373
int alpha_count, double* alpha, double Cp, double Cn,
374
CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
375
SelectWorkingSet select_working_set, CalcRho calc_rho );
376
virtual bool create( int count, int var_count, const float** samples, schar* y,
377
int alpha_count, double* alpha, double Cp, double Cn,
378
CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
379
SelectWorkingSet select_working_set, CalcRho calc_rho );
380
virtual ~CvSVMSolver();
381
382
virtual void clear();
383
virtual bool solve_generic( CvSVMSolutionInfo& si );
384
385
virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
386
double Cp, double Cn, CvMemStorage* storage,
387
CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
388
virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
389
CvMemStorage* storage, CvSVMKernel* kernel,
390
double* alpha, CvSVMSolutionInfo& si );
391
virtual bool solve_one_class( int count, int var_count, const float** samples,
392
CvMemStorage* storage, CvSVMKernel* kernel,
393
double* alpha, CvSVMSolutionInfo& si );
394
395
virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
396
CvMemStorage* storage, CvSVMKernel* kernel,
397
double* alpha, CvSVMSolutionInfo& si );
398
399
virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
400
CvMemStorage* storage, CvSVMKernel* kernel,
401
double* alpha, CvSVMSolutionInfo& si );
402
403
virtual float* get_row_base( int i, bool* _existed );
404
virtual float* get_row( int i, float* dst );
405
406
int sample_count;
407
int var_count;
408
int cache_size;
409
int cache_line_size;
410
const float** samples;
411
const CvSVMParams* params;
412
CvMemStorage* storage;
413
CvSVMKernelRow lru_list;
414
CvSVMKernelRow* rows;
415
416
int alpha_count;
417
418
double* G;
419
double* alpha;
420
421
// -1 - lower bound, 0 - free, 1 - upper bound
422
schar* alpha_status;
423
424
schar* y;
425
double* b;
426
float* buf[2];
427
double eps;
428
int max_iter;
429
double C[2]; // C[0] == Cn, C[1] == Cp
430
CvSVMKernel* kernel;
431
432
SelectWorkingSet select_working_set_func;
433
CalcRho calc_rho_func;
434
GetRow get_row_func;
435
436
virtual bool select_working_set( int& i, int& j );
437
virtual bool select_working_set_nu_svm( int& i, int& j );
438
virtual void calc_rho( double& rho, double& r );
439
virtual void calc_rho_nu_svm( double& rho, double& r );
440
441
virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
442
virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
443
virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
444
};
445
446
447
/* One SVM decision function: a scalar <rho>, a count <sv_count> and two
   arrays, <alpha> and <sv_index> (presumably sv_count entries each -- TODO
   confirm).  CvSVM exposes it through get_decision_function(). */
struct CvSVMDecisionFunc
{
    double rho;
    int sv_count;
    double* alpha;
    int* sv_index;
};
454
455
456
// SVM model
457
class CvSVM : public CvStatModel
458
{
459
public:
460
// SVM type
461
enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
462
463
// SVM kernel type
464
enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 };
465
466
// SVM params type
467
enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
468
469
CV_WRAP CvSVM();
470
virtual ~CvSVM();
471
472
CvSVM( const CvMat* trainData, const CvMat* responses,
473
const CvMat* varIdx=0, const CvMat* sampleIdx=0,
474
CvSVMParams params=CvSVMParams() );
475
476
virtual bool train( const CvMat* trainData, const CvMat* responses,
477
const CvMat* varIdx=0, const CvMat* sampleIdx=0,
478
CvSVMParams params=CvSVMParams() );
479
480
virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
481
const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
482
int kfold = 10,
483
CvParamGrid Cgrid = get_default_grid(CvSVM::C),
484
CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA),
485
CvParamGrid pGrid = get_default_grid(CvSVM::P),
486
CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
487
CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
488
CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
489
bool balanced=false );
490
491
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
492
virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const;
493
494
CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
495
const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
496
CvSVMParams params=CvSVMParams() );
497
498
CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
499
const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
500
CvSVMParams params=CvSVMParams() );
501
502
CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
503
const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
504
int k_fold = 10,
505
CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C),
506
CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA),
507
CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P),
508
CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU),
509
CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
510
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
511
bool balanced=false);
512
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
513
CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;
514
515
CV_WRAP virtual int get_support_vector_count() const;
516
virtual const float* get_support_vector(int i) const;
517
virtual CvSVMParams get_params() const { return params; }
518
CV_WRAP virtual void clear();
519
520
virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; }
521
522
static CvParamGrid get_default_grid( int param_id );
523
524
virtual void write( CvFileStorage* storage, const char* name ) const;
525
virtual void read( CvFileStorage* storage, CvFileNode* node );
526
CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
527
528
protected:
529
530
virtual bool set_params( const CvSVMParams& params );
531
virtual bool train1( int sample_count, int var_count, const float** samples,
532
const void* responses, double Cp, double Cn,
533
CvMemStorage* _storage, double* alpha, double& rho );
534
virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
535
const CvMat* responses, CvMemStorage* _storage, double* alpha );
536
virtual void create_kernel();
537
virtual void create_solver();
538
539
virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
540
541
virtual void write_params( CvFileStorage* fs ) const;
542
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
543
544
void optimize_linear_svm();
545
546
CvSVMParams params;
547
CvMat* class_labels;
548
int var_all;
549
float** sv;
550
int sv_total;
551
CvMat* var_idx;
552
CvMat* class_weights;
553
CvSVMDecisionFunc* decision_func;
554
CvMemStorage* storage;
555
556
CvSVMSolver* solver;
557
CvSVMKernel* kernel;
558
559
private:
560
CvSVM(const CvSVM&);
561
CvSVM& operator = (const CvSVM&);
562
};
563
564
/****************************************************************************************\
565
* Decision Tree *
566
\****************************************************************************************/
567
/* Pair of pointers: one to unsigned 16-bit data (<u>) and one to int data (<i>). */
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
572
573
574
/* Treats <subset> as a bitset stored in 32-bit words and looks up bit <idx>:
   evaluates to -1 when the bit is set and to +1 when it is clear.
   <subset> is parenthesized so that pointer expressions can be passed. */
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*(((subset)[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
576
577
/* One split of a decision-tree node.  Splits can be chained through <next>.
   The anonymous union holds either <subset> (two ints, usable as a bitset
   with CV_DTREE_CAT_DIR) or <ord> (a float threshold plus a split point);
   the two alternatives share storage.  Which alternative applies is
   presumably decided by the type of variable <var_idx> -- TODO confirm. */
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};
595
596
struct CvDTreeNode
597
{
598
int class_idx;
599
int Tn;
600
double value;
601
602
CvDTreeNode* parent;
603
CvDTreeNode* left;
604
CvDTreeNode* right;
605
606
CvDTreeSplit* split;
607
608
int sample_count;
609
int depth;
610
int* num_valid;
611
int offset;
612
int buf_idx;
613
double maxlr;
614
615
// global pruning data
616
int complexity;
617
double alpha;
618
double node_risk, tree_risk, tree_error;
619
620
// cross-validation pruning data
621
int* cv_Tn;
622
double* cv_node_risk;
623
double* cv_node_error;
624
625
int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
626
void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
627
};
628
629
630
struct CvDTreeParams
631
{
632
CV_PROP_RW int max_categories;
633
CV_PROP_RW int max_depth;
634
CV_PROP_RW int min_sample_count;
635
CV_PROP_RW int cv_folds;
636
CV_PROP_RW bool use_surrogates;
637
CV_PROP_RW bool use_1se_rule;
638
CV_PROP_RW bool truncate_pruned_tree;
639
CV_PROP_RW float regression_accuracy;
640
const float* priors;
641
642
CvDTreeParams();
643
CvDTreeParams( int max_depth, int min_sample_count,
644
float regression_accuracy, bool use_surrogates,
645
int max_categories, int cv_folds,
646
bool use_1se_rule, bool truncate_pruned_tree,
647
const float* priors );
648
};
649
650
651
struct CvDTreeTrainData
652
{
653
CvDTreeTrainData();
654
CvDTreeTrainData( const CvMat* trainData, int tflag,
655
const CvMat* responses, const CvMat* varIdx=0,
656
const CvMat* sampleIdx=0, const CvMat* varType=0,
657
const CvMat* missingDataMask=0,
658
const CvDTreeParams& params=CvDTreeParams(),
659
bool _shared=false, bool _add_labels=false );
660
virtual ~CvDTreeTrainData();
661
662
virtual void set_data( const CvMat* trainData, int tflag,
663
const CvMat* responses, const CvMat* varIdx=0,
664
const CvMat* sampleIdx=0, const CvMat* varType=0,
665
const CvMat* missingDataMask=0,
666
const CvDTreeParams& params=CvDTreeParams(),
667
bool _shared=false, bool _add_labels=false,
668
bool _update_data=false );
669
virtual void do_responses_copy();
670
671
virtual void get_vectors( const CvMat* _subsample_idx,
672
float* values, uchar* missing, float* responses, bool get_class_idx=false );
673
674
virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
675
676
virtual void write_params( CvFileStorage* fs ) const;
677
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
678
679
// release all the data
680
virtual void clear();
681
682
int get_num_classes() const;
683
int get_var_type(int vi) const;
684
int get_work_var_count() const {return work_var_count;}
685
686
virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
687
virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
688
virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
689
virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
690
virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
691
virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
692
const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
693
virtual int get_child_buf_idx( CvDTreeNode* n );
694
695
////////////////////////////////////
696
697
virtual bool set_params( const CvDTreeParams& params );
698
virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
699
int storage_idx, int offset );
700
701
virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
702
int split_point, int inversed, float quality );
703
virtual CvDTreeSplit* new_split_cat( int vi, float quality );
704
virtual void free_node_data( CvDTreeNode* node );
705
virtual void free_train_data();
706
virtual void free_node( CvDTreeNode* node );
707
708
int sample_count, var_all, var_count, max_c_count;
709
int ord_var_count, cat_var_count, work_var_count;
710
bool have_labels, have_priors;
711
bool is_classifier;
712
int tflag;
713
714
const CvMat* train_data;
715
const CvMat* responses;
716
CvMat* responses_copy; // used in Boosting
717
718
int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
719
bool shared;
720
int is_buf_16u;
721
722
CvMat* cat_count;
723
CvMat* cat_ofs;
724
CvMat* cat_map;
725
726
CvMat* counts;
727
CvMat* buf;
728
inline size_t get_length_subbuf() const
729
{
730
size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
731
return res;
732
}
733
734
CvMat* direction;
735
CvMat* split_buf;
736
737
CvMat* var_idx;
738
CvMat* var_type; // i-th element =
739
// k<0 - ordered
740
// k>=0 - categorical, see k-th element of cat_* arrays
741
CvMat* priors;
742
CvMat* priors_mult;
743
744
CvDTreeParams params;
745
746
CvMemStorage* tree_storage;
747
CvMemStorage* temp_storage;
748
749
CvDTreeNode* data_root;
750
751
CvSet* node_heap;
752
CvSet* split_heap;
753
CvSet* cv_heap;
754
CvSet* nv_heap;
755
756
cv::RNG* rng;
757
};
758
759
/* Forward declarations.  The two cv:: structs are named as friends by
   CvDTree and CvForestTree respectively; their definitions are not part of
   this header. */
class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
767
768
class CvDTree : public CvStatModel
769
{
770
public:
771
CV_WRAP CvDTree();
772
virtual ~CvDTree();
773
774
virtual bool train( const CvMat* trainData, int tflag,
775
const CvMat* responses, const CvMat* varIdx=0,
776
const CvMat* sampleIdx=0, const CvMat* varType=0,
777
const CvMat* missingDataMask=0,
778
CvDTreeParams params=CvDTreeParams() );
779
780
virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
781
782
// type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
783
virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
784
785
virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
786
787
virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
788
bool preprocessedInput=false ) const;
789
790
CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
791
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
792
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
793
const cv::Mat& missingDataMask=cv::Mat(),
794
CvDTreeParams params=CvDTreeParams() );
795
796
CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
797
bool preprocessedInput=false ) const;
798
CV_WRAP virtual cv::Mat getVarImportance();
799
800
virtual const CvMat* get_var_importance();
801
CV_WRAP virtual void clear();
802
803
virtual void read( CvFileStorage* fs, CvFileNode* node );
804
virtual void write( CvFileStorage* fs, const char* name ) const;
805
806
// special read & write methods for trees in the tree ensembles
807
virtual void read( CvFileStorage* fs, CvFileNode* node,
808
CvDTreeTrainData* data );
809
virtual void write( CvFileStorage* fs ) const;
810
811
const CvDTreeNode* get_root() const;
812
int get_pruned_tree_idx() const;
813
CvDTreeTrainData* get_data();
814
815
protected:
816
friend struct cv::DTreeBestSplitFinder;
817
818
virtual bool do_train( const CvMat* _subsample_idx );
819
820
virtual void try_split_node( CvDTreeNode* n );
821
virtual void split_node_data( CvDTreeNode* n );
822
virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
823
virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
824
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
825
virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
826
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
827
virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
828
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
829
virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
830
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
831
virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
832
virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
833
virtual double calc_node_dir( CvDTreeNode* node );
834
virtual void complete_node_dir( CvDTreeNode* node );
835
virtual void cluster_categories( const int* vectors, int vector_count,
836
int var_count, int* sums, int k, int* cluster_labels );
837
838
virtual void calc_node_value( CvDTreeNode* node );
839
840
virtual void prune_cv();
841
virtual double update_tree_rnc( int T, int fold );
842
virtual int cut_tree( int T, int fold, double min_alpha );
843
virtual void free_prune_data(bool cut_tree);
844
virtual void free_tree();
845
846
virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
847
virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
848
virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
849
virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
850
virtual void write_tree_nodes( CvFileStorage* fs ) const;
851
virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
852
853
CvDTreeNode* root;
854
CvMat* var_importance;
855
CvDTreeTrainData* data;
856
CvMat train_data_hdr, responses_hdr;
857
cv::Mat train_data_mat, responses_mat;
858
859
public:
860
int pruned_tree_idx;
861
};
862
863
864
/****************************************************************************************\
865
* Random Trees Classifier *
866
\****************************************************************************************/
867
868
class CvRTrees;
869
870
class CvForestTree: public CvDTree
871
{
872
public:
873
CvForestTree();
874
virtual ~CvForestTree();
875
876
virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
877
878
virtual int get_var_count() const {return data ? data->var_count : 0;}
879
virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
880
881
/* dummy methods to avoid warnings: BEGIN */
882
virtual bool train( const CvMat* trainData, int tflag,
883
const CvMat* responses, const CvMat* varIdx=0,
884
const CvMat* sampleIdx=0, const CvMat* varType=0,
885
const CvMat* missingDataMask=0,
886
CvDTreeParams params=CvDTreeParams() );
887
888
virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
889
virtual void read( CvFileStorage* fs, CvFileNode* node );
890
virtual void read( CvFileStorage* fs, CvFileNode* node,
891
CvDTreeTrainData* data );
892
/* dummy methods to avoid warnings: END */
893
894
protected:
895
friend struct cv::ForestTreeBestSplitFinder;
896
897
virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
898
CvRTrees* forest;
899
};
900
901
902
struct CvRTParams : public CvDTreeParams
903
{
904
//Parameters for the forest
905
CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
906
CV_PROP_RW int nactive_vars;
907
CV_PROP_RW CvTermCriteria term_crit;
908
909
CvRTParams();
910
CvRTParams( int max_depth, int min_sample_count,
911
float regression_accuracy, bool use_surrogates,
912
int max_categories, const float* priors, bool calc_var_importance,
913
int nactive_vars, int max_num_of_trees_in_the_forest,
914
float forest_accuracy, int termcrit_type );
915
};
916
917
918
class CvRTrees : public CvStatModel
919
{
920
public:
921
CV_WRAP CvRTrees();
922
virtual ~CvRTrees();
923
virtual bool train( const CvMat* trainData, int tflag,
924
const CvMat* responses, const CvMat* varIdx=0,
925
const CvMat* sampleIdx=0, const CvMat* varType=0,
926
const CvMat* missingDataMask=0,
927
CvRTParams params=CvRTParams() );
928
929
virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
930
virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
931
virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
932
933
CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
934
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
935
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
936
const cv::Mat& missingDataMask=cv::Mat(),
937
CvRTParams params=CvRTParams() );
938
CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
939
CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
940
CV_WRAP virtual cv::Mat getVarImportance();
941
942
CV_WRAP virtual void clear();
943
944
virtual const CvMat* get_var_importance();
945
virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
946
const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
947
948
virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
949
950
virtual float get_train_error();
951
952
virtual void read( CvFileStorage* fs, CvFileNode* node );
953
virtual void write( CvFileStorage* fs, const char* name ) const;
954
955
CvMat* get_active_var_mask();
956
CvRNG* get_rng();
957
958
int get_tree_count() const;
959
CvForestTree* get_tree(int i) const;
960
961
protected:
962
virtual cv::String getName() const;
963
964
virtual bool grow_forest( const CvTermCriteria term_crit );
965
966
// array of the trees of the forest
967
CvForestTree** trees;
968
CvDTreeTrainData* data;
969
CvMat train_data_hdr, responses_hdr;
970
cv::Mat train_data_mat, responses_mat;
971
int ntrees;
972
int nclasses;
973
double oob_error;
974
CvMat* var_importance;
975
int nsamples;
976
977
cv::RNG* rng;
978
CvMat* active_var_mask;
979
};
980
981
/****************************************************************************************\
982
* Extremely randomized trees Classifier *
983
\****************************************************************************************/
984
struct CvERTreeTrainData : public CvDTreeTrainData
985
{
986
virtual void set_data( const CvMat* trainData, int tflag,
987
const CvMat* responses, const CvMat* varIdx=0,
988
const CvMat* sampleIdx=0, const CvMat* varType=0,
989
const CvMat* missingDataMask=0,
990
const CvDTreeParams& params=CvDTreeParams(),
991
bool _shared=false, bool _add_labels=false,
992
bool _update_data=false );
993
virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
994
const float** ord_values, const int** missing, int* sample_buf = 0 );
995
virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
996
virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
997
virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
998
virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
999
float* responses, bool get_class_idx=false );
1000
virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
1001
const CvMat* missing_mask;
1002
};
1003
1004
class CvForestERTree : public CvForestTree
1005
{
1006
protected:
1007
virtual double calc_node_dir( CvDTreeNode* node );
1008
virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1009
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1010
virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1011
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1012
virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1013
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1014
virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1015
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1016
virtual void split_node_data( CvDTreeNode* n );
1017
};
1018
1019
class CvERTrees : public CvRTrees
1020
{
1021
public:
1022
CV_WRAP CvERTrees();
1023
virtual ~CvERTrees();
1024
virtual bool train( const CvMat* trainData, int tflag,
1025
const CvMat* responses, const CvMat* varIdx=0,
1026
const CvMat* sampleIdx=0, const CvMat* varType=0,
1027
const CvMat* missingDataMask=0,
1028
CvRTParams params=CvRTParams());
1029
CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1030
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1031
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1032
const cv::Mat& missingDataMask=cv::Mat(),
1033
CvRTParams params=CvRTParams());
1034
virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1035
protected:
1036
virtual cv::String getName() const;
1037
virtual bool grow_forest( const CvTermCriteria term_crit );
1038
};
1039
1040
1041
/****************************************************************************************\
1042
* Boosted tree classifier *
1043
\****************************************************************************************/
1044
1045
struct CvBoostParams : public CvDTreeParams
1046
{
1047
CV_PROP_RW int boost_type;
1048
CV_PROP_RW int weak_count;
1049
CV_PROP_RW int split_criteria;
1050
CV_PROP_RW double weight_trim_rate;
1051
1052
CvBoostParams();
1053
CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1054
int max_depth, bool use_surrogates, const float* priors );
1055
};
1056
1057
1058
class CvBoost;
1059
1060
class CvBoostTree: public CvDTree
1061
{
1062
public:
1063
CvBoostTree();
1064
virtual ~CvBoostTree();
1065
1066
virtual bool train( CvDTreeTrainData* trainData,
1067
const CvMat* subsample_idx, CvBoost* ensemble );
1068
1069
virtual void scale( double s );
1070
virtual void read( CvFileStorage* fs, CvFileNode* node,
1071
CvBoost* ensemble, CvDTreeTrainData* _data );
1072
virtual void clear();
1073
1074
/* dummy methods to avoid warnings: BEGIN */
1075
virtual bool train( const CvMat* trainData, int tflag,
1076
const CvMat* responses, const CvMat* varIdx=0,
1077
const CvMat* sampleIdx=0, const CvMat* varType=0,
1078
const CvMat* missingDataMask=0,
1079
CvDTreeParams params=CvDTreeParams() );
1080
virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
1081
1082
virtual void read( CvFileStorage* fs, CvFileNode* node );
1083
virtual void read( CvFileStorage* fs, CvFileNode* node,
1084
CvDTreeTrainData* data );
1085
/* dummy methods to avoid warnings: END */
1086
1087
protected:
1088
1089
virtual void try_split_node( CvDTreeNode* n );
1090
virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1091
virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1092
virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1093
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1094
virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1095
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1096
virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1097
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1098
virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1099
float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1100
virtual void calc_node_value( CvDTreeNode* n );
1101
virtual double calc_node_dir( CvDTreeNode* n );
1102
1103
CvBoost* ensemble;
1104
};
1105
1106
1107
class CvBoost : public CvStatModel
1108
{
1109
public:
1110
// Boosting type
1111
enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1112
1113
// Splitting criteria
1114
enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1115
1116
CV_WRAP CvBoost();
1117
virtual ~CvBoost();
1118
1119
CvBoost( const CvMat* trainData, int tflag,
1120
const CvMat* responses, const CvMat* varIdx=0,
1121
const CvMat* sampleIdx=0, const CvMat* varType=0,
1122
const CvMat* missingDataMask=0,
1123
CvBoostParams params=CvBoostParams() );
1124
1125
virtual bool train( const CvMat* trainData, int tflag,
1126
const CvMat* responses, const CvMat* varIdx=0,
1127
const CvMat* sampleIdx=0, const CvMat* varType=0,
1128
const CvMat* missingDataMask=0,
1129
CvBoostParams params=CvBoostParams(),
1130
bool update=false );
1131
1132
virtual bool train( CvMLData* data,
1133
CvBoostParams params=CvBoostParams(),
1134
bool update=false );
1135
1136
virtual float predict( const CvMat* sample, const CvMat* missing=0,
1137
CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1138
bool raw_mode=false, bool return_sum=false ) const;
1139
1140
CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
1141
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1142
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1143
const cv::Mat& missingDataMask=cv::Mat(),
1144
CvBoostParams params=CvBoostParams() );
1145
1146
CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1147
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1148
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1149
const cv::Mat& missingDataMask=cv::Mat(),
1150
CvBoostParams params=CvBoostParams(),
1151
bool update=false );
1152
1153
CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1154
const cv::Range& slice=cv::Range::all(), bool rawMode=false,
1155
bool returnSum=false ) const;
1156
1157
virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1158
1159
CV_WRAP virtual void prune( CvSlice slice );
1160
1161
CV_WRAP virtual void clear();
1162
1163
virtual void write( CvFileStorage* storage, const char* name ) const;
1164
virtual void read( CvFileStorage* storage, CvFileNode* node );
1165
virtual const CvMat* get_active_vars(bool absolute_idx=true);
1166
1167
CvSeq* get_weak_predictors();
1168
1169
CvMat* get_weights();
1170
CvMat* get_subtree_weights();
1171
CvMat* get_weak_response();
1172
const CvBoostParams& get_params() const;
1173
const CvDTreeTrainData* get_data() const;
1174
1175
protected:
1176
1177
virtual bool set_params( const CvBoostParams& params );
1178
virtual void update_weights( CvBoostTree* tree );
1179
virtual void trim_weights();
1180
virtual void write_params( CvFileStorage* fs ) const;
1181
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1182
1183
virtual void initialize_weights(double (&p)[2]);
1184
1185
CvDTreeTrainData* data;
1186
CvMat train_data_hdr, responses_hdr;
1187
cv::Mat train_data_mat, responses_mat;
1188
CvBoostParams params;
1189
CvSeq* weak;
1190
1191
CvMat* active_vars;
1192
CvMat* active_vars_abs;
1193
bool have_active_cat_vars;
1194
1195
CvMat* orig_response;
1196
CvMat* sum_response;
1197
CvMat* weak_eval;
1198
CvMat* subsample_mask;
1199
CvMat* weights;
1200
CvMat* subtree_weights;
1201
bool have_subsample;
1202
};
1203
1204
1205
/****************************************************************************************\
1206
* Gradient Boosted Trees *
1207
\****************************************************************************************/
1208
1209
// DataType: STRUCT CvGBTreesParams
1210
// Parameters of GBT (Gradient Boosted trees model), including single
1211
// tree settings and ensemble parameters.
1212
//
1213
// weak_count - count of trees in the ensemble
1214
// loss_function_type - loss function used for ensemble training
1215
// subsample_portion - portion of whole training set used for
1216
// every single tree training.
1217
// subsample_portion value is in (0.0, 1.0].
1218
// subsample_portion == 1.0 when whole dataset is
1219
// used on each step. Count of sample used on each
1220
// step is computed as
1221
// int(total_samples_count * subsample_portion).
1222
// shrinkage - regularization parameter.
1223
// Each tree prediction is multiplied on shrinkage value.
1224
1225
1226
struct CvGBTreesParams : public CvDTreeParams
1227
{
1228
CV_PROP_RW int weak_count;
1229
CV_PROP_RW int loss_function_type;
1230
CV_PROP_RW float subsample_portion;
1231
CV_PROP_RW float shrinkage;
1232
1233
CvGBTreesParams();
1234
CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
1235
float subsample_portion, int max_depth, bool use_surrogates );
1236
};
1237
1238
// DataType: CLASS CvGBTrees
1239
// Gradient Boosting Trees (GBT) algorithm implementation.
1240
//
1241
// data - training dataset
1242
// params - parameters of the CvGBTrees
1243
// weak - array[0..(class_count-1)] of CvSeq
1244
// for storing tree ensembles
1245
// orig_response - original responses of the training set samples
1246
// sum_response - predictions of the current model on the training dataset.
1247
// this matrix is updated on every iteration.
1248
// sum_response_tmp - predictions of the model on the training set on the next
1249
// step. On every iteration values of sum_responses_tmp are
1250
// computed via sum_responses values. When the current
1251
// step is complete sum_response values become equal to
1252
// sum_responses_tmp.
1253
// sampleIdx - indices of samples used for training the ensemble.
1254
// CvGBTrees training procedure takes a set of samples
1255
// (train_data) and a set of responses (responses).
1256
// Only pairs (train_data[i], responses[i]), where i is
1257
// in sample_idx are used for training the ensemble.
1258
// subsample_train - indices of samples used for training a single decision
1259
// tree on the current step. This indices are countered
1260
// relatively to the sample_idx, so that pairs
1261
// (train_data[sample_idx[i]], responses[sample_idx[i]])
1262
// are used for training a decision tree.
1263
// Training set is randomly splited
1264
// in two parts (subsample_train and subsample_test)
1265
// on every iteration accordingly to the portion parameter.
1266
// subsample_test - relative indices of samples from the training set,
1267
// which are not used for training a tree on the current
1268
// step.
1269
// missing - mask of the missing values in the training set. This
1270
// matrix has the same size as train_data. 1 - missing
1271
// value, 0 - not a missing value.
1272
// class_labels - output class labels map.
1273
// rng - random number generator. Used for splitting the
1274
// training set.
1275
// class_count - count of output classes.
1276
// class_count == 1 in the case of regression,
1277
// and > 1 in the case of classification.
1278
// delta - Huber loss function parameter.
1279
// base_value - start point of the gradient descent procedure.
1280
// model prediction is
1281
// f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
1282
// f_0 is the base value.
1283
1284
1285
1286
class CvGBTrees : public CvStatModel
1287
{
1288
public:
1289
1290
/*
1291
// DataType: ENUM
1292
// Loss functions implemented in CvGBTrees.
1293
//
1294
// SQUARED_LOSS
1295
// problem: regression
1296
// loss = (x - x')^2
1297
//
1298
// ABSOLUTE_LOSS
1299
// problem: regression
1300
// loss = abs(x - x')
1301
//
1302
// HUBER_LOSS
1303
// problem: regression
1304
// loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
1305
// 1/2*(x - x')^2, if abs(x - x') <= delta,
1306
// where delta is the alpha-quantile of pseudo responses from
1307
// the training set.
1308
//
1309
// DEVIANCE_LOSS
1310
// problem: classification
1311
//
1312
*/
1313
enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
1314
1315
1316
/*
1317
// Default constructor. Creates a model only (without training).
1318
// Should be followed by one form of the train(...) function.
1319
//
1320
// API
1321
// CvGBTrees();
1322
1323
// INPUT
1324
// OUTPUT
1325
// RESULT
1326
*/
1327
CV_WRAP CvGBTrees();
1328
1329
1330
/*
1331
// Full form constructor. Creates a gradient boosting model and does the
1332
// train.
1333
//
1334
// API
1335
// CvGBTrees( const CvMat* trainData, int tflag,
1336
const CvMat* responses, const CvMat* varIdx=0,
1337
const CvMat* sampleIdx=0, const CvMat* varType=0,
1338
const CvMat* missingDataMask=0,
1339
CvGBTreesParams params=CvGBTreesParams() );
1340
1341
// INPUT
1342
// trainData - a set of input feature vectors.
1343
// size of matrix is
1344
// <count of samples> x <variables count>
1345
// or <variables count> x <count of samples>
1346
// depending on the tflag parameter.
1347
// matrix values are float.
1348
// tflag - a flag showing how do samples stored in the
1349
// trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1350
// or column by column (tflag=CV_COL_SAMPLE).
1351
// responses - a vector of responses corresponding to the samples
1352
// in trainData.
1353
// varIdx - indices of used variables. zero value means that all
1354
// variables are active.
1355
// sampleIdx - indices of used samples. zero value means that all
1356
// samples from trainData are in the training set.
1357
// varType - vector of <variables count> length. gives every
1358
// variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1359
// varType = 0 means all variables are numerical.
1360
// missingDataMask - a mask of misiing values in trainData.
1361
// missingDataMask = 0 means that there are no missing
1362
// values.
1363
// params - parameters of GTB algorithm.
1364
// OUTPUT
1365
// RESULT
1366
*/
1367
CvGBTrees( const CvMat* trainData, int tflag,
1368
const CvMat* responses, const CvMat* varIdx=0,
1369
const CvMat* sampleIdx=0, const CvMat* varType=0,
1370
const CvMat* missingDataMask=0,
1371
CvGBTreesParams params=CvGBTreesParams() );
1372
1373
1374
/*
1375
// Destructor.
1376
*/
1377
virtual ~CvGBTrees();
1378
1379
1380
/*
1381
// Gradient tree boosting model training
1382
//
1383
// API
1384
// virtual bool train( const CvMat* trainData, int tflag,
1385
const CvMat* responses, const CvMat* varIdx=0,
1386
const CvMat* sampleIdx=0, const CvMat* varType=0,
1387
const CvMat* missingDataMask=0,
1388
CvGBTreesParams params=CvGBTreesParams(),
1389
bool update=false );
1390
1391
// INPUT
1392
// trainData - a set of input feature vectors.
1393
// size of matrix is
1394
// <count of samples> x <variables count>
1395
// or <variables count> x <count of samples>
1396
// depending on the tflag parameter.
1397
// matrix values are float.
1398
// tflag - a flag showing how do samples stored in the
1399
// trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1400
// or column by column (tflag=CV_COL_SAMPLE).
1401
// responses - a vector of responses corresponding to the samples
1402
// in trainData.
1403
// varIdx - indices of used variables. zero value means that all
1404
// variables are active.
1405
// sampleIdx - indices of used samples. zero value means that all
1406
// samples from trainData are in the training set.
1407
// varType - vector of <variables count> length. gives every
1408
// variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1409
// varType = 0 means all variables are numerical.
1410
// missingDataMask - a mask of misiing values in trainData.
1411
// missingDataMask = 0 means that there are no missing
1412
// values.
1413
// params - parameters of GTB algorithm.
1414
// update - is not supported now. (!)
1415
// OUTPUT
1416
// RESULT
1417
// Error state.
1418
*/
1419
virtual bool train( const CvMat* trainData, int tflag,
1420
const CvMat* responses, const CvMat* varIdx=0,
1421
const CvMat* sampleIdx=0, const CvMat* varType=0,
1422
const CvMat* missingDataMask=0,
1423
CvGBTreesParams params=CvGBTreesParams(),
1424
bool update=false );
1425
1426
1427
/*
1428
// Gradient tree boosting model training
1429
//
1430
// API
1431
// virtual bool train( CvMLData* data,
1432
CvGBTreesParams params=CvGBTreesParams(),
1433
bool update=false ) {return false;}
1434
1435
// INPUT
1436
// data - training set.
1437
// params - parameters of GTB algorithm.
1438
// update - is not supported now. (!)
1439
// OUTPUT
1440
// RESULT
1441
// Error state.
1442
*/
1443
virtual bool train( CvMLData* data,
1444
CvGBTreesParams params=CvGBTreesParams(),
1445
bool update=false );
1446
1447
1448
/*
1449
// Response value prediction
1450
//
1451
// API
1452
// virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1453
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1454
int k=-1 ) const;
1455
1456
// INPUT
1457
// sample - input sample of the same type as in the training set.
1458
// missing - missing values mask. missing=0 if there are no
1459
// missing values in sample vector.
1460
// weak_responses - predictions of all of the trees.
1461
// not implemented (!)
1462
// slice - part of the ensemble used for prediction.
1463
// slice = CV_WHOLE_SEQ when all trees are used.
1464
// k - number of ensemble used.
1465
// k is in {-1,0,1,..,<count of output classes-1>}.
1466
// in the case of classification problem
1467
// <count of output classes-1> ensembles are built.
1468
// If k = -1 ordinary prediction is the result,
1469
// otherwise function gives the prediction of the
1470
// k-th ensemble only.
1471
// OUTPUT
1472
// RESULT
1473
// Predicted value.
1474
*/
1475
virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1476
CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1477
int k=-1 ) const;
1478
1479
/*
1480
// Response value prediction.
1481
// Parallel version (in the case of TBB existence)
1482
//
1483
// API
1484
// virtual float predict( const CvMat* sample, const CvMat* missing=0,
1485
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1486
int k=-1 ) const;
1487
1488
// INPUT
1489
// sample - input sample of the same type as in the training set.
1490
// missing - missing values mask. missing=0 if there are no
1491
// missing values in sample vector.
1492
// weak_responses - predictions of all of the trees.
1493
// not implemented (!)
1494
// slice - part of the ensemble used for prediction.
1495
// slice = CV_WHOLE_SEQ when all trees are used.
1496
// k - number of ensemble used.
1497
// k is in {-1,0,1,..,<count of output classes-1>}.
1498
// in the case of classification problem
1499
// <count of output classes-1> ensembles are built.
1500
// If k = -1 ordinary prediction is the result,
1501
// otherwise function gives the prediction of the
1502
// k-th ensemble only.
1503
// OUTPUT
1504
// RESULT
1505
// Predicted value.
1506
*/
1507
virtual float predict( const CvMat* sample, const CvMat* missing=0,
1508
CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1509
int k=-1 ) const;
1510
1511
/*
1512
// Deletes all the data.
1513
//
1514
// API
1515
// virtual void clear();
1516
1517
// INPUT
1518
// OUTPUT
1519
// delete data, weak, orig_response, sum_response,
1520
// weak_eval, subsample_train, subsample_test,
1521
// sample_idx, missing, lass_labels
1522
// delta = 0.0
1523
// RESULT
1524
*/
1525
CV_WRAP virtual void clear();
1526
1527
/*
1528
// Compute error on the train/test set.
1529
//
1530
// API
1531
// virtual float calc_error( CvMLData* _data, int type,
1532
// std::vector<float> *resp = 0 );
1533
//
1534
// INPUT
1535
// data - dataset
1536
// type - defines which error is to compute: train (CV_TRAIN_ERROR) or
1537
// test (CV_TEST_ERROR).
1538
// OUTPUT
1539
// resp - vector of predictions
1540
// RESULT
1541
// Error value.
1542
*/
1543
virtual float calc_error( CvMLData* _data, int type,
1544
std::vector<float> *resp = 0 );
1545
1546
/*
1547
//
1548
// Write parameters of the gtb model and data. Write learned model.
1549
//
1550
// API
1551
// virtual void write( CvFileStorage* fs, const char* name ) const;
1552
//
1553
// INPUT
1554
// fs - file storage to read parameters from.
1555
// name - model name.
1556
// OUTPUT
1557
// RESULT
1558
*/
1559
virtual void write( CvFileStorage* fs, const char* name ) const;
1560
1561
1562
/*
1563
//
1564
// Read parameters of the gtb model and data. Read learned model.
1565
//
1566
// API
1567
// virtual void read( CvFileStorage* fs, CvFileNode* node );
1568
//
1569
// INPUT
1570
// fs - file storage to read parameters from.
1571
// node - file node.
1572
// OUTPUT
1573
// RESULT
1574
*/
1575
virtual void read( CvFileStorage* fs, CvFileNode* node );
1576
1577
1578
// new-style C++ interface
1579
CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
1580
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1581
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1582
const cv::Mat& missingDataMask=cv::Mat(),
1583
CvGBTreesParams params=CvGBTreesParams() );
1584
1585
CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1586
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1587
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1588
const cv::Mat& missingDataMask=cv::Mat(),
1589
CvGBTreesParams params=CvGBTreesParams(),
1590
bool update=false );
1591
1592
CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1593
const cv::Range& slice = cv::Range::all(),
1594
int k=-1 ) const;
1595
1596
protected:
1597
1598
/*
1599
// Compute the gradient vector components.
1600
//
1601
// API
1602
// virtual void find_gradient( const int k = 0);
1603
1604
// INPUT
1605
// k - used for classification problem, determining current
1606
// tree ensemble.
1607
// OUTPUT
1608
// changes components of data->responses
1609
// which correspond to samples used for training
1610
// on the current step.
1611
// RESULT
1612
*/
1613
virtual void find_gradient( const int k = 0);
1614
1615
1616
/*
1617
//
1618
// Change values in tree leaves according to the used loss function.
1619
//
1620
// API
1621
// virtual void change_values(CvDTree* tree, const int k = 0);
1622
//
1623
// INPUT
1624
// tree - decision tree to change.
1625
// k - used for classification problem, determining current
1626
// tree ensemble.
1627
// OUTPUT
1628
// changes 'value' fields of the trees' leaves.
1629
// changes sum_response_tmp.
1630
// RESULT
1631
*/
1632
virtual void change_values(CvDTree* tree, const int k = 0);
1633
1634
1635
/*
1636
//
1637
// Find optimal constant prediction value according to the used loss
1638
// function.
1639
// The goal is to find a constant which gives the minimal summary loss
1640
// on the _Idx samples.
1641
//
1642
// API
1643
// virtual float find_optimal_value( const CvMat* _Idx );
1644
//
1645
// INPUT
1646
// _Idx - indices of the samples from the training set.
1647
// OUTPUT
1648
// RESULT
1649
// optimal constant value.
1650
*/
1651
virtual float find_optimal_value( const CvMat* _Idx );
1652
1653
1654
/*
1655
//
1656
// Randomly split the whole training set in two parts according
1657
// to params.portion.
1658
//
1659
// API
1660
// virtual void do_subsample();
1661
//
1662
// INPUT
1663
// OUTPUT
1664
// subsample_train - indices of samples used for training
1665
// subsample_test - indices of samples used for test
1666
// RESULT
1667
*/
1668
virtual void do_subsample();
1669
1670
1671
/*
1672
//
1673
// Internal recursive function giving an array of subtree tree leaves.
1674
//
1675
// API
1676
// void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1677
//
1678
// INPUT
1679
// node - current leaf.
1680
// OUTPUT
1681
// count - count of leaves in the subtree.
1682
// leaves - array of pointers to leaves.
1683
// RESULT
1684
*/
1685
void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1686
1687
1688
/*
1689
//
1690
// Get leaves of the tree.
1691
//
1692
// API
1693
// CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1694
//
1695
// INPUT
1696
// dtree - decision tree.
1697
// OUTPUT
1698
// len - count of the leaves.
1699
// RESULT
1700
// CvDTreeNode** - array of pointers to leaves.
1701
*/
1702
CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1703
1704
1705
/*
1706
//
1707
// Is it a regression or a classification.
1708
//
1709
// API
1710
// bool problem_type();
1711
//
1712
// INPUT
1713
// OUTPUT
1714
// RESULT
1715
// false if it is a classification problem,
1716
// true - if regression.
1717
*/
1718
virtual bool problem_type() const;
1719
1720
1721
/*
1722
//
1723
// Write parameters of the gtb model.
1724
//
1725
// API
1726
// virtual void write_params( CvFileStorage* fs ) const;
1727
//
1728
// INPUT
1729
// fs - file storage to write parameters to.
1730
// OUTPUT
1731
// RESULT
1732
*/
1733
virtual void write_params( CvFileStorage* fs ) const;
1734
1735
1736
/*
1737
//
1738
// Read parameters of the gtb model and data.
1739
//
1740
// API
1741
// virtual void read_params( CvFileStorage* fs );
1742
//
1743
// INPUT
1744
// fs - file storage to read parameters from.
1745
// OUTPUT
1746
// params - parameters of the gtb model.
1747
// data - contains information about the structure
1748
// of the data set (count of variables,
1749
// their types, etc.).
1750
// class_labels - output class labels map.
1751
// RESULT
1752
*/
1753
virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1754
int get_len(const CvMat* mat) const;
1755
1756
1757
CvDTreeTrainData* data;
1758
CvGBTreesParams params;
1759
1760
CvSeq** weak;
1761
CvMat* orig_response;
1762
CvMat* sum_response;
1763
CvMat* sum_response_tmp;
1764
CvMat* sample_idx;
1765
CvMat* subsample_train;
1766
CvMat* subsample_test;
1767
CvMat* missing;
1768
CvMat* class_labels;
1769
1770
cv::RNG* rng;
1771
1772
int class_count;
1773
float delta;
1774
float base_value;
1775
1776
};
1777
1778
1779
1780
/****************************************************************************************\
1781
* Artificial Neural Networks (ANN) *
1782
\****************************************************************************************/
1783
1784
/////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1785
1786
struct CvANN_MLP_TrainParams
1787
{
1788
CvANN_MLP_TrainParams();
1789
CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1790
double param1, double param2=0 );
1791
~CvANN_MLP_TrainParams();
1792
1793
enum { BACKPROP=0, RPROP=1 };
1794
1795
CV_PROP_RW CvTermCriteria term_crit;
1796
CV_PROP_RW int train_method;
1797
1798
// backpropagation parameters
1799
CV_PROP_RW double bp_dw_scale, bp_moment_scale;
1800
1801
// rprop parameters
1802
CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1803
};
1804
1805
1806
class CvANN_MLP : public CvStatModel
1807
{
1808
public:
1809
CV_WRAP CvANN_MLP();
1810
CvANN_MLP( const CvMat* layerSizes,
1811
int activateFunc=CvANN_MLP::SIGMOID_SYM,
1812
double fparam1=0, double fparam2=0 );
1813
1814
virtual ~CvANN_MLP();
1815
1816
virtual void create( const CvMat* layerSizes,
1817
int activateFunc=CvANN_MLP::SIGMOID_SYM,
1818
double fparam1=0, double fparam2=0 );
1819
1820
virtual int train( const CvMat* inputs, const CvMat* outputs,
1821
const CvMat* sampleWeights, const CvMat* sampleIdx=0,
1822
CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1823
int flags=0 );
1824
virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
1825
1826
CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
1827
int activateFunc=CvANN_MLP::SIGMOID_SYM,
1828
double fparam1=0, double fparam2=0 );
1829
1830
CV_WRAP virtual void create( const cv::Mat& layerSizes,
1831
int activateFunc=CvANN_MLP::SIGMOID_SYM,
1832
double fparam1=0, double fparam2=0 );
1833
1834
CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
1835
const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
1836
CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1837
int flags=0 );
1838
1839
CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
1840
1841
CV_WRAP virtual void clear();
1842
1843
// possible activation functions
1844
enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1845
1846
// available training flags
1847
enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1848
1849
virtual void read( CvFileStorage* fs, CvFileNode* node );
1850
virtual void write( CvFileStorage* storage, const char* name ) const;
1851
1852
int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1853
const CvMat* get_layer_sizes() { return layer_sizes; }
1854
double* get_weights(int layer)
1855
{
1856
return layer_sizes && weights &&
1857
(unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1858
}
1859
1860
virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1861
1862
protected:
1863
1864
virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1865
const CvMat* _sample_weights, const CvMat* sampleIdx,
1866
CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1867
1868
// sequential random backpropagation
1869
virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1870
1871
// RPROP algorithm
1872
virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1873
1874
virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1875
virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1876
double _f_param1=0, double _f_param2=0 );
1877
virtual void init_weights();
1878
virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1879
virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1880
virtual void calc_input_scale( const CvVectors* vecs, int flags );
1881
virtual void calc_output_scale( const CvVectors* vecs, int flags );
1882
1883
virtual void write_params( CvFileStorage* fs ) const;
1884
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1885
1886
CvMat* layer_sizes;
1887
CvMat* wbuf;
1888
CvMat* sample_weights;
1889
double** weights;
1890
double f_param1, f_param2;
1891
double min_val, max_val, min_val1, max_val1;
1892
int activ_func;
1893
int max_count, max_buf_sz;
1894
CvANN_MLP_TrainParams params;
1895
cv::RNG* rng;
1896
};
1897
1898
/****************************************************************************************\
*                                          Data                                          *
\****************************************************************************************/
1901
1902
// Selector constants for the two members of CvTrainTestSplit::train_sample_part
// (an absolute sample count vs. a fractional portion).
#define CV_COUNT 0
#define CV_PORTION 1
1904
1905
// Describes how a data set is divided into a training part and a test part.
//
// The size of the training part is given either as an absolute number of
// samples (int constructor) or as a fraction (float constructor).
//
// Overload caveat: a `double` argument (e.g. the literal 0.8) converts equally
// well to int and to float, so the call is ambiguous -- pass an int or a float
// (0.8f) explicitly.
struct CvTrainTestSplit
{
    CvTrainTestSplit();
    CvTrainTestSplit( int train_sample_count, bool mix = true);
    CvTrainTestSplit( float train_sample_portion, bool mix = true);

    // Size of the training part; only one member is meaningful at a time.
    union
    {
        int count;
        float portion;
    } train_sample_part;

    // Tells which union member is in use.
    // NOTE(review): presumably CV_COUNT or CV_PORTION -- confirm in the
    // constructor definitions.
    int train_sample_part_mode;

    bool mix;
};
1920
1921
class CvMLData
1922
{
1923
public:
1924
CvMLData();
1925
virtual ~CvMLData();
1926
1927
// returns:
1928
// 0 - OK
1929
// -1 - file can not be opened or is not correct
1930
int read_csv( const char* filename );
1931
1932
const CvMat* get_values() const;
1933
const CvMat* get_responses();
1934
const CvMat* get_missing() const;
1935
1936
void set_header_lines_number( int n );
1937
int get_header_lines_number() const;
1938
1939
void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
1940
// if idx < 0 there will be no response
1941
int get_response_idx() const;
1942
1943
void set_train_test_split( const CvTrainTestSplit * spl );
1944
const CvMat* get_train_sample_idx() const;
1945
const CvMat* get_test_sample_idx() const;
1946
void mix_train_and_test_idx();
1947
1948
const CvMat* get_var_idx();
1949
void chahge_var_idx( int vi, bool state ); // misspelled (saved for back compitability),
1950
// use change_var_idx
1951
void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
1952
1953
const CvMat* get_var_types();
1954
int get_var_type( int var_idx ) const;
1955
// following 2 methods enable to change vars type
1956
// use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
1957
// with numerical labels; in the other cases var types are correctly determined automatically
1958
void set_var_types( const char* str ); // str examples:
1959
// "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
1960
// "cat", "ord" (all vars are categorical/ordered)
1961
void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
1962
1963
void set_delimiter( char ch );
1964
char get_delimiter() const;
1965
1966
void set_miss_ch( char ch );
1967
char get_miss_ch() const;
1968
1969
const std::map<cv::String, int>& get_class_labels_map() const;
1970
1971
protected:
1972
virtual void clear();
1973
1974
void str_to_flt_elem( const char* token, float& flt_elem, int& type);
1975
void free_train_test_idx();
1976
1977
char delimiter;
1978
char miss_ch;
1979
//char flt_separator;
1980
1981
CvMat* values;
1982
CvMat* missing;
1983
CvMat* var_types;
1984
CvMat* var_idx_mask;
1985
1986
CvMat* response_out; // header
1987
CvMat* var_idx_out; // mat
1988
CvMat* var_types_out; // mat
1989
1990
int header_lines_number;
1991
1992
int response_idx;
1993
1994
int train_sample_count;
1995
bool mix;
1996
1997
int total_class_count;
1998
std::map<cv::String, int> class_map;
1999
2000
CvMat* train_sample_idx;
2001
CvMat* test_sample_idx;
2002
int* sample_idx; // data of train_sample_idx and test_sample_idx
2003
2004
cv::RNG* rng;
2005
};
2006
2007
2008
namespace cv
2009
{
2010
2011
typedef CvStatModel StatModel;
2012
typedef CvParamGrid ParamGrid;
2013
typedef CvNormalBayesClassifier NormalBayesClassifier;
2014
typedef CvKNearest KNearest;
2015
typedef CvSVMParams SVMParams;
2016
typedef CvSVMKernel SVMKernel;
2017
typedef CvSVMSolver SVMSolver;
2018
typedef CvSVM SVM;
2019
typedef CvDTreeParams DTreeParams;
2020
typedef CvMLData TrainData;
2021
typedef CvDTree DecisionTree;
2022
typedef CvForestTree ForestTree;
2023
typedef CvRTParams RandomTreeParams;
2024
typedef CvRTrees RandomTrees;
2025
typedef CvERTreeTrainData ERTreeTRainData;
2026
typedef CvForestERTree ERTree;
2027
typedef CvERTrees ERTrees;
2028
typedef CvBoostParams BoostParams;
2029
typedef CvBoostTree BoostTree;
2030
typedef CvBoost Boost;
2031
typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
2032
typedef CvANN_MLP NeuralNet_MLP;
2033
typedef CvGBTreesParams GradientBoostingTreeParams;
2034
typedef CvGBTrees GradientBoostingTrees;
2035
2036
template<> struct DefaultDeleter<CvDTreeSplit>{ void operator ()(CvDTreeSplit* obj) const; };
2037
2038
}
2039
2040
#endif // __cplusplus
2041
#endif // OPENCV_OLD_ML_HPP
2042
2043
/* End of file. */
2044
2045