Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/imgproc/src/emd.cpp
16354 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
// For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2000, Intel Corporation, all rights reserved.
14
// Third party copyrights are property of their respective owners.
15
//
16
// Redistribution and use in source and binary forms, with or without modification,
17
// are permitted provided that the following conditions are met:
18
//
19
// * Redistribution's of source code must retain the above copyright notice,
20
// this list of conditions and the following disclaimer.
21
//
22
// * Redistribution's in binary form must reproduce the above copyright notice,
23
// this list of conditions and the following disclaimer in the documentation
24
// and/or other materials provided with the distribution.
25
//
26
// * The name of Intel Corporation may not be used to endorse or promote products
27
// derived from this software without specific prior written permission.
28
//
29
// This software is provided by the copyright holders and contributors "as is" and
30
// any express or implied warranties, including, but not limited to, the implied
31
// warranties of merchantability and fitness for a particular purpose are disclaimed.
32
// In no event shall the Intel Corporation or contributors be liable for any direct,
33
// indirect, incidental, special, exemplary, or consequential damages
34
// (including, but not limited to, procurement of substitute goods or services;
35
// loss of use, data, or profits; or business interruption) however caused
36
// and on any theory of liability, whether in contract, strict liability,
37
// or tort (including negligence or otherwise) arising in any way out of
38
// the use of this software, even if advised of the possibility of such damage.
39
//
40
//M*/
41
42
/*
43
Partially based on Yossi Rubner code:
44
=========================================================================
45
emd.c
46
47
Last update: 3/14/98
48
49
An implementation of the Earth Movers Distance.
50
Based on the solution for the Transportation problem as described in
51
"Introduction to Mathematical Programming" by F. S. Hillier and
52
G. J. Lieberman, McGraw-Hill, 1990.
53
54
Copyright (C) 1998 Yossi Rubner
55
Computer Science Department, Stanford University
56
E-Mail: [email protected] URL: http://vision.stanford.edu/~rubner
57
==========================================================================
58
*/
59
#include "precomp.hpp"
60
61
/* Tuning constants of the transportation-simplex solver in this file. */
#define MAX_ITERATIONS  (500)            /* upper bound on the optimization iterations run by cvCalcEMD2 */
#define CV_EMD_INF      ((float)1e20)    /* "infinity" sentinel for the min/max searches */
#define CV_EMD_EPS      ((float)1e-5)    /* relative tolerance, scaled by max cost or total weight */
64
65
/* Node of a singly linked list; chains of these nodes represent sparse 1D arrays
   (the row values u and column values v of the solver). */
struct CvNode1D
{
    float val;          /* stored value */
    CvNode1D* next;     /* next node of the list, 0 at the tail */
};
72
73
/* Element (i, j) of a sparse 2D matrix, threaded onto two linked lists at once:
   next[0] links the elements of the same row (rows_x lists) and next[1] the
   elements of the same column (cols_x lists). The member order matters because
   the element is aggregate-initialized as {val, {next...}, i, j}. */
struct CvNode2D
{
    float val;              /* stored value (flow of a basic variable) */
    CvNode2D* next[2];      /* [0] = next in the same row, [1] = next in the same column */
    int i;                  /* row index */
    int j;                  /* column index */
};
81
82
83
typedef struct CvEMDState
84
{
85
int ssize, dsize;
86
87
float **cost;
88
CvNode2D *_x;
89
CvNode2D *end_x;
90
CvNode2D *enter_x;
91
char **is_x;
92
93
CvNode2D **rows_x;
94
CvNode2D **cols_x;
95
96
CvNode1D *u;
97
CvNode1D *v;
98
99
int* idx1;
100
int* idx2;
101
102
/* find_loop buffers */
103
CvNode2D **loop;
104
char *is_used;
105
106
/* russel buffers */
107
float *s;
108
float *d;
109
float **delta;
110
111
float weight, max_cost;
112
char *buffer;
113
}
114
CvEMDState;
115
116
/* static function declaration */
117
static int icvInitEMD( const float *signature1, int size1,
118
const float *signature2, int size2,
119
int dims, CvDistanceFunction dist_func, void *user_param,
120
const float* cost, int cost_step,
121
CvEMDState * state, float *lower_bound,
122
cv::AutoBuffer<char>& _buffer );
123
124
static int icvFindBasicVariables( float **cost, char **is_x,
125
CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126
127
static float icvIsOptimal( float **cost, char **is_x,
128
CvNode1D * u, CvNode1D * v,
129
int ssize, int dsize, CvNode2D * enter_x );
130
131
static void icvRussel( CvEMDState * state );
132
133
134
static bool icvNewSolution( CvEMDState * state );
135
static int icvFindLoop( CvEMDState * state );
136
137
static void icvAddBasicVariable( CvEMDState * state,
138
int min_i, int min_j,
139
CvNode1D * prev_u_min_i,
140
CvNode1D * prev_v_min_j,
141
CvNode1D * u_head );
142
143
static float icvDistL2( const float *x, const float *y, void *user_param );
144
static float icvDistL1( const float *x, const float *y, void *user_param );
145
static float icvDistC( const float *x, const float *y, void *user_param );
146
147
/* The main function */
148
CV_IMPL float cvCalcEMD2( const CvArr* signature_arr1,
149
const CvArr* signature_arr2,
150
int dist_type,
151
CvDistanceFunction dist_func,
152
const CvArr* cost_matrix,
153
CvArr* flow_matrix,
154
float *lower_bound,
155
void *user_param )
156
{
157
cv::AutoBuffer<char> local_buf;
158
CvEMDState state;
159
float emd = 0;
160
161
memset( &state, 0, sizeof(state));
162
163
double total_cost = 0;
164
int result = 0;
165
float eps, min_delta;
166
CvNode2D *xp = 0;
167
CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
168
CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
169
CvMat cost_stub, *cost = &cost_stub;
170
CvMat flow_stub, *flow = (CvMat*)flow_matrix;
171
int dims, size1, size2;
172
173
signature1 = cvGetMat( signature1, &sign_stub1 );
174
signature2 = cvGetMat( signature2, &sign_stub2 );
175
176
if( signature1->cols != signature2->cols )
177
CV_Error( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
178
179
dims = signature1->cols - 1;
180
size1 = signature1->rows;
181
size2 = signature2->rows;
182
183
if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
184
CV_Error( CV_StsUnmatchedFormats, "The array must have equal types" );
185
186
if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
187
CV_Error( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
188
189
if( flow )
190
{
191
flow = cvGetMat( flow, &flow_stub );
192
193
if( flow->rows != size1 || flow->cols != size2 )
194
CV_Error( CV_StsUnmatchedSizes,
195
"The flow matrix size does not match to the signatures' sizes" );
196
197
if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
198
CV_Error( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
199
}
200
201
cost->data.fl = 0;
202
cost->step = 0;
203
204
if( dist_type < 0 )
205
{
206
if( cost_matrix )
207
{
208
if( dist_func )
209
CV_Error( CV_StsBadArg,
210
"Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
211
212
if( lower_bound )
213
CV_Error( CV_StsBadArg,
214
"The lower boundary can not be calculated if the cost matrix is used" );
215
216
cost = cvGetMat( cost_matrix, &cost_stub );
217
if( cost->rows != size1 || cost->cols != size2 )
218
CV_Error( CV_StsUnmatchedSizes,
219
"The cost matrix size does not match to the signatures' sizes" );
220
221
if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
222
CV_Error( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
223
}
224
else if( !dist_func )
225
CV_Error( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
226
}
227
else
228
{
229
if( dims == 0 )
230
CV_Error( CV_StsBadSize,
231
"Number of dimensions can be 0 only if a user-defined metric is used" );
232
user_param = (void *) (size_t)dims;
233
switch (dist_type)
234
{
235
case CV_DIST_L1:
236
dist_func = icvDistL1;
237
break;
238
case CV_DIST_L2:
239
dist_func = icvDistL2;
240
break;
241
case CV_DIST_C:
242
dist_func = icvDistC;
243
break;
244
default:
245
CV_Error( CV_StsBadFlag, "Bad or unsupported metric type" );
246
}
247
}
248
249
result = icvInitEMD( signature1->data.fl, size1,
250
signature2->data.fl, size2,
251
dims, dist_func, user_param,
252
cost->data.fl, cost->step,
253
&state, lower_bound, local_buf );
254
255
if( result > 0 && lower_bound )
256
{
257
emd = *lower_bound;
258
return emd;
259
}
260
261
eps = CV_EMD_EPS * state.max_cost;
262
263
/* if ssize = 1 or dsize = 1 then we are done, else ... */
264
if( state.ssize > 1 && state.dsize > 1 )
265
{
266
int itr;
267
268
for( itr = 1; itr < MAX_ITERATIONS; itr++ )
269
{
270
/* find basic variables */
271
result = icvFindBasicVariables( state.cost, state.is_x,
272
state.u, state.v, state.ssize, state.dsize );
273
if( result < 0 )
274
break;
275
276
/* check for optimality */
277
min_delta = icvIsOptimal( state.cost, state.is_x,
278
state.u, state.v,
279
state.ssize, state.dsize, state.enter_x );
280
281
if( min_delta == CV_EMD_INF )
282
CV_Error( CV_StsNoConv, "" );
283
284
/* if no negative deltamin, we found the optimal solution */
285
if( min_delta >= -eps )
286
break;
287
288
/* improve solution */
289
if(!icvNewSolution( &state ))
290
CV_Error( CV_StsNoConv, "" );
291
}
292
}
293
294
/* compute the total flow */
295
for( xp = state._x; xp < state.end_x; xp++ )
296
{
297
float val = xp->val;
298
int i = xp->i;
299
int j = xp->j;
300
301
if( xp == state.enter_x )
302
continue;
303
304
int ci = state.idx1[i];
305
int cj = state.idx2[j];
306
307
if( ci >= 0 && cj >= 0 )
308
{
309
total_cost += (double)val * state.cost[i][j];
310
if( flow )
311
((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
312
}
313
}
314
315
emd = (float) (total_cost / state.weight);
316
return emd;
317
}
318
319
320
/************************************************************************************\
321
* initialize structure, allocate buffers and generate initial solution *
322
\************************************************************************************/
323
static int icvInitEMD( const float* signature1, int size1,
324
const float* signature2, int size2,
325
int dims, CvDistanceFunction dist_func, void* user_param,
326
const float* cost, int cost_step,
327
CvEMDState* state, float* lower_bound,
328
cv::AutoBuffer<char>& _buffer )
329
{
330
float s_sum = 0, d_sum = 0, diff;
331
int i, j;
332
int ssize = 0, dsize = 0;
333
int equal_sums = 1;
334
int buffer_size;
335
float max_cost = 0;
336
char *buffer, *buffer_end;
337
338
memset( state, 0, sizeof( *state ));
339
assert( cost_step % sizeof(float) == 0 );
340
cost_step /= sizeof(float);
341
342
/* calculate buffer size */
343
buffer_size = (size1+1) * (size2+1) * (sizeof( float ) + /* cost */
344
sizeof( char ) + /* is_x */
345
sizeof( float )) + /* delta matrix */
346
(size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
347
sizeof( CvNode2D * ) + /* cols_x & rows_x */
348
sizeof( CvNode1D ) + /* u & v */
349
sizeof( float ) + /* s & d */
350
sizeof( int ) + sizeof(CvNode2D*)) + /* idx1 & idx2 */
351
(size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
352
sizeof( float * )) + 256; /* cost, is_x and delta */
353
354
if( buffer_size < (int) (dims * 2 * sizeof( float )))
355
{
356
buffer_size = dims * 2 * sizeof( float );
357
}
358
359
/* allocate buffers */
360
_buffer.allocate(buffer_size);
361
362
state->buffer = buffer = _buffer.data();
363
buffer_end = buffer + buffer_size;
364
365
state->idx1 = (int*) buffer;
366
buffer += (size1 + 1) * sizeof( int );
367
368
state->idx2 = (int*) buffer;
369
buffer += (size2 + 1) * sizeof( int );
370
371
state->s = (float *) buffer;
372
buffer += (size1 + 1) * sizeof( float );
373
374
state->d = (float *) buffer;
375
buffer += (size2 + 1) * sizeof( float );
376
377
/* sum up the supply and demand */
378
for( i = 0; i < size1; i++ )
379
{
380
float weight = signature1[i * (dims + 1)];
381
382
if( weight > 0 )
383
{
384
s_sum += weight;
385
state->s[ssize] = weight;
386
state->idx1[ssize++] = i;
387
388
}
389
else if( weight < 0 )
390
CV_Error(CV_StsBadArg, "signature1 must not contain negative weights");
391
}
392
393
for( i = 0; i < size2; i++ )
394
{
395
float weight = signature2[i * (dims + 1)];
396
397
if( weight > 0 )
398
{
399
d_sum += weight;
400
state->d[dsize] = weight;
401
state->idx2[dsize++] = i;
402
}
403
else if( weight < 0 )
404
CV_Error(CV_StsBadArg, "signature2 must not contain negative weights");
405
}
406
407
if( ssize == 0 )
408
CV_Error(CV_StsBadArg, "signature1 must contain at least one non-zero value");
409
if( dsize == 0 )
410
CV_Error(CV_StsBadArg, "signature2 must contain at least one non-zero value");
411
412
/* if supply different than the demand, add a zero-cost dummy cluster */
413
diff = s_sum - d_sum;
414
if( fabs( diff ) >= CV_EMD_EPS * s_sum )
415
{
416
equal_sums = 0;
417
if( diff < 0 )
418
{
419
state->s[ssize] = -diff;
420
state->idx1[ssize++] = -1;
421
}
422
else
423
{
424
state->d[dsize] = diff;
425
state->idx2[dsize++] = -1;
426
}
427
}
428
429
state->ssize = ssize;
430
state->dsize = dsize;
431
state->weight = s_sum > d_sum ? s_sum : d_sum;
432
433
if( lower_bound && equal_sums ) /* check lower bound */
434
{
435
int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
436
float lb = 0;
437
438
float* xs = (float *) buffer;
439
float* xd = xs + dims;
440
441
memset( xs, 0, dims*sizeof(xs[0]));
442
memset( xd, 0, dims*sizeof(xd[0]));
443
444
for( j = 0; j < sz1; j += dims + 1 )
445
{
446
float weight = signature1[j];
447
for( i = 0; i < dims; i++ )
448
xs[i] += signature1[j + i + 1] * weight;
449
}
450
451
for( j = 0; j < sz2; j += dims + 1 )
452
{
453
float weight = signature2[j];
454
for( i = 0; i < dims; i++ )
455
xd[i] += signature2[j + i + 1] * weight;
456
}
457
458
lb = dist_func( xs, xd, user_param ) / state->weight;
459
i = *lower_bound <= lb;
460
*lower_bound = lb;
461
if( i )
462
return 1;
463
}
464
465
/* assign pointers */
466
state->is_used = (char *) buffer;
467
/* init delta matrix */
468
state->delta = (float **) buffer;
469
buffer += ssize * sizeof( float * );
470
471
for( i = 0; i < ssize; i++ )
472
{
473
state->delta[i] = (float *) buffer;
474
buffer += dsize * sizeof( float );
475
}
476
477
state->loop = (CvNode2D **) buffer;
478
buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
479
480
state->_x = state->end_x = (CvNode2D *) buffer;
481
buffer += (ssize + dsize) * sizeof( CvNode2D );
482
483
/* init cost matrix */
484
state->cost = (float **) buffer;
485
buffer += ssize * sizeof( float * );
486
487
/* compute the distance matrix */
488
for( i = 0; i < ssize; i++ )
489
{
490
int ci = state->idx1[i];
491
492
state->cost[i] = (float *) buffer;
493
buffer += dsize * sizeof( float );
494
495
if( ci >= 0 )
496
{
497
for( j = 0; j < dsize; j++ )
498
{
499
int cj = state->idx2[j];
500
if( cj < 0 )
501
state->cost[i][j] = 0;
502
else
503
{
504
float val;
505
if( dist_func )
506
{
507
val = dist_func( signature1 + ci * (dims + 1) + 1,
508
signature2 + cj * (dims + 1) + 1,
509
user_param );
510
}
511
else
512
{
513
assert( cost );
514
val = cost[cost_step*ci + cj];
515
}
516
state->cost[i][j] = val;
517
if( max_cost < val )
518
max_cost = val;
519
}
520
}
521
}
522
else
523
{
524
for( j = 0; j < dsize; j++ )
525
state->cost[i][j] = 0;
526
}
527
}
528
529
state->max_cost = max_cost;
530
531
memset( buffer, 0, buffer_end - buffer );
532
533
state->rows_x = (CvNode2D **) buffer;
534
buffer += ssize * sizeof( CvNode2D * );
535
536
state->cols_x = (CvNode2D **) buffer;
537
buffer += dsize * sizeof( CvNode2D * );
538
539
state->u = (CvNode1D *) buffer;
540
buffer += ssize * sizeof( CvNode1D );
541
542
state->v = (CvNode1D *) buffer;
543
buffer += dsize * sizeof( CvNode1D );
544
545
/* init is_x matrix */
546
state->is_x = (char **) buffer;
547
buffer += ssize * sizeof( char * );
548
549
for( i = 0; i < ssize; i++ )
550
{
551
state->is_x[i] = buffer;
552
buffer += dsize;
553
}
554
555
assert( buffer <= buffer_end );
556
557
icvRussel( state );
558
559
state->enter_x = (state->end_x)++;
560
return 0;
561
}
562
563
564
/****************************************************************************************\
565
* icvFindBasicVariables *
566
\****************************************************************************************/
567
static int icvFindBasicVariables( float **cost, char **is_x,
568
CvNode1D * u, CvNode1D * v, int ssize, int dsize )
569
{
570
int i, j;
571
int u_cfound, v_cfound;
572
CvNode1D u0_head, u1_head, *cur_u, *prev_u;
573
CvNode1D v0_head, v1_head, *cur_v, *prev_v;
574
bool found;
575
576
CV_Assert(u != 0 && v != 0);
577
578
/* initialize the rows list (u) and the columns list (v) */
579
u0_head.next = u;
580
for( i = 0; i < ssize; i++ )
581
{
582
u[i].next = u + i + 1;
583
}
584
u[ssize - 1].next = 0;
585
u1_head.next = 0;
586
587
v0_head.next = ssize > 1 ? v + 1 : 0;
588
for( i = 1; i < dsize; i++ )
589
{
590
v[i].next = v + i + 1;
591
}
592
v[dsize - 1].next = 0;
593
v1_head.next = 0;
594
595
/* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
596
so set v[0]=0 */
597
v[0].val = 0;
598
v1_head.next = v;
599
v1_head.next->next = 0;
600
601
/* loop until all variables are found */
602
u_cfound = v_cfound = 0;
603
while( u_cfound < ssize || v_cfound < dsize )
604
{
605
found = false;
606
if( v_cfound < dsize )
607
{
608
/* loop over all marked columns */
609
prev_v = &v1_head;
610
cur_v = v1_head.next;
611
found = found || (cur_v != 0);
612
for( ; cur_v != 0; cur_v = cur_v->next )
613
{
614
float cur_v_val = cur_v->val;
615
616
j = (int)(cur_v - v);
617
/* find the variables in column j */
618
prev_u = &u0_head;
619
for( cur_u = u0_head.next; cur_u != 0; )
620
{
621
i = (int)(cur_u - u);
622
if( is_x[i][j] )
623
{
624
/* compute u[i] */
625
cur_u->val = cost[i][j] - cur_v_val;
626
/* ...and add it to the marked list */
627
prev_u->next = cur_u->next;
628
cur_u->next = u1_head.next;
629
u1_head.next = cur_u;
630
cur_u = prev_u->next;
631
}
632
else
633
{
634
prev_u = cur_u;
635
cur_u = cur_u->next;
636
}
637
}
638
prev_v->next = cur_v->next;
639
v_cfound++;
640
}
641
}
642
643
if( u_cfound < ssize )
644
{
645
/* loop over all marked rows */
646
prev_u = &u1_head;
647
cur_u = u1_head.next;
648
found = found || (cur_u != 0);
649
for( ; cur_u != 0; cur_u = cur_u->next )
650
{
651
float cur_u_val = cur_u->val;
652
float *_cost;
653
char *_is_x;
654
655
i = (int)(cur_u - u);
656
_cost = cost[i];
657
_is_x = is_x[i];
658
/* find the variables in rows i */
659
prev_v = &v0_head;
660
for( cur_v = v0_head.next; cur_v != 0; )
661
{
662
j = (int)(cur_v - v);
663
if( _is_x[j] )
664
{
665
/* compute v[j] */
666
cur_v->val = _cost[j] - cur_u_val;
667
/* ...and add it to the marked list */
668
prev_v->next = cur_v->next;
669
cur_v->next = v1_head.next;
670
v1_head.next = cur_v;
671
cur_v = prev_v->next;
672
}
673
else
674
{
675
prev_v = cur_v;
676
cur_v = cur_v->next;
677
}
678
}
679
prev_u->next = cur_u->next;
680
u_cfound++;
681
}
682
}
683
684
if( !found )
685
return -1;
686
}
687
688
return 0;
689
}
690
691
692
/****************************************************************************************\
693
* icvIsOptimal *
694
\****************************************************************************************/
695
static float
696
icvIsOptimal( float **cost, char **is_x,
697
CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
698
{
699
float delta, min_delta = CV_EMD_INF;
700
int i, j, min_i = 0, min_j = 0;
701
702
/* find the minimal cij-ui-vj over all i,j */
703
for( i = 0; i < ssize; i++ )
704
{
705
float u_val = u[i].val;
706
float *_cost = cost[i];
707
char *_is_x = is_x[i];
708
709
for( j = 0; j < dsize; j++ )
710
{
711
if( !_is_x[j] )
712
{
713
delta = _cost[j] - u_val - v[j].val;
714
if( min_delta > delta )
715
{
716
min_delta = delta;
717
min_i = i;
718
min_j = j;
719
}
720
}
721
}
722
}
723
724
enter_x->i = min_i;
725
enter_x->j = min_j;
726
727
return min_delta;
728
}
729
730
/****************************************************************************************\
731
* icvNewSolution *
732
\****************************************************************************************/
733
static bool
734
icvNewSolution( CvEMDState * state )
735
{
736
int i, j;
737
float min_val = CV_EMD_INF;
738
int steps;
739
CvNode2D head = {0, {0}, 0, 0}, *cur_x, *next_x, *leave_x = 0;
740
CvNode2D *enter_x = state->enter_x;
741
CvNode2D **loop = state->loop;
742
743
/* enter the new basic variable */
744
i = enter_x->i;
745
j = enter_x->j;
746
state->is_x[i][j] = 1;
747
enter_x->next[0] = state->rows_x[i];
748
enter_x->next[1] = state->cols_x[j];
749
enter_x->val = 0;
750
state->rows_x[i] = enter_x;
751
state->cols_x[j] = enter_x;
752
753
/* find a chain reaction */
754
steps = icvFindLoop( state );
755
756
if( steps == 0 )
757
return false;
758
759
/* find the largest value in the loop */
760
for( i = 1; i < steps; i += 2 )
761
{
762
float temp = loop[i]->val;
763
764
if( min_val > temp )
765
{
766
leave_x = loop[i];
767
min_val = temp;
768
}
769
}
770
771
/* update the loop */
772
for( i = 0; i < steps; i += 2 )
773
{
774
float temp0 = loop[i]->val + min_val;
775
float temp1 = loop[i + 1]->val - min_val;
776
777
loop[i]->val = temp0;
778
loop[i + 1]->val = temp1;
779
}
780
781
/* remove the leaving basic variable */
782
CV_Assert(leave_x != NULL);
783
i = leave_x->i;
784
j = leave_x->j;
785
state->is_x[i][j] = 0;
786
787
head.next[0] = state->rows_x[i];
788
cur_x = &head;
789
while( (next_x = cur_x->next[0]) != leave_x )
790
{
791
cur_x = next_x;
792
CV_Assert( cur_x );
793
}
794
cur_x->next[0] = next_x->next[0];
795
state->rows_x[i] = head.next[0];
796
797
head.next[1] = state->cols_x[j];
798
cur_x = &head;
799
while( (next_x = cur_x->next[1]) != leave_x )
800
{
801
cur_x = next_x;
802
CV_Assert( cur_x );
803
}
804
cur_x->next[1] = next_x->next[1];
805
state->cols_x[j] = head.next[1];
806
807
/* set enter_x to be the new empty slot */
808
state->enter_x = leave_x;
809
810
return true;
811
}
812
813
814
815
/****************************************************************************************\
816
* icvFindLoop *
817
\****************************************************************************************/
818
static int
819
icvFindLoop( CvEMDState * state )
820
{
821
int i, steps = 1;
822
CvNode2D *new_x;
823
CvNode2D **loop = state->loop;
824
CvNode2D *enter_x = state->enter_x, *_x = state->_x;
825
char *is_used = state->is_used;
826
827
memset( is_used, 0, state->ssize + state->dsize );
828
829
new_x = loop[0] = enter_x;
830
is_used[enter_x - _x] = 1;
831
steps = 1;
832
833
do
834
{
835
if( (steps & 1) == 1 )
836
{
837
/* find an unused x in the row */
838
new_x = state->rows_x[new_x->i];
839
while( new_x != 0 && is_used[new_x - _x] )
840
new_x = new_x->next[0];
841
}
842
else
843
{
844
/* find an unused x in the column, or the entering x */
845
new_x = state->cols_x[new_x->j];
846
while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
847
new_x = new_x->next[1];
848
if( new_x == enter_x )
849
break;
850
}
851
852
if( new_x != 0 ) /* found the next x */
853
{
854
/* add x to the loop */
855
loop[steps++] = new_x;
856
is_used[new_x - _x] = 1;
857
}
858
else /* didn't find the next x */
859
{
860
/* backtrack */
861
do
862
{
863
i = steps & 1;
864
new_x = loop[steps - 1];
865
do
866
{
867
new_x = new_x->next[i];
868
}
869
while( new_x != 0 && is_used[new_x - _x] );
870
871
if( new_x == 0 )
872
{
873
is_used[loop[--steps] - _x] = 0;
874
}
875
}
876
while( new_x == 0 && steps > 0 );
877
878
is_used[loop[steps - 1] - _x] = 0;
879
loop[steps - 1] = new_x;
880
is_used[new_x - _x] = 1;
881
}
882
}
883
while( steps > 0 );
884
885
return steps;
886
}
887
888
889
890
/****************************************************************************************\
891
* icvRussel *
892
\****************************************************************************************/
893
static void
894
icvRussel( CvEMDState * state )
895
{
896
int i, j, min_i = -1, min_j = -1;
897
float min_delta, diff;
898
CvNode1D u_head, *cur_u, *prev_u;
899
CvNode1D v_head, *cur_v, *prev_v;
900
CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
901
CvNode1D *u = state->u, *v = state->v;
902
int ssize = state->ssize, dsize = state->dsize;
903
float eps = CV_EMD_EPS * state->max_cost;
904
float **cost = state->cost;
905
float **delta = state->delta;
906
907
/* initialize the rows list (ur), and the columns list (vr) */
908
u_head.next = u;
909
for( i = 0; i < ssize; i++ )
910
{
911
u[i].next = u + i + 1;
912
}
913
u[ssize - 1].next = 0;
914
915
v_head.next = v;
916
for( i = 0; i < dsize; i++ )
917
{
918
v[i].val = -CV_EMD_INF;
919
v[i].next = v + i + 1;
920
}
921
v[dsize - 1].next = 0;
922
923
/* find the maximum row and column values (ur[i] and vr[j]) */
924
for( i = 0; i < ssize; i++ )
925
{
926
float u_val = -CV_EMD_INF;
927
float *cost_row = cost[i];
928
929
for( j = 0; j < dsize; j++ )
930
{
931
float temp = cost_row[j];
932
933
if( u_val < temp )
934
u_val = temp;
935
if( v[j].val < temp )
936
v[j].val = temp;
937
}
938
u[i].val = u_val;
939
}
940
941
/* compute the delta matrix */
942
for( i = 0; i < ssize; i++ )
943
{
944
float u_val = u[i].val;
945
float *delta_row = delta[i];
946
float *cost_row = cost[i];
947
948
for( j = 0; j < dsize; j++ )
949
{
950
delta_row[j] = cost_row[j] - u_val - v[j].val;
951
}
952
}
953
954
/* find the basic variables */
955
do
956
{
957
/* find the smallest delta[i][j] */
958
min_i = -1;
959
min_delta = CV_EMD_INF;
960
prev_u = &u_head;
961
for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
962
{
963
i = (int)(cur_u - u);
964
float *delta_row = delta[i];
965
966
prev_v = &v_head;
967
for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
968
{
969
j = (int)(cur_v - v);
970
if( min_delta > delta_row[j] )
971
{
972
min_delta = delta_row[j];
973
min_i = i;
974
min_j = j;
975
prev_u_min_i = prev_u;
976
prev_v_min_j = prev_v;
977
}
978
prev_v = cur_v;
979
}
980
prev_u = cur_u;
981
}
982
983
if( min_i < 0 )
984
break;
985
986
/* add x[min_i][min_j] to the basis, and adjust supplies and cost */
987
remember = prev_u_min_i->next;
988
icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
989
990
/* update the necessary delta[][] */
991
if( remember == prev_u_min_i->next ) /* line min_i was deleted */
992
{
993
for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
994
{
995
j = (int)(cur_v - v);
996
if( cur_v->val == cost[min_i][j] ) /* column j needs updating */
997
{
998
float max_val = -CV_EMD_INF;
999
1000
/* find the new maximum value in the column */
1001
for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1002
{
1003
float temp = cost[cur_u - u][j];
1004
1005
if( max_val < temp )
1006
max_val = temp;
1007
}
1008
1009
/* if needed, adjust the relevant delta[*][j] */
1010
diff = max_val - cur_v->val;
1011
cur_v->val = max_val;
1012
if( fabs( diff ) < eps )
1013
{
1014
for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1015
delta[cur_u - u][j] += diff;
1016
}
1017
}
1018
}
1019
}
1020
else /* column min_j was deleted */
1021
{
1022
for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1023
{
1024
i = (int)(cur_u - u);
1025
if( cur_u->val == cost[i][min_j] ) /* row i needs updating */
1026
{
1027
float max_val = -CV_EMD_INF;
1028
1029
/* find the new maximum value in the row */
1030
for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1031
{
1032
float temp = cost[i][cur_v - v];
1033
1034
if( max_val < temp )
1035
max_val = temp;
1036
}
1037
1038
/* if needed, adjust the relevant delta[i][*] */
1039
diff = max_val - cur_u->val;
1040
cur_u->val = max_val;
1041
1042
if( fabs( diff ) < eps )
1043
{
1044
for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1045
delta[i][cur_v - v] += diff;
1046
}
1047
}
1048
}
1049
}
1050
}
1051
while( u_head.next != 0 || v_head.next != 0 );
1052
}
1053
1054
1055
1056
/****************************************************************************************\
1057
* icvAddBasicVariable *
1058
\****************************************************************************************/
1059
static void
1060
icvAddBasicVariable( CvEMDState * state,
1061
int min_i, int min_j,
1062
CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1063
{
1064
float temp;
1065
CvNode2D *end_x = state->end_x;
1066
1067
if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1068
{ /* supply exhausted */
1069
temp = state->s[min_i];
1070
state->s[min_i] = 0;
1071
state->d[min_j] -= temp;
1072
}
1073
else /* demand exhausted */
1074
{
1075
temp = state->d[min_j];
1076
state->d[min_j] = 0;
1077
state->s[min_i] -= temp;
1078
}
1079
1080
/* x(min_i,min_j) is a basic variable */
1081
state->is_x[min_i][min_j] = 1;
1082
1083
end_x->val = temp;
1084
end_x->i = min_i;
1085
end_x->j = min_j;
1086
end_x->next[0] = state->rows_x[min_i];
1087
end_x->next[1] = state->cols_x[min_j];
1088
state->rows_x[min_i] = end_x;
1089
state->cols_x[min_j] = end_x;
1090
state->end_x = end_x + 1;
1091
1092
/* delete supply row only if the empty, and if not last row */
1093
if( state->s[min_i] == 0 && u_head->next->next != 0 )
1094
prev_u_min_i->next = prev_u_min_i->next->next; /* remove row from list */
1095
else
1096
prev_v_min_j->next = prev_v_min_j->next->next; /* remove column from list */
1097
}
1098
1099
1100
/****************************************************************************************\
1101
* standard metrics *
1102
\****************************************************************************************/
1103
/* L1 (Manhattan) distance between two points; user_param holds the number of
   coordinates. Accumulated in double, returned as float. */
static float
icvDistL1( const float *x, const float *y, void *user_param )
{
    const int dims = (int)(size_t)user_param;
    double sum = 0;

    for( int k = 0; k < dims; ++k )
        sum += fabs( (double)(x[k] - y[k]) );

    return (float)sum;
}
1117
1118
static float
1119
icvDistL2( const float *x, const float *y, void *user_param )
1120
{
1121
int i, dims = (int)(size_t)user_param;
1122
double s = 0;
1123
1124
for( i = 0; i < dims; i++ )
1125
{
1126
double t = x[i] - y[i];
1127
1128
s += t * t;
1129
}
1130
return cvSqrt( (float)s );
1131
}
1132
1133
/* Chebyshev (L-infinity) distance: the largest absolute coordinate
   difference; user_param holds the number of coordinates. */
static float
icvDistC( const float *x, const float *y, void *user_param )
{
    const int dims = (int)(size_t)user_param;
    double largest = 0;

    for( int k = 0; k < dims; ++k )
    {
        const double gap = fabs( x[k] - y[k] );
        if( gap > largest )
            largest = gap;
    }

    return (float)largest;
}
1148
1149
1150
float cv::EMD( InputArray _signature1, InputArray _signature2,
1151
int distType, InputArray _cost,
1152
float* lowerBound, OutputArray _flow )
1153
{
1154
CV_INSTRUMENT_REGION();
1155
1156
Mat signature1 = _signature1.getMat(), signature2 = _signature2.getMat();
1157
Mat cost = _cost.getMat(), flow;
1158
1159
CvMat _csignature1 = cvMat(signature1);
1160
CvMat _csignature2 = cvMat(signature2);
1161
CvMat _ccost = cvMat(cost), _cflow;
1162
if( _flow.needed() )
1163
{
1164
_flow.create(signature1.rows, signature2.rows, CV_32F);
1165
flow = _flow.getMat();
1166
flow = Scalar::all(0);
1167
_cflow = cvMat(flow);
1168
}
1169
1170
return cvCalcEMD2( &_csignature1, &_csignature2, distType, 0, cost.empty() ? 0 : &_ccost,
1171
_flow.needed() ? &_cflow : 0, lowerBound, 0 );
1172
}
1173
1174
float cv::wrapperEMD(InputArray _signature1, InputArray _signature2,
1175
int distType, InputArray _cost,
1176
Ptr<float> lowerBound, OutputArray _flow)
1177
{
1178
return EMD(_signature1, _signature2, distType, _cost, lowerBound.get(), _flow);
1179
}
1180
1181
/* End of file. */
1182
1183