Path: blob/master/modules/core/src/batch_distance.cpp
16337 views
// This file is part of OpenCV project.1// It is subject to the license terms in the LICENSE file found in the top-level directory2// of this distribution and at http://opencv.org/license.html345#include "precomp.hpp"6#include "stat.hpp"7#include <opencv2/core/hal/hal.hpp>89namespace cv10{1112template<typename _Tp, typename _Rt>13void batchDistL1_(const _Tp* src1, const _Tp* src2, size_t step2,14int nvecs, int len, _Rt* dist, const uchar* mask)15{16step2 /= sizeof(src2[0]);17if( !mask )18{19for( int i = 0; i < nvecs; i++ )20dist[i] = normL1<_Tp, _Rt>(src1, src2 + step2*i, len);21}22else23{24_Rt val0 = std::numeric_limits<_Rt>::max();25for( int i = 0; i < nvecs; i++ )26dist[i] = mask[i] ? normL1<_Tp, _Rt>(src1, src2 + step2*i, len) : val0;27}28}2930template<typename _Tp, typename _Rt>31void batchDistL2Sqr_(const _Tp* src1, const _Tp* src2, size_t step2,32int nvecs, int len, _Rt* dist, const uchar* mask)33{34step2 /= sizeof(src2[0]);35if( !mask )36{37for( int i = 0; i < nvecs; i++ )38dist[i] = normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len);39}40else41{42_Rt val0 = std::numeric_limits<_Rt>::max();43for( int i = 0; i < nvecs; i++ )44dist[i] = mask[i] ? normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len) : val0;45}46}4748template<>49void batchDistL2Sqr_(const float* src1, const float* src2, size_t step2,50int nvecs, int len, float* dist, const uchar* mask)51{52step2 /= sizeof(src2[0]);53if( !mask )54{55for( int i = 0; i < nvecs; i++ )56dist[i] = hal::normL2Sqr_(src1, src2 + step2*i, len);57}58else59{60float val0 = std::numeric_limits<float>::max();61for( int i = 0; i < nvecs; i++ )62dist[i] = mask[i] ? hal::normL2Sqr_(src1, src2 + step2*i, len) : val0;63}64}6566template<typename _Tp, typename _Rt>67void batchDistL2_(const _Tp* src1, const _Tp* src2, size_t step2,68int nvecs, int len, _Rt* dist, const uchar* mask)69{70step2 /= sizeof(src2[0]);71if( !mask )72{73for( int i = 0; i < nvecs; i++ )74dist[i] = std::sqrt(normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len));75}76else77{78_Rt val0 = std::numeric_limits<_Rt>::max();79for( int i = 0; i < nvecs; i++ )80dist[i] = mask[i] ? std::sqrt(normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len)) : val0;81}82}8384template<>85void batchDistL2_(const float* src1, const float* src2, size_t step2,86int nvecs, int len, float* dist, const uchar* mask)87{88step2 /= sizeof(src2[0]);89if( !mask )90{91for( int i = 0; i < nvecs; i++ )92dist[i] = std::sqrt(hal::normL2Sqr_(src1, src2 + step2*i, len));93}94else95{96float val0 = std::numeric_limits<float>::max();97for( int i = 0; i < nvecs; i++ )98dist[i] = mask[i] ? std::sqrt(hal::normL2Sqr_(src1, src2 + step2*i, len)) : val0;99}100}101102static void batchDistHamming(const uchar* src1, const uchar* src2, size_t step2,103int nvecs, int len, int* dist, const uchar* mask)104{105step2 /= sizeof(src2[0]);106if( !mask )107{108for( int i = 0; i < nvecs; i++ )109dist[i] = hal::normHamming(src1, src2 + step2*i, len);110}111else112{113int val0 = INT_MAX;114for( int i = 0; i < nvecs; i++ )115{116if (mask[i])117dist[i] = hal::normHamming(src1, src2 + step2*i, len);118else119dist[i] = val0;120}121}122}123124static void batchDistHamming2(const uchar* src1, const uchar* src2, size_t step2,125int nvecs, int len, int* dist, const uchar* mask)126{127step2 /= sizeof(src2[0]);128if( !mask )129{130for( int i = 0; i < nvecs; i++ )131dist[i] = hal::normHamming(src1, src2 + step2*i, len, 2);132}133else134{135int val0 = INT_MAX;136for( int i = 0; i < nvecs; i++ )137{138if (mask[i])139dist[i] = hal::normHamming(src1, src2 + step2*i, len, 2);140else141dist[i] = val0;142}143}144}145146static void batchDistL1_8u32s(const uchar* src1, const uchar* src2, size_t step2,147int nvecs, int len, int* dist, const uchar* mask)148{149batchDistL1_<uchar, int>(src1, src2, step2, nvecs, len, dist, mask);150}151152static void batchDistL1_8u32f(const uchar* src1, const uchar* src2, size_t step2,153int nvecs, int len, float* dist, const uchar* mask)154{155batchDistL1_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);156}157158static void batchDistL2Sqr_8u32s(const uchar* src1, const uchar* src2, size_t step2,159int nvecs, int len, int* dist, const uchar* mask)160{161batchDistL2Sqr_<uchar, int>(src1, src2, step2, nvecs, len, dist, mask);162}163164static void batchDistL2Sqr_8u32f(const uchar* src1, const uchar* src2, size_t step2,165int nvecs, int len, float* dist, const uchar* mask)166{167batchDistL2Sqr_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);168}169170static void batchDistL2_8u32f(const uchar* src1, const uchar* src2, size_t step2,171int nvecs, int len, float* dist, const uchar* mask)172{173batchDistL2_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);174}175176static void batchDistL1_32f(const float* src1, const float* src2, size_t step2,177int nvecs, int len, float* dist, const uchar* mask)178{179batchDistL1_<float, float>(src1, src2, step2, nvecs, len, dist, mask);180}181182static void batchDistL2Sqr_32f(const float* src1, const float* src2, size_t step2,183int nvecs, int len, float* dist, const uchar* mask)184{185batchDistL2Sqr_<float, float>(src1, src2, step2, nvecs, len, dist, mask);186}187188static void batchDistL2_32f(const float* src1, const float* src2, size_t step2,189int nvecs, int len, float* dist, const uchar* mask)190{191batchDistL2_<float, float>(src1, src2, step2, nvecs, len, dist, mask);192}193194typedef void (*BatchDistFunc)(const uchar* src1, const uchar* src2, size_t step2,195int nvecs, int len, uchar* dist, const uchar* mask);196197198struct BatchDistInvoker : public ParallelLoopBody199{200BatchDistInvoker( const Mat& _src1, const Mat& _src2,201Mat& _dist, Mat& _nidx, int _K,202const Mat& _mask, int _update,203BatchDistFunc _func)204{205src1 = &_src1;206src2 = &_src2;207dist = &_dist;208nidx = &_nidx;209K = _K;210mask = &_mask;211update = _update;212func = _func;213}214215void operator()(const Range& range) const CV_OVERRIDE216{217AutoBuffer<int> buf(src2->rows);218int* bufptr = buf.data();219220for( int i = range.start; i < range.end; i++ )221{222func(src1->ptr(i), src2->ptr(), src2->step, src2->rows, src2->cols,223K > 0 ? (uchar*)bufptr : dist->ptr(i), mask->data ? mask->ptr(i) : 0);224225if( K > 0 )226{227int* nidxptr = nidx->ptr<int>(i);228// since positive float's can be compared just like int's,229// we handle both CV_32S and CV_32F cases with a single branch230int* distptr = (int*)dist->ptr(i);231232int j, k;233234for( j = 0; j < src2->rows; j++ )235{236int d = bufptr[j];237if( d < distptr[K-1] )238{239for( k = K-2; k >= 0 && distptr[k] > d; k-- )240{241nidxptr[k+1] = nidxptr[k];242distptr[k+1] = distptr[k];243}244nidxptr[k+1] = j + update;245distptr[k+1] = d;246}247}248}249}250}251252const Mat *src1;253const Mat *src2;254Mat *dist;255Mat *nidx;256const Mat *mask;257int K;258int update;259BatchDistFunc func;260};261262}263264void cv::batchDistance( InputArray _src1, InputArray _src2,265OutputArray _dist, int dtype, OutputArray _nidx,266int normType, int K, InputArray _mask,267int update, bool crosscheck )268{269CV_INSTRUMENT_REGION();270271Mat src1 = _src1.getMat(), src2 = _src2.getMat(), mask = _mask.getMat();272int type = src1.type();273CV_Assert( type == src2.type() && src1.cols == src2.cols &&274(type == CV_32F || type == CV_8U));275CV_Assert( _nidx.needed() == (K > 0) );276277if( dtype == -1 )278{279dtype = normType == NORM_HAMMING || normType == NORM_HAMMING2 ? CV_32S : CV_32F;280}281CV_Assert( (type == CV_8U && dtype == CV_32S) || dtype == CV_32F);282283K = std::min(K, src2.rows);284285_dist.create(src1.rows, (K > 0 ? K : src2.rows), dtype);286Mat dist = _dist.getMat(), nidx;287if( _nidx.needed() )288{289_nidx.create(dist.size(), CV_32S);290nidx = _nidx.getMat();291}292293if( update == 0 && K > 0 )294{295dist = Scalar::all(dtype == CV_32S ? (double)INT_MAX : (double)FLT_MAX);296nidx = Scalar::all(-1);297}298299if( crosscheck )300{301CV_Assert( K == 1 && update == 0 && mask.empty() );302CV_Assert(!nidx.empty());303Mat tdist, tidx;304batchDistance(src2, src1, tdist, dtype, tidx, normType, K, mask, 0, false);305306// if an idx-th element from src1 appeared to be the nearest to i-th element of src2,307// we update the minimum mutual distance between idx-th element of src1 and the whole src2 set.308// As a result, if nidx[idx] = i*, it means that idx-th element of src1 is the nearest309// to i*-th element of src2 and i*-th element of src2 is the closest to idx-th element of src1.310// If nidx[idx] = -1, it means that there is no such ideal couple for it in src2.311// This O(N) procedure is called cross-check and it helps to eliminate some false matches.312if( dtype == CV_32S )313{314for( int i = 0; i < tdist.rows; i++ )315{316int idx = tidx.at<int>(i);317int d = tdist.at<int>(i), d0 = dist.at<int>(idx);318if( d < d0 )319{320dist.at<int>(idx) = d;321nidx.at<int>(idx) = i + update;322}323}324}325else326{327for( int i = 0; i < tdist.rows; i++ )328{329int idx = tidx.at<int>(i);330float d = tdist.at<float>(i), d0 = dist.at<float>(idx);331if( d < d0 )332{333dist.at<float>(idx) = d;334nidx.at<int>(idx) = i + update;335}336}337}338return;339}340341BatchDistFunc func = 0;342if( type == CV_8U )343{344if( normType == NORM_L1 && dtype == CV_32S )345func = (BatchDistFunc)batchDistL1_8u32s;346else if( normType == NORM_L1 && dtype == CV_32F )347func = (BatchDistFunc)batchDistL1_8u32f;348else if( normType == NORM_L2SQR && dtype == CV_32S )349func = (BatchDistFunc)batchDistL2Sqr_8u32s;350else if( normType == NORM_L2SQR && dtype == CV_32F )351func = (BatchDistFunc)batchDistL2Sqr_8u32f;352else if( normType == NORM_L2 && dtype == CV_32F )353func = (BatchDistFunc)batchDistL2_8u32f;354else if( normType == NORM_HAMMING && dtype == CV_32S )355func = (BatchDistFunc)batchDistHamming;356else if( normType == NORM_HAMMING2 && dtype == CV_32S )357func = (BatchDistFunc)batchDistHamming2;358}359else if( type == CV_32F && dtype == CV_32F )360{361if( normType == NORM_L1 )362func = (BatchDistFunc)batchDistL1_32f;363else if( normType == NORM_L2SQR )364func = (BatchDistFunc)batchDistL2Sqr_32f;365else if( normType == NORM_L2 )366func = (BatchDistFunc)batchDistL2_32f;367}368369if( func == 0 )370CV_Error_(CV_StsUnsupportedFormat,371("The combination of type=%d, dtype=%d and normType=%d is not supported",372type, dtype, normType));373374parallel_for_(Range(0, src1.rows),375BatchDistInvoker(src1, src2, dist, nidx, K, mask, update, func));376}377378379