Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/ml/src/nbayes.cpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
//
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
13
// Third party copyrights are property of their respective owners.
14
//
15
// Redistribution and use in source and binary forms, with or without modification,
16
// are permitted provided that the following conditions are met:
17
//
18
// * Redistribution's of source code must retain the above copyright notice,
19
// this list of conditions and the following disclaimer.
20
//
21
// * Redistribution's in binary form must reproduce the above copyright notice,
22
// this list of conditions and the following disclaimer in the documentation
23
// and/or other materials provided with the distribution.
24
//
25
// * The name of Intel Corporation may not be used to endorse or promote products
26
// derived from this software without specific prior written permission.
27
//
28
// This software is provided by the copyright holders and contributors "as is" and
29
// any express or implied warranties, including, but not limited to, the implied
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
32
// indirect, incidental, special, exemplary, or consequential damages
33
// (including, but not limited to, procurement of substitute goods or services;
34
// loss of use, data, or profits; or business interruption) however caused
35
// and on any theory of liability, whether in contract, strict liability,
36
// or tort (including negligence or otherwise) arising in any way out of
37
// the use of this software, even if advised of the possibility of such damage.
38
//
39
//M*/
40
41
#include "precomp.hpp"
42
43
namespace cv {
44
namespace ml {
45
46
47
// Implementation of the Normal (Gaussian) Bayes classifier.
//
// The model stores, per class: element counts, sums, product sums, the mean
// vector, the eigen-decomposition of the covariance matrix (rotation matrix
// plus inverted eigenvalues) and the log-determinant term `c`. Prediction
// evaluates the Gaussian log-density of a sample under each class and picks
// the class with the best (smallest) value.
class NormalBayesClassifierImpl : public NormalBayesClassifier
{
public:
    NormalBayesClassifierImpl()
    {
        // 0 marks "not trained yet"; set by train()/read().
        nallvars = 0;
    }

    // Trains the classifier, or incrementally updates it when the
    // UPDATE_MODEL flag is set.
    //
    // For every class the method accumulates per-variable counts, sums and a
    // matrix of pairwise product sums, derives the class mean vector and the
    // (unbiased) sample covariance matrix, then factors the covariance with
    // SVD so that prediction can evaluate the quadratic form in the
    // eigenbasis (see NBPredictBody).
    //
    // trainData - training samples with normalized categorical responses
    // flags     - StatModel flags; only UPDATE_MODEL is honoured here
    // Returns true on success; throws CV_StsBadArg if update data is
    // inconsistent with the original training data.
    bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
    {
        // Lower bound for covariance eigenvalues; keeps the inversion
        // (divide(1., w, w) below) away from division by ~0.
        const float min_variation = FLT_EPSILON;
        Mat responses = trainData->getNormCatResponses();
        Mat __cls_labels = trainData->getClassLabels();
        Mat __var_idx = trainData->getVarIdx();
        Mat samples = trainData->getTrainSamples();
        int nclasses = (int)__cls_labels.total();

        int nvars = trainData->getNVars();
        int s, c1, c2, cls;

        int __nallvars = trainData->getNAllVars();
        bool update = (flags & UPDATE_MODEL) != 0;

        if( !update )
        {
            // Fresh training: (re)allocate all per-class statistics.
            nallvars = __nallvars;
            count.resize(nclasses);
            sum.resize(nclasses);
            productsum.resize(nclasses);
            avg.resize(nclasses);
            inv_eigen_values.resize(nclasses);
            cov_rotate_mats.resize(nclasses);

            for( cls = 0; cls < nclasses; cls++ )
            {
                count[cls] = Mat::zeros( 1, nvars, CV_32SC1 );
                sum[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
                productsum[cls] = Mat::zeros( nvars, nvars, CV_64FC1 );
                avg[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
                inv_eigen_values[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
                cov_rotate_mats[cls] = Mat::zeros( nvars, nvars, CV_64FC1 );
            }

            var_idx = __var_idx;
            cls_labels = __cls_labels;

            c.create(1, nclasses, CV_64FC1);
        }
        else
        {
            // check that the new training data has the same dimensionality etc.
            if( nallvars != __nallvars ||
                var_idx.size() != __var_idx.size() ||
                norm(var_idx, __var_idx, NORM_INF) != 0 ||
                cls_labels.size() != __cls_labels.size() ||
                norm(cls_labels, __cls_labels, NORM_INF) != 0 )
                CV_Error( CV_StsBadArg,
                "The new training data is inconsistent with the original training data; varIdx and the class labels should be the same" );
        }

        Mat cov( nvars, nvars, CV_64FC1 );
        int nsamples = samples.rows;

        // process train data (count, sum , productsum)
        for( s = 0; s < nsamples; s++ )
        {
            cls = responses.at<int>(s);
            int* count_data = count[cls].ptr<int>();
            double* sum_data = sum[cls].ptr<double>();
            double* prod_data = productsum[cls].ptr<double>();
            const float* train_vec = samples.ptr<float>(s);

            for( c1 = 0; c1 < nvars; c1++, prod_data += nvars )
            {
                double val1 = train_vec[c1];
                sum_data[c1] += val1;
                count_data[c1]++;
                // Only the upper triangle is accumulated; it is mirrored
                // with completeSymm() below.
                for( c2 = c1; c2 < nvars; c2++ )
                    prod_data[c2] += train_vec[c2]*val1;
            }
        }

        // NOTE(review): vt is never used in this function.
        Mat vt;

        // calculate avg, covariance matrix, c
        for( cls = 0; cls < nclasses; cls++ )
        {
            double det = 1;
            int i, j;
            Mat& w = inv_eigen_values[cls];
            int* count_data = count[cls].ptr<int>();
            double* avg_data = avg[cls].ptr<double>();
            double* sum1 = sum[cls].ptr<double>();

            completeSymm(productsum[cls], 0);

            // Per-variable mean; 0 if the class never saw that variable.
            for( j = 0; j < nvars; j++ )
            {
                int n = count_data[j];
                avg_data[j] = n ? sum1[j] / n : 0.;
            }

            count_data = count[cls].ptr<int>();
            avg_data = avg[cls].ptr<double>();
            sum1 = sum[cls].ptr<double>();

            for( i = 0; i < nvars; i++ )
            {
                double* avg2_data = avg[cls].ptr<double>();
                double* sum2 = sum[cls].ptr<double>();
                double* prod_data = productsum[cls].ptr<double>(i);
                double* cov_data = cov.ptr<double>(i);
                double s1val = sum1[i];
                double avg1 = avg_data[i];
                int _count = count_data[i];

                for( j = 0; j <= i; j++ )
                {
                    // Unbiased sample covariance from the accumulated sums:
                    // sum(x1*x2) - avg1*sum(x2) - avg2*sum(x1) + n*avg1*avg2,
                    // divided by (n - 1) when more than one sample was seen.
                    double avg2 = avg2_data[j];
                    double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * _count;
                    cov_val = (_count > 1) ? cov_val / (_count - 1) : cov_val;
                    cov_data[j] = cov_val;
                }
            }

            completeSymm( cov, 1 );

            // cov = U * diag(w) * U'. Store U transposed (rotation into the
            // eigenbasis) and the clamped, inverted eigenvalues so predict
            // can evaluate (x-mean)' cov^-1 (x-mean) cheaply.
            SVD::compute(cov, w, cov_rotate_mats[cls], noArray());
            transpose(cov_rotate_mats[cls], cov_rotate_mats[cls]);
            cv::max(w, min_variation, w);
            for( j = 0; j < nvars; j++ )
                det *= w.at<double>(j);

            divide(1., w, w);
            // Log-determinant term of the Gaussian log-density; -700 is a
            // finite stand-in (exp(-700) ~ DBL_MIN) if det underflowed to 0.
            c.at<double>(cls) = det > 0 ? log(det) : -700;
        }

        return true;
    }

    // Parallel loop body that classifies a range of samples.
    //
    // For each sample and each class it evaluates
    //   cur = c[i] + (x - mean_i)' * cov_i^-1 * (x - mean_i)
    // via the precomputed eigen-decomposition, picks the class with the
    // smallest value, and optionally stores exp(-0.5 * cur) per class as an
    // (unnormalized) probability score.
    class NBPredictBody : public ParallelLoopBody
    {
    public:
        NBPredictBody( const Mat& _c, const vector<Mat>& _cov_rotate_mats,
                       const vector<Mat>& _inv_eigen_values,
                       const vector<Mat>& _avg,
                       const Mat& _samples, const Mat& _vidx, const Mat& _cls_labels,
                       Mat& _results, Mat& _results_prob, bool _rawOutput )
        {
            c = &_c;
            cov_rotate_mats = &_cov_rotate_mats;
            inv_eigen_values = &_inv_eigen_values;
            avg = &_avg;
            samples = &_samples;
            vidx = &_vidx;
            cls_labels = &_cls_labels;
            results = &_results;
            // A probability output is optional; an empty Mat disables it.
            results_prob = !_results_prob.empty() ? &_results_prob : 0;
            rawOutput = _rawOutput;
            // NOTE(review): `value` is set here and never used afterwards —
            // appears to be dead state.
            value = 0;
        }

        // Borrowed model state and outputs; none of these are owned.
        const Mat* c;
        const vector<Mat>* cov_rotate_mats;
        const vector<Mat>* inv_eigen_values;
        const vector<Mat>* avg;
        const Mat* samples;
        const Mat* vidx;
        const Mat* cls_labels;

        Mat* results_prob;
        Mat* results;
        float* value;
        bool rawOutput;

        // Classifies samples [range.start, range.end).
        void operator()(const Range& range) const CV_OVERRIDE
        {
            int cls = -1;
            int rtype = 0, rptype = 0;
            size_t rstep = 0, rpstep = 0;
            int nclasses = (int)cls_labels->total();
            int nvars = avg->at(0).cols;
            double probability = 0;
            // Optional active-variable subset: vptr maps the model's
            // variable index j to the sample column vptr[j].
            const int* vptr = vidx && !vidx->empty() ? vidx->ptr<int>() : 0;

            if (results)
            {
                rtype = results->type();
                // Element stride between consecutive output rows.
                rstep = results->isContinuous() ? 1 : results->step/results->elemSize();
            }
            if (results_prob)
            {
                rptype = results_prob->type();
                rpstep = results_prob->isContinuous() ? results_prob->cols : results_prob->step/results_prob->elemSize();
            }
            // allocate memory and initializing headers for calculating
            cv::AutoBuffer<double> _buffer(nvars*2);
            double* _diffin = _buffer.data();
            double* _diffout = _buffer.data() + nvars;
            Mat diffin( 1, nvars, CV_64FC1, _diffin );
            Mat diffout( 1, nvars, CV_64FC1, _diffout );

            for(int k = range.start; k < range.end; k++ )
            {
                double opt = FLT_MAX;

                for(int i = 0; i < nclasses; i++ )
                {
                    double cur = c->at<double>(i);
                    const Mat& u = cov_rotate_mats->at(i);
                    const Mat& w = inv_eigen_values->at(i);

                    const double* avg_data = avg->at(i).ptr<double>();
                    const float* x = samples->ptr<float>(k);

                    // cov = u w u' --> cov^(-1) = u w^(-1) u'
                    for(int j = 0; j < nvars; j++ )
                        _diffin[j] = avg_data[j] - x[vptr ? vptr[j] : j];

                    // Rotate the difference into the eigenbasis, then the
                    // quadratic form reduces to a weighted sum of squares.
                    gemm( diffin, u, 1, noArray(), 0, diffout, GEMM_2_T );
                    for(int j = 0; j < nvars; j++ )
                    {
                        double d = _diffout[j];
                        cur += d*d*w.ptr<double>()[j];
                    }

                    if( cur < opt )
                    {
                        cls = i;
                        opt = cur;
                    }
                    // Unnormalized score; not divided by the sum over classes.
                    probability = exp( -0.5 * cur );

                    if( results_prob )
                    {
                        if ( rptype == CV_32FC1 )
                            results_prob->ptr<float>()[k*rpstep + i] = (float)probability;
                        else
                            results_prob->ptr<double>()[k*rpstep + i] = probability;
                    }
                }

                // Map the winning class index back to the original label,
                // unless the caller asked for raw indices.
                int ival = rawOutput ? cls : cls_labels->at<int>(cls);
                if( results )
                {
                    if( rtype == CV_32SC1 )
                        results->ptr<int>()[k*rstep] = ival;
                    else
                        results->ptr<float>()[k*rstep] = (float)ival;
                }
            }
        }
    };

    // Convenience wrapper: predict labels only (no per-class scores).
    float predict( InputArray _samples, OutputArray _results, int flags ) const CV_OVERRIDE
    {
        return predictProb(_samples, _results, noArray(), flags);
    }

    // Classifies the given samples (CV_32F, nallvars columns).
    //
    // _samples     - input samples, one per row
    // _results     - optional Nx1 CV_32S matrix of predicted labels
    // _resultsProb - optional NxK CV_32F matrix of per-class scores
    // flags        - RAW_OUTPUT returns class indices instead of labels
    //
    // The float return value is meaningful only when _results is not
    // requested (single-sample case): then `results` aliases `value`, which
    // receives the predicted label. Otherwise 0 is returned.
    float predictProb( InputArray _samples, OutputArray _results, OutputArray _resultsProb, int flags ) const CV_OVERRIDE
    {
        int value=0;
        Mat samples = _samples.getMat(), results, resultsProb;
        int nsamples = samples.rows, nclasses = (int)cls_labels.total();
        bool rawOutput = (flags & RAW_OUTPUT) != 0;

        if( samples.type() != CV_32F || samples.cols != nallvars )
            CV_Error( CV_StsBadArg,
                     "The input samples must be 32f matrix with the number of columns = nallvars" );

        if( (samples.rows > 1) && (! _results.needed()) )
            CV_Error( CV_StsNullPtr,
                     "When the number of input samples is >1, the output vector of results must be passed" );

        if( _results.needed() )
        {
            _results.create(nsamples, 1, CV_32S);
            results = _results.getMat();
        }
        else
            // Single sample, no output requested: write the label into the
            // local `value` via this 1x1 header.
            results = Mat(1, 1, CV_32S, &value);

        if( _resultsProb.needed() )
        {
            _resultsProb.create(nsamples, nclasses, CV_32F);
            resultsProb = _resultsProb.getMat();
        }

        cv::parallel_for_(cv::Range(0, nsamples),
                          NBPredictBody(c, cov_rotate_mats, inv_eigen_values, avg, samples,
                                        var_idx, cls_labels, results, resultsProb, rawOutput));

        return (float)value;
    }

    // Serializes the trained model (variable info, labels, per-class
    // statistics and decomposition) to a FileStorage.
    void write( FileStorage& fs ) const CV_OVERRIDE
    {
        int nclasses = (int)cls_labels.total(), i;

        writeFormat(fs);
        fs << "var_count" << (var_idx.empty() ? nallvars : (int)var_idx.total());
        fs << "var_all" << nallvars;

        if( !var_idx.empty() )
            fs << "var_idx" << var_idx;
        fs << "cls_labels" << cls_labels;

        fs << "count" << "[";
        for( i = 0; i < nclasses; i++ )
            fs << count[i];

        fs << "]" << "sum" << "[";
        for( i = 0; i < nclasses; i++ )
            fs << sum[i];

        fs << "]" << "productsum" << "[";
        for( i = 0; i < nclasses; i++ )
            fs << productsum[i];

        fs << "]" << "avg" << "[";
        for( i = 0; i < nclasses; i++ )
            fs << avg[i];

        fs << "]" << "inv_eigen_values" << "[";
        for( i = 0; i < nclasses; i++ )
            fs << inv_eigen_values[i];

        fs << "]" << "cov_rotate_mats" << "[";
        for( i = 0; i < nclasses; i++ )
            fs << cov_rotate_mats[i];

        fs << "]";

        fs << "c" << c;
    }

    // Restores a model previously saved by write(); throws on missing or
    // invalid "var_all" / "cls_labels" fields.
    void read( const FileNode& fn ) CV_OVERRIDE
    {
        clear();

        fn["var_all"] >> nallvars;

        if( nallvars <= 0 )
            CV_Error( CV_StsParseError,
                     "The field \"var_count\" of NBayes classifier is missing or non-positive" );

        fn["var_idx"] >> var_idx;
        fn["cls_labels"] >> cls_labels;

        int nclasses = (int)cls_labels.total(), i;

        if( cls_labels.empty() || nclasses < 1 )
            CV_Error( CV_StsParseError, "No or invalid \"cls_labels\" in NBayes classifier" );

        // Iterate the six per-class sequences in lockstep.
        FileNodeIterator
            count_it = fn["count"].begin(),
            sum_it = fn["sum"].begin(),
            productsum_it = fn["productsum"].begin(),
            avg_it = fn["avg"].begin(),
            inv_eigen_values_it = fn["inv_eigen_values"].begin(),
            cov_rotate_mats_it = fn["cov_rotate_mats"].begin();

        count.resize(nclasses);
        sum.resize(nclasses);
        productsum.resize(nclasses);
        avg.resize(nclasses);
        inv_eigen_values.resize(nclasses);
        cov_rotate_mats.resize(nclasses);

        for( i = 0; i < nclasses; i++, ++count_it, ++sum_it, ++productsum_it, ++avg_it,
                                       ++inv_eigen_values_it, ++cov_rotate_mats_it )
        {
            *count_it >> count[i];
            *sum_it >> sum[i];
            *productsum_it >> productsum[i];
            *avg_it >> avg[i];
            *inv_eigen_values_it >> inv_eigen_values[i];
            *cov_rotate_mats_it >> cov_rotate_mats[i];
        }

        fn["c"] >> c;
    }

    // Drops all trained state; the model becomes untrained again.
    void clear() CV_OVERRIDE
    {
        count.clear();
        sum.clear();
        productsum.clear();
        avg.clear();
        inv_eigen_values.clear();
        cov_rotate_mats.clear();

        var_idx.release();
        cls_labels.release();
        c.release();
        nallvars = 0;
    }

    bool isTrained() const CV_OVERRIDE { return !avg.empty(); }
    bool isClassifier() const CV_OVERRIDE { return true; }
    int getVarCount() const CV_OVERRIDE { return nallvars; }
    String getDefaultName() const CV_OVERRIDE { return "opencv_ml_nbayes"; }

    // Number of all variables in a training sample (TrainData::getNAllVars()).
    int nallvars;
    // Active variable indices, class labels, and per-class log-det terms.
    Mat var_idx, cls_labels, c;
    // Per-class statistics, one Mat per class.
    vector<Mat> count, sum, productsum, avg, inv_eigen_values, cov_rotate_mats;
};
454
455
456
// Factory: constructs a fresh, untrained normal Bayes classifier.
Ptr<NormalBayesClassifier> NormalBayesClassifier::create()
{
    return makePtr<NormalBayesClassifierImpl>();
}
461
462
// Loads a serialized classifier via the generic Algorithm loader, reading
// the model stored under nodeName (or the first/default node) in filepath.
Ptr<NormalBayesClassifier> NormalBayesClassifier::load(const String& filepath, const String& nodeName)
{
    Ptr<NormalBayesClassifier> model = Algorithm::load<NormalBayesClassifier>(filepath, nodeName);
    return model;
}
466
467
}
468
}
469
470
/* End of file. */
471
472