Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/core/src/batch_distance.cpp
16337 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html
4
5
6
#include "precomp.hpp"
7
#include "stat.hpp"
8
#include <opencv2/core/hal/hal.hpp>
9
10
namespace cv
11
{
12
13
template<typename _Tp, typename _Rt>
14
void batchDistL1_(const _Tp* src1, const _Tp* src2, size_t step2,
15
int nvecs, int len, _Rt* dist, const uchar* mask)
16
{
17
step2 /= sizeof(src2[0]);
18
if( !mask )
19
{
20
for( int i = 0; i < nvecs; i++ )
21
dist[i] = normL1<_Tp, _Rt>(src1, src2 + step2*i, len);
22
}
23
else
24
{
25
_Rt val0 = std::numeric_limits<_Rt>::max();
26
for( int i = 0; i < nvecs; i++ )
27
dist[i] = mask[i] ? normL1<_Tp, _Rt>(src1, src2 + step2*i, len) : val0;
28
}
29
}
30
31
template<typename _Tp, typename _Rt>
32
void batchDistL2Sqr_(const _Tp* src1, const _Tp* src2, size_t step2,
33
int nvecs, int len, _Rt* dist, const uchar* mask)
34
{
35
step2 /= sizeof(src2[0]);
36
if( !mask )
37
{
38
for( int i = 0; i < nvecs; i++ )
39
dist[i] = normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len);
40
}
41
else
42
{
43
_Rt val0 = std::numeric_limits<_Rt>::max();
44
for( int i = 0; i < nvecs; i++ )
45
dist[i] = mask[i] ? normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len) : val0;
46
}
47
}
48
49
template<>
50
void batchDistL2Sqr_(const float* src1, const float* src2, size_t step2,
51
int nvecs, int len, float* dist, const uchar* mask)
52
{
53
step2 /= sizeof(src2[0]);
54
if( !mask )
55
{
56
for( int i = 0; i < nvecs; i++ )
57
dist[i] = hal::normL2Sqr_(src1, src2 + step2*i, len);
58
}
59
else
60
{
61
float val0 = std::numeric_limits<float>::max();
62
for( int i = 0; i < nvecs; i++ )
63
dist[i] = mask[i] ? hal::normL2Sqr_(src1, src2 + step2*i, len) : val0;
64
}
65
}
66
67
template<typename _Tp, typename _Rt>
68
void batchDistL2_(const _Tp* src1, const _Tp* src2, size_t step2,
69
int nvecs, int len, _Rt* dist, const uchar* mask)
70
{
71
step2 /= sizeof(src2[0]);
72
if( !mask )
73
{
74
for( int i = 0; i < nvecs; i++ )
75
dist[i] = std::sqrt(normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len));
76
}
77
else
78
{
79
_Rt val0 = std::numeric_limits<_Rt>::max();
80
for( int i = 0; i < nvecs; i++ )
81
dist[i] = mask[i] ? std::sqrt(normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len)) : val0;
82
}
83
}
84
85
template<>
86
void batchDistL2_(const float* src1, const float* src2, size_t step2,
87
int nvecs, int len, float* dist, const uchar* mask)
88
{
89
step2 /= sizeof(src2[0]);
90
if( !mask )
91
{
92
for( int i = 0; i < nvecs; i++ )
93
dist[i] = std::sqrt(hal::normL2Sqr_(src1, src2 + step2*i, len));
94
}
95
else
96
{
97
float val0 = std::numeric_limits<float>::max();
98
for( int i = 0; i < nvecs; i++ )
99
dist[i] = mask[i] ? std::sqrt(hal::normL2Sqr_(src1, src2 + step2*i, len)) : val0;
100
}
101
}
102
103
static void batchDistHamming(const uchar* src1, const uchar* src2, size_t step2,
104
int nvecs, int len, int* dist, const uchar* mask)
105
{
106
step2 /= sizeof(src2[0]);
107
if( !mask )
108
{
109
for( int i = 0; i < nvecs; i++ )
110
dist[i] = hal::normHamming(src1, src2 + step2*i, len);
111
}
112
else
113
{
114
int val0 = INT_MAX;
115
for( int i = 0; i < nvecs; i++ )
116
{
117
if (mask[i])
118
dist[i] = hal::normHamming(src1, src2 + step2*i, len);
119
else
120
dist[i] = val0;
121
}
122
}
123
}
124
125
static void batchDistHamming2(const uchar* src1, const uchar* src2, size_t step2,
126
int nvecs, int len, int* dist, const uchar* mask)
127
{
128
step2 /= sizeof(src2[0]);
129
if( !mask )
130
{
131
for( int i = 0; i < nvecs; i++ )
132
dist[i] = hal::normHamming(src1, src2 + step2*i, len, 2);
133
}
134
else
135
{
136
int val0 = INT_MAX;
137
for( int i = 0; i < nvecs; i++ )
138
{
139
if (mask[i])
140
dist[i] = hal::normHamming(src1, src2 + step2*i, len, 2);
141
else
142
dist[i] = val0;
143
}
144
}
145
}
146
147
static void batchDistL1_8u32s(const uchar* src1, const uchar* src2, size_t step2,
148
int nvecs, int len, int* dist, const uchar* mask)
149
{
150
batchDistL1_<uchar, int>(src1, src2, step2, nvecs, len, dist, mask);
151
}
152
153
static void batchDistL1_8u32f(const uchar* src1, const uchar* src2, size_t step2,
154
int nvecs, int len, float* dist, const uchar* mask)
155
{
156
batchDistL1_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);
157
}
158
159
static void batchDistL2Sqr_8u32s(const uchar* src1, const uchar* src2, size_t step2,
160
int nvecs, int len, int* dist, const uchar* mask)
161
{
162
batchDistL2Sqr_<uchar, int>(src1, src2, step2, nvecs, len, dist, mask);
163
}
164
165
static void batchDistL2Sqr_8u32f(const uchar* src1, const uchar* src2, size_t step2,
166
int nvecs, int len, float* dist, const uchar* mask)
167
{
168
batchDistL2Sqr_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);
169
}
170
171
static void batchDistL2_8u32f(const uchar* src1, const uchar* src2, size_t step2,
172
int nvecs, int len, float* dist, const uchar* mask)
173
{
174
batchDistL2_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);
175
}
176
177
static void batchDistL1_32f(const float* src1, const float* src2, size_t step2,
178
int nvecs, int len, float* dist, const uchar* mask)
179
{
180
batchDistL1_<float, float>(src1, src2, step2, nvecs, len, dist, mask);
181
}
182
183
static void batchDistL2Sqr_32f(const float* src1, const float* src2, size_t step2,
184
int nvecs, int len, float* dist, const uchar* mask)
185
{
186
batchDistL2Sqr_<float, float>(src1, src2, step2, nvecs, len, dist, mask);
187
}
188
189
static void batchDistL2_32f(const float* src1, const float* src2, size_t step2,
190
int nvecs, int len, float* dist, const uchar* mask)
191
{
192
batchDistL2_<float, float>(src1, src2, step2, nvecs, len, dist, mask);
193
}
194
195
typedef void (*BatchDistFunc)(const uchar* src1, const uchar* src2, size_t step2,
196
int nvecs, int len, uchar* dist, const uchar* mask);
197
198
199
struct BatchDistInvoker : public ParallelLoopBody
200
{
201
BatchDistInvoker( const Mat& _src1, const Mat& _src2,
202
Mat& _dist, Mat& _nidx, int _K,
203
const Mat& _mask, int _update,
204
BatchDistFunc _func)
205
{
206
src1 = &_src1;
207
src2 = &_src2;
208
dist = &_dist;
209
nidx = &_nidx;
210
K = _K;
211
mask = &_mask;
212
update = _update;
213
func = _func;
214
}
215
216
void operator()(const Range& range) const CV_OVERRIDE
217
{
218
AutoBuffer<int> buf(src2->rows);
219
int* bufptr = buf.data();
220
221
for( int i = range.start; i < range.end; i++ )
222
{
223
func(src1->ptr(i), src2->ptr(), src2->step, src2->rows, src2->cols,
224
K > 0 ? (uchar*)bufptr : dist->ptr(i), mask->data ? mask->ptr(i) : 0);
225
226
if( K > 0 )
227
{
228
int* nidxptr = nidx->ptr<int>(i);
229
// since positive float's can be compared just like int's,
230
// we handle both CV_32S and CV_32F cases with a single branch
231
int* distptr = (int*)dist->ptr(i);
232
233
int j, k;
234
235
for( j = 0; j < src2->rows; j++ )
236
{
237
int d = bufptr[j];
238
if( d < distptr[K-1] )
239
{
240
for( k = K-2; k >= 0 && distptr[k] > d; k-- )
241
{
242
nidxptr[k+1] = nidxptr[k];
243
distptr[k+1] = distptr[k];
244
}
245
nidxptr[k+1] = j + update;
246
distptr[k+1] = d;
247
}
248
}
249
}
250
}
251
}
252
253
const Mat *src1;
254
const Mat *src2;
255
Mat *dist;
256
Mat *nidx;
257
const Mat *mask;
258
int K;
259
int update;
260
BatchDistFunc func;
261
};
262
263
}
264
265
// Computes distances between every row of _src1 and every row of _src2.
//
//   _src1, _src2 - inputs of the same type (CV_32F or CV_8U) with the same
//                  number of columns; each row is one vector
//   _dist        - output, created as src1.rows x (K > 0 ? K : src2.rows) of dtype
//   dtype        - CV_32F, or CV_32S when the inputs are CV_8U; -1 selects
//                  CV_32S for the Hamming norms and CV_32F otherwise
//   _nidx        - output indices (CV_32S, same size as _dist); must be
//                  requested exactly when K > 0
//   normType     - NORM_L1, NORM_L2SQR, NORM_L2, NORM_HAMMING or NORM_HAMMING2;
//                  unsupported type/dtype/normType combinations raise
//                  CV_StsUnsupportedFormat
//   K            - when > 0 (clamped to src2.rows), only the K smallest
//                  distances per src1 row and their src2 row indices are kept
//   _mask        - optional; row i is handed to the kernel for src1 row i, and
//                  zero entries produce the maximum distance value
//   update       - offset added to every stored index; when 0 and K > 0 the
//                  outputs are reset before the search, otherwise the existing
//                  contents of _dist/_nidx take part in the K-best merge
//   crosscheck   - requires K == 1, update == 0 and an empty mask; keeps a
//                  match only when it is found by the reverse (src2 -> src1)
//                  nearest-neighbour search, leaving -1 in _nidx otherwise
void cv::batchDistance( InputArray _src1, InputArray _src2,
                        OutputArray _dist, int dtype, OutputArray _nidx,
                        int normType, int K, InputArray _mask,
                        int update, bool crosscheck )
{
    CV_INSTRUMENT_REGION();

    Mat src1 = _src1.getMat(), src2 = _src2.getMat(), mask = _mask.getMat();
    int type = src1.type();
    CV_Assert( type == src2.type() && src1.cols == src2.cols &&
               (type == CV_32F || type == CV_8U));
    CV_Assert( _nidx.needed() == (K > 0) );

    if( dtype == -1 )
    {
        dtype = normType == NORM_HAMMING || normType == NORM_HAMMING2 ? CV_32S : CV_32F;
    }
    CV_Assert( (type == CV_8U && dtype == CV_32S) || dtype == CV_32F);

    K = std::min(K, src2.rows);

    _dist.create(src1.rows, (K > 0 ? K : src2.rows), dtype);
    Mat dist = _dist.getMat(), nidx;
    if( _nidx.needed() )
    {
        _nidx.create(dist.size(), CV_32S);
        nidx = _nidx.getMat();
    }

    // Fresh K-best search: start from "infinitely far, no neighbour".
    if( update == 0 && K > 0 )
    {
        dist = Scalar::all(dtype == CV_32S ? (double)INT_MAX : (double)FLT_MAX);
        nidx = Scalar::all(-1);
    }

    if( crosscheck )
    {
        CV_Assert( K == 1 && update == 0 && mask.empty() );
        CV_Assert(!nidx.empty());
        Mat tdist, tidx;
        batchDistance(src2, src1, tdist, dtype, tidx, normType, K, mask, 0, false);

        // if an idx-th element from src1 appeared to be the nearest to i-th element of src2,
        // we update the minimum mutual distance between idx-th element of src1 and the whole src2 set.
        // As a result, if nidx[idx] = i*, it means that idx-th element of src1 is the nearest
        // to i*-th element of src2 and i*-th element of src2 is the closest to idx-th element of src1.
        // If nidx[idx] = -1, it means that there is no such ideal couple for it in src2.
        // This O(N) procedure is called cross-check and it helps to eliminate some false matches.
        if( dtype == CV_32S )
        {
            for( int i = 0; i < tdist.rows; i++ )
            {
                int idx = tidx.at<int>(i);
                // The reverse search leaves -1 when no distance was below the
                // initial maximum; such rows have no partner and must not be
                // used as an index into dist/nidx.
                if( idx < 0 )
                    continue;
                int d = tdist.at<int>(i), d0 = dist.at<int>(idx);
                if( d < d0 )
                {
                    dist.at<int>(idx) = d;
                    nidx.at<int>(idx) = i + update;
                }
            }
        }
        else
        {
            for( int i = 0; i < tdist.rows; i++ )
            {
                int idx = tidx.at<int>(i);
                // See the CV_32S branch: -1 means "no neighbour found"
                // (e.g. every distance overflowed to +inf).
                if( idx < 0 )
                    continue;
                float d = tdist.at<float>(i), d0 = dist.at<float>(idx);
                if( d < d0 )
                {
                    dist.at<float>(idx) = d;
                    nidx.at<int>(idx) = i + update;
                }
            }
        }
        return;
    }

    // Pick the kernel matching the input type, output type and norm.
    BatchDistFunc func = 0;
    if( type == CV_8U )
    {
        if( normType == NORM_L1 && dtype == CV_32S )
            func = (BatchDistFunc)batchDistL1_8u32s;
        else if( normType == NORM_L1 && dtype == CV_32F )
            func = (BatchDistFunc)batchDistL1_8u32f;
        else if( normType == NORM_L2SQR && dtype == CV_32S )
            func = (BatchDistFunc)batchDistL2Sqr_8u32s;
        else if( normType == NORM_L2SQR && dtype == CV_32F )
            func = (BatchDistFunc)batchDistL2Sqr_8u32f;
        else if( normType == NORM_L2 && dtype == CV_32F )
            func = (BatchDistFunc)batchDistL2_8u32f;
        else if( normType == NORM_HAMMING && dtype == CV_32S )
            func = (BatchDistFunc)batchDistHamming;
        else if( normType == NORM_HAMMING2 && dtype == CV_32S )
            func = (BatchDistFunc)batchDistHamming2;
    }
    else if( type == CV_32F && dtype == CV_32F )
    {
        if( normType == NORM_L1 )
            func = (BatchDistFunc)batchDistL1_32f;
        else if( normType == NORM_L2SQR )
            func = (BatchDistFunc)batchDistL2Sqr_32f;
        else if( normType == NORM_L2 )
            func = (BatchDistFunc)batchDistL2_32f;
    }

    if( func == 0 )
        CV_Error_(CV_StsUnsupportedFormat,
                  ("The combination of type=%d, dtype=%d and normType=%d is not supported",
                   type, dtype, normType));

    // Each src1 row is independent, so the rows are split across workers.
    parallel_for_(Range(0, src1.rows),
                  BatchDistInvoker(src1, src2, dist, nidx, K, mask, update, func));
}
378
379