// Captured from a CoCalc share page of GitHub repository Tetragramm/opencv,
// path: blob/master/modules/ml/test/test_mltests2.cpp
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// Intel License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of Intel Corporation may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
41
42
#include "test_precomp.hpp"
43
44
//#define GENERATE_TESTDATA
45
46
namespace opencv_test { namespace {
47
48
// Translate an SVM type name read from the validation file (e.g. "C_SVC")
// into the matching SVM type constant. An unrecognised name is reported
// through CV_Error with CV_StsBadArg.
int str_to_svm_type(String& str)
{
    static const struct
    {
        const char* name;
        int type;
    } knownTypes[] =
    {
        { "C_SVC",     SVM::C_SVC },
        { "NU_SVC",    SVM::NU_SVC },
        { "ONE_CLASS", SVM::ONE_CLASS },
        { "EPS_SVR",   SVM::EPS_SVR },
        { "NU_SVR",    SVM::NU_SVR }
    };
    const size_t typeCount = sizeof(knownTypes) / sizeof(knownTypes[0]);
    for( size_t k = 0; k < typeCount; ++k )
    {
        if( str.compare(knownTypes[k].name) == 0 )
            return knownTypes[k].type;
    }
    CV_Error( CV_StsBadArg, "incorrect svm type string" );
}
62
// Translate an SVM kernel name read from the validation file (e.g. "RBF")
// into the matching SVM kernel constant. An unrecognised name is reported
// through CV_Error with CV_StsBadArg.
int str_to_svm_kernel_type( String& str )
{
    if( !str.compare("LINEAR") )
        return SVM::LINEAR;
    if( !str.compare("POLY") )
        return SVM::POLY;
    if( !str.compare("RBF") )
        return SVM::RBF;
    if( !str.compare("SIGMOID") )
        return SVM::SIGMOID;
    // The message used to say "svm type", which pointed at the wrong
    // parameter when a kernel name in the validation file was misspelled.
    CV_Error( CV_StsBadArg, "incorrect svm kernel type string" );
}
74
75
// 4. em
76
// 5. ann
77
// Map an ANN_MLP training-method name ("BACKPROP", "RPROP" or "ANNEAL") to
// the corresponding ANN_MLP constant; any other name is reported through
// CV_Error with CV_StsBadArg.
int str_to_ann_train_method( String& str )
{
    int method = -1;
    if( str.compare("BACKPROP") == 0 )
        method = ANN_MLP::BACKPROP;
    else if( str.compare("RPROP") == 0 )
        method = ANN_MLP::RPROP;
    else if( str.compare("ANNEAL") == 0 )
        method = ANN_MLP::ANNEAL;
    else
        CV_Error( CV_StsBadArg, "incorrect ann train method string" );
    return method;
}
87
88
#if 0
// Disabled helper: map an activation-function name to its ANN_MLP constant.
// Unknown names are reported through CV_Error with CV_StsBadArg.
int str_to_ann_activation_function(String& str)
{
    if (str.compare("IDENTITY") == 0)
        return ANN_MLP::IDENTITY;
    else if (str.compare("SIGMOID_SYM") == 0)
        return ANN_MLP::SIGMOID_SYM;
    else if (str.compare("GAUSSIAN") == 0)
        return ANN_MLP::GAUSSIAN;
    else if (str.compare("RELU") == 0)
        return ANN_MLP::RELU;
    else if (str.compare("LEAKYRELU") == 0)
        return ANN_MLP::LEAKYRELU;
    CV_Error(CV_StsBadArg, "incorrect ann activation function string");
}
#endif
104
105
// Reject TrainData features the ANN error computation does not handle: a
// variable subset that selects fewer (or more) columns than the samples
// have, and a non-empty missing-value mask. Both are reported via CV_Error.
void ann_check_data( Ptr<TrainData> _data )
{
    CV_TRACE_FUNCTION();
    const Mat samples = _data->getSamples();
    const int selectedVars = (int)_data->getVarIdx().total();
    const bool usesVarSubset = selectedVars != 0 && selectedVars != samples.cols;
    if( usesVarSubset )
        CV_Error( CV_StsBadArg, "var_idx is not supported" );

    const bool hasMissingMask = !_data->getMissing().empty();
    if( hasMissingMask )
        CV_Error( CV_StsBadArg, "missing values are not supported" );
}
116
117
// unroll the categorical responses to binary vectors
118
Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
119
{
120
CV_TRACE_FUNCTION();
121
Mat train_sidx = _data->getTrainSampleIdx();
122
int* train_sidx_ptr = train_sidx.ptr<int>();
123
Mat responses = _data->getResponses();
124
int cls_count = 0;
125
// construct cls_map
126
cls_map.clear();
127
int nresponses = (int)responses.total();
128
int si, n = !train_sidx.empty() ? (int)train_sidx.total() : nresponses;
129
130
for( si = 0; si < n; si++ )
131
{
132
int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
133
int r = cvRound(responses.at<float>(sidx));
134
CV_DbgAssert( fabs(responses.at<float>(sidx) - r) < FLT_EPSILON );
135
map<int,int>::iterator it = cls_map.find(r);
136
if( it == cls_map.end() )
137
cls_map[r] = cls_count++;
138
}
139
Mat new_responses = Mat::zeros( nresponses, cls_count, CV_32F );
140
for( si = 0; si < n; si++ )
141
{
142
int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
143
int r = cvRound(responses.at<float>(sidx));
144
int cidx = cls_map[r];
145
new_responses.at<float>(sidx, cidx) = 1.f;
146
}
147
return new_responses;
148
}
149
150
// Classification error, in percent, of an ANN on the test or train subset.
//   ann         - trained network. The arg-max column of its output row is
//                 taken as the predicted class index.
//   _data       - samples and original (float) labels; checked by ann_check_data().
//   cls_map     - original label -> output column, as built by
//                 ann_get_new_responses(). It is only read here.
//   type        - CV_TEST_ERROR selects the test subset; otherwise the train
//                 subset (all samples when no train subset is set and type is
//                 CV_TRAIN_ERROR).
//   resp_labels - optional; receives the predicted column of each evaluated
//                 sample. It is resized only when there is at least one sample.
// Returns -FLT_MAX when there are no samples to evaluate.
float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels )
{
    CV_TRACE_FUNCTION();
    float err = 0;
    Mat samples = _data->getSamples();
    Mat responses = _data->getResponses();
    Mat sample_idx = (type == CV_TEST_ERROR) ? _data->getTestSampleIdx() : _data->getTrainSampleIdx();
    int* sidx = !sample_idx.empty() ? sample_idx.ptr<int>() : 0;
    ann_check_data( _data );
    int sample_count = (int)sample_idx.total();
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? samples.rows : sample_count;

    // Predictions go to the caller's vector when given, otherwise to a scratch one.
    vector<float> innresp;
    vector<float>& pred_resp = resp_labels ? *resp_labels : innresp;
    if( sample_count > 0 )
        pred_resp.resize( sample_count );

    Mat output( 1, (int)cls_map.size(), CV_32FC1 );
    for( int i = 0; i < sample_count; i++ )
    {
        int si = sidx ? sidx[i] : i;
        Mat sample = samples.row(si);
        ann->predict( sample, output );
        Point best_cls;
        minMaxLoc(output, 0, 0, 0, &best_cls, 0);
        int r = cvRound(responses.at<float>(si));
        CV_DbgAssert( fabs(responses.at<float>(si) - r) < FLT_EPSILON );
        // Look the label up without inserting. operator[] would add a label
        // never seen in training with column 0, which both mutates the
        // caller's map and can score such a sample as correctly classified.
        map<int, int>::const_iterator it = cls_map.find(r);
        const bool correct = it != cls_map.end() && it->second == best_cls.x;
        if( !correct )
            err += 1;
        pred_resp[i] = (float)best_cls.x;
    }
    err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
    return err;
}
196
197
// For each supported activation function: train an MLP on the waveform data.
// With GENERATE_TESTDATA, save it as reference data. Otherwise, load the
// stored reference model and check that both predict the same test outputs.
TEST(ML_ANN, ActivationFunction)
{
    String folder = string(cvtest::TS::ptr()->get_data_path());
    String original_path = folder + "waveform.data";
    String dataname = folder + "waveform";

    Ptr<TrainData> tdata = TrainData::loadFromCSV(original_path, 0);

    ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
    RNG& rng = theRNG();
    rng.state = 1027401484159173092;
    tdata->setTrainTestSplit(500);

    // Activation functions under test, each paired with the suffix of its
    // reference model file ("waveform<suffix>.yml") so the two cannot drift apart.
    static const struct
    {
        int type;
        const char* suffix;
    } activations[] =
    {
        { ml::ANN_MLP::IDENTITY,    "_identity" },
        { ml::ANN_MLP::SIGMOID_SYM, "_sigmoid_sym" },
        { ml::ANN_MLP::GAUSSIAN,    "_gaussian" },
        { ml::ANN_MLP::RELU,        "_relu" },
        { ml::ANN_MLP::LEAKYRELU,   "_leakyrelu" }
    };
    const size_t activationCount = sizeof(activations) / sizeof(activations[0]);
    for (size_t i = 0; i < activationCount; i++)
    {
        const String activationName = activations[i].suffix;
        Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
        Mat_<int> layerSizes(1, 4);
        layerSizes(0, 0) = tdata->getNVars();
        layerSizes(0, 1) = 100;
        layerSizes(0, 2) = 100;
        layerSizes(0, 3) = tdata->getResponses().cols;
        x->setLayerSizes(layerSizes);
        x->setActivationFunction(activations[i].type);
        x->setTrainMethod(ml::ANN_MLP::RPROP, 0.01, 0.1);
        x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
        x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
        ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << activationName;
#ifdef GENERATE_TESTDATA
        x->save(dataname + activationName + ".yml");
#else
        const String modelFile = dataname + activationName + ".yml";
        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(modelFile);
        ASSERT_TRUE(y) << "Could not load " << modelFile;
        Mat testSamples = tdata->getTestSamples();
        Mat rx, ry;
        x->predict(testSamples, rx);
        y->predict(testSamples, ry);
        double n = cvtest::norm(rx, ry, NORM_INF);
        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << modelFile << " and " << activationName;
#endif
    }
}
250
251
CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
252
253
typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
254
typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
255
256
// Regression test for one ANN_MLP training method:
// 1. Train from stored initial weights on the first N waveform samples.
// 2. Compare weights and predictions with the model reloaded from disk.
// 3. Compare predictions with a stored reference ("gold") response.
// With GENERATE_TESTDATA the reference files are written instead of read.
TEST_P(ML_ANN_METHOD, Test)
{
    int methodType = get<0>(GetParam());
    string methodName = get<1>(GetParam());
    int N = get<2>(GetParam());

    String folder = string(cvtest::TS::ptr()->get_data_path());
    String original_path = folder + "waveform.data";
    String dataname = folder + "waveform" + '_' + methodName;

    Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
    // Validate the CSV load before tdata2 is dereferenced. The "could not find"
    // assertion used to run only on the TrainData derived from it, i.e. after
    // a null tdata2 would already have been used.
    ASSERT_FALSE(tdata2.empty()) << "Could not find test data file : " << original_path;
    Mat allSamples = tdata2->getSamples();
    Mat allLabels = tdata2->getResponses();
    ASSERT_LE(N, allSamples.rows) << "Not enough samples in " << original_path;
    Mat samples = allSamples(Range(0, N), Range::all());
    // One-hot encode the first N labels: each label value is used directly as
    // a column index into a 3-column response matrix.
    Mat responses(N, 3, CV_32FC1, Scalar(0));
    for (int i = 0; i < N; i++)
        responses.at<float>(i, static_cast<int>(allLabels.at<float>(i, 0))) = 1;
    Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
    ASSERT_FALSE(tdata.empty()) << "Could not create training data from " << original_path;

    // Fixed RNG state so the train/test split is reproducible.
    RNG& rng = theRNG();
    rng.state = 0;
    tdata->setTrainTestSplitRatio(0.8);

    Mat testSamples = tdata->getTestSamples();

#ifdef GENERATE_TESTDATA
    {
        Ptr<ml::ANN_MLP> xx = ml::ANN_MLP::create();
        Mat_<int> layerSizesXX(1, 4);
        layerSizesXX(0, 0) = tdata->getNVars();
        layerSizesXX(0, 1) = 30;
        layerSizesXX(0, 2) = 30;
        layerSizesXX(0, 3) = tdata->getResponses().cols;
        xx->setLayerSizes(layerSizesXX);
        xx->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
        xx->setTrainMethod(ml::ANN_MLP::RPROP);
        xx->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
        xx->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE);
        FileStorage fs;
        fs.open(dataname + "_init_weight.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
        xx->write(fs);
        fs.release();
    }
#endif
    {
        FileStorage fs;
        fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ);
        // Without the initial weights the model below would be read from an
        // empty node; fail here with a clear message instead.
        ASSERT_TRUE(fs.isOpened()) << "Could not open " << dataname + "_init_weight.yml.gz";
        Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
        x->read(fs.root());
        x->setTrainMethod(methodType);
        if (methodType == ml::ANN_MLP::ANNEAL)
        {
            x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
            x->setAnnealInitialT(12);
            x->setAnnealFinalT(0.15);
            x->setAnnealCoolingRatio(0.96);
            x->setAnnealItePerStep(11);
        }
        x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
        x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
        ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName;
        string filename = dataname + ".yml.gz";
        Mat r_gold;
#ifdef GENERATE_TESTDATA
        x->save(filename);
        x->predict(testSamples, r_gold);
        {
            FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
            fs_response << "response" << r_gold;
        }
#else
        {
            FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ);
            fs_response["response"] >> r_gold;
        }
#endif
        ASSERT_FALSE(r_gold.empty());
        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(filename);
        ASSERT_TRUE(y) << "Could not load " << filename;
        Mat rx, ry;
        for (int j = 0; j < 4; j++)
        {
            rx = x->getWeights(j);
            ry = y->getWeights(j);
            double n = cvtest::norm(rx, ry, NORM_INF);
            EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j;
        }
        x->predict(testSamples, rx);
        y->predict(testSamples, ry);
        double n = cvtest::norm(ry, rx, NORM_INF);
        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model";
        n = cvtest::norm(r_gold, rx, NORM_INF);
        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response";
    }
}
350
351
// RPROP is exercised on the first 5000 waveform samples and ANNEAL on the
// first 1000. ANN_MLP::BACKPROP has no entry here, so it is not covered by
// this test.
INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD,
    testing::Values(
        make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::RPROP, "rprop", 5000),
        make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::ANNEAL, "anneal", 1000)
    )
);
358
359
360
// 6. dtree
361
// 7. boost
362
// Convert a boosting-variant name ("DISCRETE", "REAL", "LOGIT", "GENTLE")
// into the matching Boost type constant; unknown names are reported through
// CV_Error with CV_StsBadArg.
int str_to_boost_type( String& str )
{
    static const char* const names[]  = { "DISCRETE", "REAL", "LOGIT", "GENTLE" };
    static const int         values[] = { Boost::DISCRETE, Boost::REAL, Boost::LOGIT, Boost::GENTLE };
    for ( size_t k = 0; k < sizeof(values) / sizeof(values[0]); ++k )
    {
        if ( str.compare(names[k]) == 0 )
            return values[k];
    }
    CV_Error( CV_StsBadArg, "incorrect boost type string" );
}
374
375
// 8. rtrees
376
// 9. ertrees
377
378
// Map an SVMSGD variant name ("SGD" or "ASGD") to its SVMSGD constant;
// anything else is reported through CV_Error with CV_StsBadArg.
int str_to_svmsgd_type( String& str )
{
    const bool matchesSgd = str.compare("SGD") == 0;
    if ( matchesSgd )
        return SVMSGD::SGD;

    const bool matchesAsgd = str.compare("ASGD") == 0;
    if ( matchesAsgd )
        return SVMSGD::ASGD;

    CV_Error( CV_StsBadArg, "incorrect svmsgd type string" );
}
386
387
// Map an SVMSGD margin name ("SOFT_MARGIN" or "HARD_MARGIN") to its SVMSGD
// constant; anything else is reported through CV_Error with CV_StsBadArg.
int str_to_margin_type( String& str )
{
    int margin = -1;
    if ( str.compare("SOFT_MARGIN") == 0 )
        margin = SVMSGD::SOFT_MARGIN;
    else if ( str.compare("HARD_MARGIN") == 0 )
        margin = SVMSGD::HARD_MARGIN;
    else
        CV_Error( CV_StsBadArg, "incorrect svmsgd margin type string" );
    return margin;
}
395
396
}
397
// ---------------------------------- MLBaseTest ---------------------------------------------------
398
399
// Store the model name. Save the current global RNG state and reseed the
// global RNG with one of a fixed set of states, chosen at random.
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
{
    static const int64 candidateSeeds[] =
    {
        CV_BIG_INT(0x00009fff4f9c8d52),
        CV_BIG_INT(0x0000a17166072c7c),
        CV_BIG_INT(0x0201b32115cd1f9a),
        CV_BIG_INT(0x0513cb37abcd1234),
        CV_BIG_INT(0x0001a2b3c4d5f678)
    };
    const int candidateCount = (int)(sizeof(candidateSeeds) / sizeof(candidateSeeds[0]));

    RNG& globalRng = theRNG();
    // Remember the incoming state; the destructor writes it back.
    initSeed = globalRng.state;
    globalRng.state = candidateSeeds[globalRng(candidateCount)];

    modelName = _modelName;
}
416
417
// Close the validation file if it is still open, then restore the global RNG
// state that the constructor saved in initSeed.
CV_MLBaseTest::~CV_MLBaseTest()
{
    if( validationFS.isOpened() )
    {
        validationFS.release();
    }
    RNG& globalRng = theRNG();
    globalRng.state = initSeed;
}
423
424
// Read the list of data-set names for this model from
// <first top-level node>/run_params/<modelName> into dataSetNames.
// test_case_count is set to the number of names, or to -1 when the storage
// is not open or the list is empty. Always returns cvtest::TS::OK.
int CV_MLBaseTest::read_params( CvFileStorage* __fs )
{
    CV_TRACE_FUNCTION();
    // NOTE(review): the `false` argument presumably makes this wrapper
    // non-owning, so the underlying storage is not released here — confirm.
    FileStorage _fs(__fs, false);
    test_case_count = -1;
    if( _fs.isOpened() )
    {
        FileNode runList = _fs.getFirstTopLevelNode()["run_params"][modelName];
        const int count = (int)runList.size();
        if( count > 0 )
        {
            test_case_count = count;
            dataSetNames.resize( count );
            FileNodeIterator node = runList.begin();
            for( int idx = 0; idx < count; ++idx, ++node )
                dataSetNames[idx] = (string)*node;
        }
    }
    return cvtest::TS::OK;
}
448
449
void CV_MLBaseTest::run( int )
450
{
451
CV_TRACE_FUNCTION();
452
string filename = ts->get_data_path();
453
filename += get_validation_filename();
454
validationFS.open( filename, FileStorage::READ );
455
read_params( *validationFS );
456
457
int code = cvtest::TS::OK;
458
for (int i = 0; i < test_case_count; i++)
459
{
460
CV_TRACE_REGION("iteration");
461
int temp_code = run_test_case( i );
462
if (temp_code == cvtest::TS::OK)
463
temp_code = validate_test_results( i );
464
if (temp_code != cvtest::TS::OK)
465
code = temp_code;
466
}
467
if ( test_case_count <= 0)
468
{
469
ts->printf( cvtest::TS::LOG, "validation file is not determined or not correct" );
470
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
471
}
472
ts->set_failed_test_info( code );
473
}
474
475
int CV_MLBaseTest::prepare_test_case( int test_case_idx )
476
{
477
CV_TRACE_FUNCTION();
478
clear();
479
480
string dataPath = ts->get_data_path();
481
if ( dataPath.empty() )
482
{
483
ts->printf( cvtest::TS::LOG, "data path is empty" );
484
return cvtest::TS::FAIL_INVALID_TEST_DATA;
485
}
486
487
string dataName = dataSetNames[test_case_idx],
488
filename = dataPath + dataName + ".data";
489
490
FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
491
CV_DbgAssert( !dataParamsNode.empty() );
492
493
CV_DbgAssert( !dataParamsNode["LS"].empty() );
494
int trainSampleCount = (int)dataParamsNode["LS"];
495
496
CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
497
int respIdx = (int)dataParamsNode["resp_idx"];
498
499
CV_DbgAssert( !dataParamsNode["types"].empty() );
500
String varTypes = (String)dataParamsNode["types"];
501
502
data = TrainData::loadFromCSV(filename, 0, respIdx, respIdx+1, varTypes);
503
if( data.empty() )
504
{
505
ts->printf( cvtest::TS::LOG, "file %s can not be read\n", filename.c_str() );
506
return cvtest::TS::FAIL_INVALID_TEST_DATA;
507
}
508
509
data->setTrainTestSplit(trainSampleCount);
510
return cvtest::TS::OK;
511
}
512
513
// Mutable reference to the validation file name. run() appends it to the
// test data path before opening the file.
string& CV_MLBaseTest::get_validation_filename()
{
    string& name = validationFN;
    return name;
}
517
518
int CV_MLBaseTest::train( int testCaseIdx )
519
{
520
CV_TRACE_FUNCTION();
521
bool is_trained = false;
522
FileNode modelParamsNode =
523
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
524
525
if( modelName == CV_NBAYES )
526
model = NormalBayesClassifier::create();
527
else if( modelName == CV_KNEAREST )
528
{
529
model = KNearest::create();
530
}
531
else if( modelName == CV_SVM )
532
{
533
String svm_type_str, kernel_type_str;
534
modelParamsNode["svm_type"] >> svm_type_str;
535
modelParamsNode["kernel_type"] >> kernel_type_str;
536
Ptr<SVM> m = SVM::create();
537
m->setType(str_to_svm_type( svm_type_str ));
538
m->setKernel(str_to_svm_kernel_type( kernel_type_str ));
539
m->setDegree(modelParamsNode["degree"]);
540
m->setGamma(modelParamsNode["gamma"]);
541
m->setCoef0(modelParamsNode["coef0"]);
542
m->setC(modelParamsNode["C"]);
543
m->setNu(modelParamsNode["nu"]);
544
m->setP(modelParamsNode["p"]);
545
model = m;
546
}
547
else if( modelName == CV_EM )
548
{
549
assert( 0 );
550
}
551
else if( modelName == CV_ANN )
552
{
553
String train_method_str;
554
double param1, param2;
555
modelParamsNode["train_method"] >> train_method_str;
556
modelParamsNode["param1"] >> param1;
557
modelParamsNode["param2"] >> param2;
558
Mat new_responses = ann_get_new_responses( data, cls_map );
559
// binarize the responses
560
data = TrainData::create(data->getSamples(), data->getLayout(), new_responses,
561
data->getVarIdx(), data->getTrainSampleIdx());
562
int layer_sz[] = { data->getNAllVars(), 100, 100, (int)cls_map.size() };
563
Mat layer_sizes( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
564
Ptr<ANN_MLP> m = ANN_MLP::create();
565
m->setLayerSizes(layer_sizes);
566
m->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
567
m->setTermCriteria(TermCriteria(TermCriteria::COUNT,300,0.01));
568
m->setTrainMethod(str_to_ann_train_method(train_method_str), param1, param2);
569
model = m;
570
571
}
572
else if( modelName == CV_DTREE )
573
{
574
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
575
float REG_ACCURACY = 0;
576
bool USE_SURROGATE = false, IS_PRUNED;
577
modelParamsNode["max_depth"] >> MAX_DEPTH;
578
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
579
//modelParamsNode["use_surrogate"] >> USE_SURROGATE;
580
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
581
modelParamsNode["cv_folds"] >> CV_FOLDS;
582
modelParamsNode["is_pruned"] >> IS_PRUNED;
583
584
Ptr<DTrees> m = DTrees::create();
585
m->setMaxDepth(MAX_DEPTH);
586
m->setMinSampleCount(MIN_SAMPLE_COUNT);
587
m->setRegressionAccuracy(REG_ACCURACY);
588
m->setUseSurrogates(USE_SURROGATE);
589
m->setMaxCategories(MAX_CATEGORIES);
590
m->setCVFolds(CV_FOLDS);
591
m->setUse1SERule(false);
592
m->setTruncatePrunedTree(IS_PRUNED);
593
m->setPriors(Mat());
594
model = m;
595
}
596
else if( modelName == CV_BOOST )
597
{
598
int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
599
float WEIGHT_TRIM_RATE;
600
bool USE_SURROGATE = false;
601
String typeStr;
602
modelParamsNode["type"] >> typeStr;
603
BOOST_TYPE = str_to_boost_type( typeStr );
604
modelParamsNode["weak_count"] >> WEAK_COUNT;
605
modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
606
modelParamsNode["max_depth"] >> MAX_DEPTH;
607
//modelParamsNode["use_surrogate"] >> USE_SURROGATE;
608
609
Ptr<Boost> m = Boost::create();
610
m->setBoostType(BOOST_TYPE);
611
m->setWeakCount(WEAK_COUNT);
612
m->setWeightTrimRate(WEIGHT_TRIM_RATE);
613
m->setMaxDepth(MAX_DEPTH);
614
m->setUseSurrogates(USE_SURROGATE);
615
m->setPriors(Mat());
616
model = m;
617
}
618
else if( modelName == CV_RTREES )
619
{
620
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
621
float REG_ACCURACY = 0, OOB_EPS = 0.0;
622
bool USE_SURROGATE = false, IS_PRUNED;
623
modelParamsNode["max_depth"] >> MAX_DEPTH;
624
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
625
//modelParamsNode["use_surrogate"] >> USE_SURROGATE;
626
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
627
modelParamsNode["cv_folds"] >> CV_FOLDS;
628
modelParamsNode["is_pruned"] >> IS_PRUNED;
629
modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
630
modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
631
632
Ptr<RTrees> m = RTrees::create();
633
m->setMaxDepth(MAX_DEPTH);
634
m->setMinSampleCount(MIN_SAMPLE_COUNT);
635
m->setRegressionAccuracy(REG_ACCURACY);
636
m->setUseSurrogates(USE_SURROGATE);
637
m->setMaxCategories(MAX_CATEGORIES);
638
m->setPriors(Mat());
639
m->setCalculateVarImportance(true);
640
m->setActiveVarCount(NACTIVE_VARS);
641
m->setTermCriteria(TermCriteria(TermCriteria::COUNT, MAX_TREES_NUM, OOB_EPS));
642
model = m;
643
}
644
645
else if( modelName == CV_SVMSGD )
646
{
647
String svmsgdTypeStr;
648
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
649
650
Ptr<SVMSGD> m = SVMSGD::create();
651
int svmsgdType = str_to_svmsgd_type( svmsgdTypeStr );
652
m->setSvmsgdType(svmsgdType);
653
654
String marginTypeStr;
655
modelParamsNode["marginType"] >> marginTypeStr;
656
int marginType = str_to_margin_type( marginTypeStr );
657
m->setMarginType(marginType);
658
659
m->setMarginRegularization(modelParamsNode["marginRegularization"]);
660
m->setInitialStepSize(modelParamsNode["initialStepSize"]);
661
m->setStepDecreasingPower(modelParamsNode["stepDecreasingPower"]);
662
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
663
model = m;
664
}
665
666
if( !model.empty() )
667
is_trained = model->train(data, 0);
668
669
if( !is_trained )
670
{
671
ts->printf( cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx );
672
return cvtest::TS::FAIL_INVALID_OUTPUT;
673
}
674
return cvtest::TS::OK;
675
}
676
677
float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
678
{
679
CV_TRACE_FUNCTION();
680
int type = CV_TEST_ERROR;
681
float err = 0;
682
Mat _resp;
683
if( modelName == CV_EM )
684
assert( 0 );
685
else if( modelName == CV_ANN )
686
err = ann_calc_error( model, data, cls_map, type, resp );
687
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
688
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
689
err = model->calcError( data, true, _resp );
690
if( !_resp.empty() && resp )
691
_resp.convertTo(*resp, CV_32F);
692
return err;
693
}
694
695
void CV_MLBaseTest::save( const char* filename )
696
{
697
CV_TRACE_FUNCTION();
698
model->save( filename );
699
}
700
701
void CV_MLBaseTest::load( const char* filename )
702
{
703
CV_TRACE_FUNCTION();
704
if( modelName == CV_NBAYES )
705
model = Algorithm::load<NormalBayesClassifier>( filename );
706
else if( modelName == CV_KNEAREST )
707
model = Algorithm::load<KNearest>( filename );
708
else if( modelName == CV_SVM )
709
model = Algorithm::load<SVM>( filename );
710
else if( modelName == CV_ANN )
711
model = Algorithm::load<ANN_MLP>( filename );
712
else if( modelName == CV_DTREE )
713
model = Algorithm::load<DTrees>( filename );
714
else if( modelName == CV_BOOST )
715
model = Algorithm::load<Boost>( filename );
716
else if( modelName == CV_RTREES )
717
model = Algorithm::load<RTrees>( filename );
718
else if( modelName == CV_SVMSGD )
719
model = Algorithm::load<SVMSGD>( filename );
720
else
721
CV_Error( CV_StsNotImplemented, "invalid stat model name");
722
}
723
724
725
726
// Test-subset accessors for ROW_SAMPLE data (see issue #12236): with a 0.9
// train ratio, 15 of 150 samples form the test set. Their responses and
// samples must keep the row layout and values.
TEST(TrainDataGet, layout_ROW_SAMPLE) // Details: #12236
{
    cv::Mat test = cv::Mat::ones(150, 30, CV_32FC1) * 2;
    test.col(3) += Scalar::all(3);
    cv::Mat labels = cv::Mat::ones(150, 3, CV_32SC1) * 5;
    labels.col(1) += 1;
    cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::ROW_SAMPLE, labels);
    train_data->setTrainTestSplitRatio(0.9);

    Mat tidx = train_data->getTestSampleIdx();
    EXPECT_EQ((size_t)15, tidx.total());

    // Shape checks are fatal: the element checks below index up to row 14 and
    // column 2 (responses) / test.cols-1 (samples), which would read out of
    // bounds if the matrices were smaller.
    Mat tresp = train_data->getTestResponses();
    ASSERT_EQ(15, tresp.rows);
    ASSERT_EQ(labels.cols, tresp.cols);
    EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
    EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
    EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
    EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;

    Mat tsamples = train_data->getTestSamples();
    ASSERT_EQ(15, tsamples.rows);
    ASSERT_EQ(test.cols, tsamples.cols);
    EXPECT_EQ(2, tsamples.at<float>(0, 0)) << tsamples;
    EXPECT_EQ(5, tsamples.at<float>(0, 3)) << tsamples;
    EXPECT_EQ(2, tsamples.at<float>(14, test.cols - 1)) << tsamples;
    EXPECT_EQ(5, tsamples.at<float>(14, 3)) << tsamples;
}
754
755
// Test-subset accessors for COL_SAMPLE data (see issue #12236): with a 0.9
// train ratio, 15 of 150 samples form the test set. Responses come back
// row-based (transposed); samples keep the column layout.
TEST(TrainDataGet, layout_COL_SAMPLE) // Details: #12236
{
    cv::Mat test = cv::Mat::ones(30, 150, CV_32FC1) * 3;
    test.row(3) += Scalar::all(3);
    cv::Mat labels = cv::Mat::ones(3, 150, CV_32SC1) * 5;
    labels.row(1) += 1;
    cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::COL_SAMPLE, labels);
    train_data->setTrainTestSplitRatio(0.9);

    Mat tidx = train_data->getTestSampleIdx();
    EXPECT_EQ((size_t)15, tidx.total());

    // Shape checks are fatal: the element checks below index up to (14, 2)
    // in the responses and (test.rows-1, 14) in the samples, which would read
    // out of bounds if the matrices were smaller.
    Mat tresp = train_data->getTestResponses(); // always row-based, transposed
    ASSERT_EQ(15, tresp.rows);
    ASSERT_EQ(labels.rows, tresp.cols);
    EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
    EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
    EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
    EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;

    Mat tsamples = train_data->getTestSamples();
    ASSERT_EQ(15, tsamples.cols);
    ASSERT_EQ(test.rows, tsamples.rows);
    EXPECT_EQ(3, tsamples.at<float>(0, 0)) << tsamples;
    EXPECT_EQ(6, tsamples.at<float>(3, 0)) << tsamples;
    EXPECT_EQ(6, tsamples.at<float>(3, 14)) << tsamples;
    EXPECT_EQ(3, tsamples.at<float>(test.rows - 1, 14)) << tsamples;
}
}
784
785
786
787
} // namespace
788
/* End of file. */
789
790