// Extracted from the CoCalc web viewer (site navigation text removed).
// GitHub Repository: Tetragramm/opencv
// Path: blob/master/modules/ml/src/data.cpp
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// Intel License Agreement
11
//
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
13
// Third party copyrights are property of their respective owners.
14
//
15
// Redistribution and use in source and binary forms, with or without modification,
16
// are permitted provided that the following conditions are met:
17
//
18
// * Redistribution's of source code must retain the above copyright notice,
19
// this list of conditions and the following disclaimer.
20
//
21
// * Redistribution's in binary form must reproduce the above copyright notice,
22
// this list of conditions and the following disclaimer in the documentation
23
// and/or other materials provided with the distribution.
24
//
25
// * The name of Intel Corporation may not be used to endorse or promote products
26
// derived from this software without specific prior written permission.
27
//
28
// This software is provided by the copyright holders and contributors "as is" and
29
// any express or implied warranties, including, but not limited to, the implied
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
32
// indirect, incidental, special, exemplary, or consequential damages
33
// (including, but not limited to, procurement of substitute goods or services;
34
// loss of use, data, or profits; or business interruption) however caused
35
// and on any theory of liability, whether in contract, strict liability,
36
// or tort (including negligence or otherwise) arising in any way out of
37
// the use of this software, even if advised of the possibility of such damage.
38
//
39
//M*/
40
41
#include "precomp.hpp"
42
#include <ctype.h>
43
#include <algorithm>
44
#include <iterator>
45
46
#include <opencv2/core/utils/logger.hpp>
47
48
namespace cv { namespace ml {
49
50
static const float MISSED_VAL = TrainData::missingValue();
51
static const int VAR_MISSED = VAR_ORDERED;
52
53
TrainData::~TrainData() {}
54
55
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
56
{
57
if (!(vec.cols == 1 || vec.rows == 1))
58
CV_LOG_WARNING(NULL, "'getSubVector(const Mat& vec, const Mat& idx)' call with non-1D input is deprecated. It is not designed to work with 2D matrixes (especially with 'cv::ml::COL_SAMPLE' layout).");
59
return getSubMatrix(vec, idx, vec.rows == 1 ? cv::ml::COL_SAMPLE : cv::ml::ROW_SAMPLE);
60
}
61
62
template<typename T>
63
Mat getSubMatrixImpl(const Mat& m, const Mat& idx, int layout)
64
{
65
int nidx = idx.checkVector(1, CV_32S);
66
int dims = m.cols, nsamples = m.rows;
67
68
Mat subm;
69
if (layout == COL_SAMPLE)
70
{
71
std::swap(dims, nsamples);
72
subm.create(dims, nidx, m.type());
73
}
74
else
75
{
76
subm.create(nidx, dims, m.type());
77
}
78
79
for (int i = 0; i < nidx; i++)
80
{
81
int k = idx.at<int>(i); CV_CheckGE(k, 0, "Bad idx"); CV_CheckLT(k, nsamples, "Bad idx or layout");
82
if (dims == 1)
83
{
84
subm.at<T>(i) = m.at<T>(k); // at() has "transparent" access for 1D col-based / row-based vectors.
85
}
86
else if (layout == COL_SAMPLE)
87
{
88
for (int j = 0; j < dims; j++)
89
subm.at<T>(j, i) = m.at<T>(j, k);
90
}
91
else
92
{
93
for (int j = 0; j < dims; j++)
94
subm.at<T>(i, j) = m.at<T>(k, j);
95
}
96
}
97
return subm;
98
}
99
100
Mat TrainData::getSubMatrix(const Mat& m, const Mat& idx, int layout)
101
{
102
if (idx.empty())
103
return m;
104
int type = m.type();
105
CV_CheckType(type, type == CV_32S || type == CV_32F || type == CV_64F, "");
106
if (type == CV_32S || type == CV_32F) // 32-bit
107
return getSubMatrixImpl<int>(m, idx, layout);
108
if (type == CV_64F) // 64-bit
109
return getSubMatrixImpl<double>(m, idx, layout);
110
CV_Error(Error::StsInternal, "");
111
}
112
113
114
class TrainDataImpl CV_FINAL : public TrainData
115
{
116
public:
117
typedef std::map<String, int> MapType;
118
119
TrainDataImpl()
120
{
121
file = 0;
122
clear();
123
}
124
125
virtual ~TrainDataImpl() { closeFile(); }
126
127
int getLayout() const CV_OVERRIDE { return layout; }
128
int getNSamples() const CV_OVERRIDE
129
{
130
return !sampleIdx.empty() ? (int)sampleIdx.total() :
131
layout == ROW_SAMPLE ? samples.rows : samples.cols;
132
}
133
int getNTrainSamples() const CV_OVERRIDE
134
{
135
return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
136
}
137
int getNTestSamples() const CV_OVERRIDE
138
{
139
return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
140
}
141
int getNVars() const CV_OVERRIDE
142
{
143
return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
144
}
145
int getNAllVars() const CV_OVERRIDE
146
{
147
return layout == ROW_SAMPLE ? samples.cols : samples.rows;
148
}
149
150
Mat getTestSamples() const CV_OVERRIDE
151
{
152
Mat idx = getTestSampleIdx();
153
return idx.empty() ? Mat() : getSubMatrix(samples, idx, getLayout());
154
}
155
156
Mat getSamples() const CV_OVERRIDE { return samples; }
157
Mat getResponses() const CV_OVERRIDE { return responses; }
158
Mat getMissing() const CV_OVERRIDE { return missing; }
159
Mat getVarIdx() const CV_OVERRIDE { return varIdx; }
160
Mat getVarType() const CV_OVERRIDE { return varType; }
161
int getResponseType() const CV_OVERRIDE
162
{
163
return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
164
}
165
Mat getTrainSampleIdx() const CV_OVERRIDE { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
166
Mat getTestSampleIdx() const CV_OVERRIDE { return testSampleIdx; }
167
Mat getSampleWeights() const CV_OVERRIDE
168
{
169
return sampleWeights;
170
}
171
Mat getTrainSampleWeights() const CV_OVERRIDE
172
{
173
return getSubVector(sampleWeights, getTrainSampleIdx()); // 1D-vector
174
}
175
Mat getTestSampleWeights() const CV_OVERRIDE
176
{
177
Mat idx = getTestSampleIdx();
178
return idx.empty() ? Mat() : getSubVector(sampleWeights, idx); // 1D-vector
179
}
180
Mat getTrainResponses() const CV_OVERRIDE
181
{
182
return getSubMatrix(responses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE); // col-based responses are transposed in setData()
183
}
184
Mat getTrainNormCatResponses() const CV_OVERRIDE
185
{
186
return getSubMatrix(normCatResponses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE); // like 'responses'
187
}
188
Mat getTestResponses() const CV_OVERRIDE
189
{
190
Mat idx = getTestSampleIdx();
191
return idx.empty() ? Mat() : getSubMatrix(responses, idx, cv::ml::ROW_SAMPLE); // col-based responses are transposed in setData()
192
}
193
Mat getTestNormCatResponses() const CV_OVERRIDE
194
{
195
Mat idx = getTestSampleIdx();
196
return idx.empty() ? Mat() : getSubMatrix(normCatResponses, idx, cv::ml::ROW_SAMPLE); // like 'responses'
197
}
198
Mat getNormCatResponses() const CV_OVERRIDE { return normCatResponses; }
199
Mat getClassLabels() const CV_OVERRIDE { return classLabels; }
200
Mat getClassCounters() const { return classCounters; }
201
int getCatCount(int vi) const CV_OVERRIDE
202
{
203
int n = (int)catOfs.total();
204
CV_Assert( 0 <= vi && vi < n );
205
Vec2i ofs = catOfs.at<Vec2i>(vi);
206
return ofs[1] - ofs[0];
207
}
208
209
Mat getCatOfs() const CV_OVERRIDE { return catOfs; }
210
Mat getCatMap() const CV_OVERRIDE { return catMap; }
211
212
Mat getDefaultSubstValues() const CV_OVERRIDE { return missingSubst; }
213
214
void closeFile() { if(file) fclose(file); file=0; }
215
void clear()
216
{
217
closeFile();
218
samples.release();
219
missing.release();
220
varType.release();
221
varSymbolFlags.release();
222
responses.release();
223
sampleIdx.release();
224
trainSampleIdx.release();
225
testSampleIdx.release();
226
normCatResponses.release();
227
classLabels.release();
228
classCounters.release();
229
catMap.release();
230
catOfs.release();
231
nameMap = MapType();
232
layout = ROW_SAMPLE;
233
}
234
235
typedef std::map<int, int> CatMapHash;
236
237
void setData(InputArray _samples, int _layout, InputArray _responses,
238
InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
239
InputArray _varType, InputArray _missing)
240
{
241
clear();
242
243
CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
244
samples = _samples.getMat();
245
layout = _layout;
246
responses = _responses.getMat();
247
varIdx = _varIdx.getMat();
248
sampleIdx = _sampleIdx.getMat();
249
sampleWeights = _sampleWeights.getMat();
250
varType = _varType.getMat();
251
missing = _missing.getMat();
252
253
int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
254
int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
255
int i, noutputvars = 0;
256
257
CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
258
259
if( !sampleIdx.empty() )
260
{
261
CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
262
checkRange(sampleIdx, true, 0, 0, nsamples)) ||
263
sampleIdx.checkVector(1, CV_8U, true) == nsamples );
264
if( sampleIdx.type() == CV_8U )
265
sampleIdx = convertMaskToIdx(sampleIdx);
266
}
267
268
if( !sampleWeights.empty() )
269
{
270
CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
271
}
272
else
273
{
274
sampleWeights = Mat::ones(nsamples, 1, CV_32F);
275
}
276
277
if( !varIdx.empty() )
278
{
279
CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
280
checkRange(varIdx, true, 0, 0, ninputvars)) ||
281
varIdx.checkVector(1, CV_8U, true) == ninputvars );
282
if( varIdx.type() == CV_8U )
283
varIdx = convertMaskToIdx(varIdx);
284
varIdx = varIdx.clone();
285
std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
286
}
287
288
if( !responses.empty() )
289
{
290
CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
291
if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
292
noutputvars = 1;
293
else
294
{
295
CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
296
(layout == COL_SAMPLE && responses.cols == nsamples) );
297
noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
298
}
299
if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
300
{
301
Mat temp;
302
transpose(responses, temp);
303
responses = temp;
304
}
305
}
306
307
int nvars = ninputvars + noutputvars;
308
309
if( !varType.empty() )
310
{
311
CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
312
checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
313
}
314
else
315
{
316
varType.create(1, nvars, CV_8U);
317
varType = Scalar::all(VAR_ORDERED);
318
if( noutputvars == 1 )
319
varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
320
}
321
322
if( noutputvars > 1 )
323
{
324
for( i = 0; i < noutputvars; i++ )
325
CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
326
}
327
328
catOfs = Mat::zeros(1, nvars, CV_32SC2);
329
missingSubst = Mat::zeros(1, nvars, CV_32F);
330
331
vector<int> labels, counters, sortbuf, tempCatMap;
332
vector<Vec2i> tempCatOfs;
333
CatMapHash ofshash;
334
335
AutoBuffer<uchar> buf(nsamples);
336
Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, buf.data());
337
bool haveMissing = !missing.empty();
338
if( haveMissing )
339
{
340
CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
341
}
342
343
// we iterate through all the variables. For each categorical variable we build a map
344
// in order to convert input values of the variable into normalized values (0..catcount_vi-1)
345
// often many categorical variables are similar, so we compress the map - try to re-use
346
// maps for different variables if they are identical
347
for( i = 0; i < ninputvars; i++ )
348
{
349
Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
350
351
if( varType.at<uchar>(i) == VAR_CATEGORICAL )
352
{
353
preprocessCategorical(values_i, 0, labels, 0, sortbuf);
354
missingSubst.at<float>(i) = -1.f;
355
int j, m = (int)labels.size();
356
CV_Assert( m > 0 );
357
int a = labels.front(), b = labels.back();
358
const int* currmap = &labels[0];
359
int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
360
CatMapHash::iterator it = ofshash.find(hashval);
361
if( it != ofshash.end() )
362
{
363
int vi = it->second;
364
Vec2i ofs0 = tempCatOfs[vi];
365
int m0 = ofs0[1] - ofs0[0];
366
const int* map0 = &tempCatMap[ofs0[0]];
367
if( m0 == m && map0[0] == a && map0[m0-1] == b )
368
{
369
for( j = 0; j < m; j++ )
370
if( map0[j] != currmap[j] )
371
break;
372
if( j == m )
373
{
374
// re-use the map
375
tempCatOfs.push_back(ofs0);
376
continue;
377
}
378
}
379
}
380
else
381
ofshash[hashval] = i;
382
Vec2i ofs;
383
ofs[0] = (int)tempCatMap.size();
384
ofs[1] = ofs[0] + m;
385
tempCatOfs.push_back(ofs);
386
std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
387
}
388
else
389
{
390
tempCatOfs.push_back(Vec2i(0, 0));
391
/*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
392
compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
393
missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
394
missingSubst.at<float>(i) = 0.f;
395
}
396
}
397
398
if( !tempCatOfs.empty() )
399
{
400
Mat(tempCatOfs).copyTo(catOfs);
401
Mat(tempCatMap).copyTo(catMap);
402
}
403
404
if( noutputvars > 0 && varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
405
{
406
preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
407
Mat(labels).copyTo(classLabels);
408
Mat(counters).copyTo(classCounters);
409
}
410
}
411
412
Mat convertMaskToIdx(const Mat& mask)
413
{
414
int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
415
Mat idx(1, nz, CV_32S);
416
for( i = j = 0; i < n; i++ )
417
if( mask.at<uchar>(i) )
418
idx.at<int>(j++) = i;
419
return idx;
420
}
421
422
struct CmpByIdx
423
{
424
CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}
425
bool operator ()(int i, int j) const { return data[i*step] < data[j*step]; }
426
const int* data;
427
int step;
428
};
429
430
void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
431
vector<int>* counters, vector<int>& sortbuf)
432
{
433
CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
434
int* odata = 0;
435
int ostep = 0;
436
437
if(normdata)
438
{
439
normdata->create(data.size(), CV_32S);
440
odata = normdata->ptr<int>();
441
ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
442
}
443
444
int i, n = data.cols + data.rows - 1;
445
sortbuf.resize(n*2);
446
int* idx = &sortbuf[0];
447
int* idata = (int*)data.ptr<int>();
448
int istep = data.isContinuous() ? 1 : (int)data.step1();
449
450
if( data.type() == CV_32F )
451
{
452
idata = idx + n;
453
const float* fdata = data.ptr<float>();
454
for( i = 0; i < n; i++ )
455
{
456
if( fdata[i*istep] == MISSED_VAL )
457
idata[i] = -1;
458
else
459
{
460
idata[i] = cvRound(fdata[i*istep]);
461
CV_Assert( (float)idata[i] == fdata[i*istep] );
462
}
463
}
464
istep = 1;
465
}
466
467
for( i = 0; i < n; i++ )
468
idx[i] = i;
469
470
std::sort(idx, idx + n, CmpByIdx(idata, istep));
471
472
int clscount = 1;
473
for( i = 1; i < n; i++ )
474
clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
475
476
int clslabel = -1;
477
int prev = ~idata[idx[0]*istep];
478
int previdx = 0;
479
480
labels.resize(clscount);
481
if(counters)
482
counters->resize(clscount);
483
484
for( i = 0; i < n; i++ )
485
{
486
int l = idata[idx[i]*istep];
487
if( l != prev )
488
{
489
clslabel++;
490
labels[clslabel] = l;
491
int k = i - previdx;
492
if( clslabel > 0 && counters )
493
counters->at(clslabel-1) = k;
494
prev = l;
495
previdx = i;
496
}
497
if(odata)
498
odata[idx[i]*ostep] = clslabel;
499
}
500
if(counters)
501
counters->at(clslabel) = i - previdx;
502
}
503
504
bool loadCSV(const String& filename, int headerLines,
505
int responseStartIdx, int responseEndIdx,
506
const String& varTypeSpec, char delimiter, char missch)
507
{
508
const int M = 1000000;
509
const char delimiters[3] = { ' ', delimiter, '\0' };
510
int nvars = 0;
511
bool varTypesSet = false;
512
513
clear();
514
515
file = fopen( filename.c_str(), "rt" );
516
517
if( !file )
518
return false;
519
520
std::vector<char> _buf(M);
521
std::vector<float> allresponses;
522
std::vector<float> rowvals;
523
std::vector<uchar> vtypes, rowtypes;
524
std::vector<uchar> vsymbolflags;
525
bool haveMissed = false;
526
char* buf = &_buf[0];
527
528
int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
529
int ninputvars = 0, noutputvars = 0;
530
531
Mat tempSamples, tempMissing, tempResponses;
532
MapType tempNameMap;
533
int catCounter = 1;
534
535
// skip header lines
536
int lineno = 0;
537
for(;;lineno++)
538
{
539
if( !fgets(buf, M, file) )
540
break;
541
if(lineno < headerLines )
542
continue;
543
// trim trailing spaces
544
int idx = (int)strlen(buf)-1;
545
while( idx >= 0 && isspace(buf[idx]) )
546
buf[idx--] = '\0';
547
// skip spaces in the beginning
548
char* ptr = buf;
549
while( *ptr != '\0' && isspace(*ptr) )
550
ptr++;
551
// skip commented off lines
552
if(*ptr == '#')
553
continue;
554
rowvals.clear();
555
rowtypes.clear();
556
557
char* token = strtok(buf, delimiters);
558
if (!token)
559
break;
560
561
for(;;)
562
{
563
float val=0.f; int tp = 0;
564
decodeElem( token, val, tp, missch, tempNameMap, catCounter );
565
if( tp == VAR_MISSED )
566
haveMissed = true;
567
rowvals.push_back(val);
568
rowtypes.push_back((uchar)tp);
569
token = strtok(NULL, delimiters);
570
if (!token)
571
break;
572
}
573
574
if( nvars == 0 )
575
{
576
if( rowvals.empty() )
577
CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
578
nvars = (int)rowvals.size();
579
if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
580
{
581
setVarTypes(varTypeSpec, nvars, vtypes);
582
varTypesSet = true;
583
}
584
else
585
vtypes = rowtypes;
586
vsymbolflags.resize(nvars);
587
for( i = 0; i < nvars; i++ )
588
vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
589
590
ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
591
ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
592
CV_Assert(ridx1 > ridx0);
593
noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
594
ninputvars = nvars - noutputvars;
595
}
596
else
597
CV_Assert( nvars == (int)rowvals.size() );
598
599
// check var types
600
for( i = 0; i < nvars; i++ )
601
{
602
CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
603
(varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
604
uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
605
if( vsymbolflags[i] == VAR_MISSED )
606
vsymbolflags[i] = sflag;
607
else
608
CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED);
609
}
610
611
if( ridx0 >= 0 )
612
{
613
for( i = ridx1; i < nvars; i++ )
614
std::swap(rowvals[i], rowvals[i-noutputvars]);
615
for( i = ninputvars; i < nvars; i++ )
616
allresponses.push_back(rowvals[i]);
617
rowvals.pop_back();
618
}
619
Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
620
tempSamples.push_back(rmat);
621
}
622
623
closeFile();
624
625
int nsamples = tempSamples.rows;
626
if( nsamples == 0 )
627
return false;
628
629
if( haveMissed )
630
compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
631
632
if( ridx0 >= 0 )
633
{
634
for( i = ridx1; i < nvars; i++ )
635
std::swap(vtypes[i], vtypes[i-noutputvars]);
636
if( noutputvars > 1 )
637
{
638
for( i = ninputvars; i < nvars; i++ )
639
if( vtypes[i] == VAR_CATEGORICAL )
640
CV_Error(CV_StsBadArg,
641
"If responses are vector values, not scalars, they must be marked as ordered responses");
642
}
643
}
644
645
if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
646
{
647
for( i = 0; i < nsamples; i++ )
648
if( allresponses[i] != cvRound(allresponses[i]) )
649
break;
650
if( i == nsamples )
651
vtypes[ninputvars] = VAR_CATEGORICAL;
652
}
653
654
//If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
655
if (noutputvars != 0){
656
Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
657
setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
658
noArray(), Mat(vtypes).clone(), tempMissing);
659
}
660
else{
661
Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
662
zero_mat.copyTo(tempResponses);
663
setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
664
noArray(), noArray(), tempMissing);
665
}
666
bool ok = !samples.empty();
667
if(ok)
668
{
669
std::swap(tempNameMap, nameMap);
670
Mat(vsymbolflags).copyTo(varSymbolFlags);
671
}
672
return ok;
673
}
674
675
void decodeElem( const char* token, float& elem, int& type,
676
char missch, MapType& namemap, int& counter ) const
677
{
678
char* stopstring = NULL;
679
elem = (float)strtod( token, &stopstring );
680
if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
681
{
682
elem = MISSED_VAL;
683
type = VAR_MISSED;
684
}
685
else if( *stopstring != '\0' )
686
{
687
MapType::iterator it = namemap.find(token);
688
if( it == namemap.end() )
689
{
690
elem = (float)counter;
691
namemap[token] = counter++;
692
}
693
else
694
elem = (float)it->second;
695
type = VAR_CATEGORICAL;
696
}
697
else
698
type = VAR_ORDERED;
699
}
700
701
void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
702
{
703
const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
704
"\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
705
const char* str = s.c_str();
706
int specCounter = 0;
707
708
vtypes.resize(nvars);
709
710
for( int k = 0; k < 2; k++ )
711
{
712
const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
713
int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
714
if( ptr ) // parse ord/cat str
715
{
716
char* stopstring = NULL;
717
718
if( ptr[3] == '\0' )
719
{
720
for( int i = 0; i < nvars; i++ )
721
vtypes[i] = (uchar)tp;
722
specCounter = nvars;
723
break;
724
}
725
726
if ( ptr[3] != '[')
727
CV_Error( CV_StsBadArg, errmsg );
728
729
ptr += 4; // pass "ord["
730
do
731
{
732
int b1 = (int)strtod( ptr, &stopstring );
733
if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
734
CV_Error( CV_StsBadArg, errmsg );
735
ptr = stopstring + 1;
736
if( (stopstring[0] == ',') || (stopstring[0] == ']'))
737
{
738
CV_Assert( 0 <= b1 && b1 < nvars );
739
vtypes[b1] = (uchar)tp;
740
specCounter++;
741
}
742
else
743
{
744
if( stopstring[0] == '-')
745
{
746
int b2 = (int)strtod( ptr, &stopstring);
747
if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
748
CV_Error( CV_StsBadArg, errmsg );
749
ptr = stopstring + 1;
750
CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
751
for (int i = b1; i <= b2; i++)
752
vtypes[i] = (uchar)tp;
753
specCounter += b2 - b1 + 1;
754
}
755
else
756
CV_Error( CV_StsBadArg, errmsg );
757
758
}
759
}
760
while(*stopstring != ']');
761
}
762
}
763
764
if( specCounter != nvars )
765
CV_Error( CV_StsBadArg, "type of some variables is not specified" );
766
}
767
768
void setTrainTestSplitRatio(double ratio, bool shuffle) CV_OVERRIDE
769
{
770
CV_Assert( 0. <= ratio && ratio <= 1. );
771
setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
772
}
773
774
void setTrainTestSplit(int count, bool shuffle) CV_OVERRIDE
775
{
776
int i, nsamples = getNSamples();
777
CV_Assert( 0 <= count && count < nsamples );
778
779
trainSampleIdx.release();
780
testSampleIdx.release();
781
782
if( count == 0 )
783
trainSampleIdx = sampleIdx;
784
else if( count == nsamples )
785
testSampleIdx = sampleIdx;
786
else
787
{
788
Mat mask(1, nsamples, CV_8U);
789
uchar* mptr = mask.ptr();
790
for( i = 0; i < nsamples; i++ )
791
mptr[i] = (uchar)(i < count);
792
trainSampleIdx.create(1, count, CV_32S);
793
testSampleIdx.create(1, nsamples - count, CV_32S);
794
int j0 = 0, j1 = 0;
795
const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
796
int* trainptr = trainSampleIdx.ptr<int>();
797
int* testptr = testSampleIdx.ptr<int>();
798
for( i = 0; i < nsamples; i++ )
799
{
800
int idx = sptr ? sptr[i] : i;
801
if( mptr[i] )
802
trainptr[j0++] = idx;
803
else
804
testptr[j1++] = idx;
805
}
806
if( shuffle )
807
shuffleTrainTest();
808
}
809
}
810
811
void shuffleTrainTest() CV_OVERRIDE
812
{
813
if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
814
{
815
int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
816
int* trainIdx = trainSampleIdx.ptr<int>();
817
int* testIdx = testSampleIdx.ptr<int>();
818
RNG& rng = theRNG();
819
820
for( i = 0; i < nsamples; i++)
821
{
822
int a = rng.uniform(0, nsamples);
823
int b = rng.uniform(0, nsamples);
824
int* ptra = trainIdx;
825
int* ptrb = trainIdx;
826
if( a >= ntrain )
827
{
828
ptra = testIdx;
829
a -= ntrain;
830
CV_Assert( a < ntest );
831
}
832
if( b >= ntrain )
833
{
834
ptrb = testIdx;
835
b -= ntrain;
836
CV_Assert( b < ntest );
837
}
838
std::swap(ptra[a], ptrb[b]);
839
}
840
}
841
}
842
843
Mat getTrainSamples(int _layout,
844
bool compressSamples,
845
bool compressVars) const CV_OVERRIDE
846
{
847
if( samples.empty() )
848
return samples;
849
850
if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
851
(!compressVars || varIdx.empty()) &&
852
layout == _layout )
853
return samples;
854
855
int drows = getNTrainSamples(), dcols = getNVars();
856
Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
857
const float* src0 = samples.ptr<float>();
858
const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
859
const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
860
size_t sstep0 = samples.step/samples.elemSize();
861
size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
862
size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
863
864
if( _layout == COL_SAMPLE )
865
{
866
std::swap(drows, dcols);
867
std::swap(sptr, vptr);
868
std::swap(sstep, vstep);
869
}
870
871
Mat dsamples(drows, dcols, CV_32F);
872
873
for( int i = 0; i < drows; i++ )
874
{
875
const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
876
float* dst = dsamples.ptr<float>(i);
877
878
for( int j = 0; j < dcols; j++ )
879
dst[j] = src[(vptr ? vptr[j] : j)*vstep];
880
}
881
882
return dsamples;
883
}
884
885
void getValues( int vi, InputArray _sidx, float* values ) const CV_OVERRIDE
886
{
887
Mat sidx = _sidx.getMat();
888
int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
889
CV_Assert( 0 <= vi && vi < getNAllVars() );
890
CV_Assert( n >= 0 );
891
const int* s = n > 0 ? sidx.ptr<int>() : 0;
892
if( n == 0 )
893
n = nsamples;
894
895
size_t step = samples.step/samples.elemSize();
896
size_t sstep = layout == ROW_SAMPLE ? step : 1;
897
size_t vstep = layout == ROW_SAMPLE ? 1 : step;
898
899
const float* src = samples.ptr<float>() + vi*vstep;
900
float subst = missingSubst.at<float>(vi);
901
for( i = 0; i < n; i++ )
902
{
903
int j = i;
904
if( s )
905
{
906
j = s[i];
907
CV_Assert( 0 <= j && j < nsamples );
908
}
909
values[i] = src[j*sstep];
910
if( values[i] == MISSED_VAL )
911
values[i] = subst;
912
}
913
}
914
915
void getNormCatValues( int vi, InputArray _sidx, int* values ) const CV_OVERRIDE
916
{
917
float* fvalues = (float*)values;
918
getValues(vi, _sidx, fvalues);
919
int i, n = (int)_sidx.total();
920
Vec2i ofs = catOfs.at<Vec2i>(vi);
921
int m = ofs[1] - ofs[0];
922
923
CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
924
const int* cmap = &catMap.at<int>(ofs[0]);
925
bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
926
927
if( fastMap )
928
{
929
for( i = 0; i < n; i++ )
930
{
931
int val = cvRound(fvalues[i]);
932
int idx = val - cmap[0];
933
CV_Assert(cmap[idx] == val);
934
values[i] = idx;
935
}
936
}
937
else
938
{
939
for( i = 0; i < n; i++ )
940
{
941
int val = cvRound(fvalues[i]);
942
int a = 0, b = m, c = -1;
943
944
while( a < b )
945
{
946
c = (a + b) >> 1;
947
if( val < cmap[c] )
948
b = c;
949
else if( val > cmap[c] )
950
a = c+1;
951
else
952
break;
953
}
954
955
CV_DbgAssert( c >= 0 && val == cmap[c] );
956
values[i] = c;
957
}
958
}
959
}
960
961
void getSample(InputArray _vidx, int sidx, float* buf) const CV_OVERRIDE
962
{
963
CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
964
Mat vidx = _vidx.getMat();
965
int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
966
CV_Assert( n >= 0 );
967
const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
968
if( n == 0 )
969
n = nvars;
970
971
size_t step = samples.step/samples.elemSize();
972
size_t sstep = layout == ROW_SAMPLE ? step : 1;
973
size_t vstep = layout == ROW_SAMPLE ? 1 : step;
974
975
const float* src = samples.ptr<float>() + sidx*sstep;
976
for( i = 0; i < n; i++ )
977
{
978
int j = i;
979
if( vptr )
980
{
981
j = vptr[i];
982
CV_Assert( 0 <= j && j < nvars );
983
}
984
buf[i] = src[j*vstep];
985
}
986
}
987
988
void getNames(std::vector<String>& names) const CV_OVERRIDE
989
{
990
size_t n = nameMap.size();
991
TrainDataImpl::MapType::const_iterator it = nameMap.begin(),
992
it_end = nameMap.end();
993
names.resize(n+1);
994
names[0] = "?";
995
for( ; it != it_end; ++it )
996
{
997
String s = it->first;
998
int label = it->second;
999
CV_Assert( label > 0 && label <= (int)n );
1000
names[label] = s;
1001
}
1002
}
1003
1004
Mat getVarSymbolFlags() const CV_OVERRIDE
1005
{
1006
return varSymbolFlags;
1007
}
1008
1009
FILE* file;
1010
int layout;
1011
Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
1012
Mat sampleIdx, trainSampleIdx, testSampleIdx;
1013
Mat sampleWeights, catMap, catOfs;
1014
Mat normCatResponses, classLabels, classCounters;
1015
MapType nameMap;
1016
};
1017
1018
1019
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
1020
int headerLines,
1021
int responseStartIdx,
1022
int responseEndIdx,
1023
const String& varTypeSpec,
1024
char delimiter, char missch)
1025
{
1026
CV_TRACE_FUNCTION_SKIP_NESTED();
1027
Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
1028
if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
1029
td.release();
1030
return td;
1031
}
1032
1033
Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
1034
InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
1035
InputArray varType)
1036
{
1037
CV_TRACE_FUNCTION_SKIP_NESTED();
1038
Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
1039
td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
1040
return td;
1041
}
1042
1043
}}
1044
1045
/* End of file. */
1046
1047