// Source: GitHub repository Tetragramm/opencv, path modules/ml/test/test_lr.cpp
1
///////////////////////////////////////////////////////////////////////////////////////
2
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
3
4
// By downloading, copying, installing or using the software you agree to this license.
5
// If you do not agree to this license, do not download, install,
6
// copy or use the software.
7
8
// This is an implementation of the Logistic Regression algorithm in C++ in OpenCV.
9
10
// AUTHOR:
11
// Rahul Kavi rahulkavi[at]live[dot]com
12
//
13
14
// The test data contains a subset of the popular Iris Dataset (taken from "http://archive.ics.uci.edu/ml/datasets/Iris")
15
16
// # You are free to use, change, or redistribute the code in any way you wish for
17
// # non-commercial purposes, but please maintain the name of the original author.
18
// # This code comes with no warranty of any kind.
19
20
// #
21
// # You are free to use, change, or redistribute the code in any way you wish for
22
// # non-commercial purposes, but please maintain the name of the original author.
23
// # This code comes with no warranty of any kind.
24
25
// # Logistic Regression ALGORITHM
26
27
28
// License Agreement
29
// For Open Source Computer Vision Library
30
31
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
32
// Copyright (C) 2008-2011, Willow Garage Inc., all rights reserved.
33
// Third party copyrights are property of their respective owners.
34
35
// Redistribution and use in source and binary forms, with or without modification,
36
// are permitted provided that the following conditions are met:
37
38
// * Redistributions of source code must retain the above copyright notice,
39
// this list of conditions and the following disclaimer.
40
41
// * Redistributions in binary form must reproduce the above copyright notice,
42
// this list of conditions and the following disclaimer in the documentation
43
// and/or other materials provided with the distribution.
44
45
// * The name of the copyright holders may not be used to endorse or promote products
46
// derived from this software without specific prior written permission.
47
48
// This software is provided by the copyright holders and contributors "as is" and
49
// any express or implied warranties, including, but not limited to, the implied
50
// warranties of merchantability and fitness for a particular purpose are disclaimed.
51
// In no event shall the Intel Corporation or contributors be liable for any direct,
52
// indirect, incidental, special, exemplary, or consequential damages
53
// (including, but not limited to, procurement of substitute goods or services;
54
// loss of use, data, or profits; or business interruption) however caused
55
// and on any theory of liability, whether in contract, strict liability,
56
// or tort (including negligence or otherwise) arising in any way out of
57
// the use of this software, even if advised of the possibility of such damage.
58
59
#include "test_precomp.hpp"
60
61
namespace opencv_test { namespace {
62
63
bool calculateError( const Mat& _p_labels, const Mat& _o_labels, float& error)
64
{
65
CV_TRACE_FUNCTION();
66
error = 0.0f;
67
float accuracy = 0.0f;
68
Mat _p_labels_temp;
69
Mat _o_labels_temp;
70
_p_labels.convertTo(_p_labels_temp, CV_32S);
71
_o_labels.convertTo(_o_labels_temp, CV_32S);
72
73
CV_Assert(_p_labels_temp.total() == _o_labels_temp.total());
74
CV_Assert(_p_labels_temp.rows == _o_labels_temp.rows);
75
76
accuracy = (float)countNonZero(_p_labels_temp == _o_labels_temp)/_p_labels_temp.rows;
77
error = 1 - accuracy;
78
return true;
79
}
80
81
//--------------------------------------------------------------------------------------------
82
83
class CV_LRTest : public cvtest::BaseTest
84
{
85
public:
86
CV_LRTest() {}
87
protected:
88
virtual void run( int start_from );
89
};
90
91
void CV_LRTest::run( int /*start_from*/ )
92
{
93
CV_TRACE_FUNCTION();
94
// initialize variables from the popular Iris Dataset
95
string dataFileName = ts->get_data_path() + "iris.data";
96
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
97
98
if (tdata.empty()) {
99
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
100
return;
101
}
102
103
// run LR classifier train classifier
104
Ptr<LogisticRegression> p = LogisticRegression::create();
105
p->setLearningRate(1.0);
106
p->setIterations(10001);
107
p->setRegularization(LogisticRegression::REG_L2);
108
p->setTrainMethod(LogisticRegression::BATCH);
109
p->setMiniBatchSize(10);
110
p->train(tdata);
111
112
// predict using the same data
113
Mat responses;
114
p->predict(tdata->getSamples(), responses);
115
116
// calculate error
117
int test_code = cvtest::TS::OK;
118
float error = 0.0f;
119
if(!calculateError(responses, tdata->getResponses(), error))
120
{
121
ts->printf(cvtest::TS::LOG, "Bad prediction labels\n" );
122
test_code = cvtest::TS::FAIL_INVALID_OUTPUT;
123
}
124
else if(error > 0.05f)
125
{
126
ts->printf(cvtest::TS::LOG, "Bad accuracy of (%f)\n", error);
127
test_code = cvtest::TS::FAIL_BAD_ACCURACY;
128
}
129
130
{
131
FileStorage s("debug.xml", FileStorage::WRITE);
132
s << "original" << tdata->getResponses();
133
s << "predicted1" << responses;
134
s << "learnt" << p->get_learnt_thetas();
135
s << "error" << error;
136
s.release();
137
}
138
ts->set_failed_test_info(test_code);
139
}
140
141
//--------------------------------------------------------------------------------------------
142
class CV_LRTest_SaveLoad : public cvtest::BaseTest
143
{
144
public:
145
CV_LRTest_SaveLoad(){}
146
protected:
147
virtual void run(int start_from);
148
};
149
150
151
void CV_LRTest_SaveLoad::run( int /*start_from*/ )
152
{
153
CV_TRACE_FUNCTION();
154
int code = cvtest::TS::OK;
155
156
// initialize variables from the popular Iris Dataset
157
string dataFileName = ts->get_data_path() + "iris.data";
158
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
159
160
Mat responses1, responses2;
161
Mat learnt_mat1, learnt_mat2;
162
163
// train and save the classifier
164
String filename = tempfile(".xml");
165
try
166
{
167
// run LR classifier train classifier
168
Ptr<LogisticRegression> lr1 = LogisticRegression::create();
169
lr1->setLearningRate(1.0);
170
lr1->setIterations(10001);
171
lr1->setRegularization(LogisticRegression::REG_L2);
172
lr1->setTrainMethod(LogisticRegression::BATCH);
173
lr1->setMiniBatchSize(10);
174
lr1->train(tdata);
175
lr1->predict(tdata->getSamples(), responses1);
176
learnt_mat1 = lr1->get_learnt_thetas();
177
lr1->save(filename);
178
}
179
catch(...)
180
{
181
ts->printf(cvtest::TS::LOG, "Crash in write method.\n" );
182
ts->set_failed_test_info(cvtest::TS::FAIL_EXCEPTION);
183
}
184
185
// and load to another
186
try
187
{
188
Ptr<LogisticRegression> lr2 = Algorithm::load<LogisticRegression>(filename);
189
lr2->predict(tdata->getSamples(), responses2);
190
learnt_mat2 = lr2->get_learnt_thetas();
191
}
192
catch(...)
193
{
194
ts->printf(cvtest::TS::LOG, "Crash in write method.\n" );
195
ts->set_failed_test_info(cvtest::TS::FAIL_EXCEPTION);
196
}
197
198
CV_Assert(responses1.rows == responses2.rows);
199
200
// compare difference in learnt matrices before and after loading from disk
201
Mat comp_learnt_mats;
202
comp_learnt_mats = (learnt_mat1 == learnt_mat2);
203
comp_learnt_mats = comp_learnt_mats.reshape(1, comp_learnt_mats.rows*comp_learnt_mats.cols);
204
comp_learnt_mats.convertTo(comp_learnt_mats, CV_32S);
205
comp_learnt_mats = comp_learnt_mats/255;
206
207
// compare difference in prediction outputs and stored inputs
208
// check if there is any difference between computed learnt mat and retrieved mat
209
210
float errorCount = 0.0;
211
errorCount += 1 - (float)countNonZero(responses1 == responses2)/responses1.rows;
212
errorCount += 1 - (float)sum(comp_learnt_mats)[0]/comp_learnt_mats.rows;
213
214
if(errorCount>0)
215
{
216
ts->printf( cvtest::TS::LOG, "Different prediction results before writing and after reading (errorCount=%d).\n", errorCount );
217
code = cvtest::TS::FAIL_BAD_ACCURACY;
218
}
219
220
remove( filename.c_str() );
221
222
ts->set_failed_test_info( code );
223
}
224
225
// Test entry points: each constructs one of the BaseTest-derived cases
// defined above and executes it through safe_run().
TEST(ML_LR, accuracy)
{
    CV_LRTest test;
    test.safe_run();
}

TEST(ML_LR, save_load)
{
    CV_LRTest_SaveLoad test;
    test.safe_run();
}
227
228
}} // namespace
229
230