Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/samples/python/digits_adjust.py
16337 views
1
#!/usr/bin/env python
2
3
'''
4
Digit recognition adjustment.
5
Grid search is used to find the best parameters for SVM and KNearest classifiers.
6
SVM adjustment follows the guidelines given in
7
http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
8
9
Usage:
10
digits_adjust.py [--model {svm|knearest}]
11
12
--model {svm|knearest} - select the classifier (SVM is the default)
13
14
'''
15
16
# Python 2/3 compatibility
17
from __future__ import print_function
18
import sys
19
PY3 = sys.version_info[0] == 3
20
21
if PY3:
22
xrange = range
23
24
import numpy as np
25
import cv2 as cv
26
from multiprocessing.pool import ThreadPool
27
28
from digits import *
29
30
def cross_validate(model_class, params, samples, labels, kfold=3, pool=None):
    '''Estimate a model's misclassification rate with k-fold cross-validation.

    The sample indices are split into `kfold` contiguous folds with
    np.array_split. For each fold, a fresh `model_class(**params)` is trained on
    all other folds and evaluated on the held-out one. A '.' is printed for
    each finished fold.

    model_class -- callable returning an object with train(samples, labels)
                   and predict(samples) methods
    params      -- keyword arguments forwarded to model_class
    samples     -- array supporting integer-array (fancy) indexing
    labels      -- array aligned with samples along the first axis
    kfold       -- number of folds; must satisfy 2 <= kfold <= len(samples)
    pool        -- optional object with a map() method used to run the folds
                   (e.g. a thread pool); folds run sequentially when None

    Returns the mean over folds of the per-fold error rate, i.e. the fraction
    of held-out samples whose prediction differs from the label.

    Raises ValueError if kfold is outside [2, len(samples)].
    '''
    n = len(samples)
    # With fewer than 2 folds there is nothing to train on (np.hstack of an
    # empty list fails with an unrelated message). With more folds than
    # samples some test folds are empty, and their mean error is NaN.
    if not 2 <= kfold <= n:
        raise ValueError('kfold must be in [2, %d], got %r' % (n, kfold))
    folds = np.array_split(np.arange(n), kfold)

    def f(i):
        model = model_class(**params)
        test_idx = folds[i]
        # Train on every fold except the held-out i-th one.
        train_idx = np.hstack(folds[:i] + folds[i + 1:])
        train_samples, train_labels = samples[train_idx], labels[train_idx]
        test_samples, test_labels = samples[test_idx], labels[test_idx]
        model.train(train_samples, train_labels)
        resp = model.predict(test_samples)
        score = (resp != test_labels).mean()
        print(".", end='')
        return score

    if pool is None:
        scores = [f(i) for i in range(kfold)]
    else:
        scores = pool.map(f, range(kfold))
    return np.mean(scores)
51
52
53
class App(object):
    '''Grid-search driver for the digits SVM and KNearest classifiers.

    The dataset is loaded and preprocessed once in __init__. Every candidate
    parameter set is then scored with cross_validate() on that same data.
    '''

    def __init__(self):
        self._samples, self._labels = self.preprocess()

    def preprocess(self):
        '''Load, shuffle, deskew and HOG-encode the digits; return (samples, labels).

        load_digits, DIGITS_FN, deskew and preprocess_hog are not defined here;
        they are expected to come from `from digits import *`.
        '''
        digits, labels = load_digits(DIGITS_FN)
        # Random permutation before the contiguous folds of cross_validate.
        # NOTE(review): presumably because the raw data is grouped by class -- confirm.
        shuffle = np.random.permutation(len(digits))
        digits, labels = digits[shuffle], labels[shuffle]
        digits2 = list(map(deskew, digits))
        samples = preprocess_hog(digits2)
        return samples, labels

    def get_dataset(self):
        '''Return the preprocessed (samples, labels) pair.'''
        return self._samples, self._labels

    def run_jobs(self, f, jobs):
        '''Apply f to every item of jobs on a thread pool.

        Returns an iterator over the results in completion order
        (imap_unordered), not in submission order.
        '''
        pool = ThreadPool(processes=cv.getNumberOfCPUs())
        ires = pool.imap_unordered(f, jobs)
        # close() only forbids new submissions: the jobs queued above still
        # run and the iterator stays usable. The worker threads then exit once
        # the work is done instead of leaking with an unreachable pool.
        pool.close()
        return ires

    def adjust_SVM(self):
        '''Grid-search the SVM's C and gamma by cross-validation.

        C spans 15 log2-spaced values in [2^0, 2^10] and gamma spans 15
        log2-spaced values in [2^-7, 2^4]. Progress is printed as jobs finish,
        the full error table is saved to svm_scores.npz, and the best pair is
        returned as dict(C=..., gamma=...).
        '''
        Cs = np.logspace(0, 10, 15, base=2)
        gammas = np.logspace(-7, 4, 15, base=2)
        # NaN marks cells that have not been evaluated yet; nanmin skips them.
        scores = np.full((len(Cs), len(gammas)), np.nan)

        print('adjusting SVM (may take a long time) ...')
        def f(job):
            i, j = job
            samples, labels = self.get_dataset()
            params = dict(C = Cs[i], gamma=gammas[j])
            score = cross_validate(SVM, params, samples, labels)
            return i, j, score

        ires = self.run_jobs(f, np.ndindex(*scores.shape))
        for count, (i, j, score) in enumerate(ires):
            scores[i, j] = score
            print('%d / %d (best error: %.2f %%, last: %.2f %%)' %
                  (count+1, scores.size, np.nanmin(scores)*100, score*100))
        print(scores)

        print('writing score table to "svm_scores.npz"')
        np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)

        # Use NaN-aware selection, consistent with the progress report above.
        # Plain argmin/min would pick a NaN cell if any score came back NaN.
        i, j = np.unravel_index(np.nanargmin(scores), scores.shape)
        best_params = dict(C = Cs[i], gamma=gammas[j])
        print('best params:', best_params)
        print('best error: %.2f %%' % (np.nanmin(scores)*100))
        return best_params

    def adjust_KNearest(self):
        '''Pick k in 1..8 for KNearest by cross-validation; return dict(k=best_k).'''
        print('adjusting KNearest ...')
        def f(k):
            samples, labels = self.get_dataset()
            err = cross_validate(KNearest, dict(k=k), samples, labels)
            return k, err
        best_err, best_k = np.inf, -1
        for k, err in self.run_jobs(f, xrange(1, 9)):
            # Results arrive in completion order. Comparing (err, k) breaks
            # ties toward the smaller k, so the choice does not depend on
            # thread timing.
            if (err, k) < (best_err, best_k):
                best_err, best_k = err, k
            print('k = %d, error: %.2f %%' % (k, err*100))
        best_params = dict(k=best_k)
        print('best params:', best_params, 'err: %.2f %%' % (best_err*100))
        return best_params
117
118
119
if __name__ == '__main__':
    # Command-line entry point: parse --model, then run the matching grid search.
    import getopt
    import sys

    print(__doc__)

    try:
        args, _ = getopt.getopt(sys.argv[1:], '', ['model='])
    except getopt.GetoptError as e:
        # Report unrecognised options like an unknown model, not as a traceback.
        print(e)
        sys.exit(1)
    args = dict(args)
    args.setdefault('--model', 'svm')
    if args['--model'] not in ['svm', 'knearest']:
        print('unknown model "%s"' % args['--model'])
        sys.exit(1)

    # clock() is not defined in this file.
    # NOTE(review): expected to arrive via `from digits import *` -- confirm.
    # The timing includes dataset preprocessing done in App().
    t = clock()
    app = App()
    if args['--model'] == 'knearest':
        app.adjust_KNearest()
    else:
        app.adjust_SVM()
    print('work time: %f s' % (clock() - t))
140
141