// Path: blob/master/apps/traincascade/old_ml_boost.cpp
// (16337 views)
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such
damage.37//38//M*/3940#include "old_ml_precomp.hpp"4142static inline double43log_ratio( double val )44{45const double eps = 1e-5;4647val = MAX( val, eps );48val = MIN( val, 1. - eps );49return log( val/(1. - val) );50}515253CvBoostParams::CvBoostParams()54{55boost_type = CvBoost::REAL;56weak_count = 100;57weight_trim_rate = 0.95;58cv_folds = 0;59max_depth = 1;60}616263CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,64double _weight_trim_rate, int _max_depth,65bool _use_surrogates, const float* _priors )66{67boost_type = _boost_type;68weak_count = _weak_count;69weight_trim_rate = _weight_trim_rate;70split_criteria = CvBoost::DEFAULT;71cv_folds = 0;72max_depth = _max_depth;73use_surrogates = _use_surrogates;74priors = _priors;75}76777879///////////////////////////////// CvBoostTree ///////////////////////////////////8081CvBoostTree::CvBoostTree()82{83ensemble = 0;84}858687CvBoostTree::~CvBoostTree()88{89clear();90}919293void94CvBoostTree::clear()95{96CvDTree::clear();97ensemble = 0;98}99100101bool102CvBoostTree::train( CvDTreeTrainData* _train_data,103const CvMat* _subsample_idx, CvBoost* _ensemble )104{105clear();106ensemble = _ensemble;107data = _train_data;108data->shared = true;109return do_train( _subsample_idx );110}111112113bool114CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,115const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )116{117assert(0);118return false;119}120121122bool123CvBoostTree::train( CvDTreeTrainData*, const CvMat* )124{125assert(0);126return false;127}128129130void131CvBoostTree::scale( double _scale )132{133CvDTreeNode* node = root;134135// traverse the tree and scale all the node values136for(;;)137{138CvDTreeNode* parent;139for(;;)140{141node->value *= _scale;142if( !node->left )143break;144node = node->left;145}146147for( parent = node->parent; parent && parent->right == node;148node = parent, parent = parent->parent )149;150151if( !parent )152break;153154node = 
parent->right;155}156}157158159void160CvBoostTree::try_split_node( CvDTreeNode* node )161{162CvDTree::try_split_node( node );163164if( !node->left )165{166// if the node has not been split,167// store the responses for the corresponding training samples168double* weak_eval = ensemble->get_weak_response()->data.db;169cv::AutoBuffer<int> inn_buf(node->sample_count);170const int* labels = data->get_cv_labels(node, inn_buf.data());171int i, count = node->sample_count;172double value = node->value;173174for( i = 0; i < count; i++ )175weak_eval[labels[i]] = value;176}177}178179180double181CvBoostTree::calc_node_dir( CvDTreeNode* node )182{183char* dir = (char*)data->direction->data.ptr;184const double* weights = ensemble->get_subtree_weights()->data.db;185int i, n = node->sample_count, vi = node->split->var_idx;186double L, R;187188assert( !node->split->inversed );189190if( data->get_var_type(vi) >= 0 ) // split on categorical var191{192cv::AutoBuffer<int> inn_buf(n);193const int* cat_labels = data->get_cat_var_data(node, vi, inn_buf.data());194const int* subset = node->split->subset;195double sum = 0, sum_abs = 0;196197for( i = 0; i < n; i++ )198{199int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];200double w = weights[i];201int d = idx >= 0 ? 
CV_DTREE_CAT_DIR(idx,subset) : 0;202sum += d*w; sum_abs += (d & 1)*w;203dir[i] = (char)d;204}205206R = (sum_abs + sum) * 0.5;207L = (sum_abs - sum) * 0.5;208}209else // split on ordered var210{211cv::AutoBuffer<uchar> inn_buf(2*n*sizeof(int)+n*sizeof(float));212float* values_buf = (float*)inn_buf.data();213int* sorted_indices_buf = (int*)(values_buf + n);214int* sample_indices_buf = sorted_indices_buf + n;215const float* values = 0;216const int* sorted_indices = 0;217data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );218int split_point = node->split->ord.split_point;219int n1 = node->get_num_valid(vi);220221assert( 0 <= split_point && split_point < n1-1 );222L = R = 0;223224for( i = 0; i <= split_point; i++ )225{226int idx = sorted_indices[i];227double w = weights[idx];228dir[idx] = (char)-1;229L += w;230}231232for( ; i < n1; i++ )233{234int idx = sorted_indices[i];235double w = weights[idx];236dir[idx] = (char)1;237R += w;238}239240for( ; i < n; i++ )241dir[sorted_indices[i]] = (char)0;242}243244node->maxlr = MAX( L, R );245return node->split->quality/(L + R);246}247248249CvDTreeSplit*250CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality,251CvDTreeSplit* _split, uchar* _ext_buf )252{253const float epsilon = FLT_EPSILON*2;254255const double* weights = ensemble->get_subtree_weights()->data.db;256int n = node->sample_count;257int n1 = node->get_num_valid(vi);258259cv::AutoBuffer<uchar> inn_buf;260if( !_ext_buf )261inn_buf.allocate(n*(3*sizeof(int)+sizeof(float)));262uchar* ext_buf = _ext_buf ? 
_ext_buf : inn_buf.data();263float* values_buf = (float*)ext_buf;264int* sorted_indices_buf = (int*)(values_buf + n);265int* sample_indices_buf = sorted_indices_buf + n;266const float* values = 0;267const int* sorted_indices = 0;268data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );269int* responses_buf = sorted_indices_buf + n;270const int* responses = data->get_class_labels( node, responses_buf );271const double* rcw0 = weights + n;272double lcw[2] = {0,0}, rcw[2];273int i, best_i = -1;274double best_val = init_quality;275int boost_type = ensemble->get_params().boost_type;276int split_criteria = ensemble->get_params().split_criteria;277278rcw[0] = rcw0[0]; rcw[1] = rcw0[1];279for( i = n1; i < n; i++ )280{281int idx = sorted_indices[i];282double w = weights[idx];283rcw[responses[idx]] -= w;284}285286if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )287split_criteria = boost_type == CvBoost::DISCRETE ? 
CvBoost::MISCLASS : CvBoost::GINI;288289if( split_criteria == CvBoost::GINI )290{291double L = 0, R = rcw[0] + rcw[1];292double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];293294for( i = 0; i < n1 - 1; i++ )295{296int idx = sorted_indices[i];297double w = weights[idx], w2 = w*w;298double lv, rv;299idx = responses[idx];300L += w; R -= w;301lv = lcw[idx]; rv = rcw[idx];302lsum2 += 2*lv*w + w2;303rsum2 -= 2*rv*w - w2;304lcw[idx] = lv + w; rcw[idx] = rv - w;305306if( values[i] + epsilon < values[i+1] )307{308double val = (lsum2*R + rsum2*L)/(L*R);309if( best_val < val )310{311best_val = val;312best_i = i;313}314}315}316}317else318{319for( i = 0; i < n1 - 1; i++ )320{321int idx = sorted_indices[i];322double w = weights[idx];323idx = responses[idx];324lcw[idx] += w;325rcw[idx] -= w;326327if( values[i] + epsilon < values[i+1] )328{329double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];330val = MAX(val, val2);331if( best_val < val )332{333best_val = val;334best_i = i;335}336}337}338}339340CvDTreeSplit* split = 0;341if( best_i >= 0 )342{343split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );344split->var_idx = vi;345split->ord.c = (values[best_i] + values[best_i+1])*0.5f;346split->ord.split_point = best_i;347split->inversed = 0;348split->quality = (float)best_val;349}350return split;351}352353template<typename T>354class LessThanPtr355{356public:357bool operator()(T* a, T* b) const { return *a < *b; }358};359360CvDTreeSplit*361CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )362{363int ci = data->get_var_type(vi);364int n = node->sample_count;365int mi = data->cat_count->data.i[ci];366367int base_size = (2*mi+3)*sizeof(double) + mi*sizeof(double*);368cv::AutoBuffer<uchar> inn_buf((2*mi+3)*sizeof(double) + mi*sizeof(double*));369if( !_ext_buf)370inn_buf.allocate( base_size + 2*n*sizeof(int) );371uchar* base_buf = inn_buf.data();372uchar* ext_buf = _ext_buf ? 
_ext_buf : base_buf + base_size;373374int* cat_labels_buf = (int*)ext_buf;375const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);376int* responses_buf = cat_labels_buf + n;377const int* responses = data->get_class_labels(node, responses_buf);378double lcw[2]={0,0}, rcw[2]={0,0};379380double* cjk = (double*)cv::alignPtr(base_buf,sizeof(double))+2;381const double* weights = ensemble->get_subtree_weights()->data.db;382double** dbl_ptr = (double**)(cjk + 2*mi);383int i, j, k, idx;384double L = 0, R;385double best_val = init_quality;386int best_subset = -1, subset_i;387int boost_type = ensemble->get_params().boost_type;388int split_criteria = ensemble->get_params().split_criteria;389390// init array of counters:391// c_{jk} - number of samples that have vi-th input variable = j and response = k.392for( j = -1; j < mi; j++ )393cjk[j*2] = cjk[j*2+1] = 0;394395for( i = 0; i < n; i++ )396{397double w = weights[i];398j = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];399k = responses[i];400cjk[j*2 + k] += w;401}402403for( j = 0; j < mi; j++ )404{405rcw[0] += cjk[j*2];406rcw[1] += cjk[j*2+1];407dbl_ptr[j] = cjk + j*2 + 1;408}409410R = rcw[0] + rcw[1];411412if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )413split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;414415// sort rows of c_jk by increasing c_j,1416// (i.e. 
by the weight of samples in j-th category that belong to class 1)417std::sort(dbl_ptr, dbl_ptr + mi, LessThanPtr<double>());418419for( subset_i = 0; subset_i < mi-1; subset_i++ )420{421idx = (int)(dbl_ptr[subset_i] - cjk)/2;422const double* crow = cjk + idx*2;423double w0 = crow[0], w1 = crow[1];424double weight = w0 + w1;425426if( weight < FLT_EPSILON )427continue;428429lcw[0] += w0; rcw[0] -= w0;430lcw[1] += w1; rcw[1] -= w1;431432if( split_criteria == CvBoost::GINI )433{434double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];435double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];436437L += weight;438R -= weight;439440if( L > FLT_EPSILON && R > FLT_EPSILON )441{442double val = (lsum2*R + rsum2*L)/(L*R);443if( best_val < val )444{445best_val = val;446best_subset = subset_i;447}448}449}450else451{452double val = lcw[0] + rcw[1];453double val2 = lcw[1] + rcw[0];454455val = MAX(val, val2);456if( best_val < val )457{458best_val = val;459best_subset = subset_i;460}461}462}463464CvDTreeSplit* split = 0;465if( best_subset >= 0 )466{467split = _split ? _split : data->new_split_cat( 0, -1.0f);468split->var_idx = vi;469split->quality = (float)best_val;470memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));471for( i = 0; i <= best_subset; i++ )472{473idx = (int)(dbl_ptr[i] - cjk) >> 1;474split->subset[idx >> 5] |= 1 << (idx & 31);475}476}477return split;478}479480481CvDTreeSplit*482CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )483{484const float epsilon = FLT_EPSILON*2;485const double* weights = ensemble->get_subtree_weights()->data.db;486int n = node->sample_count;487int n1 = node->get_num_valid(vi);488489cv::AutoBuffer<uchar> inn_buf;490if( !_ext_buf )491inn_buf.allocate(2*n*(sizeof(int)+sizeof(float)));492uchar* ext_buf = _ext_buf ? 
_ext_buf : inn_buf.data();493494float* values_buf = (float*)ext_buf;495int* indices_buf = (int*)(values_buf + n);496int* sample_indices_buf = indices_buf + n;497const float* values = 0;498const int* indices = 0;499data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );500float* responses_buf = (float*)(indices_buf + n);501const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );502503int i, best_i = -1;504double L = 0, R = weights[n];505double best_val = init_quality, lsum = 0, rsum = node->value*R;506507// compensate for missing values508for( i = n1; i < n; i++ )509{510int idx = indices[i];511double w = weights[idx];512rsum -= responses[idx]*w;513R -= w;514}515516// find the optimal split517for( i = 0; i < n1 - 1; i++ )518{519int idx = indices[i];520double w = weights[idx];521double t = responses[idx]*w;522L += w; R -= w;523lsum += t; rsum -= t;524525if( values[i] + epsilon < values[i+1] )526{527double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);528if( best_val < val )529{530best_val = val;531best_i = i;532}533}534}535536CvDTreeSplit* split = 0;537if( best_i >= 0 )538{539split = _split ? 
_split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );540split->var_idx = vi;541split->ord.c = (values[best_i] + values[best_i+1])*0.5f;542split->ord.split_point = best_i;543split->inversed = 0;544split->quality = (float)best_val;545}546return split;547}548549550CvDTreeSplit*551CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )552{553const double* weights = ensemble->get_subtree_weights()->data.db;554int ci = data->get_var_type(vi);555int n = node->sample_count;556int mi = data->cat_count->data.i[ci];557int base_size = (2*mi+3)*sizeof(double) + mi*sizeof(double*);558cv::AutoBuffer<uchar> inn_buf(base_size);559if( !_ext_buf )560inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));561uchar* base_buf = inn_buf.data();562uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;563564int* cat_labels_buf = (int*)ext_buf;565const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);566float* responses_buf = (float*)(cat_labels_buf + n);567int* sample_indices_buf = (int*)(responses_buf + n);568const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);569570double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;571double* counts = sum + mi + 1;572double** sum_ptr = (double**)(counts + mi);573double L = 0, R = 0, best_val = init_quality, lsum = 0, rsum = 0;574int i, best_subset = -1, subset_i;575576for( i = -1; i < mi; i++ )577sum[i] = counts[i] = 0;578579// calculate sum response and weight of each category of the input var580for( i = 0; i < n; i++ )581{582int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];583double w = weights[i];584double s = sum[idx] + responses[i]*w;585double nc = counts[idx] + w;586sum[idx] = s;587counts[idx] = nc;588}589590// calculate average response in each category591for( i = 0; i < mi; i++ )592{593R += counts[i];594rsum += sum[i];595sum[i] = fabs(counts[i]) > DBL_EPSILON ? 
sum[i]/counts[i] : 0;596sum_ptr[i] = sum + i;597}598599std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());600601// revert back to unnormalized sums602// (there should be a very little loss in accuracy)603for( i = 0; i < mi; i++ )604sum[i] *= counts[i];605606for( subset_i = 0; subset_i < mi-1; subset_i++ )607{608int idx = (int)(sum_ptr[subset_i] - sum);609double ni = counts[idx];610611if( ni > FLT_EPSILON )612{613double s = sum[idx];614lsum += s; L += ni;615rsum -= s; R -= ni;616617if( L > FLT_EPSILON && R > FLT_EPSILON )618{619double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);620if( best_val < val )621{622best_val = val;623best_subset = subset_i;624}625}626}627}628629CvDTreeSplit* split = 0;630if( best_subset >= 0 )631{632split = _split ? _split : data->new_split_cat( 0, -1.0f);633split->var_idx = vi;634split->quality = (float)best_val;635memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));636for( i = 0; i <= best_subset; i++ )637{638int idx = (int)(sum_ptr[i] - sum);639split->subset[idx >> 5] |= 1 << (idx & 31);640}641}642return split;643}644645646CvDTreeSplit*647CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )648{649const float epsilon = FLT_EPSILON*2;650int n = node->sample_count;651cv::AutoBuffer<uchar> inn_buf;652if( !_ext_buf )653inn_buf.allocate(n*(2*sizeof(int)+sizeof(float)));654uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();655float* values_buf = (float*)ext_buf;656int* indices_buf = (int*)(values_buf + n);657int* sample_indices_buf = indices_buf + n;658const float* values = 0;659const int* indices = 0;660data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices, sample_indices_buf );661662const double* weights = ensemble->get_subtree_weights()->data.db;663const char* dir = (char*)data->direction->data.ptr;664int n1 = node->get_num_valid(vi);665// LL - number of samples that both the primary and the surrogate splits send to the left666// LR - ... 
primary split sends to the left and the surrogate split sends to the right667// RL - ... primary split sends to the right and the surrogate split sends to the left668// RR - ... both send to the right669int i, best_i = -1, best_inversed = 0;670double best_val;671double LL = 0, RL = 0, LR, RR;672double worst_val = node->maxlr;673double sum = 0, sum_abs = 0;674best_val = worst_val;675676for( i = 0; i < n1; i++ )677{678int idx = indices[i];679double w = weights[idx];680int d = dir[idx];681sum += d*w; sum_abs += (d & 1)*w;682}683684// sum_abs = R + L; sum = R - L685RR = (sum_abs + sum)*0.5;686LR = (sum_abs - sum)*0.5;687688// initially all the samples are sent to the right by the surrogate split,689// LR of them are sent to the left by primary split, and RR - to the right.690// now iteratively compute LL, LR, RL and RR for every possible surrogate split value.691for( i = 0; i < n1 - 1; i++ )692{693int idx = indices[i];694double w = weights[idx];695int d = dir[idx];696697if( d < 0 )698{699LL += w; LR -= w;700if( LL + RR > best_val && values[i] + epsilon < values[i+1] )701{702best_val = LL + RR;703best_i = i; best_inversed = 0;704}705}706else if( d > 0 )707{708RL += w; RR -= w;709if( RL + LR > best_val && values[i] + epsilon < values[i+1] )710{711best_val = RL + LR;712best_i = i; best_inversed = 1;713}714}715}716717return best_i >= 0 && best_val > node->maxlr ? 
data->new_split_ord( vi,718(values[best_i] + values[best_i+1])*0.5f, best_i,719best_inversed, (float)best_val ) : 0;720}721722723CvDTreeSplit*724CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )725{726const char* dir = (char*)data->direction->data.ptr;727const double* weights = ensemble->get_subtree_weights()->data.db;728int n = node->sample_count;729int i, mi = data->cat_count->data.i[data->get_var_type(vi)];730731int base_size = (2*mi+3)*sizeof(double);732cv::AutoBuffer<uchar> inn_buf(base_size);733if( !_ext_buf )734inn_buf.allocate(base_size + n*sizeof(int));735uchar* ext_buf = _ext_buf ? _ext_buf : inn_buf.data();736int* cat_labels_buf = (int*)ext_buf;737const int* cat_labels = data->get_cat_var_data(node, vi, cat_labels_buf);738739// LL - number of samples that both the primary and the surrogate splits send to the left740// LR - ... primary split sends to the left and the surrogate split sends to the right741// RL - ... primary split sends to the right and the surrogate split sends to the left742// RR - ... both send to the right743CvDTreeSplit* split = data->new_split_cat( vi, 0 );744double best_val = 0;745double* lc = (double*)cv::alignPtr(cat_labels_buf + n, sizeof(double)) + 1;746double* rc = lc + mi + 1;747748for( i = -1; i < mi; i++ )749lc[i] = rc[i] = 0;750751// 1. for each category calculate the weight of samples752// sent to the left (lc) and to the right (rc) by the primary split753for( i = 0; i < n; i++ )754{755int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];756double w = weights[i];757int d = dir[i];758double sum = lc[idx] + d*w;759double sum_abs = rc[idx] + (d & 1)*w;760lc[idx] = sum; rc[idx] = sum_abs;761}762763for( i = 0; i < mi; i++ )764{765double sum = lc[i];766double sum_abs = rc[i];767lc[i] = (sum_abs - sum) * 0.5;768rc[i] = (sum_abs + sum) * 0.5;769}770771// 2. 
now form the split.772// in each category send all the samples to the same direction as majority773for( i = 0; i < mi; i++ )774{775double lval = lc[i], rval = rc[i];776if( lval > rval )777{778split->subset[i >> 5] |= 1 << (i & 31);779best_val += lval;780}781else782best_val += rval;783}784785split->quality = (float)best_val;786if( split->quality <= node->maxlr )787cvSetRemoveByPtr( data->split_heap, split ), split = 0;788789return split;790}791792793void794CvBoostTree::calc_node_value( CvDTreeNode* node )795{796int i, n = node->sample_count;797const double* weights = ensemble->get_weights()->data.db;798cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int) + ( data->is_classifier ? sizeof(int) : sizeof(int) + sizeof(float))));799int* labels_buf = (int*)inn_buf.data();800const int* labels = data->get_cv_labels(node, labels_buf);801double* subtree_weights = ensemble->get_subtree_weights()->data.db;802double rcw[2] = {0,0};803int boost_type = ensemble->get_params().boost_type;804805if( data->is_classifier )806{807int* _responses_buf = labels_buf + n;808const int* _responses = data->get_class_labels(node, _responses_buf);809int m = data->get_num_classes();810int* cls_count = data->counts->data.i;811for( int k = 0; k < m; k++ )812cls_count[k] = 0;813814for( i = 0; i < n; i++ )815{816int idx = labels[i];817double w = weights[idx];818int r = _responses[i];819rcw[r] += w;820cls_count[r]++;821subtree_weights[i] = w;822}823824node->class_idx = rcw[1] > rcw[0];825826if( boost_type == CvBoost::DISCRETE )827{828// ignore cat_map for responses, and use {-1,1},829// as the whole ensemble response is computes as sign(sum_i(weak_response_i)830node->value = node->class_idx*2 - 1;831}832else833{834double p = rcw[1]/(rcw[0] + rcw[1]);835assert( boost_type == CvBoost::REAL );836837// store log-ratio of the probability838node->value = 0.5*log_ratio(p);839}840}841else842{843// in case of regression tree:844// * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,845// n is the number of 
samples in the node.846// * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)847double sum = 0, sum2 = 0, iw;848float* values_buf = (float*)(labels_buf + n);849int* sample_indices_buf = (int*)(values_buf + n);850const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);851852for( i = 0; i < n; i++ )853{854int idx = labels[i];855double w = weights[idx]/*priors[values[i] > 0]*/;856double t = values[i];857rcw[0] += w;858subtree_weights[i] = w;859sum += t*w;860sum2 += t*t*w;861}862863iw = 1./rcw[0];864node->value = sum*iw;865node->node_risk = sum2 - (sum*iw)*sum;866867// renormalize the risk, as in try_split_node the unweighted formula868// sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)869node->node_risk *= n*iw*n*iw;870}871872// store summary weights873subtree_weights[n] = rcw[0];874subtree_weights[n+1] = rcw[1];875}876877878void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )879{880CvDTree::read( fs, fnode, _data );881ensemble = _ensemble;882}883884void CvBoostTree::read( CvFileStorage*, CvFileNode* )885{886assert(0);887}888889void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,890CvDTreeTrainData* _data )891{892CvDTree::read( _fs, _node, _data );893}894895896/////////////////////////////////// CvBoost /////////////////////////////////////897898CvBoost::CvBoost()899{900data = 0;901weak = 0;902default_model_name = "my_boost_tree";903904active_vars = active_vars_abs = orig_response = sum_response = weak_eval =905subsample_mask = weights = subtree_weights = 0;906have_active_cat_vars = have_subsample = false;907908clear();909}910911912void CvBoost::prune( CvSlice slice )913{914if( weak && weak->total > 0 )915{916CvSeqReader reader;917int i, count = cvSliceLength( slice, weak );918919cvStartReadSeq( weak, &reader );920cvSetSeqReaderPos( &reader, slice.start_index );921922for( i = 0; i < count; i++ )923{924CvBoostTree* w;925CV_READ_SEQ_ELEM( w, reader 
);926delete w;927}928929cvSeqRemoveSlice( weak, slice );930}931}932933934void CvBoost::clear()935{936if( weak )937{938prune( CV_WHOLE_SEQ );939cvReleaseMemStorage( &weak->storage );940}941if( data )942delete data;943weak = 0;944data = 0;945cvReleaseMat( &active_vars );946cvReleaseMat( &active_vars_abs );947cvReleaseMat( &orig_response );948cvReleaseMat( &sum_response );949cvReleaseMat( &weak_eval );950cvReleaseMat( &subsample_mask );951cvReleaseMat( &weights );952cvReleaseMat( &subtree_weights );953954have_subsample = false;955}956957958CvBoost::~CvBoost()959{960clear();961}962963964CvBoost::CvBoost( const CvMat* _train_data, int _tflag,965const CvMat* _responses, const CvMat* _var_idx,966const CvMat* _sample_idx, const CvMat* _var_type,967const CvMat* _missing_mask, CvBoostParams _params )968{969weak = 0;970data = 0;971default_model_name = "my_boost_tree";972973active_vars = active_vars_abs = orig_response = sum_response = weak_eval =974subsample_mask = weights = subtree_weights = 0;975976train( _train_data, _tflag, _responses, _var_idx, _sample_idx,977_var_type, _missing_mask, _params );978}979980981bool982CvBoost::set_params( const CvBoostParams& _params )983{984bool ok = false;985986CV_FUNCNAME( "CvBoost::set_params" );987988__BEGIN__;989990params = _params;991if( params.boost_type != DISCRETE && params.boost_type != REAL &&992params.boost_type != LOGIT && params.boost_type != GENTLE )993CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );994995params.weak_count = MAX( params.weak_count, 1 );996params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );997params.weight_trim_rate = MIN( params.weight_trim_rate, 1. 
);998if( params.weight_trim_rate < FLT_EPSILON )999params.weight_trim_rate = 1.f;10001001if( params.boost_type == DISCRETE &&1002params.split_criteria != GINI && params.split_criteria != MISCLASS )1003params.split_criteria = MISCLASS;1004if( params.boost_type == REAL &&1005params.split_criteria != GINI && params.split_criteria != MISCLASS )1006params.split_criteria = GINI;1007if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&1008params.split_criteria != SQERR )1009params.split_criteria = SQERR;10101011ok = true;10121013__END__;10141015return ok;1016}101710181019bool1020CvBoost::train( const CvMat* _train_data, int _tflag,1021const CvMat* _responses, const CvMat* _var_idx,1022const CvMat* _sample_idx, const CvMat* _var_type,1023const CvMat* _missing_mask,1024CvBoostParams _params, bool _update )1025{1026bool ok = false;1027CvMemStorage* storage = 0;10281029CV_FUNCNAME( "CvBoost::train" );10301031__BEGIN__;10321033int i;10341035set_params( _params );10361037cvReleaseMat( &active_vars );1038cvReleaseMat( &active_vars_abs );10391040if( !_update || !data )1041{1042clear();1043data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,1044_sample_idx, _var_type, _missing_mask, _params, true, true );10451046if( data->get_num_classes() != 2 )1047CV_ERROR( CV_StsNotImplemented,1048"Boosted trees can only be used for 2-class classification." 
);1049CV_CALL( storage = cvCreateMemStorage() );1050weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );1051storage = 0;1052}1053else1054{1055data->set_data( _train_data, _tflag, _responses, _var_idx,1056_sample_idx, _var_type, _missing_mask, _params, true, true, true );1057}10581059if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )1060data->do_responses_copy();10611062update_weights( 0 );10631064for( i = 0; i < params.weak_count; i++ )1065{1066CvBoostTree* tree = new CvBoostTree;1067if( !tree->train( data, subsample_mask, this ) )1068{1069delete tree;1070break;1071}1072//cvCheckArr( get_weak_response());1073cvSeqPush( weak, &tree );1074update_weights( tree );1075trim_weights();1076if( cvCountNonZero(subsample_mask) == 0 )1077break;1078}10791080if(weak->total > 0)1081{1082get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.1083data->is_classifier = true;1084data->free_train_data();1085ok = true;1086}1087else1088clear();10891090__END__;10911092return ok;1093}10941095bool CvBoost::train( CvMLData* _data,1096CvBoostParams _params,1097bool update )1098{1099bool result = false;11001101CV_FUNCNAME( "CvBoost::train" );11021103__BEGIN__;11041105const CvMat* values = _data->get_values();1106const CvMat* response = _data->get_responses();1107const CvMat* missing = _data->get_missing();1108const CvMat* var_types = _data->get_var_types();1109const CvMat* train_sidx = _data->get_train_sample_idx();1110const CvMat* var_idx = _data->get_var_idx();11111112CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,1113train_sidx, var_types, missing, _params, update ) );11141115__END__;11161117return result;1118}11191120void CvBoost::initialize_weights(double (&p)[2])1121{1122p[0] = 1.;1123p[1] = 1.;1124}11251126void1127CvBoost::update_weights( CvBoostTree* tree )1128{1129CV_FUNCNAME( "CvBoost::update_weights" );11301131__BEGIN__;11321133int i, n = data->sample_count;1134double sumw = 0.;1135int step = 
0;1136float* fdata = 0;1137int *sample_idx_buf;1138const int* sample_idx = 0;1139cv::AutoBuffer<uchar> inn_buf;1140size_t _buf_size = (params.boost_type == LOGIT) || (params.boost_type == GENTLE) ? (size_t)(data->sample_count)*sizeof(int) : 0;1141if( !tree )1142_buf_size += n*sizeof(int);1143else1144{1145if( have_subsample )1146_buf_size += data->get_length_subbuf()*(sizeof(float)+sizeof(uchar));1147}1148inn_buf.allocate(_buf_size);1149uchar* cur_buf_pos = inn_buf.data();11501151if ( (params.boost_type == LOGIT) || (params.boost_type == GENTLE) )1152{1153step = CV_IS_MAT_CONT(data->responses_copy->type) ?11541 : data->responses_copy->step / CV_ELEM_SIZE(data->responses_copy->type);1155fdata = data->responses_copy->data.fl;1156sample_idx_buf = (int*)cur_buf_pos;1157cur_buf_pos = (uchar*)(sample_idx_buf + data->sample_count);1158sample_idx = data->get_sample_indices( data->data_root, sample_idx_buf );1159}1160CvMat* dtree_data_buf = data->buf;1161size_t length_buf_row = data->get_length_subbuf();1162if( !tree ) // before training the first tree, initialize weights and other parameters1163{1164int* class_labels_buf = (int*)cur_buf_pos;1165cur_buf_pos = (uchar*)(class_labels_buf + n);1166const int* class_labels = data->get_class_labels(data->data_root, class_labels_buf);1167// in case of logitboost and gentle adaboost each weak tree is a regression tree,1168// so we need to convert class labels to floating-point values11691170double w0 = 1./ n;1171double p[2] = { 1., 1. 
};1172initialize_weights(p);11731174cvReleaseMat( &orig_response );1175cvReleaseMat( &sum_response );1176cvReleaseMat( &weak_eval );1177cvReleaseMat( &subsample_mask );1178cvReleaseMat( &weights );1179cvReleaseMat( &subtree_weights );11801181CV_CALL( orig_response = cvCreateMat( 1, n, CV_32S ));1182CV_CALL( weak_eval = cvCreateMat( 1, n, CV_64F ));1183CV_CALL( subsample_mask = cvCreateMat( 1, n, CV_8U ));1184CV_CALL( weights = cvCreateMat( 1, n, CV_64F ));1185CV_CALL( subtree_weights = cvCreateMat( 1, n + 2, CV_64F ));11861187if( data->have_priors )1188{1189// compute weight scale for each class from their prior probabilities1190int c1 = 0;1191for( i = 0; i < n; i++ )1192c1 += class_labels[i];1193p[0] = data->priors->data.db[0]*(c1 < n ? 1./(n - c1) : 0.);1194p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);1195p[0] /= p[0] + p[1];1196p[1] = 1. - p[0];1197}11981199if (data->is_buf_16u)1200{1201unsigned short* labels = (unsigned short*)(dtree_data_buf->data.s + data->data_root->buf_idx*length_buf_row +1202data->data_root->offset + (size_t)(data->work_var_count-1)*data->sample_count);1203for( i = 0; i < n; i++ )1204{1205// save original categorical responses {0,1}, convert them to {-1,1}1206orig_response->data.i[i] = class_labels[i]*2 - 1;1207// make all the samples active at start.1208// later, in trim_weights() deactivate/reactive again some, if need1209subsample_mask->data.ptr[i] = (uchar)1;1210// make all the initial weights the same.1211weights->data.db[i] = w0*p[class_labels[i]];1212// set the labels to find (from within weak tree learning proc)1213// the particular sample weight, and where to store the response.1214labels[i] = (unsigned short)i;1215}1216}1217else1218{1219int* labels = dtree_data_buf->data.i + data->data_root->buf_idx*length_buf_row +1220data->data_root->offset + (size_t)(data->work_var_count-1)*data->sample_count;12211222for( i = 0; i < n; i++ )1223{1224// save original categorical responses {0,1}, convert them to 
{-1,1}1225orig_response->data.i[i] = class_labels[i]*2 - 1;1226// make all the samples active at start.1227// later, in trim_weights() deactivate/reactive again some, if need1228subsample_mask->data.ptr[i] = (uchar)1;1229// make all the initial weights the same.1230weights->data.db[i] = w0*p[class_labels[i]];1231// set the labels to find (from within weak tree learning proc)1232// the particular sample weight, and where to store the response.1233labels[i] = i;1234}1235}12361237if( params.boost_type == LOGIT )1238{1239CV_CALL( sum_response = cvCreateMat( 1, n, CV_64F ));12401241for( i = 0; i < n; i++ )1242{1243sum_response->data.db[i] = 0;1244fdata[sample_idx[i]*step] = orig_response->data.i[i] > 0 ? 2.f : -2.f;1245}12461247// in case of logitboost each weak tree is a regression tree.1248// the target function values are recalculated for each of the trees1249data->is_classifier = false;1250}1251else if( params.boost_type == GENTLE )1252{1253for( i = 0; i < n; i++ )1254fdata[sample_idx[i]*step] = (float)orig_response->data.i[i];12551256data->is_classifier = false;1257}1258}1259else1260{1261// at this moment, for all the samples that participated in the training of the most1262// recent weak classifier we know the responses. 
For other samples we need to compute them1263if( have_subsample )1264{1265float* values = (float*)cur_buf_pos;1266cur_buf_pos = (uchar*)(values + data->get_length_subbuf());1267uchar* missing = cur_buf_pos;1268cur_buf_pos = missing + data->get_length_subbuf() * (size_t)CV_ELEM_SIZE(data->buf->type);12691270CvMat _sample, _mask;12711272// invert the subsample mask1273cvXorS( subsample_mask, cvScalar(1.), subsample_mask );1274data->get_vectors( subsample_mask, values, missing, 0 );12751276_sample = cvMat( 1, data->var_count, CV_32F );1277_mask = cvMat( 1, data->var_count, CV_8U );12781279// run tree through all the non-processed samples1280for( i = 0; i < n; i++ )1281if( subsample_mask->data.ptr[i] )1282{1283_sample.data.fl = values;1284_mask.data.ptr = missing;1285values += _sample.cols;1286missing += _mask.cols;1287weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;1288}1289}12901291// now update weights and other parameters for each type of boosting1292if( params.boost_type == DISCRETE )1293{1294// Discrete AdaBoost:1295// weak_eval[i] (=f(x_i)) is in {-1,1}1296// err = sum(w_i*(f(x_i) != y_i))/sum(w_i)1297// C = log((1-err)/err)1298// w_i *= exp(C*(f(x_i) != y_i))12991300double C, err = 0.;1301double scale[] = { 1., 0. 
};13021303for( i = 0; i < n; i++ )1304{1305double w = weights->data.db[i];1306sumw += w;1307err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);1308}13091310if( sumw != 0 )1311err /= sumw;1312C = err = -log_ratio( err );1313scale[1] = exp(err);13141315sumw = 0;1316for( i = 0; i < n; i++ )1317{1318double w = weights->data.db[i]*1319scale[weak_eval->data.db[i] != orig_response->data.i[i]];1320sumw += w;1321weights->data.db[i] = w;1322}13231324tree->scale( C );1325}1326else if( params.boost_type == REAL )1327{1328// Real AdaBoost:1329// weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)1330// w_i *= exp(-y_i*f(x_i))13311332for( i = 0; i < n; i++ )1333weak_eval->data.db[i] *= -orig_response->data.i[i];13341335cvExp( weak_eval, weak_eval );13361337for( i = 0; i < n; i++ )1338{1339double w = weights->data.db[i]*weak_eval->data.db[i];1340sumw += w;1341weights->data.db[i] = w;1342}1343}1344else if( params.boost_type == LOGIT )1345{1346// LogitBoost:1347// weak_eval[i] = f(x_i) in [-z_max,z_max]1348// sum_response = F(x_i).1349// F(x_i) += 0.5*f(x_i)1350// p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))1351// reuse weak_eval: weak_eval[i] <- p(x_i)1352// w_i = p(x_i)*1(1 - p(x_i))1353// z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))1354// store z_i to the data->data_root as the new target responses13551356const double lb_weight_thresh = FLT_EPSILON;1357const double lb_z_max = 10.;1358/*float* responses_buf = data->get_resp_float_buf();1359const float* responses = 0;1360data->get_ord_responses(data->data_root, responses_buf, &responses);*/13611362/*if( weak->total == 7 )1363putchar('*');*/13641365for( i = 0; i < n; i++ )1366{1367double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];1368sum_response->data.db[i] = s;1369weak_eval->data.db[i] = -2*s;1370}13711372cvExp( weak_eval, weak_eval );13731374for( i = 0; i < n; i++ )1375{1376double p = 1./(1. 
+ weak_eval->data.db[i]);1377double w = p*(1 - p), z;1378w = MAX( w, lb_weight_thresh );1379weights->data.db[i] = w;1380sumw += w;1381if( orig_response->data.i[i] > 0 )1382{1383z = 1./p;1384fdata[sample_idx[i]*step] = (float)MIN(z, lb_z_max);1385}1386else1387{1388z = 1./(1-p);1389fdata[sample_idx[i]*step] = (float)-MIN(z, lb_z_max);1390}1391}1392}1393else1394{1395// Gentle AdaBoost:1396// weak_eval[i] = f(x_i) in [-1,1]1397// w_i *= exp(-y_i*f(x_i))1398assert( params.boost_type == GENTLE );13991400for( i = 0; i < n; i++ )1401weak_eval->data.db[i] *= -orig_response->data.i[i];14021403cvExp( weak_eval, weak_eval );14041405for( i = 0; i < n; i++ )1406{1407double w = weights->data.db[i] * weak_eval->data.db[i];1408weights->data.db[i] = w;1409sumw += w;1410}1411}1412}14131414// renormalize weights1415if( sumw > FLT_EPSILON )1416{1417sumw = 1./sumw;1418for( i = 0; i < n; ++i )1419weights->data.db[i] *= sumw;1420}14211422__END__;1423}142414251426void1427CvBoost::trim_weights()1428{1429//CV_FUNCNAME( "CvBoost::trim_weights" );14301431__BEGIN__;14321433int i, count = data->sample_count, nz_count = 0;1434double sum, threshold;14351436if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )1437EXIT;14381439// use weak_eval as temporary buffer for sorted weights1440cvCopy( weights, weak_eval );14411442std::sort(weak_eval->data.db, weak_eval->data.db + count);14431444// as weight trimming occurs immediately after updating the weights,1445// where they are renormalized, we assume that the weight sum = 1.1446sum = 1. - params.weight_trim_rate;14471448for( i = 0; i < count; i++ )1449{1450double w = weak_eval->data.db[i];1451if( sum <= 0 )1452break;1453sum -= w;1454}14551456threshold = i < count ? 
weak_eval->data.db[i] : DBL_MAX;14571458for( i = 0; i < count; i++ )1459{1460double w = weights->data.db[i];1461int f = w >= threshold;1462subsample_mask->data.ptr[i] = (uchar)f;1463nz_count += f;1464}14651466have_subsample = nz_count < count;14671468__END__;1469}147014711472const CvMat*1473CvBoost::get_active_vars( bool absolute_idx )1474{1475CvMat* mask = 0;1476CvMat* inv_map = 0;1477CvMat* result = 0;14781479CV_FUNCNAME( "CvBoost::get_active_vars" );14801481__BEGIN__;14821483if( !weak )1484CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );14851486if( !active_vars || !active_vars_abs )1487{1488CvSeqReader reader;1489int i, j, nactive_vars;1490CvBoostTree* wtree;1491const CvDTreeNode* node;14921493assert(!active_vars && !active_vars_abs);1494mask = cvCreateMat( 1, data->var_count, CV_8U );1495inv_map = cvCreateMat( 1, data->var_count, CV_32S );1496cvZero( mask );1497cvSet( inv_map, cvScalar(-1) );14981499// first pass: compute the mask of used variables1500cvStartReadSeq( weak, &reader );1501for( i = 0; i < weak->total; i++ )1502{1503CV_READ_SEQ_ELEM(wtree, reader);15041505node = wtree->get_root();1506assert( node != 0 );1507for(;;)1508{1509const CvDTreeNode* parent;1510for(;;)1511{1512CvDTreeSplit* split = node->split;1513for( ; split != 0; split = split->next )1514mask->data.ptr[split->var_idx] = 1;1515if( !node->left )1516break;1517node = node->left;1518}15191520for( parent = node->parent; parent && parent->right == node;1521node = parent, parent = parent->parent )1522;15231524if( !parent )1525break;15261527node = parent->right;1528}1529}15301531nactive_vars = cvCountNonZero(mask);15321533//if ( nactive_vars > 0 )1534{1535active_vars = cvCreateMat( 1, nactive_vars, CV_32S );1536active_vars_abs = cvCreateMat( 1, nactive_vars, CV_32S );15371538have_active_cat_vars = false;15391540for( i = j = 0; i < data->var_count; i++ )1541{1542if( mask->data.ptr[i] )1543{1544active_vars->data.i[j] = i;1545active_vars_abs->data.i[j] = data->var_idx ? 
data->var_idx->data.i[i] : i;1546inv_map->data.i[i] = j;1547if( data->var_type->data.i[i] >= 0 )1548have_active_cat_vars = true;1549j++;1550}1551}155215531554// second pass: now compute the condensed indices1555cvStartReadSeq( weak, &reader );1556for( i = 0; i < weak->total; i++ )1557{1558CV_READ_SEQ_ELEM(wtree, reader);1559node = wtree->get_root();1560for(;;)1561{1562const CvDTreeNode* parent;1563for(;;)1564{1565CvDTreeSplit* split = node->split;1566for( ; split != 0; split = split->next )1567{1568split->condensed_idx = inv_map->data.i[split->var_idx];1569assert( split->condensed_idx >= 0 );1570}15711572if( !node->left )1573break;1574node = node->left;1575}15761577for( parent = node->parent; parent && parent->right == node;1578node = parent, parent = parent->parent )1579;15801581if( !parent )1582break;15831584node = parent->right;1585}1586}1587}1588}15891590result = absolute_idx ? active_vars_abs : active_vars;15911592__END__;15931594cvReleaseMat( &mask );1595cvReleaseMat( &inv_map );15961597return result;1598}159916001601float1602CvBoost::predict( const CvMat* _sample, const CvMat* _missing,1603CvMat* weak_responses, CvSlice slice,1604bool raw_mode, bool return_sum ) const1605{1606float value = -FLT_MAX;16071608CvSeqReader reader;1609double sum = 0;1610int wstep = 0;1611const float* sample_data;16121613if( !weak )1614CV_Error( CV_StsError, "The boosted tree ensemble has not been trained yet" );16151616if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||1617(_sample->cols != 1 && _sample->rows != 1) ||1618(_sample->cols + _sample->rows - 1 != data->var_all && !raw_mode) ||1619(active_vars && _sample->cols + _sample->rows - 1 != active_vars->cols && raw_mode) )1620CV_Error( CV_StsBadArg,1621"the input sample must be 1d floating-point vector with the same "1622"number of elements as the total number of variables or "1623"as the number of variables used for training" );16241625if( _missing )1626{1627if( !CV_IS_MAT(_missing) || 
!CV_IS_MASK_ARR(_missing) ||1628!CV_ARE_SIZES_EQ(_missing, _sample) )1629CV_Error( CV_StsBadArg,1630"the missing data mask must be 8-bit vector of the same size as input sample" );1631}16321633int i, weak_count = cvSliceLength( slice, weak );1634if( weak_count >= weak->total )1635{1636weak_count = weak->total;1637slice.start_index = 0;1638}16391640if( weak_responses )1641{1642if( !CV_IS_MAT(weak_responses) ||1643CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||1644(weak_responses->cols != 1 && weak_responses->rows != 1) ||1645weak_responses->cols + weak_responses->rows - 1 != weak_count )1646CV_Error( CV_StsBadArg,1647"The output matrix of weak classifier responses must be valid "1648"floating-point vector of the same number of components as the length of input slice" );1649wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);1650}16511652int var_count = active_vars->cols;1653const int* vtype = data->var_type->data.i;1654const int* cmap = data->cat_map->data.i;1655const int* cofs = data->cat_ofs->data.i;16561657cv::Mat sample = cv::cvarrToMat(_sample);1658cv::Mat missing;1659if(!_missing)1660missing = cv::cvarrToMat(_missing);16611662// if need, preprocess the input vector1663if( !raw_mode )1664{1665int sstep, mstep = 0;1666const float* src_sample;1667const uchar* src_mask = 0;1668float* dst_sample;1669uchar* dst_mask;1670const int* vidx = active_vars->data.i;1671const int* vidx_abs = active_vars_abs->data.i;1672bool have_mask = _missing != 0;16731674sample = cv::Mat(1, var_count, CV_32FC1);1675missing = cv::Mat(1, var_count, CV_8UC1);16761677dst_sample = sample.ptr<float>();1678dst_mask = missing.ptr<uchar>();16791680src_sample = _sample->data.fl;1681sstep = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);16821683if( _missing )1684{1685src_mask = _missing->data.ptr;1686mstep = CV_IS_MAT_CONT(_missing->type) ? 
1 : _missing->step;1687}16881689for( i = 0; i < var_count; i++ )1690{1691int idx = vidx[i], idx_abs = vidx_abs[i];1692float val = src_sample[idx_abs*sstep];1693int ci = vtype[idx];1694uchar m = src_mask ? src_mask[idx_abs*mstep] : (uchar)0;16951696if( ci >= 0 )1697{1698int a = cofs[ci], b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1],1699c = a;1700int ival = cvRound(val);1701if ( (ival != val) && (!m) )1702CV_Error( CV_StsBadArg,1703"one of input categorical variable is not an integer" );17041705while( a < b )1706{1707c = (a + b) >> 1;1708if( ival < cmap[c] )1709b = c;1710else if( ival > cmap[c] )1711a = c+1;1712else1713break;1714}17151716if( c < 0 || ival != cmap[c] )1717{1718m = 1;1719have_mask = true;1720}1721else1722{1723val = (float)(c - cofs[ci]);1724}1725}17261727dst_sample[i] = val;1728dst_mask[i] = m;1729}17301731if( !have_mask )1732missing.release();1733}1734else1735{1736if( !CV_IS_MAT_CONT(_sample->type & (_missing ? _missing->type : -1)) )1737CV_Error( CV_StsBadArg, "In raw mode the input vectors must be continuous" );1738}17391740cvStartReadSeq( weak, &reader );1741cvSetSeqReaderPos( &reader, slice.start_index );17421743sample_data = sample.ptr<float>();17441745if( !have_active_cat_vars && missing.empty() && !weak_responses )1746{1747for( i = 0; i < weak_count; i++ )1748{1749CvBoostTree* wtree;1750const CvDTreeNode* node;1751CV_READ_SEQ_ELEM( wtree, reader );17521753node = wtree->get_root();1754while( node->left )1755{1756CvDTreeSplit* split = node->split;1757int vi = split->condensed_idx;1758float val = sample_data[vi];1759int dir = val <= split->ord.c ? -1 : 1;1760if( split->inversed )1761dir = -dir;1762node = dir < 0 ? node->left : node->right;1763}1764sum += node->value;1765}1766}1767else1768{1769const int* avars = active_vars->data.i;1770const uchar* m = !missing.empty() ? 
missing.ptr<uchar>() : 0;17711772// full-featured version1773for( i = 0; i < weak_count; i++ )1774{1775CvBoostTree* wtree;1776const CvDTreeNode* node;1777CV_READ_SEQ_ELEM( wtree, reader );17781779node = wtree->get_root();1780while( node->left )1781{1782const CvDTreeSplit* split = node->split;1783int dir = 0;1784for( ; !dir && split != 0; split = split->next )1785{1786int vi = split->condensed_idx;1787int ci = vtype[avars[vi]];1788float val = sample_data[vi];1789if( m && m[vi] )1790continue;1791if( ci < 0 ) // ordered1792dir = val <= split->ord.c ? -1 : 1;1793else // categorical1794{1795int c = cvRound(val);1796dir = CV_DTREE_CAT_DIR(c, split->subset);1797}1798if( split->inversed )1799dir = -dir;1800}18011802if( !dir )1803{1804int diff = node->right->sample_count - node->left->sample_count;1805dir = diff < 0 ? -1 : 1;1806}1807node = dir < 0 ? node->left : node->right;1808}1809if( weak_responses )1810weak_responses->data.fl[i*wstep] = (float)node->value;1811sum += node->value;1812}1813}18141815if( return_sum )1816value = (float)sum;1817else1818{1819int cls_idx = sum >= 0;1820if( raw_mode )1821value = (float)cls_idx;1822else1823value = (float)cmap[cofs[vtype[data->var_count]] + cls_idx];1824}18251826return value;1827}18281829float CvBoost::calc_error( CvMLData* _data, int type, std::vector<float> *resp )1830{1831float err = 0;1832const CvMat* values = _data->get_values();1833const CvMat* response = _data->get_responses();1834const CvMat* missing = _data->get_missing();1835const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();1836const CvMat* var_types = _data->get_var_types();1837int* sidx = sample_idx ? sample_idx->data.i : 0;1838int r_step = CV_IS_MAT_CONT(response->type) ?18391 : response->step / CV_ELEM_SIZE(response->type);1840bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;1841int sample_count = sample_idx ? 
sample_idx->cols : 0;1842sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;1843float* pred_resp = 0;1844if( resp && (sample_count > 0) )1845{1846resp->resize( sample_count );1847pred_resp = &((*resp)[0]);1848}1849if ( is_classifier )1850{1851for( int i = 0; i < sample_count; i++ )1852{1853CvMat sample, miss;1854int si = sidx ? sidx[i] : i;1855cvGetRow( values, &sample, si );1856if( missing )1857cvGetRow( missing, &miss, si );1858float r = (float)predict( &sample, missing ? &miss : 0 );1859if( pred_resp )1860pred_resp[i] = r;1861int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;1862err += d;1863}1864err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;1865}1866else1867{1868for( int i = 0; i < sample_count; i++ )1869{1870CvMat sample, miss;1871int si = sidx ? sidx[i] : i;1872cvGetRow( values, &sample, si );1873if( missing )1874cvGetRow( missing, &miss, si );1875float r = (float)predict( &sample, missing ? &miss : 0 );1876if( pred_resp )1877pred_resp[i] = r;1878float d = r - response->data.fl[si*r_step];1879err += d*d;1880}1881err = sample_count ? err / (float)sample_count : -FLT_MAX;1882}1883return err;1884}18851886void CvBoost::write_params( CvFileStorage* fs ) const1887{1888const char* boost_type_str =1889params.boost_type == DISCRETE ? "DiscreteAdaboost" :1890params.boost_type == REAL ? "RealAdaboost" :1891params.boost_type == LOGIT ? "LogitBoost" :1892params.boost_type == GENTLE ? "GentleAdaboost" : 0;18931894const char* split_crit_str =1895params.split_criteria == DEFAULT ? "Default" :1896params.split_criteria == GINI ? "Gini" :1897params.boost_type == MISCLASS ? "Misclassification" :1898params.boost_type == SQERR ? 
"SquaredErr" : 0;18991900if( boost_type_str )1901cvWriteString( fs, "boosting_type", boost_type_str );1902else1903cvWriteInt( fs, "boosting_type", params.boost_type );19041905if( split_crit_str )1906cvWriteString( fs, "splitting_criteria", split_crit_str );1907else1908cvWriteInt( fs, "splitting_criteria", params.split_criteria );19091910cvWriteInt( fs, "ntrees", weak->total );1911cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );19121913data->write_params( fs );1914}191519161917void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )1918{1919CV_FUNCNAME( "CvBoost::read_params" );19201921__BEGIN__;19221923CvFileNode* temp;19241925if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )1926return;19271928data = new CvDTreeTrainData();1929CV_CALL( data->read_params(fs, fnode));1930data->shared = true;19311932params.max_depth = data->params.max_depth;1933params.min_sample_count = data->params.min_sample_count;1934params.max_categories = data->params.max_categories;1935params.priors = data->params.priors;1936params.regression_accuracy = data->params.regression_accuracy;1937params.use_surrogates = data->params.use_surrogates;19381939temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );1940if( !temp )1941return;19421943if( temp && CV_NODE_IS_STRING(temp->tag) )1944{1945const char* boost_type_str = cvReadString( temp, "" );1946params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :1947strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :1948strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :1949strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? 
GENTLE : -1;1950}1951else1952params.boost_type = cvReadInt( temp, -1 );19531954if( params.boost_type < DISCRETE || params.boost_type > GENTLE )1955CV_ERROR( CV_StsBadArg, "Unknown boosting type" );19561957temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );1958if( temp && CV_NODE_IS_STRING(temp->tag) )1959{1960const char* split_crit_str = cvReadString( temp, "" );1961params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :1962strcmp( split_crit_str, "Gini" ) == 0 ? GINI :1963strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :1964strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;1965}1966else1967params.split_criteria = cvReadInt( temp, -1 );19681969if( params.split_criteria < DEFAULT || params.boost_type > SQERR )1970CV_ERROR( CV_StsBadArg, "Unknown boosting type" );19711972params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );1973params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );19741975__END__;1976}1977197819791980void1981CvBoost::read( CvFileStorage* fs, CvFileNode* node )1982{1983CV_FUNCNAME( "CvBoost::read" );19841985__BEGIN__;19861987CvSeqReader reader;1988CvFileNode* trees_fnode;1989CvMemStorage* storage;1990int i, ntrees;19911992clear();1993read_params( fs, node );19941995if( !data )1996EXIT;19971998trees_fnode = cvGetFileNodeByName( fs, node, "trees" );1999if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )2000CV_ERROR( CV_StsParseError, "<trees> tag is missing" );20012002cvStartReadSeq( trees_fnode->data.seq, &reader );2003ntrees = trees_fnode->data.seq->total;20042005if( ntrees != params.weak_count )2006CV_ERROR( CV_StsUnmatchedSizes,2007"The number of trees stored does not match <ntrees> tag value" );20082009CV_CALL( storage = cvCreateMemStorage() );2010weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );20112012for( i = 0; i < ntrees; i++ )2013{2014CvBoostTree* tree = new CvBoostTree();2015CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, 
data ));2016CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );2017cvSeqPush( weak, &tree );2018}2019get_active_vars();20202021__END__;2022}202320242025void2026CvBoost::write( CvFileStorage* fs, const char* name ) const2027{2028CV_FUNCNAME( "CvBoost::write" );20292030__BEGIN__;20312032CvSeqReader reader;2033int i;20342035cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );20362037if( !weak )2038CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );20392040write_params( fs );2041cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );20422043cvStartReadSeq( weak, &reader );20442045for( i = 0; i < weak->total; i++ )2046{2047CvBoostTree* tree;2048CV_READ_SEQ_ELEM( tree, reader );2049cvStartWriteStruct( fs, 0, CV_NODE_MAP );2050tree->write( fs );2051cvEndWriteStruct( fs );2052}20532054cvEndWriteStruct( fs );2055cvEndWriteStruct( fs );20562057__END__;2058}205920602061CvMat*2062CvBoost::get_weights()2063{2064return weights;2065}206620672068CvMat*2069CvBoost::get_subtree_weights()2070{2071return subtree_weights;2072}207320742075CvMat*2076CvBoost::get_weak_response()2077{2078return weak_eval;2079}208020812082const CvBoostParams&2083CvBoost::get_params() const2084{2085return params;2086}20872088CvSeq* CvBoost::get_weak_predictors()2089{2090return weak;2091}20922093const CvDTreeTrainData* CvBoost::get_data() const2094{2095return data;2096}20972098using namespace cv;20992100CvBoost::CvBoost( const Mat& _train_data, int _tflag,2101const Mat& _responses, const Mat& _var_idx,2102const Mat& _sample_idx, const Mat& _var_type,2103const Mat& _missing_mask,2104CvBoostParams _params )2105{2106weak = 0;2107data = 0;2108default_model_name = "my_boost_tree";2109active_vars = active_vars_abs = orig_response = sum_response = weak_eval =2110subsample_mask = weights = subtree_weights = 0;21112112train( _train_data, _tflag, _responses, _var_idx, _sample_idx,2113_var_type, _missing_mask, _params );2114}211521162117bool2118CvBoost::train( const Mat& _train_data, int 
_tflag,2119const Mat& _responses, const Mat& _var_idx,2120const Mat& _sample_idx, const Mat& _var_type,2121const Mat& _missing_mask,2122CvBoostParams _params, bool _update )2123{2124train_data_hdr = cvMat(_train_data);2125train_data_mat = _train_data;2126responses_hdr = cvMat(_responses);2127responses_mat = _responses;21282129CvMat vidx = cvMat(_var_idx), sidx = cvMat(_sample_idx), vtype = cvMat(_var_type), mmask = cvMat(_missing_mask);21302131return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0,2132sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,2133mmask.data.ptr ? &mmask : 0, _params, _update);2134}21352136float2137CvBoost::predict( const Mat& _sample, const Mat& _missing,2138const Range& slice, bool raw_mode, bool return_sum ) const2139{2140CvMat sample = cvMat(_sample), mmask = cvMat(_missing);2141/*if( weak_responses )2142{2143int weak_count = cvSliceLength( slice, weak );2144if( weak_count >= weak->total )2145{2146weak_count = weak->total;2147slice.start_index = 0;2148}21492150if( !(weak_responses->data && weak_responses->type() == CV_32FC1 &&2151(weak_responses->cols == 1 || weak_responses->rows == 1) &&2152weak_responses->cols + weak_responses->rows - 1 == weak_count) )2153weak_responses->create(weak_count, 1, CV_32FC1);2154pwr = &(wr = *weak_responses);2155}*/2156return predict(&sample, _missing.empty() ? 0 : &mmask, 0,2157slice == Range::all() ? CV_WHOLE_SEQ : cvSlice(slice.start, slice.end),2158raw_mode, return_sum);2159}21602161/* End of file. */216221632164