Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/ml/src/precomp.hpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
//
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
13
// Third party copyrights are property of their respective owners.
14
//
15
// Redistribution and use in source and binary forms, with or without modification,
16
// are permitted provided that the following conditions are met:
17
//
18
// * Redistribution's of source code must retain the above copyright notice,
19
// this list of conditions and the following disclaimer.
20
//
21
// * Redistribution's in binary form must reproduce the above copyright notice,
22
// this list of conditions and the following disclaimer in the documentation
23
// and/or other materials provided with the distribution.
24
//
25
// * The name of Intel Corporation may not be used to endorse or promote products
26
// derived from this software without specific prior written permission.
27
//
28
// This software is provided by the copyright holders and contributors "as is" and
29
// any express or implied warranties, including, but not limited to, the implied
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
32
// indirect, incidental, special, exemplary, or consequential damages
33
// (including, but not limited to, procurement of substitute goods or services;
34
// loss of use, data, or profits; or business interruption) however caused
35
// and on any theory of liability, whether in contract, strict liability,
36
// or tort (including negligence or otherwise) arising in any way out of
37
// the use of this software, even if advised of the possibility of such damage.
38
//
39
//M*/
40
41
#ifndef __OPENCV_ML_PRECOMP_HPP__
42
#define __OPENCV_ML_PRECOMP_HPP__
43
44
#include "opencv2/core.hpp"
45
#include "opencv2/ml.hpp"
46
#include "opencv2/core/core_c.h"
47
#include "opencv2/core/utility.hpp"
48
49
#include "opencv2/core/private.hpp"
50
51
#include <assert.h>
52
#include <float.h>
53
#include <limits.h>
54
#include <math.h>
55
#include <stdlib.h>
56
#include <stdio.h>
57
#include <string.h>
58
#include <time.h>
59
#include <vector>
60
61
/****************************************************************************************\
62
* Main struct definitions *
63
\****************************************************************************************/
64
65
/* Natural logarithm of two pi: ln(2*PI) ~= 1.8379.
   Kept as a parenthesized literal so it expands safely inside expressions. */
#define CV_LOG2PI (1.8378770664093454835606594728112)
67
68
namespace cv
69
{
70
namespace ml
71
{
72
using std::vector;
73
74
// Looks up category `idx` in a bit-packed category subset (32 categories per
// int word: word idx>>5, bit idx&31) and evaluates to -1 when that bit is set
// and +1 when it is clear.
// The mask is built from 1u rather than 1: for idx&31 == 31, `1 << 31` shifts
// into the sign bit of a signed int, which is undefined behavior under C++11
// rules. The unsigned mask gives the same bit pattern without that hazard.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((subset[(idx)>>5]&(1u << ((idx) & 31)))==0)-1)
76
77
/** Comparator over indices: index a sorts before index b when arr[a] < arr[b].
 *  Only a raw pointer to the array is stored, so the array must stay alive
 *  for as long as the comparator is used. */
template<typename _Tp> struct cmp_lt_idx
{
    cmp_lt_idx(const _Tp* _arr) : arr(_arr) {}

    bool operator ()(int a, int b) const
    {
        const _Tp& lhs = arr[a];
        const _Tp& rhs = arr[b];
        return lhs < rhs;
    }

    const _Tp* arr;  // not owned
};
83
84
/** Comparator over pointers: orders them by the values they point to (*a < *b). */
template<typename _Tp> struct cmp_lt_ptr
{
    cmp_lt_ptr() {}

    bool operator ()(const _Tp* a, const _Tp* b) const
    {
        const _Tp& lhs = *a;
        const _Tp& rhs = *b;
        return lhs < rhs;
    }
};
89
90
/** Resizes `vec` and fills it with 0, 1, ..., n-1.
 *  A negative `n` yields an empty vector. Previously a negative `n` was handed
 *  straight to resize(), where it converts to a huge size_t and resize()
 *  throws std::length_error. */
static inline void setRangeVector(std::vector<int>& vec, int n)
{
    const int count = n > 0 ? n : 0;
    vec.resize(count);
    for( int i = 0; i < count; i++ )
        vec[i] = i;
}
96
97
/** Writes the active parts of `termCrit` to `fs`.
 *  "epsilon" is emitted only when the EPS flag is set, and "iterations" only
 *  when the COUNT flag is set, in that order. */
static inline void writeTermCrit(FileStorage& fs, const TermCriteria& termCrit)
{
    const bool hasEps = (termCrit.type & TermCriteria::EPS) != 0;
    const bool hasCount = (termCrit.type & TermCriteria::COUNT) != 0;

    if( hasEps )
        fs << "epsilon" << termCrit.epsilon;
    if( hasCount )
        fs << "iterations" << termCrit.maxCount;
}
104
105
/** Rebuilds a TermCriteria from the "epsilon" / "iterations" entries of `fn`.
 *  A positive epsilon sets EPS and `epsilon`, and a positive iteration count
 *  sets COUNT and `maxCount`. Entries that read as non-positive leave the
 *  default-constructed fields untouched. */
static inline TermCriteria readTermCrit(const FileNode& fn)
{
    const double eps = static_cast<double>(fn["epsilon"]);
    const int maxIters = static_cast<int>(fn["iterations"]);

    TermCriteria result;
    if( eps > 0 )
    {
        result.type |= TermCriteria::EPS;
        result.epsilon = eps;
    }
    if( maxIters > 0 )
    {
        result.type |= TermCriteria::COUNT;
        result.maxCount = maxIters;
    }
    return result;
}
122
123
struct TreeParams
124
{
125
TreeParams();
126
TreeParams( int maxDepth, int minSampleCount,
127
double regressionAccuracy, bool useSurrogates,
128
int maxCategories, int CVFolds,
129
bool use1SERule, bool truncatePrunedTree,
130
const Mat& priors );
131
132
inline void setMaxCategories(int val)
133
{
134
if( val < 2 )
135
CV_Error( CV_StsOutOfRange, "max_categories should be >= 2" );
136
maxCategories = std::min(val, 15 );
137
}
138
inline void setMaxDepth(int val)
139
{
140
if( val < 0 )
141
CV_Error( CV_StsOutOfRange, "max_depth should be >= 0" );
142
maxDepth = std::min( val, 25 );
143
}
144
inline void setMinSampleCount(int val)
145
{
146
minSampleCount = std::max(val, 1);
147
}
148
inline void setCVFolds(int val)
149
{
150
if( val < 0 )
151
CV_Error( CV_StsOutOfRange,
152
"params.CVFolds should be =0 (the tree is not pruned) "
153
"or n>0 (tree is pruned using n-fold cross-validation)" );
154
if(val > 1)
155
CV_Error( CV_StsNotImplemented,
156
"tree pruning using cross-validation is not implemented."
157
"Set CVFolds to 1");
158
159
if( val == 1 )
160
val = 0;
161
CVFolds = val;
162
}
163
inline void setRegressionAccuracy(float val)
164
{
165
if( val < 0 )
166
CV_Error( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
167
regressionAccuracy = val;
168
}
169
170
inline int getMaxCategories() const { return maxCategories; }
171
inline int getMaxDepth() const { return maxDepth; }
172
inline int getMinSampleCount() const { return minSampleCount; }
173
inline int getCVFolds() const { return CVFolds; }
174
inline float getRegressionAccuracy() const { return regressionAccuracy; }
175
176
inline bool getUseSurrogates() const { return useSurrogates; }
177
inline void setUseSurrogates(bool val) { useSurrogates = val; }
178
inline bool getUse1SERule() const { return use1SERule; }
179
inline void setUse1SERule(bool val) { use1SERule = val; }
180
inline bool getTruncatePrunedTree() const { return truncatePrunedTree; }
181
inline void setTruncatePrunedTree(bool val) { truncatePrunedTree = val; }
182
inline cv::Mat getPriors() const { return priors; }
183
inline void setPriors(const cv::Mat& val) { priors = val; }
184
185
public:
186
bool useSurrogates;
187
bool use1SERule;
188
bool truncatePrunedTree;
189
Mat priors;
190
191
protected:
192
int maxCategories;
193
int maxDepth;
194
int minSampleCount;
195
int CVFolds;
196
float regressionAccuracy;
197
};
198
199
struct RTreeParams
200
{
201
RTreeParams();
202
RTreeParams(bool calcVarImportance, int nactiveVars, TermCriteria termCrit );
203
bool calcVarImportance;
204
int nactiveVars;
205
TermCriteria termCrit;
206
};
207
208
/** Extra parameters used for boosted-tree ensembles (plain aggregate of values;
 *  the constructors are defined elsewhere). */
struct BoostTreeParams
{
    BoostTreeParams();
    BoostTreeParams(int boostType, int weakCount, double weightTrimRate);

    // boostType presumably holds one of the Boost::Types codes -- verify at call sites.
    int boostType, weakCount;
    double weightTrimRate;
};
216
217
class DTreesImpl : public DTrees
218
{
219
public:
220
struct WNode
221
{
222
WNode()
223
{
224
class_idx = sample_count = depth = complexity = 0;
225
parent = left = right = split = defaultDir = -1;
226
Tn = INT_MAX;
227
value = maxlr = alpha = node_risk = tree_risk = tree_error = 0.;
228
}
229
230
int class_idx;
231
double Tn;
232
double value;
233
234
int parent;
235
int left;
236
int right;
237
int defaultDir;
238
239
int split;
240
241
int sample_count;
242
int depth;
243
double maxlr;
244
245
// global pruning data
246
int complexity;
247
double alpha;
248
double node_risk, tree_risk, tree_error;
249
};
250
251
struct WSplit
252
{
253
WSplit()
254
{
255
varIdx = next = 0;
256
inversed = false;
257
quality = c = 0.f;
258
subsetOfs = -1;
259
}
260
261
int varIdx;
262
bool inversed;
263
float quality;
264
int next;
265
float c;
266
int subsetOfs;
267
};
268
269
struct WorkData
270
{
271
WorkData(const Ptr<TrainData>& _data);
272
273
Ptr<TrainData> data;
274
vector<WNode> wnodes;
275
vector<WSplit> wsplits;
276
vector<int> wsubsets;
277
vector<double> cv_Tn;
278
vector<double> cv_node_risk;
279
vector<double> cv_node_error;
280
vector<int> cv_labels;
281
vector<double> sample_weights;
282
vector<int> cat_responses;
283
vector<double> ord_responses;
284
vector<int> sidx;
285
int maxSubsetSize;
286
};
287
288
inline int getMaxCategories() const CV_OVERRIDE { return params.getMaxCategories(); }
289
inline void setMaxCategories(int val) CV_OVERRIDE { params.setMaxCategories(val); }
290
inline int getMaxDepth() const CV_OVERRIDE { return params.getMaxDepth(); }
291
inline void setMaxDepth(int val) CV_OVERRIDE { params.setMaxDepth(val); }
292
inline int getMinSampleCount() const CV_OVERRIDE { return params.getMinSampleCount(); }
293
inline void setMinSampleCount(int val) CV_OVERRIDE { params.setMinSampleCount(val); }
294
inline int getCVFolds() const CV_OVERRIDE { return params.getCVFolds(); }
295
inline void setCVFolds(int val) CV_OVERRIDE { params.setCVFolds(val); }
296
inline bool getUseSurrogates() const CV_OVERRIDE { return params.getUseSurrogates(); }
297
inline void setUseSurrogates(bool val) CV_OVERRIDE { params.setUseSurrogates(val); }
298
inline bool getUse1SERule() const CV_OVERRIDE { return params.getUse1SERule(); }
299
inline void setUse1SERule(bool val) CV_OVERRIDE { params.setUse1SERule(val); }
300
inline bool getTruncatePrunedTree() const CV_OVERRIDE { return params.getTruncatePrunedTree(); }
301
inline void setTruncatePrunedTree(bool val) CV_OVERRIDE { params.setTruncatePrunedTree(val); }
302
inline float getRegressionAccuracy() const CV_OVERRIDE { return params.getRegressionAccuracy(); }
303
inline void setRegressionAccuracy(float val) CV_OVERRIDE { params.setRegressionAccuracy(val); }
304
inline cv::Mat getPriors() const CV_OVERRIDE { return params.getPriors(); }
305
inline void setPriors(const cv::Mat& val) CV_OVERRIDE { params.setPriors(val); }
306
307
DTreesImpl();
308
virtual ~DTreesImpl() CV_OVERRIDE;
309
virtual void clear() CV_OVERRIDE;
310
311
String getDefaultName() const CV_OVERRIDE { return "opencv_ml_dtree"; }
312
bool isTrained() const CV_OVERRIDE { return !roots.empty(); }
313
bool isClassifier() const CV_OVERRIDE { return _isClassifier; }
314
int getVarCount() const CV_OVERRIDE { return varType.empty() ? 0 : (int)(varType.size() - 1); }
315
int getCatCount(int vi) const { return catOfs[vi][1] - catOfs[vi][0]; }
316
int getSubsetSize(int vi) const { return (getCatCount(vi) + 31)/32; }
317
318
virtual void setDParams(const TreeParams& _params);
319
virtual void startTraining( const Ptr<TrainData>& trainData, int flags );
320
virtual void endTraining();
321
virtual void initCompVarIdx();
322
virtual bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE;
323
324
virtual int addTree( const vector<int>& sidx );
325
virtual int addNodeAndTrySplit( int parent, const vector<int>& sidx );
326
virtual const vector<int>& getActiveVars();
327
virtual int findBestSplit( const vector<int>& _sidx );
328
virtual void calcValue( int nidx, const vector<int>& _sidx );
329
330
virtual WSplit findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality );
331
332
// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
333
virtual void clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels );
334
virtual WSplit findSplitCatClass( int vi, const vector<int>& _sidx, double initQuality, int* subset );
335
336
virtual WSplit findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality );
337
virtual WSplit findSplitCatReg( int vi, const vector<int>& _sidx, double initQuality, int* subset );
338
339
virtual int calcDir( int splitidx, const vector<int>& _sidx, vector<int>& _sleft, vector<int>& _sright );
340
virtual int pruneCV( int root );
341
342
virtual double updateTreeRNC( int root, double T, int fold );
343
virtual bool cutTree( int root, double T, int fold, double min_alpha );
344
virtual float predictTrees( const Range& range, const Mat& sample, int flags ) const;
345
virtual float predict( InputArray inputs, OutputArray outputs, int flags ) const CV_OVERRIDE;
346
347
virtual void writeTrainingParams( FileStorage& fs ) const;
348
virtual void writeParams( FileStorage& fs ) const;
349
virtual void writeSplit( FileStorage& fs, int splitidx ) const;
350
virtual void writeNode( FileStorage& fs, int nidx, int depth ) const;
351
virtual void writeTree( FileStorage& fs, int root ) const;
352
virtual void write( FileStorage& fs ) const CV_OVERRIDE;
353
354
virtual void readParams( const FileNode& fn );
355
virtual int readSplit( const FileNode& fn );
356
virtual int readNode( const FileNode& fn );
357
virtual int readTree( const FileNode& fn );
358
virtual void read( const FileNode& fn ) CV_OVERRIDE;
359
360
virtual const std::vector<int>& getRoots() const CV_OVERRIDE { return roots; }
361
virtual const std::vector<Node>& getNodes() const CV_OVERRIDE { return nodes; }
362
virtual const std::vector<Split>& getSplits() const CV_OVERRIDE { return splits; }
363
virtual const std::vector<int>& getSubsets() const CV_OVERRIDE { return subsets; }
364
365
TreeParams params;
366
367
vector<int> varIdx;
368
vector<int> compVarIdx;
369
vector<uchar> varType;
370
vector<Vec2i> catOfs;
371
vector<int> catMap;
372
vector<int> roots;
373
vector<Node> nodes;
374
vector<Split> splits;
375
vector<int> subsets;
376
vector<int> classLabels;
377
vector<float> missingSubst;
378
vector<int> varMapping;
379
bool _isClassifier;
380
381
Ptr<WorkData> w;
382
};
383
384
template <typename T>
385
static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
386
{
387
if (node.type() == FileNode::MAP)
388
{
389
Mat m;
390
node >> m;
391
m.copyTo(v);
392
}
393
else if (node.type() == FileNode::SEQ)
394
{
395
node >> v;
396
}
397
}
398
399
}}
400
401
#endif /* __OPENCV_ML_PRECOMP_HPP__ */
402
403