// Path: blob/master/apps/traincascade/cascadeclassifier.cpp
#include "opencv2/core.hpp"

#include "cascadeclassifier.h"

#include <cstdio>
#include <queue>

using namespace std;
using namespace cv;

// Textual names for the integer stage/feature type ids; indexed by
// CvCascadeParams::stageType / featureType and used both for the command
// line (-stageType/-featureType) and for the params file.
static const char* stageTypes[] = { CC_BOOST };
static const char* featureTypes[] = { CC_HAAR, CC_LBP, CC_HOG };

CvCascadeParams::CvCascadeParams() : stageType( defaultStageType ),
    featureType( defaultFeatureType ), winSize( cvSize(24, 24) )
{
    name = CC_CASCADE_PARAMS;
}
CvCascadeParams::CvCascadeParams( int _stageType, int _featureType ) : stageType( _stageType ),
    featureType( _featureType ), winSize( cvSize(24, 24) )
{
    name = CC_CASCADE_PARAMS;
}

//---------------------------- CascadeParams --------------------------------------

// Serializes stage type, feature type and window size into the opened storage.
// Asserts (rather than silently writing garbage) if either type id is unknown.
void CvCascadeParams::write( FileStorage &fs ) const
{
    string stageTypeStr = stageType == BOOST ? CC_BOOST : string();
    CV_Assert( !stageTypeStr.empty() );
    fs << CC_STAGE_TYPE << stageTypeStr;
    string featureTypeStr = featureType == CvFeatureParams::HAAR ? CC_HAAR :
                            featureType == CvFeatureParams::LBP ? CC_LBP :
                            featureType == CvFeatureParams::HOG ? CC_HOG :
                            string();  // BUGFIX: was 0 — std::string(nullptr) is undefined behavior
    CV_Assert( !featureTypeStr.empty() );  // BUGFIX: was re-asserting stageTypeStr
    fs << CC_FEATURE_TYPE << featureTypeStr;
    fs << CC_HEIGHT << winSize.height;
    fs << CC_WIDTH << winSize.width;
}

// Reads the parameters back from a file node; returns false on any missing
// or malformed entry. winSize must end up strictly positive.
bool CvCascadeParams::read( const FileNode &node )
{
    if ( node.empty() )
        return false;
    string stageTypeStr, featureTypeStr;
    FileNode rnode = node[CC_STAGE_TYPE];
    if ( !rnode.isString() )
        return false;
    rnode >> stageTypeStr;
    stageType = !stageTypeStr.compare( CC_BOOST ) ? BOOST : -1;
    if (stageType == -1)
        return false;
    rnode = node[CC_FEATURE_TYPE];
    if ( !rnode.isString() )
        return false;
    rnode >> featureTypeStr;
    featureType = !featureTypeStr.compare( CC_HAAR ) ? CvFeatureParams::HAAR :
                  !featureTypeStr.compare( CC_LBP ) ? CvFeatureParams::LBP :
                  !featureTypeStr.compare( CC_HOG ) ? CvFeatureParams::HOG :
                  -1;
    if (featureType == -1)
        return false;
    node[CC_HEIGHT] >> winSize.height;
    node[CC_WIDTH] >> winSize.width;
    return winSize.height > 0 && winSize.width > 0;
}

// Prints the command-line usage for the cascade parameters, marking the
// compiled-in defaults.
void CvCascadeParams::printDefaults() const
{
    CvParams::printDefaults();
    cout << " [-stageType <";
    for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ )
    {
        cout << (i ? " | " : "") << stageTypes[i];
        if ( i == defaultStageType )
            cout << "(default)";
    }
    cout << ">]" << endl;

    cout << " [-featureType <{";
    for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ )
    {
        cout << (i ? ", " : "") << featureTypes[i];
        if ( i == defaultFeatureType )  // BUGFIX: was defaultStageType — marked the wrong entry
            cout << "(default)";
    }
    cout << "}>]" << endl;
    cout << " [-w <sampleWidth = " << winSize.width << ">]" << endl;
    cout << " [-h <sampleHeight = " << winSize.height << ">]" << endl;
}

// Dumps the currently active parameter values.
void CvCascadeParams::printAttrs() const
{
    cout << "stageType: " << stageTypes[stageType] << endl;
    cout << "featureType: " << featureTypes[featureType] << endl;
    cout << "sampleWidth: " << winSize.width << endl;
    cout << "sampleHeight: " << winSize.height << endl;
}

// Parses a single "-name value" pair from the command line.
// Returns false only for an unrecognized parameter name; an unrecognized
// VALUE for -stageType/-featureType is silently ignored (existing behavior).
bool CvCascadeParams::scanAttr( const string prmName, const string val )
{
    bool res = true;
    if( !prmName.compare( "-stageType" ) )
    {
        for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ )
            if( !val.compare( stageTypes[i] ) )
                stageType = i;
    }
    else if( !prmName.compare( "-featureType" ) )
    {
        for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ )
            if( !val.compare( featureTypes[i] ) )
                featureType = i;
    }
    else if( !prmName.compare( "-w" ) )
    {
        winSize.width = atoi( val.c_str() );
    }
    else if( !prmName.compare( "-h" ) )
    {
        winSize.height = atoi( val.c_str() );
    }
    else
        res = false;
    return res;
}

//---------------------------- CascadeClassifier --------------------------------------

// Trains the full cascade: resumes from any stages already present in
// _cascadeDirName, then trains the remaining stages, writing params.xml
// after stage 0 and stageN.xml after each stage, and finally saves the
// combined cascade. Returns false if no stage at all could be trained or
// a file could not be written.
bool CvCascadeClassifier::train( const string _cascadeDirName,
                const string _posFilename,
                const string _negFilename,
                int _numPos, int _numNeg,
                int _precalcValBufSize, int _precalcIdxBufSize,
                int _numStages,
                const CvCascadeParams& _cascadeParams,
                const CvFeatureParams& _featureParams,
                const CvCascadeBoostParams& _stageParams,
                bool baseFormatSave,
                double acceptanceRatioBreakValue )
{
    // Start recording clock ticks for training time output
    double time = (double)getTickCount();

    if( _cascadeDirName.empty() || _posFilename.empty() || _negFilename.empty() )
        CV_Error( CV_StsBadArg, "_cascadeDirName or _bgfileName or _vecFileName is NULL" );

    // Normalize the output directory to end with a path separator.
    string dirName;
    if (_cascadeDirName.find_last_of("/\\") == (_cascadeDirName.length() - 1) )
        dirName = _cascadeDirName;
    else
        dirName = _cascadeDirName + '/';

    numPos = _numPos;
    numNeg = _numNeg;
    numStages = _numStages;
    if ( !imgReader.create( _posFilename, _negFilename, _cascadeParams.winSize ) )
    {
        cout << "Image reader can not be created from -vec " << _posFilename
                << " and -bg " << _negFilename << "." << endl;
        return false;
    }
    if ( !load( dirName ) )
    {
        // Fresh training run: take all parameters from the command line.
        cascadeParams = _cascadeParams;
        featureParams = CvFeatureParams::create(cascadeParams.featureType);
        featureParams->init(_featureParams);
        stageParams = makePtr<CvCascadeBoostParams>();
        *stageParams = _stageParams;
        featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
        featureEvaluator->init( featureParams, numPos + numNeg, cascadeParams.winSize );
        stageClassifiers.reserve( numStages );
    }else{
        // Make sure that if model parameters are preloaded, that people are aware of this,
        // even when passing other parameters to the training command
        cout << "---------------------------------------------------------------------------------" << endl;
        cout << "Training parameters are pre-loaded from the parameter file in data folder!" << endl;
        cout << "Please empty this folder if you want to use a NEW set of training parameters." << endl;
        cout << "---------------------------------------------------------------------------------" << endl;
    }
    cout << "PARAMETERS:" << endl;
    cout << "cascadeDirName: " << _cascadeDirName << endl;
    cout << "vecFileName: " << _posFilename << endl;
    cout << "bgFileName: " << _negFilename << endl;
    cout << "numPos: " << _numPos << endl;
    cout << "numNeg: " << _numNeg << endl;
    cout << "numStages: " << numStages << endl;
    cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;
    cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;
    cout << "acceptanceRatioBreakValue : " << acceptanceRatioBreakValue << endl;
    cascadeParams.printAttrs();
    stageParams->printAttrs();
    featureParams->printAttrs();
    cout << "Number of unique features given windowSize [" << _cascadeParams.winSize.width << "," << _cascadeParams.winSize.height << "] : " << featureEvaluator->getNumFeatures() << "" << endl;

    int startNumStages = (int)stageClassifiers.size();
    if ( startNumStages > 1 )
        cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
    else if ( startNumStages == 1)
        cout << endl << "Stage 0 is loaded" << endl;

    // Overall leaf false-alarm target derived from the per-stage target.
    double requiredLeafFARate = pow( (double) stageParams->maxFalseAlarm, (double) numStages ) /
                                (double)stageParams->max_depth;
    double tempLeafFARate;

    for( int i = startNumStages; i < numStages; i++ )
    {
        cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
        cout << "<BEGIN" << endl;

        if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
        {
            cout << "Train dataset for temp stage can not be filled. "
                "Branch training terminated." << endl;
            break;
        }
        if( tempLeafFARate <= requiredLeafFARate )
        {
            cout << "Required leaf false alarm rate achieved. "
                    "Branch training terminated." << endl;
            break;
        }
        if( (tempLeafFARate <= acceptanceRatioBreakValue) && (acceptanceRatioBreakValue >= 0) ){
            cout << "The required acceptanceRatio for the model has been reached to avoid overfitting of trainingdata. "
                    "Branch training terminated." << endl;
            break;
        }

        Ptr<CvCascadeBoost> tempStage = makePtr<CvCascadeBoost>();
        bool isStageTrained = tempStage->train( featureEvaluator,
                                                curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
                                                *stageParams );
        cout << "END>" << endl;

        if(!isStageTrained)
            break;

        stageClassifiers.push_back( tempStage );

        // save params
        if( i == 0)
        {
            std::string paramsFilename = dirName + CC_PARAMS_FILENAME;
            FileStorage fs( paramsFilename, FileStorage::WRITE);
            if ( !fs.isOpened() )
            {
                cout << "Parameters can not be written, because file " << paramsFilename
                    << " can not be opened." << endl;
                return false;
            }
            fs << FileStorage::getDefaultObjectName(paramsFilename) << "{";
            writeParams( fs );
            fs << "}";
        }
        // save current stage
        // BUGFIX: buffer was char[10], which sprintf("%s%d","stage",i) can
        // overflow for i >= 10000; use snprintf with an adequate buffer.
        char buf[32];
        snprintf( buf, sizeof(buf), "%s%d", "stage", i );
        string stageFilename = dirName + buf + ".xml";
        FileStorage fs( stageFilename, FileStorage::WRITE );
        if ( !fs.isOpened() )
        {
            cout << "Current stage can not be written, because file " << stageFilename
                << " can not be opened." << endl;
            return false;
        }
        fs << FileStorage::getDefaultObjectName(stageFilename) << "{";
        tempStage->write( fs, Mat() );
        fs << "}";

        // Output training time up till now
        double seconds = ( (double)getTickCount() - time)/ getTickFrequency();
        int days = int(seconds) / 60 / 60 / 24;
        int hours = (int(seconds) / 60 / 60) % 24;
        int minutes = (int(seconds) / 60) % 60;
        int seconds_left = int(seconds) % 60;
        cout << "Training until now has taken " << days << " days " << hours << " hours " << minutes << " minutes " << seconds_left <<" seconds." << endl;
    }

    if(stageClassifiers.size() == 0)
    {
        cout << "Cascade classifier can't be trained. Check the used training parameters." << endl;
        return false;
    }

    save( dirName + CC_CASCADE_FILENAME, baseFormatSave );

    return true;
}

// Runs the sample through all trained stages; returns 0 as soon as any
// stage rejects it, otherwise 1.
int CvCascadeClassifier::predict( int sampleIdx )
{
    CV_DbgAssert( sampleIdx < numPos + numNeg );
    for (vector< Ptr<CvCascadeBoost> >::iterator it = stageClassifiers.begin();
        it != stageClassifiers.end();++it )
    {
        if ( (*it)->predict( sampleIdx ) == 0.f )
            return 0;
    }
    return 1;
}

// Refills the working sample set for the next stage: first positives, then
// a proportional number of negatives that pass the current cascade.
// acceptanceRatio (out) is negCount/negConsumed, i.e. how hard negatives
// have become to find. Returns false if not enough samples could be gathered.
bool CvCascadeClassifier::updateTrainingSet( double minimumAcceptanceRatio, double& acceptanceRatio)
{
    int64 posConsumed = 0, negConsumed = 0;
    imgReader.restart();
    int posCount = fillPassedSamples( 0, numPos, true, 0, posConsumed );
    if( !posCount )
        return false;
    cout << "POS count : consumed " << posCount << " : " << (int)posConsumed << endl;

    int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible
    int negCount = fillPassedSamples( posCount, proNumNeg, false, minimumAcceptanceRatio, negConsumed );
    if ( !negCount )
        // Zero negatives is only acceptable if the acceptance-ratio break
        // condition already triggered inside fillPassedSamples.
        if ( !(negConsumed > 0 && ((double)negCount+1)/(double)negConsumed <= minimumAcceptanceRatio) )
            return false;

    curNumSamples = posCount + negCount;
    acceptanceRatio = negConsumed == 0 ? 0 : ( (double)negCount/(double)(int64)negConsumed );
    cout << "NEG count : acceptanceRatio " << negCount << " : " << acceptanceRatio << endl;
    return true;
}

// Pulls samples from the image reader until `count` of them pass the
// current cascade (predict()==1), installing each into the feature
// evaluator at slots [first, first+count). `consumed` (in/out) counts every
// sample drawn; the loop bails out early when the pass/consumed ratio drops
// to minimumAcceptanceRatio or the reader runs dry. Returns the number of
// accepted samples.
int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, double minimumAcceptanceRatio, int64& consumed )
{
    int getcount = 0;
    Mat img(cascadeParams.winSize, CV_8UC1);
    for( int i = first; i < first + count; i++ )
    {
        for( ; ; )
        {
            if( consumed != 0 && ((double)getcount+1)/(double)(int64)consumed <= minimumAcceptanceRatio )
                return getcount;

            bool isGetImg = isPositive ? imgReader.getPos( img ) :
                                         imgReader.getNeg( img );
            if( !isGetImg )
                return getcount;
            consumed++;

            featureEvaluator->setImage( img, isPositive ? 1 : 0, i );
            if( predict( i ) == 1 )
            {
                getcount++;
                printf("%s current samples: %d\r", isPositive ? "POS":"NEG", getcount);
                break;
            }
        }
    }
    return getcount;
}

// Writes cascade, stage and feature parameters as nested maps.
void CvCascadeClassifier::writeParams( FileStorage &fs ) const
{
    cascadeParams.write( fs );
    fs << CC_STAGE_PARAMS << "{"; stageParams->write( fs ); fs << "}";
    fs << CC_FEATURE_PARAMS << "{"; featureParams->write( fs ); fs << "}";
}

// Writes only the features referenced by featureMap (built by
// getUsedFeaturesIdxMap).
void CvCascadeClassifier::writeFeatures( FileStorage &fs, const Mat& featureMap ) const
{
    featureEvaluator->writeFeatures( fs, featureMap );
}

// Writes every trained stage as a sequence element, each preceded by a
// "stage N" comment.
void CvCascadeClassifier::writeStages( FileStorage &fs, const Mat& featureMap ) const
{
    char cmnt[30];
    int i = 0;
    fs << CC_STAGES << "[";
    for( vector< Ptr<CvCascadeBoost> >::const_iterator it = stageClassifiers.begin();
        it != stageClassifiers.end();++it, ++i )
    {
        snprintf( cmnt, sizeof(cmnt), "stage %d", i );  // was sprintf; bounded is safer
        cvWriteComment( fs.fs, cmnt, 0 );
        fs << "{";
        (*it)->write( fs, featureMap );
        fs << "}";
    }
    fs << "]";
}

// Reads cascade, stage and feature parameters from a params node; creates
// the stageParams/featureParams objects as a side effect.
bool CvCascadeClassifier::readParams( const FileNode &node )
{
    if ( !node.isMap() || !cascadeParams.read( node ) )
        return false;

    stageParams = makePtr<CvCascadeBoostParams>();
    FileNode rnode = node[CC_STAGE_PARAMS];
    if ( !stageParams->read( rnode ) )
        return false;

    featureParams = CvFeatureParams::create(cascadeParams.featureType);
    rnode = node[CC_FEATURE_PARAMS];
    if ( !featureParams->read( rnode ) )
        return false;
    return true;
}

// Reads up to numStages stage classifiers from the CC_STAGES sequence.
bool CvCascadeClassifier::readStages( const FileNode &node)
{
    FileNode rnode = node[CC_STAGES];
    // BUGFIX: was `!rnode.empty() || !rnode.isSeq()`, which rejected every
    // valid (non-empty) stages sequence. Reject only missing/non-sequence nodes.
    if (rnode.empty() || !rnode.isSeq())
        return false;
    stageClassifiers.reserve(numStages);
    FileNodeIterator it = rnode.begin();
    for( int i = 0; i < min( (int)rnode.size(), numStages ); i++, ++it )
    {
        Ptr<CvCascadeBoost> tempStage = makePtr<CvCascadeBoost>();
        if ( !tempStage->read( *it, featureEvaluator, *stageParams) )
            return false;
        stageClassifiers.push_back(tempStage);
    }
    return true;
}

// For old Haar Classifier file saving
#define ICV_HAAR_TYPE_ID "opencv-haar-classifier"
#define ICV_HAAR_SIZE_NAME "size"
#define ICV_HAAR_STAGES_NAME "stages"
#define ICV_HAAR_TREES_NAME "trees"
#define ICV_HAAR_FEATURE_NAME "feature"
#define ICV_HAAR_RECTS_NAME "rects"
#define ICV_HAAR_TILTED_NAME "tilted"
#define ICV_HAAR_THRESHOLD_NAME "threshold"
#define ICV_HAAR_LEFT_NODE_NAME "left_node"
#define ICV_HAAR_LEFT_VAL_NAME "left_val"
#define ICV_HAAR_RIGHT_NODE_NAME "right_node"
#define ICV_HAAR_RIGHT_VAL_NAME "right_val"
#define ICV_HAAR_STAGE_THRESHOLD_NAME "stage_threshold"
#define ICV_HAAR_PARENT_NAME "parent"
#define ICV_HAAR_NEXT_NAME "next"

// Saves the trained cascade. In the new format this writes params, stages
// and the used-feature subset; in the old (base) format — Haar features
// only — each boosted tree is serialized breadth-first so that inner-node
// indices match the legacy reader's expectations.
void CvCascadeClassifier::save( const string filename, bool baseFormat )
{
    FileStorage fs( filename, FileStorage::WRITE );

    if ( !fs.isOpened() )
        return;

    fs << FileStorage::getDefaultObjectName(filename);
    if ( !baseFormat )
    {
        Mat featureMap;
        getUsedFeaturesIdxMap( featureMap );
        fs << "{";
        writeParams( fs );
        fs << CC_STAGE_NUM << (int)stageClassifiers.size();
        writeStages( fs, featureMap );
        writeFeatures( fs, featureMap );
    }
    else
    {
        //char buf[256];
        CvSeq* weak;
        if ( cascadeParams.featureType != CvFeatureParams::HAAR )
            CV_Error( CV_StsBadFunc, "old file format is used for Haar-like features only");
        fs << "{:" ICV_HAAR_TYPE_ID;
        fs << ICV_HAAR_SIZE_NAME << "[:" << cascadeParams.winSize.width <<
            cascadeParams.winSize.height << "]";
        fs << ICV_HAAR_STAGES_NAME << "[";
        for( size_t si = 0; si < stageClassifiers.size(); si++ )
        {
            fs << "{"; //stage
            /*sprintf( buf, "stage %d", si );
            CV_CALL( cvWriteComment( fs, buf, 1 ) );*/
            weak = stageClassifiers[si]->get_weak_predictors();
            fs << ICV_HAAR_TREES_NAME << "[";
            for( int wi = 0; wi < weak->total; wi++ )
            {
                // BFS over the tree's inner (split) nodes; total_inner_node_idx
                // assigns each queued node the index it will get in the output
                // sequence, so left_node/right_node references stay consistent.
                int inner_node_idx = -1, total_inner_node_idx = -1;
                queue<const CvDTreeNode*> inner_nodes_queue;
                CvCascadeBoostTree* tree = *((CvCascadeBoostTree**) cvGetSeqElem( weak, wi ));

                fs << "[";
                /*sprintf( buf, "tree %d", wi );
                CV_CALL( cvWriteComment( fs, buf, 1 ) );*/

                const CvDTreeNode* tempNode;

                inner_nodes_queue.push( tree->get_root() );
                total_inner_node_idx++;

                while (!inner_nodes_queue.empty())
                {
                    tempNode = inner_nodes_queue.front();
                    inner_node_idx++;

                    fs << "{";
                    fs << ICV_HAAR_FEATURE_NAME << "{";
                    ((CvHaarEvaluator*)featureEvaluator.get())->writeFeature( fs, tempNode->split->var_idx );
                    fs << "}";

                    fs << ICV_HAAR_THRESHOLD_NAME << tempNode->split->ord.c;

                    // A child with its own children is an inner node: queue it
                    // and reference it by index; otherwise emit its leaf value.
                    if( tempNode->left->left || tempNode->left->right )
                    {
                        inner_nodes_queue.push( tempNode->left );
                        total_inner_node_idx++;
                        fs << ICV_HAAR_LEFT_NODE_NAME << total_inner_node_idx;
                    }
                    else
                        fs << ICV_HAAR_LEFT_VAL_NAME << tempNode->left->value;

                    if( tempNode->right->left || tempNode->right->right )
                    {
                        inner_nodes_queue.push( tempNode->right );
                        total_inner_node_idx++;
                        fs << ICV_HAAR_RIGHT_NODE_NAME << total_inner_node_idx;
                    }
                    else
                        fs << ICV_HAAR_RIGHT_VAL_NAME << tempNode->right->value;
                    fs << "}"; // close tree node map (comment fixed: this does not close ICV_HAAR_FEATURE_NAME)
                    inner_nodes_queue.pop();
                }
                fs << "]";
            }
            fs << "]"; //ICV_HAAR_TREES_NAME
            fs << ICV_HAAR_STAGE_THRESHOLD_NAME << stageClassifiers[si]->getThreshold();
            fs << ICV_HAAR_PARENT_NAME << (int)si-1 << ICV_HAAR_NEXT_NAME << -1;
            fs << "}"; //stage
        } /* for each stage */
        fs << "]"; //ICV_HAAR_STAGES_NAME
    }
    fs << "}";
}

// Loads params.xml plus any stageN.xml files from a previous (possibly
// interrupted) run in cascadeDirName. Returns false if the params file is
// missing or unreadable; stage loading stops at the first missing/broken
// stage file, which is how interrupted runs are resumed.
bool CvCascadeClassifier::load( const string cascadeDirName )
{
    FileStorage fs( cascadeDirName + CC_PARAMS_FILENAME, FileStorage::READ );
    if ( !fs.isOpened() )
        return false;
    FileNode node = fs.getFirstTopLevelNode();
    if ( !readParams( node ) )
        return false;
    featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
    featureEvaluator->init( featureParams, numPos + numNeg, cascadeParams.winSize );
    fs.release();

    char buf[32] = {0};
    for ( int si = 0; si < numStages; si++ )
    {
        snprintf( buf, sizeof(buf), "%s%d", "stage", si);  // was sprintf; bounded is safer
        fs.open( cascadeDirName + buf + ".xml", FileStorage::READ );
        // BUGFIX: check isOpened() before querying the storage for its
        // top-level node (was read before the open check).
        if ( !fs.isOpened() )
            break;
        node = fs.getFirstTopLevelNode();
        Ptr<CvCascadeBoost> tempStage = makePtr<CvCascadeBoost>();

        if ( !tempStage->read( node, featureEvaluator, *stageParams ))
        {
            fs.release();
            break;
        }
        stageClassifiers.push_back(tempStage);
    }
    return true;
}

// Builds a map from absolute feature index to a compact index covering only
// features actually used by some stage; unused features stay at -1.
void CvCascadeClassifier::getUsedFeaturesIdxMap( Mat& featureMap )
{
    int varCount = featureEvaluator->getNumFeatures() * featureEvaluator->getFeatureSize();
    featureMap.create( 1, varCount, CV_32SC1 );
    featureMap.setTo(Scalar(-1));

    for( vector< Ptr<CvCascadeBoost> >::const_iterator it = stageClassifiers.begin();
        it != stageClassifiers.end();++it )
        (*it)->markUsedFeaturesInMap( featureMap );

    // Renumber used features densely in increasing order of original index.
    for( int fi = 0, idx = 0; fi < varCount; fi++ )
        if ( featureMap.at<int>(0, fi) >= 0 )
            featureMap.ptr<int>(0)[fi] = idx++;
}