Path: blob/master/modules/ml/test/test_emknearestkmeans.cpp
16339 views
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if
advised of the possibility of such damage.38//39//M*/4041#include "test_precomp.hpp"4243namespace opencv_test { namespace {4445using cv::ml::TrainData;46using cv::ml::EM;47using cv::ml::KNearest;4849void defaultDistribs( Mat& means, vector<Mat>& covs, int type=CV_32FC1 )50{51CV_TRACE_FUNCTION();52float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};53float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};54float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};55means.create(3, 2, type);56Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );57Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );58Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );59means.resize(3), covs.resize(3);6061Mat mr0 = means.row(0);62m0.convertTo(mr0, type);63c0.convertTo(covs[0], type);6465Mat mr1 = means.row(1);66m1.convertTo(mr1, type);67c1.convertTo(covs[1], type);6869Mat mr2 = means.row(2);70m2.convertTo(mr2, type);71c2.convertTo(covs[2], type);72}7374// generate points sets by normal distributions75void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int dataType, int labelType )76{77CV_TRACE_FUNCTION();78vector<int>::const_iterator sit = sizes.begin();79int total = 0;80for( ; sit != sizes.end(); ++sit )81total += *sit;82CV_Assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() );83CV_Assert( !data.empty() && data.rows == total );84CV_Assert( data.type() == dataType );8586labels.create( data.rows, 1, labelType );8788randn( data, Scalar::all(-1.0), Scalar::all(1.0) );89vector<Mat> means(sizes.size());90for(int i = 0; i < _means.rows; i++)91means[i] = _means.row(i);92vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();93int bi, ei = 0;94sit = sizes.begin();95for( int p = 0, l = 0; sit != sizes.end(); ++sit, ++mit, ++cit, l++ )96{97bi = ei;98ei = bi + *sit;99assert( mit->rows == 1 && mit->cols == data.cols );100assert( cit->rows == data.cols && cit->cols == 
data.cols );101for( int i = bi; i < ei; i++, p++ )102{103Mat r = data.row(i);104r = r * (*cit) + *mit;105if( labelType == CV_32FC1 )106labels.at<float>(p, 0) = (float)l;107else if( labelType == CV_32SC1 )108labels.at<int>(p, 0) = l;109else110{111CV_DbgAssert(0);112}113}114}115}116117int maxIdx( const vector<int>& count )118{119int idx = -1;120int maxVal = -1;121vector<int>::const_iterator it = count.begin();122for( int i = 0; it != count.end(); ++it, i++ )123{124if( *it > maxVal)125{126maxVal = *it;127idx = i;128}129}130assert( idx >= 0);131return idx;132}133134bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap, bool checkClusterUniq=true )135{136size_t total = 0, nclusters = sizes.size();137for(size_t i = 0; i < sizes.size(); i++)138total += sizes[i];139140assert( !labels.empty() );141assert( labels.total() == total && (labels.cols == 1 || labels.rows == 1));142assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );143144bool isFlt = labels.type() == CV_32FC1;145146labelsMap.resize(nclusters);147148vector<bool> buzy(nclusters, false);149int startIndex = 0;150for( size_t clusterIndex = 0; clusterIndex < sizes.size(); clusterIndex++ )151{152vector<int> count( nclusters, 0 );153for( int i = startIndex; i < startIndex + sizes[clusterIndex]; i++)154{155int lbl = isFlt ? 
(int)labels.at<float>(i) : labels.at<int>(i);156CV_Assert(lbl < (int)nclusters);157count[lbl]++;158CV_Assert(count[lbl] < (int)total);159}160startIndex += sizes[clusterIndex];161162int cls = maxIdx( count );163CV_Assert( !checkClusterUniq || !buzy[cls] );164165labelsMap[clusterIndex] = cls;166167buzy[cls] = true;168}169170if(checkClusterUniq)171{172for(size_t i = 0; i < buzy.size(); i++)173if(!buzy[i])174return false;175}176177return true;178}179180bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true, bool checkClusterUniq=true )181{182err = 0;183CV_Assert( !labels.empty() && !origLabels.empty() );184CV_Assert( labels.rows == 1 || labels.cols == 1 );185CV_Assert( origLabels.rows == 1 || origLabels.cols == 1 );186CV_Assert( labels.total() == origLabels.total() );187CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );188CV_Assert( origLabels.type() == labels.type() );189190vector<int> labelsMap;191bool isFlt = labels.type() == CV_32FC1;192if( !labelsEquivalent )193{194if( !getLabelsMap( labels, sizes, labelsMap, checkClusterUniq ) )195return false;196197for( int i = 0; i < labels.rows; i++ )198if( isFlt )199err += labels.at<float>(i) != labelsMap[(int)origLabels.at<float>(i)] ? 1.f : 0.f;200else201err += labels.at<int>(i) != labelsMap[origLabels.at<int>(i)] ? 1.f : 0.f;202}203else204{205for( int i = 0; i < labels.rows; i++ )206if( isFlt )207err += labels.at<float>(i) != origLabels.at<float>(i) ? 1.f : 0.f;208else209err += labels.at<int>(i) != origLabels.at<int>(i) ? 
1.f : 0.f;210}211err /= (float)labels.rows;212return true;213}214215//--------------------------------------------------------------------------------------------216class CV_KMeansTest : public cvtest::BaseTest {217public:218CV_KMeansTest() {}219protected:220virtual void run( int start_from );221};222223void CV_KMeansTest::run( int /*start_from*/ )224{225CV_TRACE_FUNCTION();226const int iters = 100;227int sizesArr[] = { 5000, 7000, 8000 };228int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];229230Mat data( pointsCount, 2, CV_32FC1 ), labels;231vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );232Mat means;233vector<Mat> covs;234defaultDistribs( means, covs );235generateData( data, labels, sizes, means, covs, CV_32FC1, CV_32SC1 );236237int code = cvtest::TS::OK;238float err;239Mat bestLabels;240// 1. flag==KMEANS_PP_CENTERS241kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );242if( !calcErr( bestLabels, labels, sizes, err , false ) )243{244ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_PP_CENTERS.\n" );245code = cvtest::TS::FAIL_INVALID_OUTPUT;246}247else if( err > 0.01f )248{249ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );250code = cvtest::TS::FAIL_BAD_ACCURACY;251}252253// 2. flag==KMEANS_RANDOM_CENTERS254kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );255if( !calcErr( bestLabels, labels, sizes, err, false ) )256{257ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_RANDOM_CENTERS.\n" );258code = cvtest::TS::FAIL_INVALID_OUTPUT;259}260else if( err > 0.01f )261{262ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_RANDOM_CENTERS.\n", err );263code = cvtest::TS::FAIL_BAD_ACCURACY;264}265266// 3. 
flag==KMEANS_USE_INITIAL_LABELS267labels.copyTo( bestLabels );268RNG rng;269for( int i = 0; i < 0.5f * pointsCount; i++ )270bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;271kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );272if( !calcErr( bestLabels, labels, sizes, err, false ) )273{274ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_USE_INITIAL_LABELS.\n" );275code = cvtest::TS::FAIL_INVALID_OUTPUT;276}277else if( err > 0.01f )278{279ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_USE_INITIAL_LABELS.\n", err );280code = cvtest::TS::FAIL_BAD_ACCURACY;281}282283ts->set_failed_test_info( code );284}285286//--------------------------------------------------------------------------------------------287class CV_KNearestTest : public cvtest::BaseTest {288public:289CV_KNearestTest() {}290protected:291virtual void run( int start_from );292};293294void CV_KNearestTest::run( int /*start_from*/ )295{296int sizesArr[] = { 500, 700, 800 };297int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];298299// train data300Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;301vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );302Mat means;303vector<Mat> covs;304defaultDistribs( means, covs );305generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );306307// test data308Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels;309generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );310311int code = cvtest::TS::OK;312313// KNearest default implementation314Ptr<KNearest> knearest = KNearest::create();315knearest->train(trainData, ml::ROW_SAMPLE, trainLabels);316knearest->findNearest(testData, 4, bestLabels);317float err;318if( !calcErr( bestLabels, testLabels, sizes, err, true ) )319{320ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );321code = 
cvtest::TS::FAIL_INVALID_OUTPUT;322}323else if( err > 0.01f )324{325ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );326code = cvtest::TS::FAIL_BAD_ACCURACY;327}328329// KNearest KDTree implementation330Ptr<KNearest> knearestKdt = KNearest::create();331knearestKdt->setAlgorithmType(KNearest::KDTREE);332knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);333knearestKdt->findNearest(testData, 4, bestLabels);334if( !calcErr( bestLabels, testLabels, sizes, err, true ) )335{336ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );337code = cvtest::TS::FAIL_INVALID_OUTPUT;338}339else if( err > 0.01f )340{341ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );342code = cvtest::TS::FAIL_BAD_ACCURACY;343}344345ts->set_failed_test_info( code );346}347348class EM_Params349{350public:351EM_Params(int _nclusters=10, int _covMatType=EM::COV_MAT_DIAGONAL, int _startStep=EM::START_AUTO_STEP,352const cv::TermCriteria& _termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),353const cv::Mat* _probs=0, const cv::Mat* _weights=0,354const cv::Mat* _means=0, const std::vector<cv::Mat>* _covs=0)355: nclusters(_nclusters), covMatType(_covMatType), startStep(_startStep),356probs(_probs), weights(_weights), means(_means), covs(_covs), termCrit(_termCrit)357{}358359int nclusters;360int covMatType;361int startStep;362363// all 4 following matrices should have type CV_32FC1364const cv::Mat* probs;365const cv::Mat* weights;366const cv::Mat* means;367const std::vector<cv::Mat>* covs;368369cv::TermCriteria termCrit;370};371372//--------------------------------------------------------------------------------------------373class CV_EMTest : public cvtest::BaseTest374{375public:376CV_EMTest() {}377protected:378virtual void run( int start_from );379int runCase( int caseIndex, const EM_Params& params,380const cv::Mat& trainData, const cv::Mat& trainLabels,381const cv::Mat& testData, const cv::Mat& testLabels,382const 
vector<int>& sizes);383};384385int CV_EMTest::runCase( int caseIndex, const EM_Params& params,386const cv::Mat& trainData, const cv::Mat& trainLabels,387const cv::Mat& testData, const cv::Mat& testLabels,388const vector<int>& sizes )389{390int code = cvtest::TS::OK;391392cv::Mat labels;393float err;394395Ptr<EM> em = EM::create();396em->setClustersNumber(params.nclusters);397em->setCovarianceMatrixType(params.covMatType);398em->setTermCriteria(params.termCrit);399if( params.startStep == EM::START_AUTO_STEP )400em->trainEM( trainData, noArray(), labels, noArray() );401else if( params.startStep == EM::START_E_STEP )402em->trainE( trainData, *params.means, *params.covs,403*params.weights, noArray(), labels, noArray() );404else if( params.startStep == EM::START_M_STEP )405em->trainM( trainData, *params.probs,406noArray(), labels, noArray() );407408// check train error409if( !calcErr( labels, trainLabels, sizes, err , false, false ) )410{411ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );412code = cvtest::TS::FAIL_INVALID_OUTPUT;413}414else if( err > 0.008f )415{416ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err );417code = cvtest::TS::FAIL_BAD_ACCURACY;418}419420// check test error421labels.create( testData.rows, 1, CV_32SC1 );422for( int i = 0; i < testData.rows; i++ )423{424Mat sample = testData.row(i);425Mat probs;426labels.at<int>(i) = static_cast<int>(em->predict2( sample, probs )[1]);427}428if( !calcErr( labels, testLabels, sizes, err, false, false ) )429{430ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );431code = cvtest::TS::FAIL_INVALID_OUTPUT;432}433else if( err > 0.008f )434{435ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err );436code = cvtest::TS::FAIL_BAD_ACCURACY;437}438439return code;440}441442void CV_EMTest::run( int /*start_from*/ )443{444int sizesArr[] = { 500, 700, 800 };445int 
pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];446447// Points distribution448Mat means;449vector<Mat> covs;450defaultDistribs( means, covs, CV_64FC1 );451452// train data453Mat trainData( pointsCount, 2, CV_64FC1 ), trainLabels;454vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );455generateData( trainData, trainLabels, sizes, means, covs, CV_64FC1, CV_32SC1 );456457// test data458Mat testData( pointsCount, 2, CV_64FC1 ), testLabels;459generateData( testData, testLabels, sizes, means, covs, CV_64FC1, CV_32SC1 );460461EM_Params params;462params.nclusters = 3;463Mat probs(trainData.rows, params.nclusters, CV_64FC1, cv::Scalar(1));464params.probs = &probs;465Mat weights(1, params.nclusters, CV_64FC1, cv::Scalar(1));466params.weights = &weights;467params.means = &means;468params.covs = &covs;469470int code = cvtest::TS::OK;471int caseIndex = 0;472{473params.startStep = EM::START_AUTO_STEP;474params.covMatType = EM::COV_MAT_GENERIC;475int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);476code = currCode == cvtest::TS::OK ? code : currCode;477}478{479params.startStep = EM::START_AUTO_STEP;480params.covMatType = EM::COV_MAT_DIAGONAL;481int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);482code = currCode == cvtest::TS::OK ? code : currCode;483}484{485params.startStep = EM::START_AUTO_STEP;486params.covMatType = EM::COV_MAT_SPHERICAL;487int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);488code = currCode == cvtest::TS::OK ? code : currCode;489}490{491params.startStep = EM::START_M_STEP;492params.covMatType = EM::COV_MAT_GENERIC;493int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);494code = currCode == cvtest::TS::OK ? 
code : currCode;495}496{497params.startStep = EM::START_M_STEP;498params.covMatType = EM::COV_MAT_DIAGONAL;499int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);500code = currCode == cvtest::TS::OK ? code : currCode;501}502{503params.startStep = EM::START_M_STEP;504params.covMatType = EM::COV_MAT_SPHERICAL;505int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);506code = currCode == cvtest::TS::OK ? code : currCode;507}508{509params.startStep = EM::START_E_STEP;510params.covMatType = EM::COV_MAT_GENERIC;511int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);512code = currCode == cvtest::TS::OK ? code : currCode;513}514{515params.startStep = EM::START_E_STEP;516params.covMatType = EM::COV_MAT_DIAGONAL;517int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);518code = currCode == cvtest::TS::OK ? code : currCode;519}520{521params.startStep = EM::START_E_STEP;522params.covMatType = EM::COV_MAT_SPHERICAL;523int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);524code = currCode == cvtest::TS::OK ? 
code : currCode;525}526527ts->set_failed_test_info( code );528}529530class CV_EMTest_SaveLoad : public cvtest::BaseTest {531public:532CV_EMTest_SaveLoad() {}533protected:534virtual void run( int /*start_from*/ )535{536int code = cvtest::TS::OK;537const int nclusters = 2;538539Mat samples = Mat(3,1,CV_64FC1);540samples.at<double>(0,0) = 1;541samples.at<double>(1,0) = 2;542samples.at<double>(2,0) = 3;543544Mat labels;545546Ptr<EM> em = EM::create();547em->setClustersNumber(nclusters);548em->trainEM(samples, noArray(), labels, noArray());549550Mat firstResult(samples.rows, 1, CV_32SC1);551for( int i = 0; i < samples.rows; i++)552firstResult.at<int>(i) = static_cast<int>(em->predict2(samples.row(i), noArray())[1]);553554// Write out555string filename = cv::tempfile(".xml");556{557FileStorage fs = FileStorage(filename, FileStorage::WRITE);558try559{560fs << "em" << "{";561em->write(fs);562fs << "}";563}564catch(...)565{566ts->printf( cvtest::TS::LOG, "Crash in write method.\n" );567ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );568}569}570571em.release();572573// Read in574try575{576em = Algorithm::load<EM>(filename);577}578catch(...)579{580ts->printf( cvtest::TS::LOG, "Crash in read method.\n" );581ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );582}583584remove( filename.c_str() );585586int errCaseCount = 0;587for( int i = 0; i < samples.rows; i++)588errCaseCount = std::abs(em->predict2(samples.row(i), noArray())[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;589590if( errCaseCount > 0 )591{592ts->printf( cvtest::TS::LOG, "Different prediction results before writing and after reading (errCaseCount=%d).\n", errCaseCount );593code = cvtest::TS::FAIL_BAD_ACCURACY;594}595596ts->set_failed_test_info( code );597}598};599600class CV_EMTest_Classification : public cvtest::BaseTest601{602public:603CV_EMTest_Classification() {}604protected:605virtual void run(int)606{607// This test classifies spam by the following way:608// 1. 
estimates distributions of "spam" / "not spam"609// 2. predict classID using Bayes classifier for estimated distributions.610611string dataFilename = string(ts->get_data_path()) + "spambase.data";612Ptr<TrainData> data = TrainData::loadFromCSV(dataFilename, 0);613614if( data.empty() )615{616ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n");617ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);618return;619}620621Mat samples = data->getSamples();622CV_Assert(samples.cols == 57);623Mat responses = data->getResponses();624625vector<int> trainSamplesMask(samples.rows, 0);626int trainSamplesCount = (int)(0.5f * samples.rows);627for(int i = 0; i < trainSamplesCount; i++)628trainSamplesMask[i] = 1;629RNG rng(0);630for(size_t i = 0; i < trainSamplesMask.size(); i++)631{632int i1 = rng(static_cast<unsigned>(trainSamplesMask.size()));633int i2 = rng(static_cast<unsigned>(trainSamplesMask.size()));634std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);635}636637Mat samples0, samples1;638for(int i = 0; i < samples.rows; i++)639{640if(trainSamplesMask[i])641{642Mat sample = samples.row(i);643int resp = (int)responses.at<float>(i);644if(resp == 0)645samples0.push_back(sample);646else647samples1.push_back(sample);648}649}650Ptr<EM> model0 = EM::create();651model0->setClustersNumber(3);652model0->trainEM(samples0, noArray(), noArray(), noArray());653654Ptr<EM> model1 = EM::create();655model1->setClustersNumber(3);656model1->trainEM(samples1, noArray(), noArray(), noArray());657658Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),659testConfusionMat(2, 2, CV_32SC1, Scalar(0));660const double lambda = 1.;661for(int i = 0; i < samples.rows; i++)662{663Mat sample = samples.row(i);664double sampleLogLikelihoods0 = model0->predict2(sample, noArray())[0];665double sampleLogLikelihoods1 = model1->predict2(sample, noArray())[0];666667int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 
0 : 1;668669if(trainSamplesMask[i])670trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;671else672testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;673}674// std::cout << trainConfusionMat << std::endl;675// std::cout << testConfusionMat << std::endl;676677double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;678double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);679const double maxTrainError = 0.23;680const double maxTestError = 0.26;681682int code = cvtest::TS::OK;683if(trainError > maxTrainError)684{685ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);686code = cvtest::TS::FAIL_INVALID_TEST_DATA;687}688if(testError > maxTestError)689{690ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", testError, maxTestError);691code = cvtest::TS::FAIL_INVALID_TEST_DATA;692}693694ts->set_failed_test_info(code);695}696};697698TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }699TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }700TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }701TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }702TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }703704TEST(ML_KNearest, regression_12347)705{706Mat xTrainData = (Mat_<float>(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1);707Mat yTrainLabels = (Mat_<float>(5,1) << 1, 1, 2, 2, 2);708Ptr<KNearest> knn = KNearest::create();709knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels);710711Mat xTestData = (Mat_<float>(2,2) << 1.1, 1.1, 2, 2.2);712Mat zBestLabels, neighbours, dist;713// check output shapes:714int K = 16, Kexp = std::min(K, xTrainData.rows);715knn->findNearest(xTestData, K, zBestLabels, neighbours, 
dist);716EXPECT_EQ(xTestData.rows, zBestLabels.rows);717EXPECT_EQ(neighbours.cols, Kexp);718EXPECT_EQ(dist.cols, Kexp);719// see if the result is still correct:720K = 2;721knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);722EXPECT_EQ(1, zBestLabels.at<float>(0,0));723EXPECT_EQ(2, zBestLabels.at<float>(1,0));724}725726}} // namespace727728729