Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Aniket025
GitHub Repository: Aniket025/Medical-Prescription-OCR
Path: blob/master/Model-3/ocr/mlhelpers.py
426 views
1
# -*- coding: utf-8 -*-
2
"""
3
Classes for controlling machine learning processes
4
"""
5
import numpy as np
6
import math
7
import matplotlib.pyplot as plt
8
import csv
9
10
11
class TrainingPlot:
    """
    Creating live plot during training.

    The left axis (ax1) shows the training loss in blue; the right axis (ax2)
    shows validation accuracy in red and training accuracy in green.

    REQUIRES notebook backend: %matplotlib notebook
    @TODO Migrate to Tensorboard
    """
    # Scalar and handle defaults only. The history lists are deliberately NOT
    # class attributes: a mutable class attribute is shared by every instance,
    # so a second TrainingPlot would append to the first one's history.
    testInterval = 0
    lossInterval = 0
    interval = 0
    ax1 = None
    ax2 = None
    fig = None
    # Line2D handles, created on the first update and reused afterwards.
    _lossLine = None
    _validLine = None
    _trainLine = None

    def __init__(self, steps, testItr, lossItr):
        """
        Create the figure with a loss axis and a twin accuracy axis.

        Args:
            steps: stored as ``self.interval``; not read by this class's methods.
            testItr: iterations between two accuracy measurements (x spacing
                of the accuracy curves).
            lossItr: iterations between two loss measurements (x spacing of
                the loss curve).
        """
        # Per-instance histories (see the class-level comment).
        self.trainLoss = []
        self.trainAcc = []
        self.validAcc = []
        self.testInterval = testItr
        self.lossInterval = lossItr
        self.interval = steps

        self.fig, self.ax1 = plt.subplots()
        self.ax2 = self.ax1.twinx()
        self.ax1.set_autoscaley_on(True)
        plt.ion()

        self.updatePlot()

        # Description
        self.ax1.set_xlabel('Iteration')
        self.ax1.set_ylabel('Train Loss')
        self.ax2.set_ylabel('Valid. Accuracy')

        # Axes limits (set_ylim also switches off y-autoscaling on ax1)
        self.ax1.set_ylim([0, 10])

    @staticmethod
    def _initialLossTop(lossTrain):
        """
        Return the upper y-limit of the loss axis for the first loss value.

        The limit is the loss rounded up and clamped to [1, 10]. The lower
        bound of 1 avoids a degenerate [0, 0] range when the first loss is
        <= 0. Returns None for a non-finite loss (NaN / inf), on which
        math.ceil would raise.
        """
        if not math.isfinite(lossTrain):
            return None
        return max(1, min(10, math.ceil(lossTrain)))

    @staticmethod
    def _drawLine(ax, line, xs, ys, fmt):
        """
        Show (xs, ys) on `ax`, reusing `line` when it exists; return the line.

        Calling ax.plot() with the full history on every update would stack
        one new Line2D per call. The total drawn points, and the memory used,
        would then grow quadratically with the number of updates.
        """
        if line is None:
            line, = ax.plot(xs, ys, fmt, linewidth=1.0)
        else:
            line.set_data(xs, ys)
            # set_data does not touch the data limits; recompute them so the
            # axes autoscale just as a fresh ax.plot() call would.
            ax.relim()
            ax.autoscale_view()
        return line

    def updatePlot(self):
        """Redraw the figure canvas."""
        self.fig.canvas.draw()

    def updateCost(self, lossTrain, index):
        """
        Append a training-loss value and redraw the loss curve.

        Args:
            lossTrain: the new training loss.
            index: unused; kept for interface compatibility.
        """
        self.trainLoss.append(lossTrain)
        if len(self.trainLoss) == 1:
            # Fit the loss axis to the first value, capped at 10.
            top = self._initialLossTop(lossTrain)
            if top is not None:
                self.ax1.set_ylim([0, top])
        xs = self.lossInterval * np.arange(len(self.trainLoss))
        self._lossLine = self._drawLine(
            self.ax1, self._lossLine, xs, self.trainLoss, 'b')

        self.updatePlot()

    def updateAcc(self, accVal, accTrain, index):
        """
        Append validation/training accuracies and redraw both accuracy curves.

        Args:
            accVal: the new validation accuracy (also shown in the title).
            accTrain: the new training accuracy.
            index: unused; kept for interface compatibility.
        """
        self.validAcc.append(accVal)
        self.trainAcc.append(accTrain)

        validXs = self.testInterval * np.arange(len(self.validAcc))
        trainXs = self.testInterval * np.arange(len(self.trainAcc))
        self._validLine = self._drawLine(
            self.ax2, self._validLine, validXs, self.validAcc, 'r')
        self._trainLine = self._drawLine(
            self.ax2, self._trainLine, trainXs, self.trainAcc, 'g')

        self.ax2.set_title('Valid. Accuracy: {:.4f}'.format(self.validAcc[-1]))

        self.updatePlot()
72
73
74
class DataSet:
    """ Class for training data and feeding train function """
    images = None
    labels = None
    length = 0
    index = 0

    def __init__(self, img, lbl):
        """
        Create the dataset.

        Args:
            img: indexable, sliceable collection of samples (e.g. a NumPy
                array or a Python list/tuple).
            lbl: labels aligned with `img`; presumably the same length as
                `img` (not checked here).
        """
        self.images = img
        self.labels = lbl
        self.length = len(img)
        self.index = 0

    @staticmethod
    def _reorder(seq, perm):
        """
        Return `seq` reordered by the integer index array `perm`.

        Plain lists and tuples cannot be indexed by an array, so they are
        rebuilt element by element with their original type. Anything else
        (e.g. NumPy arrays) uses integer-array (fancy) indexing.
        """
        if isinstance(seq, (list, tuple)):
            return type(seq)(seq[i] for i in perm)
        return seq[perm]

    def next_batch(self, batchSize):
        """
        Return the next (images, labels) batch from the data set.

        Batches are taken sequentially. When the next batch would run past
        the end, the leftover samples of the current pass are skipped, images
        and labels are shuffled with the same permutation, and the batch is
        taken from the start. The first pass is not shuffled. If `batchSize`
        exceeds the data-set length, the whole (shuffled) set is returned.
        """
        start = self.index
        self.index += batchSize

        if self.index > self.length:
            # Shuffle the data (same permutation keeps images/labels aligned)
            perm = np.arange(self.length)
            np.random.shuffle(perm)
            self.images = self._reorder(self.images, perm)
            self.labels = self._reorder(self.labels, perm)
            # Start next epoch
            start = 0
            self.index = batchSize

        end = self.index
        return self.images[start:end], self.labels[start:end]
106
107