Source: GitHub repository Tetragramm/opencv, file modules/ml/src/knearest.cpp (viewed via a CoCalc share page; viewer chrome and view counter removed).
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
43
#include "precomp.hpp"
44
#include "kdtree.hpp"
45
46
/****************************************************************************************\
47
* K-Nearest Neighbors Classifier *
48
\****************************************************************************************/
49
50
namespace cv {
51
namespace ml {
52
53
// File-storage node names identifying which backend serialized a model;
// KNearestImpl::read() dispatches on these when loading.
const String NAME_BRUTE_FORCE = "opencv_ml_knn";
const String NAME_KDTREE = "opencv_ml_knn_kd";
55
56
class Impl
57
{
58
public:
59
Impl()
60
{
61
defaultK = 10;
62
isclassifier = true;
63
Emax = INT_MAX;
64
}
65
66
virtual ~Impl() {}
67
virtual String getModelName() const = 0;
68
virtual int getType() const = 0;
69
virtual float findNearest( InputArray _samples, int k,
70
OutputArray _results,
71
OutputArray _neighborResponses,
72
OutputArray _dists ) const = 0;
73
74
bool train( const Ptr<TrainData>& data, int flags )
75
{
76
Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
77
Mat new_responses;
78
data->getTrainResponses().convertTo(new_responses, CV_32F);
79
bool update = (flags & ml::KNearest::UPDATE_MODEL) != 0 && !samples.empty();
80
81
CV_Assert( new_samples.type() == CV_32F );
82
83
if( !update )
84
{
85
clear();
86
}
87
else
88
{
89
CV_Assert( new_samples.cols == samples.cols &&
90
new_responses.cols == responses.cols );
91
}
92
93
samples.push_back(new_samples);
94
responses.push_back(new_responses);
95
96
doTrain(samples);
97
98
return true;
99
}
100
101
virtual void doTrain(InputArray points) { CV_UNUSED(points); }
102
103
void clear()
104
{
105
samples.release();
106
responses.release();
107
}
108
109
void read( const FileNode& fn )
110
{
111
clear();
112
isclassifier = (int)fn["is_classifier"] != 0;
113
defaultK = (int)fn["default_k"];
114
115
fn["samples"] >> samples;
116
fn["responses"] >> responses;
117
}
118
119
void write( FileStorage& fs ) const
120
{
121
fs << "is_classifier" << (int)isclassifier;
122
fs << "default_k" << defaultK;
123
124
fs << "samples" << samples;
125
fs << "responses" << responses;
126
}
127
128
public:
129
int defaultK;
130
bool isclassifier;
131
int Emax;
132
133
Mat samples;
134
Mat responses;
135
};
136
137
class BruteForceImpl CV_FINAL : public Impl
{
public:
    String getModelName() const CV_OVERRIDE { return NAME_BRUTE_FORCE; }
    int getType() const CV_OVERRIDE { return ml::KNearest::BRUTE_FORCE; }

    /** Exhaustive k-nearest-neighbor search over the stored training samples.
     *
     * Processes test rows [range.start, range.end) of _samples. For each test
     * row it maintains an ascending list of the k smallest squared Euclidean
     * distances and the corresponding training responses, then writes the
     * per-row outputs and the prediction (mean response for regression or
     * k==1, majority vote for classification).
     *
     * @param _samples           test samples, one CV_32F row per sample
     * @param k                  neighbor count (already clamped to samples.rows)
     * @param range              rows of _samples handled by this call
     * @param results            optional Nx1 CV_32F predicted responses
     * @param neighbor_responses optional Nxk CV_32F neighbor responses
     * @param dists              optional Nxk CV_32F squared distances
     * @param presult            receives the prediction for test row 0 only
     */
    void findNearestCore( const Mat& _samples, int k, const Range& range,
                          Mat* results, Mat* neighbor_responses,
                          Mat* dists, float* presult ) const
    {
        int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
        int testcount = range.end - range.start;

        // Per test row: k distances followed (in the second half) by k responses.
        AutoBuffer<float> buf(testcount*k*2);
        float* dbuf = buf.data();
        float* rbuf = dbuf + testcount*k;

        const float* rptr = responses.ptr<float>();

        for( testidx = 0; testidx < testcount; testidx++ )
        {
            for( i = 0; i < k; i++ )
            {
                dbuf[testidx*k + i] = FLT_MAX;
                rbuf[testidx*k + i] = 0.f;
            }
        }

        for( baseidx = 0; baseidx < nsamples; baseidx++ )
        {
            for( testidx = 0; testidx < testcount; testidx++ )
            {
                const float* v = samples.ptr<float>(baseidx);
                const float* u = _samples.ptr<float>(testidx + range.start);

                // Squared L2 distance, manually unrolled by 4.
                float s = 0;
                for( i = 0; i <= d - 4; i += 4 )
                {
                    float t0 = u[i] - v[i], t1 = u[i+1] - v[i+1];
                    float t2 = u[i+2] - v[i+2], t3 = u[i+3] - v[i+3];
                    s += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }

                for( ; i < d; i++ )
                {
                    float t0 = u[i] - v[i];
                    s += t0*t0;
                }

                // Non-negative IEEE-754 floats keep their ordering when their
                // bit patterns are compared as signed ints, so the insertion
                // below works on raw bits (Cv32suf) instead of float compares.
                Cv32suf si;
                si.f = s;
                Cv32suf* dd = (Cv32suf*)(&dbuf[testidx*k]);
                float* nr = &rbuf[testidx*k];

                // Find the insertion position in the ascending top-k list...
                for( i = k; i > 0; i-- )
                    if( si.i >= dd[i-1].i )
                        break;
                if( i >= k )
                    continue; // not among the k nearest so far

                // ...and shift the tail one slot to make room.
                for( j = k-2; j >= i; j-- )
                {
                    dd[j+1].i = dd[j].i;
                    nr[j+1] = nr[j];
                }
                dd[i].i = si.i;
                nr[i] = rptr[baseidx];
            }
        }

        float result = 0.f;
        float inv_scale = 1.f/k;

        for( testidx = 0; testidx < testcount; testidx++ )
        {
            if( neighbor_responses )
            {
                // k was clamped to samples.rows, so all k slots are filled;
                // the old zero-padding loop here was dead code and is removed.
                float* nr = neighbor_responses->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    nr[j] = rbuf[testidx*k + j];
            }

            if( dists )
            {
                float* dptr = dists->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    dptr[j] = dbuf[testidx*k + j];
            }

            if( results || testidx+range.start == 0 )
            {
                if( !isclassifier || k == 1 )
                {
                    // Regression (or trivial k==1): mean of neighbor responses.
                    float s = 0.f;
                    for( j = 0; j < k; j++ )
                        s += rbuf[testidx*k + j];
                    result = s*inv_scale;
                }
                else
                {
                    // Classification: majority vote. Sorting groups equal
                    // responses so the longest run can be found in one pass.
                    float* rp = rbuf + testidx*k;
                    std::sort(rp, rp+k);

                    result = rp[0];
                    int prev_start = 0;
                    int best_count = 0;
                    for( j = 1; j <= k; j++ )
                    {
                        if( j == k || rp[j] != rp[j-1] )
                        {
                            int count = j - prev_start;
                            if( best_count < count )
                            {
                                best_count = count;
                                result = rp[j-1];
                            }
                            prev_start = j;
                        }
                    }
                }
                if( results )
                    results->at<float>(testidx + range.start) = result;
                if( presult && testidx+range.start == 0 )
                    *presult = result;
            }
        }
    }

    // Parallel driver: splits the requested row range into chunks of at most
    // 256 test rows so findNearestCore's scratch buffers stay small.
    struct findKNearestInvoker : public ParallelLoopBody
    {
        findKNearestInvoker(const BruteForceImpl* _p, int _k, const Mat& __samples,
                            Mat* __results, Mat* __neighbor_responses, Mat* __dists, float* _presult)
        {
            p = _p;
            k = _k;
            _samples = &__samples;
            _results = __results;
            _neighbor_responses = __neighbor_responses;
            _dists = __dists;
            presult = _presult;
        }

        void operator()(const Range& range) const CV_OVERRIDE
        {
            int delta = std::min(range.end - range.start, 256);
            for( int start = range.start; start < range.end; start += delta )
            {
                p->findNearestCore( *_samples, k, Range(start, std::min(start + delta, range.end)),
                                    _results, _neighbor_responses, _dists, presult );
            }
        }

        const BruteForceImpl* p;
        int k;
        const Mat* _samples;
        Mat* _results;
        Mat* _neighbor_responses;
        Mat* _dists;
        float* presult;
    };

    /** Public entry point: validates inputs, allocates the requested outputs
     *  and runs the parallel search. Returns the prediction for the first
     *  test sample (only chunk 0 writes 'result', so the parallel run is
     *  race-free).
     */
    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const CV_OVERRIDE
    {
        float result = 0.f;
        CV_Assert( 0 < k );
        k = std::min(k, samples.rows);

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            // Nothing to predict; make sure the outputs are empty too.
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        Mat res, nr, d, *pres = 0, *pnr = 0, *pd = 0;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            pres = &(res = _results.getMat());
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            pnr = &(nr = _neighborResponses.getMat());
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            pd = &(d = _dists.getMat());
        }

        findKNearestInvoker invoker(this, k, test_samples, pres, pnr, pd, &result);
        parallel_for_(Range(0, testcount), invoker);
        //invoker(Range(0, testcount));
        return result;
    }
};
346
347
348
class KDTreeImpl CV_FINAL : public Impl
{
public:
    String getModelName() const CV_OVERRIDE { return NAME_KDTREE; }
    int getType() const CV_OVERRIDE { return ml::KNearest::KDTREE; }

    // Rebuild the kd-tree over the full set of accumulated training samples.
    void doTrain(InputArray points) CV_OVERRIDE
    {
        tr.build(points);
    }

    /** kd-tree backed nearest-neighbor query.
     *
     *  Runs one KDTree::findNearest per test row and writes each requested
     *  output row in place.
     *  NOTE(review): unlike the brute-force backend, no aggregate prediction
     *  is computed here -- the return value is always 0 (preserved behavior).
     */
    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const CV_OVERRIDE
    {
        CV_Assert( 0 < k );
        k = std::min(k, samples.rows);

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        const int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            // No queries: clear any caller-provided outputs and bail out.
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        Mat resultsMat, neighborsMat, distsMat;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            resultsMat = _results.getMat();
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            neighborsMat = _neighborResponses.getMat();
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            distsMat = _dists.getMat();
        }

        for( int row = 0; row < testcount; ++row )
        {
            // Row headers share storage with the output Mats, so the tree
            // writes directly into the caller's buffers; an empty Mat is
            // passed for any output the caller did not request.
            Mat rowRes = resultsMat.rows > row ? resultsMat.row(row) : Mat();
            Mat rowNbr = neighborsMat.rows > row ? neighborsMat.row(row) : Mat();
            Mat rowDist = distsMat.rows > row ? distsMat.row(row) : Mat();
            tr.findNearest(test_samples.row(row), k, Emax, rowRes, rowNbr, rowDist, noArray());
        }

        return 0.f;
    }

    KDTree tr; // kd-tree built over the training samples
};
420
421
//================================================================
422
423
// Public KNearest model: a thin facade that forwards everything to the
// currently selected backend (brute force or kd-tree).
class KNearestImpl CV_FINAL : public KNearest
{
    inline int getDefaultK() const CV_OVERRIDE { return impl->defaultK; }
    inline void setDefaultK(int val) CV_OVERRIDE { impl->defaultK = val; }
    inline bool getIsClassifier() const CV_OVERRIDE { return impl->isclassifier; }
    inline void setIsClassifier(bool val) CV_OVERRIDE { impl->isclassifier = val; }
    inline int getEmax() const CV_OVERRIDE { return impl->Emax; }
    inline void setEmax(int val) CV_OVERRIDE { impl->Emax = val; }

public:
    int getAlgorithmType() const CV_OVERRIDE
    {
        return impl->getType();
    }

    void setAlgorithmType(int val) CV_OVERRIDE
    {
        // Anything other than the two known backends falls back to brute force.
        const int requested = (val == BRUTE_FORCE || val == KDTREE) ? val : BRUTE_FORCE;

        // Re-creating the backend discards trained data, so carry the
        // user-visible hyper-parameters across to the new implementation.
        const int savedK = getDefaultK();
        const int savedEmax = getEmax();
        const bool savedIsClassifier = getIsClassifier();

        initImpl(requested);

        setDefaultK(savedK);
        setEmax(savedEmax);
        setIsClassifier(savedIsClassifier);
    }

public:
    KNearestImpl()
    {
        initImpl(BRUTE_FORCE);
    }
    ~KNearestImpl()
    {
    }

    bool isClassifier() const CV_OVERRIDE { return impl->isclassifier; }
    bool isTrained() const CV_OVERRIDE { return !impl->samples.empty(); }

    int getVarCount() const CV_OVERRIDE { return impl->samples.cols; }

    void write( FileStorage& fs ) const CV_OVERRIDE
    {
        writeFormat(fs);
        impl->write(fs);
    }

    void read( const FileNode& fn ) CV_OVERRIDE
    {
        // The node name encodes which backend wrote the model.
        initImpl(fn.name() == NAME_KDTREE ? KDTREE : BRUTE_FORCE);
        impl->read(fn);
    }

    float findNearest( InputArray samples, int k,
                       OutputArray results,
                       OutputArray neighborResponses=noArray(),
                       OutputArray dist=noArray() ) const CV_OVERRIDE
    {
        return impl->findNearest(samples, k, results, neighborResponses, dist);
    }

    float predict(InputArray inputs, OutputArray outputs, int) const CV_OVERRIDE
    {
        return impl->findNearest( inputs, impl->defaultK, outputs, noArray(), noArray() );
    }

    bool train( const Ptr<TrainData>& data, int flags ) CV_OVERRIDE
    {
        return impl->train(data, flags);
    }

    String getDefaultName() const CV_OVERRIDE { return impl->getModelName(); }

protected:
    // Instantiates the backend matching the given KNearest algorithm type.
    void initImpl(int algorithmType)
    {
        if (algorithmType == KDTREE)
            impl = makePtr<KDTreeImpl>();
        else
            impl = makePtr<BruteForceImpl>();
    }

    Ptr<Impl> impl; // active backend; never null after construction
};
512
513
// Factory for the public API; the new instance starts with the brute-force
// backend (see KNearestImpl's constructor).
Ptr<KNearest> KNearest::create()
{
    Ptr<KNearest> model = makePtr<KNearestImpl>();
    return model;
}
517
518
}
519
}
520
521
/* End of file */
522
523