Path: blob/master/modules/python/test/test_digits.py
16337 views
#!/usr/bin/env python

'''
SVM and KNearest digit recognition.

Test loads a dataset of handwritten digits from DIGITS_FN.
Then it trains a SVM and KNearest classifiers on it and evaluates
their accuracy.

Following preprocessing is applied to the dataset:
 - Moment-based image deskew (see deskew())
 - Digit images are split into 4 10x10 cells and 16-bin
   histogram of oriented gradients is computed for each
   cell
 - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))


[1] R. Arandjelovic, A. Zisserman
    "Three things everyone should know to improve object retrieval"
    http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf

'''


# Python 2/3 compatibility
from __future__ import print_function

# built-in modules
from multiprocessing.pool import ThreadPool

import cv2 as cv

import numpy as np
from numpy.linalg import norm


SZ = 20  # size of each digit is SZ x SZ
CLASS_N = 10
DIGITS_FN = 'samples/data/digits.png'


def split2d(img, cell_size, flatten=True):
    '''Cut `img` into a grid of cells.

    `cell_size` is an (sx, sy) pair giving the cell width and height.
    Returns an array of cells; when `flatten` is true the grid is
    reshaped to (-1, sy, sx), otherwise the row/column grid layout
    produced by the splits is kept.
    '''
    h, w = img.shape[:2]
    sx, sy = cell_size
    cells = [np.hsplit(row, w // sx) for row in np.vsplit(img, h // sy)]
    cells = np.array(cells)
    if flatten:
        cells = cells.reshape(-1, sy, sx)
    return cells


def deskew(img):
    '''Return a moment-based deskewed version of an SZ x SZ digit image.

    When |mu02| is below 1e-2 the skew ratio mu11/mu02 is not computed
    and a plain copy of the input is returned instead.
    '''
    m = cv.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SZ * skew], [0, 1, 0]])
    img = cv.warpAffine(img, M, (SZ, SZ),
                        flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
    return img


class StatModel(object):
    '''Common load/save helpers for wrappers that hold `self.model`.'''

    def load(self, fn):
        self.model.load(fn)  # Known bug: https://github.com/opencv/opencv/issues/4969

    def save(self, fn):
        self.model.save(fn)


class KNearest(StatModel):
    '''Wrapper around cv.ml.KNearest using `k` neighbours for prediction.'''

    def __init__(self, k = 3):
        self.k = k
        self.model = cv.ml.KNearest_create()

    def train(self, samples, responses):
        '''Train on `samples`, one sample per row, with `responses` as labels.'''
        self.model.train(samples, cv.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        '''Return the findNearest results for `samples` as a flat array.'''
        _retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
        return results.ravel()


class SVM(StatModel):
    '''Wrapper around cv.ml.SVM configured as a C-SVC with an RBF kernel.'''

    def __init__(self, C = 1, gamma = 0.5):
        self.model = cv.ml.SVM_create()
        self.model.setGamma(gamma)
        self.model.setC(C)
        self.model.setKernel(cv.ml.SVM_RBF)
        self.model.setType(cv.ml.SVM_C_SVC)

    def train(self, samples, responses):
        '''Train on `samples`, one sample per row, with `responses` as labels.'''
        self.model.train(samples, cv.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        '''Return the predicted responses for `samples` as a flat array.'''
        return self.model.predict(samples)[1].ravel()


def evaluate_model(model, digits, samples, labels):
    '''Predict `samples` with `model` and compare against `labels`.

    Returns (err, confusion): `err` is the fraction of mismatching
    predictions and `confusion[i, j]` counts samples labelled `i` that
    were predicted as `j`. `digits` is not used here; the parameter is
    kept so existing call sites keep working.
    '''
    resp = model.predict(samples)
    err = (labels != resp).mean()

    confusion = np.zeros((CLASS_N, CLASS_N), np.int32)
    for i, j in zip(labels, resp):
        confusion[int(i), int(j)] += 1

    return err, confusion


def preprocess_simple(digits):
    '''Flatten each digit to a float32 row of SZ*SZ values divided by 255.'''
    return np.float32(digits).reshape(-1, SZ*SZ) / 255.0


def preprocess_hog(digits):
    '''Compute a histogram-of-oriented-gradients descriptor per digit.

    Each image is cut into four quadrants at row/column 10; a 16-bin,
    magnitude-weighted orientation histogram is built for every
    quadrant and the four histograms are concatenated. The result is
    then mapped to the Hellinger kernel space. Returns a float32 array
    with one descriptor per row.
    '''
    bin_n = 16
    eps = 1e-7
    samples = []
    for img in digits:
        gx = cv.Sobel(img, cv.CV_32F, 1, 0)
        gy = cv.Sobel(img, cv.CV_32F, 0, 1)
        mag, ang = cv.cartToPolar(gx, gy)
        # Quantise the gradient angle into bin_n orientation bins.
        bins = np.int32(bin_n*ang/(2*np.pi))
        bin_cells = bins[:10,:10], bins[10:,:10], bins[:10,10:], bins[10:,10:]
        mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
        hists = [np.bincount(b.ravel(), m.ravel(), bin_n)
                 for b, m in zip(bin_cells, mag_cells)]
        hist = np.hstack(hists)

        # transform to Hellinger kernel
        hist /= hist.sum() + eps
        hist = np.sqrt(hist)
        hist /= norm(hist) + eps

        samples.append(hist)
    return np.float32(samples)


from tests_common import NewOpenCVTests


class digits_test(NewOpenCVTests):
    '''Regression test for KNearest and SVM accuracy on the digits sample.'''

    def load_digits(self, fn):
        '''Load the digits sheet `fn` and return (digits, labels).

        The sheet is split into SZ x SZ cells and the cells are labelled
        0..CLASS_N-1 in equal consecutive runs.
        '''
        # NOTE(review): the 0 flag presumably requests a grayscale load;
        # confirm against NewOpenCVTests.get_sample.
        digits_img = self.get_sample(fn, 0)
        digits = split2d(digits_img, (SZ, SZ))
        # Floor division keeps the repeat count an integer under Python 3,
        # where `/` would hand np.repeat a float.
        labels = np.repeat(np.arange(CLASS_N), len(digits) // CLASS_N)
        return digits, labels

    def test_digits(self):

        digits, labels = self.load_digits(DIGITS_FN)

        # shuffle digits
        rand = np.random.RandomState(321)
        shuffle = rand.permutation(len(digits))
        digits, labels = digits[shuffle], labels[shuffle]

        digits2 = list(map(deskew, digits))
        samples = preprocess_hog(digits2)

        # 90% of the shuffled data is used for training, the rest for testing.
        train_n = int(0.9*len(samples))
        _digits_train, digits_test = np.split(digits2, [train_n])
        samples_train, samples_test = np.split(samples, [train_n])
        labels_train, labels_test = np.split(labels, [train_n])

        # Index 0 holds the KNearest results, index 1 the SVM results.
        errors = []
        confusion_matrices = []
        for model in (KNearest(k=4), SVM(C=2.67, gamma=5.383)):
            model.train(samples_train, labels_train)
            error, confusion = evaluate_model(model, digits_test, samples_test, labels_test)
            errors.append(error)
            confusion_matrices.append(confusion)

        eps = 0.001
        # Allowed L1 distance between the measured and expected confusion matrices.
        normEps = len(samples_test) * 0.02

        confusionKNN = [[45,  0,  0,  0,  0,  0,  0,  0,  0,  0],
                        [ 0, 57,  0,  0,  0,  0,  0,  0,  0,  0],
                        [ 0,  0, 59,  1,  0,  0,  0,  0,  1,  0],
                        [ 0,  0,  0, 43,  0,  0,  0,  1,  0,  0],
                        [ 0,  0,  0,  0, 38,  0,  2,  0,  0,  0],
                        [ 0,  0,  0,  2,  0, 48,  0,  0,  1,  0],
                        [ 0,  1,  0,  0,  0,  0, 51,  0,  0,  0],
                        [ 0,  0,  1,  0,  0,  0,  0, 54,  0,  0],
                        [ 0,  0,  0,  0,  0,  1,  0,  0, 46,  0],
                        [ 1,  1,  0,  1,  1,  0,  0,  0,  2, 42]]

        confusionSVM = [[45,  0,  0,  0,  0,  0,  0,  0,  0,  0],
                        [ 0, 57,  0,  0,  0,  0,  0,  0,  0,  0],
                        [ 0,  0, 59,  2,  0,  0,  0,  0,  0,  0],
                        [ 0,  0,  0, 43,  0,  0,  0,  1,  0,  0],
                        [ 0,  0,  0,  0, 40,  0,  0,  0,  0,  0],
                        [ 0,  0,  0,  1,  0, 50,  0,  0,  0,  0],
                        [ 0,  0,  0,  0,  1,  0, 51,  0,  0,  0],
                        [ 0,  0,  1,  0,  0,  0,  0, 54,  0,  0],
                        [ 0,  0,  0,  0,  0,  0,  0,  0, 47,  0],
                        [ 0,  1,  0,  1,  0,  0,  0,  0,  1, 45]]

        self.assertLess(cv.norm(confusion_matrices[0] - confusionKNN, cv.NORM_L1), normEps)
        self.assertLess(cv.norm(confusion_matrices[1] - confusionSVM, cv.NORM_L1), normEps)

        # These checks only bound the error rates from above.
        self.assertLess(errors[0] - 0.034, eps)
        self.assertLess(errors[1] - 0.018, eps)


if __name__ == '__main__':
    NewOpenCVTests.bootstrap()