Path: blob/master/samples/cpp/logistic_regression.cpp
16337 views
/*//////////////////////////////////////////////////////////////////////////////////////1// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.23// By downloading, copying, installing or using the software you agree to this license.4// If you do not agree to this license, do not download, install,5// copy or use the software.67// This is a implementation of the Logistic Regression algorithm in C++ in OpenCV.89// AUTHOR:10// Rahul Kavi rahulkavi[at]live[at]com11//1213// contains a subset of data from the popular Iris Dataset (taken from14// "http://archive.ics.uci.edu/ml/datasets/Iris")1516// # You are free to use, change, or redistribute the code in any way you wish for17// # non-commercial purposes, but please maintain the name of the original author.18// # This code comes with no warranty of any kind.1920// #21// # You are free to use, change, or redistribute the code in any way you wish for22// # non-commercial purposes, but please maintain the name of the original author.23// # This code comes with no warranty of any kind.2425// # Logistic Regression ALGORITHM2627// License Agreement28// For Open Source Computer Vision Library2930// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.31// Copyright (C) 2008-2011, Willow Garage Inc., all rights reserved.32// Third party copyrights are property of their respective owners.3334// Redistribution and use in source and binary forms, with or without modification,35// are permitted provided that the following conditions are met:3637// * Redistributions of source code must retain the above copyright notice,38// this list of conditions and the following disclaimer.3940// * Redistributions in binary form must reproduce the above copyright notice,41// this list of conditions and the following disclaimer in the documentation42// and/or other materials provided with the distribution.4344// * The name of the copyright holders may not be used to endorse or promote products45// derived from this software without specific prior written permission.4647// This software is provided by the copyright holders and contributors "as is" and48// any express or implied warranties, including, but not limited to, the implied49// warranties of merchantability and fitness for a particular purpose are disclaimed.50// In no event shall the Intel Corporation or contributors be liable for any direct,51// indirect, incidental, special, exemplary, or consequential damages52// (including, but not limited to, procurement of substitute goods or services;53// loss of use, data, or profits; or business interruption) however caused54// and on any theory of liability, whether in contract, strict liability,55// or tort (including negligence or otherwise) arising in any way out of56// the use of this software, even if advised of the possibility of such damage.*/5758#include <iostream>5960#include <opencv2/core.hpp>61#include <opencv2/ml.hpp>62#include <opencv2/highgui.hpp>6364using namespace std;65using namespace cv;66using namespace cv::ml;6768static void showImage(const Mat &data, int columns, const String &name)69{70Mat bigImage;71for(int i = 0; i < data.rows; ++i)72{73bigImage.push_back(data.row(i).reshape(0, columns));74}75imshow(name, bigImage.t());76}7778static float calculateAccuracyPercent(const Mat &original, const Mat &predicted)79{80return 100 * (float)countNonZero(original == predicted) / predicted.rows;81}8283int main()84{85const String filename = "../data/data01.xml";86cout << "**********************************************************************" << endl;87cout << filename88<< " contains digits 0 and 1 of 20 samples each, collected on an Android device" << endl;89cout << "Each of the collected images are of size 28 x 28 re-arranged to 1 x 784 matrix"90<< endl;91cout << "**********************************************************************" << endl;9293Mat data, labels;94{95cout << "loading the dataset...";96FileStorage f;97if(f.open(filename, FileStorage::READ))98{99f["datamat"] >> data;100f["labelsmat"] >> labels;101f.release();102}103else104{105cerr << "file can not be opened: " << filename << endl;106return 1;107}108data.convertTo(data, CV_32F);109labels.convertTo(labels, CV_32F);110cout << "read " << data.rows << " rows of data" << endl;111}112113Mat data_train, data_test;114Mat labels_train, labels_test;115for(int i = 0; i < data.rows; i++)116{117if(i % 2 == 0)118{119data_train.push_back(data.row(i));120labels_train.push_back(labels.row(i));121}122else123{124data_test.push_back(data.row(i));125labels_test.push_back(labels.row(i));126}127}128cout << "training/testing samples count: " << data_train.rows << "/" << data_test.rows << endl;129130// display sample image131showImage(data_train, 28, "train data");132showImage(data_test, 28, "test data");133134// simple case with batch gradient135cout << "training...";136//! [init]137Ptr<LogisticRegression> lr1 = LogisticRegression::create();138lr1->setLearningRate(0.001);139lr1->setIterations(10);140lr1->setRegularization(LogisticRegression::REG_L2);141lr1->setTrainMethod(LogisticRegression::BATCH);142lr1->setMiniBatchSize(1);143//! [init]144lr1->train(data_train, ROW_SAMPLE, labels_train);145cout << "done!" << endl;146147cout << "predicting...";148Mat responses;149lr1->predict(data_test, responses);150cout << "done!" << endl;151152// show prediction report153cout << "original vs predicted:" << endl;154labels_test.convertTo(labels_test, CV_32S);155cout << labels_test.t() << endl;156cout << responses.t() << endl;157cout << "accuracy: " << calculateAccuracyPercent(labels_test, responses) << "%" << endl;158159// save the classfier160const String saveFilename = "NewLR_Trained.xml";161cout << "saving the classifier to " << saveFilename << endl;162lr1->save(saveFilename);163164// load the classifier onto new object165cout << "loading a new classifier from " << saveFilename << endl;166Ptr<LogisticRegression> lr2 = StatModel::load<LogisticRegression>(saveFilename);167168// predict using loaded classifier169cout << "predicting the dataset using the loaded classfier...";170Mat responses2;171lr2->predict(data_test, responses2);172cout << "done!" << endl;173174// calculate accuracy175cout << labels_test.t() << endl;176cout << responses2.t() << endl;177cout << "accuracy: " << calculateAccuracyPercent(labels_test, responses2) << "%" << endl;178179waitKey(0);180return 0;181}182183184