Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/ml/test/test_precomp.hpp
16339 views
1
#ifndef __OPENCV_TEST_PRECOMP_HPP__
2
#define __OPENCV_TEST_PRECOMP_HPP__
3
4
#include "opencv2/ts.hpp"
5
#include "opencv2/ml.hpp"
6
#include "opencv2/core/core_c.h"
7
8
namespace opencv_test {
9
using namespace cv::ml;
10
11
#define CV_NBAYES "nbayes"
12
#define CV_KNEAREST "knearest"
13
#define CV_SVM "svm"
14
#define CV_EM "em"
15
#define CV_ANN "ann"
16
#define CV_DTREE "dtree"
17
#define CV_BOOST "boost"
18
#define CV_RTREES "rtrees"
19
#define CV_ERTREES "ertrees"
20
#define CV_SVMSGD "svmsgd"
21
22
enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 };
23
24
using cv::Ptr;
25
using cv::ml::StatModel;
26
using cv::ml::TrainData;
27
using cv::ml::NormalBayesClassifier;
28
using cv::ml::SVM;
29
using cv::ml::KNearest;
30
using cv::ml::ParamGrid;
31
using cv::ml::ANN_MLP;
32
using cv::ml::DTrees;
33
using cv::ml::Boost;
34
using cv::ml::RTrees;
35
using cv::ml::SVMSGD;
36
37
class CV_MLBaseTest : public cvtest::BaseTest
38
{
39
public:
40
CV_MLBaseTest( const char* _modelName );
41
virtual ~CV_MLBaseTest();
42
protected:
43
virtual int read_params( CvFileStorage* fs );
44
virtual void run( int startFrom );
45
virtual int prepare_test_case( int testCaseIdx );
46
virtual std::string& get_validation_filename();
47
virtual int run_test_case( int testCaseIdx ) = 0;
48
virtual int validate_test_results( int testCaseIdx ) = 0;
49
50
int train( int testCaseIdx );
51
float get_test_error( int testCaseIdx, std::vector<float> *resp = 0 );
52
void save( const char* filename );
53
void load( const char* filename );
54
55
Ptr<TrainData> data;
56
std::string modelName, validationFN;
57
std::vector<std::string> dataSetNames;
58
cv::FileStorage validationFS;
59
60
Ptr<StatModel> model;
61
62
std::map<int, int> cls_map;
63
64
int64 initSeed;
65
};
66
67
class CV_AMLTest : public CV_MLBaseTest
68
{
69
public:
70
CV_AMLTest( const char* _modelName );
71
virtual ~CV_AMLTest() {}
72
protected:
73
virtual int run_test_case( int testCaseIdx );
74
virtual int validate_test_results( int testCaseIdx );
75
};
76
77
class CV_SLMLTest : public CV_MLBaseTest
78
{
79
public:
80
CV_SLMLTest( const char* _modelName );
81
virtual ~CV_SLMLTest() {}
82
protected:
83
virtual int run_test_case( int testCaseIdx );
84
virtual int validate_test_results( int testCaseIdx );
85
86
std::vector<float> test_resps1, test_resps2; // predicted responses for test data
87
std::string fname1, fname2;
88
};
89
90
} // namespace
91
92
#endif
93
94