Path: blob/master/Model-3/ocr/mlhelpers.py
426 views
# -*- coding: utf-8 -*-
"""
Classes for controlling machine learning processes.
"""
import numpy as np
import math
import matplotlib.pyplot as plt
import csv


class TrainingPlot:
    """
    Live plot of training progress.

    The training loss is drawn on the left y-axis; training and validation
    accuracy are drawn on a twin right y-axis. The canvas is redrawn after
    every update.

    REQUIRES an interactive notebook backend: %matplotlib notebook
    @TODO Migrate to Tensorboard
    """
    # NOTE: the history lists (trainLoss, trainAcc, validAcc) are created per
    # instance in __init__. They used to be class attributes, which made all
    # TrainingPlot instances share and keep appending to the same lists.
    testInterval = 0
    lossInterval = 0
    interval = 0
    ax1 = None
    ax2 = None
    fig = None
    lossLine = None
    validLine = None
    trainLine = None

    def __init__(self, steps, testItr, lossItr):
        """
        Create the figure, the twin axes and the (initially empty) curves.

        Args:
            steps: stored as ``interval``; not used by the plotting code.
            testItr: iterations between two accuracy measurements; scales the
                x coordinates of the accuracy curves.
            lossItr: iterations between two loss measurements; scales the
                x coordinates of the loss curve.
        """
        self.trainLoss = []
        self.trainAcc = []
        self.validAcc = []

        self.testInterval = testItr
        self.lossInterval = lossItr
        self.interval = steps

        self.fig, self.ax1 = plt.subplots()
        self.ax2 = self.ax1.twinx()
        self.ax1.set_autoscaley_on(True)
        plt.ion()

        # Create each curve once and only replace its data on updates.
        # Calling plot() on every update would stack a new, ever-longer line
        # on top of the previous ones (quadratic redraw cost).
        self.lossLine, = self.ax1.plot([], [], 'b', linewidth=1.0)
        self.validLine, = self.ax2.plot([], [], 'r', linewidth=1.0)
        self.trainLine, = self.ax2.plot([], [], 'g', linewidth=1.0)

        self.updatePlot()

        # Description
        self.ax1.set_xlabel('Iteration')
        self.ax1.set_ylabel('Train Loss')
        self.ax2.set_ylabel('Valid. Accuracy')

        # Axes limits
        self.ax1.set_ylim([0, 10])

    def updatePlot(self):
        """Redraw the figure canvas so pending changes become visible."""
        self.fig.canvas.draw()

    @staticmethod
    def _initialLossLimit(loss):
        """Return the loss-axis upper y-limit for the first loss value, in [1, 10]."""
        # math.ceil raises on NaN/inf (e.g. a diverged first step), and a
        # limit <= 0 would give an empty or inverted axis.
        if not math.isfinite(loss):
            return 10
        return min(10, max(1, math.ceil(loss)))

    def updateCost(self, lossTrain, index):
        """
        Record a training-loss value and redraw the loss curve.

        On the first value, the loss y-axis is rescaled to [0, ceil(loss)],
        clamped to [1, 10].

        Args:
            lossTrain: latest training loss.
            index: unused; kept for API compatibility. The x position is
                ``lossInterval`` times the number of previously stored values.
        """
        self.trainLoss.append(lossTrain)
        if len(self.trainLoss) == 1:
            self.ax1.set_ylim([0, self._initialLossLimit(lossTrain)])

        steps = self.lossInterval * np.arange(len(self.trainLoss))
        self.lossLine.set_data(steps, np.asarray(self.trainLoss))
        # Recompute data limits so the (shared) x-axis follows the new data;
        # the y-axis stays fixed because set_ylim turned y-autoscaling off.
        self.ax1.relim()
        self.ax1.autoscale_view()

        self.updatePlot()

    def updateAcc(self, accVal, accTrain, index):
        """
        Record validation/training accuracy and redraw both accuracy curves.

        The right-axis title shows the latest validation accuracy.

        Args:
            accVal: latest validation accuracy (red curve).
            accTrain: latest training accuracy (green curve).
            index: unused; kept for API compatibility. The x position is
                ``testInterval`` times the number of previously stored values.
        """
        self.validAcc.append(accVal)
        self.trainAcc.append(accTrain)

        self.validLine.set_data(self.testInterval * np.arange(len(self.validAcc)),
                                np.asarray(self.validAcc))
        self.trainLine.set_data(self.testInterval * np.arange(len(self.trainAcc)),
                                np.asarray(self.trainAcc))
        self.ax2.relim()
        self.ax2.autoscale_view()

        self.ax2.set_title('Valid. Accuracy: {:.4f}'.format(self.validAcc[-1]))

        self.updatePlot()


class DataSet:
    """Hold training data and serve it to the train function in mini-batches."""
    images = None
    labels = None
    length = 0
    index = 0

    def __init__(self, img, lbl):
        """
        Create the dataset.

        Args:
            img: samples. They are indexed with an integer permutation array
                when an epoch ends, so presumably a NumPy array (TODO confirm).
            lbl: labels aligned with ``img`` (same length and order).
        """
        self.images = img
        self.labels = lbl
        self.length = len(img)
        self.index = 0

    def next_batch(self, batchSize):
        """
        Return the next ``(images, labels)`` batch from the data set.

        Batches are served sequentially; the first epoch uses the original
        order. When the next batch would run past the end of the data, the
        leftover samples are skipped. Images and labels are then shuffled with
        the same permutation, and a new epoch starts at index 0.
        """
        start = self.index
        self.index += batchSize

        if self.index > self.length:
            # One permutation for both arrays keeps samples and labels aligned.
            perm = np.arange(self.length)
            np.random.shuffle(perm)
            self.images = self.images[perm]
            self.labels = self.labels[perm]
            # Start next epoch
            start = 0
            self.index = batchSize

        end = self.index
        return self.images[start:end], self.labels[start:end]