Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/ml/src/em.cpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
// For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2000, Intel Corporation, all rights reserved.
14
// Third party copyrights are property of their respective owners.
15
//
16
// Redistribution and use in source and binary forms, with or without modification,
17
// are permitted provided that the following conditions are met:
18
//
19
// * Redistribution's of source code must retain the above copyright notice,
20
// this list of conditions and the following disclaimer.
21
//
22
// * Redistribution's in binary form must reproduce the above copyright notice,
23
// this list of conditions and the following disclaimer in the documentation
24
// and/or other materials provided with the distribution.
25
//
26
// * The name of Intel Corporation may not be used to endorse or promote products
27
// derived from this software without specific prior written permission.
28
//
29
// This software is provided by the copyright holders and contributors "as is" and
30
// any express or implied warranties, including, but not limited to, the implied
31
// warranties of merchantability and fitness for a particular purpose are disclaimed.
32
// In no event shall the Intel Corporation or contributors be liable for any direct,
33
// indirect, incidental, special, exemplary, or consequential damages
34
// (including, but not limited to, procurement of substitute goods or services;
35
// loss of use, data, or profits; or business interruption) however caused
36
// and on any theory of liability, whether in contract, strict liability,
37
// or tort (including negligence or otherwise) arising in any way out of
38
// the use of this software, even if advised of the possibility of such damage.
39
//
40
//M*/
41
42
#include "precomp.hpp"
43
44
namespace cv
45
{
46
namespace ml
47
{
48
49
const double minEigenValue = DBL_EPSILON;
50
51
// Expectation-Maximization implementation for Gaussian mixture models.
class CV_EXPORTS EMImpl CV_FINAL : public EM
{
public:

    int nclusters;           // number of mixture components
    int covMatType;          // COV_MAT_SPHERICAL / COV_MAT_DIAGONAL / COV_MAT_GENERIC
    TermCriteria termCrit;   // stopping rule for the EM iterations

    inline TermCriteria getTermCriteria() const CV_OVERRIDE { return termCrit; }
    inline void setTermCriteria(const TermCriteria& val) CV_OVERRIDE { termCrit = val; }

    // Set the number of mixture components; must be at least 1.
    void setClustersNumber(int val) CV_OVERRIDE
    {
        nclusters = val;
        CV_Assert(nclusters >= 1);
    }

    int getClustersNumber() const CV_OVERRIDE
    {
        return nclusters;
    }

    // Set the covariance-matrix constraint; only the three known types are accepted.
    void setCovarianceMatrixType(int val) CV_OVERRIDE
    {
        covMatType = val;
        CV_Assert(covMatType == COV_MAT_SPHERICAL ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_GENERIC);
    }

    int getCovarianceMatrixType() const CV_OVERRIDE
    {
        return covMatType;
    }
85
86
EMImpl()
87
{
88
nclusters = DEFAULT_NCLUSTERS;
89
covMatType=EM::COV_MAT_DIAGONAL;
90
termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
91
}
92
93
virtual ~EMImpl() {}
94
95
// Release all training buffers and learned model state, returning the
// instance to its untrained condition (isTrained() becomes false).
void clear() CV_OVERRIDE
{
    trainSamples.release();
    trainProbs.release();
    trainLogLikelihoods.release();
    trainLabels.release();

    weights.release();
    means.release();
    covs.clear();

    covsEigenValues.clear();
    invCovsEigenValues.clear();
    covsRotateMats.clear();

    logWeightDivDet.release();
}
112
113
// StatModel entry point: delegates to trainEM() with automatic (k-means)
// initialization. The unnamed flags parameter is ignored.
bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
{
    Mat trainSamplesMat = data->getTrainSamples();
    Mat discardedOutput;  // receives the per-sample log-likelihoods, then discarded
    return trainEM(trainSamplesMat, discardedOutput, noArray(), noArray());
}
118
119
bool trainEM(InputArray samples,
120
OutputArray logLikelihoods,
121
OutputArray labels,
122
OutputArray probs) CV_OVERRIDE
123
{
124
Mat samplesMat = samples.getMat();
125
setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
126
return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
127
}
128
129
// Train starting from the E-step: the caller supplies initial means
// (mandatory, enforced by checkTrainData) and optionally initial
// covariances and weights.
bool trainE(InputArray samples,
            InputArray _means0,
            InputArray _covs0,
            InputArray _weights0,
            OutputArray logLikelihoods,
            OutputArray labels,
            OutputArray probs) CV_OVERRIDE
{
    Mat samplesMat = samples.getMat();
    std::vector<Mat> covs0;
    _covs0.getMatVector(covs0);

    Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();

    // Empty optional inputs become null pointers so setTrainData() can
    // distinguish "not provided" from "provided".
    setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                 !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
    return doTrain(START_E_STEP, logLikelihoods, labels, probs);
}
147
148
bool trainM(InputArray samples,
149
InputArray _probs0,
150
OutputArray logLikelihoods,
151
OutputArray labels,
152
OutputArray probs) CV_OVERRIDE
153
{
154
Mat samplesMat = samples.getMat();
155
Mat probs0 = _probs0.getMat();
156
157
setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
158
return doTrain(START_M_STEP, logLikelihoods, labels, probs);
159
}
160
161
// StatModel::predict interface. If _outputs is requested it receives the
// per-cluster posterior probabilities for every input row; the return
// value is the best-cluster index of the FIRST sample only.
float predict(InputArray _inputs, OutputArray _outputs, int) const CV_OVERRIDE
{
    bool needprobs = _outputs.needed();
    Mat samples = _inputs.getMat(), probs, probsrow;
    int ptype = CV_64F;
    float firstres = 0.f;
    int i, nsamples = samples.rows;

    if( needprobs )
    {
        // Honor a caller-fixed output type; default is CV_64F.
        if( _outputs.fixedType() )
            ptype = _outputs.type();
        _outputs.create(samples.rows, nclusters, ptype);
        probs = _outputs.getMat();
    }
    else
        // Without a probability output only the first sample's result is
        // observable, so skip all remaining rows.
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        if( needprobs )
            probsrow = probs.row(i);
        Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
        if( i == 0 )
            firstres = (float)res[1];  // res[1] is the most-probable cluster index
    }
    return firstres;
}
189
190
// Predict a single sample: returns {log-likelihood, index of the most
// probable component} and optionally writes the per-component posterior
// probabilities to _probs.
Vec2d predict2(InputArray _sample, OutputArray _probs) const CV_OVERRIDE
{
    int ptype = CV_64F;
    Mat sample = _sample.getMat();
    CV_Assert(isTrained());

    CV_Assert(!sample.empty());
    // computeProbabilities() asserts CV_32F/CV_64F input; normalize to double.
    if(sample.type() != CV_64FC1)
    {
        Mat tmp;
        sample.convertTo(tmp, CV_64FC1);
        sample = tmp;
    }
    sample = sample.reshape(1, 1);  // flatten to a single row vector

    Mat probs;
    if( _probs.needed() )
    {
        if( _probs.fixedType() )
            ptype = _probs.type();
        _probs.create(1, nclusters, ptype);
        probs = _probs.getMat();
    }

    return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
}
216
217
bool isTrained() const CV_OVERRIDE
218
{
219
return !means.empty();
220
}
221
222
bool isClassifier() const CV_OVERRIDE
223
{
224
return true;
225
}
226
227
int getVarCount() const CV_OVERRIDE
228
{
229
return means.cols;
230
}
231
232
String getDefaultName() const CV_OVERRIDE
233
{
234
return "opencv_ml_em";
235
}
236
237
// Validate the consistency of all inputs to setTrainData(): sample layout,
// cluster count, the types/sizes of the optional probs/weights/means/covs,
// and that the chosen start step has its mandatory inputs.
static void checkTrainData(int startStep, const Mat& samples,
        int nclusters, int covMatType, const Mat* probs, const Mat* means,
        const std::vector<Mat>* covs, const Mat* weights)
{
    // Check samples.
    CV_Assert(!samples.empty());
    CV_Assert(samples.channels() == 1);

    int nsamples = samples.rows;
    int dim = samples.cols;

    // Check training params.
    CV_Assert(nclusters > 0);
    CV_Assert(nclusters <= nsamples);
    CV_Assert(startStep == START_AUTO_STEP ||
              startStep == START_E_STEP ||
              startStep == START_M_STEP);
    CV_Assert(covMatType == COV_MAT_GENERIC ||
              covMatType == COV_MAT_DIAGONAL ||
              covMatType == COV_MAT_SPHERICAL);

    // probs: one row per sample, one column per cluster, float or double.
    CV_Assert(!probs ||
        (!probs->empty() &&
         probs->rows == nsamples && probs->cols == nclusters &&
         (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));

    // weights: a single row or column holding exactly nclusters entries.
    CV_Assert(!weights ||
        (!weights->empty() &&
         (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
         (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));

    // means: nclusters x dim, single channel.
    CV_Assert(!means ||
        (!means->empty() &&
         means->rows == nclusters && means->cols == dim &&
         means->channels() == 1));

    // covs: exactly nclusters square dim x dim single-channel matrices.
    CV_Assert(!covs ||
        (!covs->empty() &&
         static_cast<int>(covs->size()) == nclusters));
    if(covs)
    {
        const Size covSize(dim, dim);
        for(size_t i = 0; i < covs->size(); i++)
        {
            const Mat& m = (*covs)[i];
            CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
        }
    }

    // Mandatory inputs per start step: E-step requires means, M-step probs.
    if(startStep == START_E_STEP)
    {
        CV_Assert(means);
    }
    else if(startStep == START_M_STEP)
    {
        CV_Assert(probs);
    }
}
295
296
static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
297
{
298
if(src.type() == dstType && !isAlwaysClone)
299
dst = src;
300
else
301
src.convertTo(dst, dstType);
302
}
303
304
// Clamp negative entries to zero and turn every row into a valid
// distribution: rows with any mass are L1-normalized; degenerate rows
// (all entries ~zero) fall back to the uniform distribution.
static void preprocessProbability(Mat& probs)
{
    max(probs, 0., probs);

    const double uniformProbability = (double)(1./probs.cols);
    for(int rowIndex = 0; rowIndex < probs.rows; rowIndex++)
    {
        Mat rowProbs = probs.row(rowIndex);

        double rowMax = 0;
        minMaxLoc(rowProbs, 0, &rowMax);
        if(rowMax >= FLT_EPSILON)
            normalize(rowProbs, rowProbs, 1, 0, NORM_L1);
        else
            rowProbs.setTo(uniformProbability);
    }
}
321
322
// Store and preprocess the (validated) training inputs. Samples are kept
// in CV_32FC1 when a k-means initialization will run (kmeans requires
// float), otherwise in CV_64FC1. Optional inputs are honored only when
// they are meaningful for the chosen start step.
void setTrainData(int startStep, const Mat& samples,
                  const Mat* probs0,
                  const Mat* means0,
                  const std::vector<Mat>* covs0,
                  const Mat* weights0)
{
    clear();

    checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);

    // k-means init runs for AUTO start, or for an E-step start that is
    // missing covariances or weights.
    bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
    // Set checked data
    preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);

    // set probs
    if(probs0 && startStep == START_M_STEP)
    {
        // Always clone: preprocessProbability() mutates the matrix in place.
        preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
        preprocessProbability(trainProbs);
    }

    // set weights (only usable when covariances were also provided)
    if(weights0 && (startStep == START_E_STEP && covs0))
    {
        weights0->convertTo(weights, CV_64FC1);
        weights = weights.reshape(1,1);
        preprocessProbability(weights);  // clip negatives, normalize to sum 1
    }

    // set means
    if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
        means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);

    // set covs (only usable when weights were also provided)
    if(covs0 && (startStep == START_E_STEP && weights0))
    {
        covs.resize(nclusters);
        for(size_t i = 0; i < covs0->size(); i++)
            (*covs0)[i].convertTo(covs[i], CV_64FC1);
    }
}
363
364
// Decompose each covariance matrix into eigenvalues (and, for the generic
// type, a rotation matrix) so the Mahalanobis term can later be evaluated
// as a diagonal weighted sum of squares. Eigenvalues are floored at
// minEigenValue to keep the elementwise inverses finite.
void decomposeCovs()
{
    CV_Assert(!covs.empty());
    covsEigenValues.resize(nclusters);
    if(covMatType == COV_MAT_GENERIC)
        covsRotateMats.resize(nclusters);
    invCovsEigenValues.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        CV_Assert(!covs[clusterIndex].empty());

        SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);

        if(covMatType == COV_MAT_SPHERICAL)
        {
            // A single shared variance: the largest singular value.
            double maxSingularVal = svd.w.at<double>(0);
            covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
        }
        else if(covMatType == COV_MAT_DIAGONAL)
        {
            covsEigenValues[clusterIndex] = covs[clusterIndex].diag().clone(); //Preserve the original order of eigen values.
        }
        else //COV_MAT_GENERIC
        {
            covsEigenValues[clusterIndex] = svd.w;
            covsRotateMats[clusterIndex] = svd.u;
        }
        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
    }
}
395
396
// Initialize the mixture via k-means: cluster the samples, then derive
// per-cluster means, covariances and weights from the hard assignment.
void clusterTrainSamples()
{
    int nsamples = trainSamples.rows;

    // Cluster samples, compute/update means

    // Convert samples and means to 32F, because kmeans requires this type.
    Mat trainSamplesFlt, meansFlt;
    if(trainSamples.type() != CV_32FC1)
        trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
    else
        trainSamplesFlt = trainSamples;
    if(!means.empty())
    {
        if(means.type() != CV_32FC1)
            means.convertTo(meansFlt, CV_32FC1);
        else
            meansFlt = means;
    }

    Mat labels;
    // With caller-supplied means only one refinement iteration is run;
    // otherwise 10 iterations with k-means++ seeding and 10 attempts.
    kmeans(trainSamplesFlt, nclusters, labels,
           TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
           10, KMEANS_PP_CENTERS, meansFlt);

    // Convert samples and means back to 64F.
    CV_Assert(meansFlt.type() == CV_32FC1);
    if(trainSamples.type() != CV_64FC1)
    {
        Mat trainSamplesBuffer;
        trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
        trainSamples = trainSamplesBuffer;
    }
    meansFlt.convertTo(means, CV_64FC1);

    // Compute weights and covs
    weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        // Gather the rows hard-assigned to this cluster.
        Mat clusterSamples;
        for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
        {
            if(labels.at<int>(sampleIndex) == clusterIndex)
            {
                const Mat sample = trainSamples.row(sampleIndex);
                clusterSamples.push_back(sample);
            }
        }
        CV_Assert(!clusterSamples.empty());

        calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
            CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
        // Weight = fraction of samples assigned to the cluster.
        weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
    }

    decomposeCovs();
}
454
455
// Precompute log(weight_k) - 0.5*log|det(cov_k)| per cluster -- the
// sample-independent part of each log-density used by the E-step.
void computeLogWeightDivDet()
{
    CV_Assert(!covsEigenValues.empty());

    Mat logWeights;
    cv::max(weights, DBL_MIN, weights);  // guard log() against zero weights
    log(weights, logWeights);

    logWeightDivDet.create(1, nclusters, CV_64FC1);
    // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)

    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        double logDetCov = 0.;
        const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
        // NOTE(review): for COV_MAT_SPHERICAL the eigenvalue matrix is 1x1, so
        // only one log term is summed (log(var_k), not dim*log(var_k)); this
        // matches the code as written -- confirm the determinant is intended.
        for(int di = 0; di < evalCount; di++)
            logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));

        logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
    }
}
476
477
// Run the EM iterations from the requested start step and, on success,
// export per-sample log-likelihoods, labels and posteriors to the caller.
// Returns false when the likelihood degenerates (training failed).
bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
{
    int dim = trainSamples.cols;
    // Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
    if(startStep != START_M_STEP)
    {
        if(covs.empty())
        {
            CV_Assert(weights.empty());
            clusterTrainSamples();  // k-means initialization
        }
    }

    if(!covs.empty() && covsEigenValues.empty() )
    {
        CV_Assert(invCovsEigenValues.empty());
        decomposeCovs();
    }

    if(startStep == START_M_STEP)
        mStep();

    double trainLogLikelihood, prevTrainLogLikelihood = 0.;
    int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
        termCrit.maxCount : DEFAULT_MAX_ITERS;
    double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;

    // Alternate E and M steps until the iteration budget runs out or the
    // total log-likelihood stops improving (decreases, or increases by
    // less than epsilon relative to its magnitude).
    for(int iter = 0; ; iter++)
    {
        eStep();
        trainLogLikelihood = sum(trainLogLikelihoods)[0];

        if(iter >= maxIters - 1)
            break;

        double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
        if( iter != 0 &&
            (trainLogLikelihoodDelta < -DBL_EPSILON ||
             trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
            break;

        mStep();

        prevTrainLogLikelihood = trainLogLikelihood;
    }

    // A hugely negative likelihood signals numerical collapse of the model.
    if( trainLogLikelihood <= -DBL_MAX/10000. )
    {
        clear();
        return false;
    }

    // postprocess covs: rebuild full dim x dim matrices from the compact
    // spherical/diagonal representations for the public getters.
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(covMatType == COV_MAT_SPHERICAL)
        {
            covs[clusterIndex].create(dim, dim, CV_64FC1);
            setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
        }
        else if(covMatType == COV_MAT_DIAGONAL)
        {
            covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
        }
    }

    if(labels.needed())
        trainLabels.copyTo(labels);
    if(probs.needed())
        trainProbs.copyTo(probs);
    if(logLikelihoods.needed())
        trainLogLikelihoods.copyTo(logLikelihoods);

    // Per-sample buffers are only needed during training.
    trainSamples.release();
    trainProbs.release();
    trainLabels.release();
    trainLogLikelihoods.release();

    return true;
}
558
559
// Evaluate one sample against all mixture components.
// Returns {overall log-likelihood of the sample, index of the most probable
// component}; optionally writes the normalized posteriors to *probs.
Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
{
    // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
    // q = arg(max_k(L_ik))
    // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
    // see Alex Smola's blog http://blog.smola.org/page/2 for
    // details on the log-sum-exp trick

    int stype = sample.type();
    CV_Assert(!means.empty());
    CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
    CV_Assert(sample.size() == Size(means.cols, 1));

    int dim = sample.cols;

    Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
    int i, label = 0;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        // centeredSample = sample - mean_k (float input promoted to double).
        const double* mptr = means.ptr<double>(clusterIndex);
        double* dptr = centeredSample.ptr<double>();
        if( stype == CV_32F )
        {
            const float* sptr = sample.ptr<float>();
            for( i = 0; i < dim; i++ )
                dptr[i] = sptr[i] - mptr[i];
        }
        else
        {
            const double* sptr = sample.ptr<double>();
            for( i = 0; i < dim; i++ )
                dptr[i] = sptr[i] - mptr[i];
        }

        // For the generic type rotate into the covariance eigenbasis so the
        // quadratic form becomes a diagonal weighted sum of squares.
        Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
                centeredSample : centeredSample * covsRotateMats[clusterIndex];

        double Lval = 0;
        for(int di = 0; di < dim; di++)
        {
            double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
            double val = rotatedCenteredSample.at<double>(di);
            Lval += w * val * val;
        }
        CV_DbgAssert(!logWeightDivDet.empty());
        L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;

        if(L.at<double>(clusterIndex) > L.at<double>(label))
            label = clusterIndex;
    }

    // Log-sum-exp: shift by the maximum so the exponentials cannot overflow.
    double maxLVal = L.at<double>(label);
    double expDiffSum = 0;
    for( i = 0; i < L.cols; i++ )
    {
        double v = std::exp(L.at<double>(i) - maxLVal);
        L.at<double>(i) = v;
        expDiffSum += v; // sum_j(exp(L_ij - L_iq))
    }

    CV_Assert(expDiffSum > 0);
    if(probs)
        L.convertTo(*probs, ptype, 1./expDiffSum);  // normalized posteriors

    Vec2d res;
    res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
    res[1] = label;

    return res;
}
629
630
// E-step: with the current means/covs/weights, compute for every sample
// its posterior over clusters, its log-likelihood and its hard label.
void eStep()
{
    // Compute probs_ik from means_k, covs_k and weights_k.
    trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
    trainLabels.create(trainSamples.rows, 1, CV_32SC1);
    trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);

    computeLogWeightDivDet();  // refresh the cached per-cluster constants

    CV_DbgAssert(trainSamples.type() == CV_64FC1);
    CV_DbgAssert(means.type() == CV_64FC1);

    for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
    {
        Mat sampleProbs = trainProbs.row(sampleIndex);
        Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
        trainLogLikelihoods.at<double>(sampleIndex) = res[0];
        trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
    }
}
650
651
// M-step: re-estimate weights, means and covariances from the current
// posteriors. Clusters whose total responsibility is numerically zero are
// re-seeded from the smallest surviving cluster at the end.
void mStep()
{
    // Update means_k, covs_k and weights_k from probs_ik
    int dim = trainSamples.cols;

    // Update weights
    // not normalized first
    reduce(trainProbs, weights, 0, CV_REDUCE_SUM);

    // Update means
    means.create(nclusters, dim, CV_64FC1);
    means = Scalar(0);

    // Any cluster whose weight falls at or below this is treated as empty.
    const double minPosWeight = trainSamples.rows * DBL_EPSILON;
    double minWeight = DBL_MAX;
    int minWeightClusterIndex = -1;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
            continue;

        // Track the smallest non-empty cluster for re-seeding below.
        if(weights.at<double>(clusterIndex) < minWeight)
        {
            minWeight = weights.at<double>(clusterIndex);
            minWeightClusterIndex = clusterIndex;
        }

        // Mean = responsibility-weighted average of the samples.
        Mat clusterMean = means.row(clusterIndex);
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
            clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
        clusterMean /= weights.at<double>(clusterIndex);
    }

    // Update covsEigenValues and invCovsEigenValues
    covs.resize(nclusters);
    covsEigenValues.resize(nclusters);
    if(covMatType == COV_MAT_GENERIC)
        covsRotateMats.resize(nclusters);
    invCovsEigenValues.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
            continue;

        // Compact storage: 1 x dim eigenvalues for diagonal, a single scalar
        // for spherical, the full dim x dim matrix only for generic.
        if(covMatType != COV_MAT_SPHERICAL)
            covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
        else
            covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);

        if(covMatType == COV_MAT_GENERIC)
            covs[clusterIndex].create(dim, dim, CV_64FC1);

        // clusterCov aliases the target storage; writes below update it in place.
        Mat clusterCov = covMatType != COV_MAT_GENERIC ?
            covsEigenValues[clusterIndex] : covs[clusterIndex];

        clusterCov = Scalar(0);

        // Accumulate the responsibility-weighted scatter around the mean.
        Mat centeredSample;
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
        {
            centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);

            if(covMatType == COV_MAT_GENERIC)
                clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
            else
            {
                double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                for(int di = 0; di < dim; di++ )
                {
                    double val = centeredSample.at<double>(di);
                    clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
                }
            }
        }

        // Spherical variance is the average over all dimensions.
        if(covMatType == COV_MAT_SPHERICAL)
            clusterCov /= dim;

        clusterCov /= weights.at<double>(clusterIndex);

        // Update covsRotateMats for COV_MAT_GENERIC only
        if(covMatType == COV_MAT_GENERIC)
        {
            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
            covsEigenValues[clusterIndex] = svd.w;
            covsRotateMats[clusterIndex] = svd.u;
        }

        // Floor the eigenvalues so the inverses below stay finite.
        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);

        // update invCovsEigenValues
        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
    }

    // Re-seed empty clusters with the parameters of the smallest surviving
    // cluster so that all nclusters components stay alive.
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
        {
            Mat clusterMean = means.row(clusterIndex);
            means.row(minWeightClusterIndex).copyTo(clusterMean);
            covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
            covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
            if(covMatType == COV_MAT_GENERIC)
                covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
            invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
        }
    }

    // Normalize weights
    weights /= trainSamples.rows;
}
762
763
// Write the hyper-parameters (cluster count, covariance type as a string,
// termination criteria) into the currently open storage node.
void write_params(FileStorage& fs) const
{
    fs << "nclusters" << nclusters;
    fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
                             covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
                             covMatType == COV_MAT_GENERIC ? String("generic") :
                             format("unknown_%d", covMatType));
    writeTermCrit(fs, termCrit);
}
772
773
// Serialize the trained model: training parameters, mixture weights,
// component means and the per-component covariance matrices.
void write(FileStorage& fs) const CV_OVERRIDE
{
    writeFormat(fs);
    fs << "training_params" << "{";
    write_params(fs);
    fs << "}";
    fs << "weights" << weights;
    fs << "means" << means;

    fs << "covs" << "[";
    for( size_t covIndex = 0; covIndex < covs.size(); covIndex++ )
        fs << covs[covIndex];
    fs << "]";
}
789
790
// Parse hyper-parameters from a "training_params" node; the covariance
// type is stored as a string and must be one of the known names.
void read_params(const FileNode& fn)
{
    nclusters = (int)fn["nclusters"];
    String s = (String)fn["cov_mat_type"];
    covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
                 s == "diagonal" ? COV_MAT_DIAGONAL :
                 s == "generic" ? COV_MAT_GENERIC : -1;
    CV_Assert(covMatType >= 0);  // reject unknown type strings
    termCrit = readTermCrit(fn);
}
800
801
// Deserialize a trained model and rebuild the derived quantities (eigen
// decompositions and cached log-weight terms) so the model is immediately
// ready for prediction.
void read(const FileNode& fn) CV_OVERRIDE
{
    clear();
    read_params(fn["training_params"]);

    fn["weights"] >> weights;
    fn["means"] >> means;

    FileNode cfn = fn["covs"];
    FileNodeIterator cfn_it = cfn.begin();
    int i, n = (int)cfn.size();
    covs.resize(n);

    for( i = 0; i < n; i++, ++cfn_it )
        (*cfn_it) >> covs[i];

    decomposeCovs();
    computeLogWeightDivDet();
}
820
821
Mat getWeights() const CV_OVERRIDE { return weights; }
822
Mat getMeans() const CV_OVERRIDE { return means; }
823
void getCovs(std::vector<Mat>& _covs) const CV_OVERRIDE
824
{
825
_covs.resize(covs.size());
826
std::copy(covs.begin(), covs.end(), _covs.begin());
827
}
828
829
// all inner matrices have type CV_64FC1 (trainLabels is CV_32SC1)

Mat trainSamples;        // cached training samples, one sample per row
Mat trainProbs;          // posterior p(k | x_i), nsamples x nclusters
Mat trainLogLikelihoods; // per-sample log-likelihood, nsamples x 1
Mat trainLabels;         // most probable component per sample

Mat weights;             // mixture weights, 1 x nclusters
Mat means;               // component means, nclusters x dim
std::vector<Mat> covs;   // full covariance matrix per component

std::vector<Mat> covsEigenValues;    // compact eigenvalues per component (1x1 spherical, 1xdim otherwise)
std::vector<Mat> covsRotateMats;     // SVD rotation matrices, COV_MAT_GENERIC only
std::vector<Mat> invCovsEigenValues; // elementwise inverses of covsEigenValues
Mat logWeightDivDet;     // log(weight_k) - 0.5*log|det(cov_k)|, 1 x nclusters
};
844
845
// Factory: construct a fresh, untrained EM model with default parameters.
Ptr<EM> EM::create()
{
    Ptr<EMImpl> model = makePtr<EMImpl>();
    return model;
}

// Load a serialized EM model from file; nodeName selects a sub-node inside
// the storage (an empty name reads from the file root).
Ptr<EM> EM::load(const String& filepath, const String& nodeName)
{
    return Algorithm::load<EM>(filepath, nodeName);
}
854
855
}
856
} // namespace cv
857
858
/* End of file. */
859
860