Path: blob/master/ invest-robot-contest_NeuroInvest-main/NeuroInvest/training.py
5931 views
from config import predictPeriod, scanInterval, predictionIntervals
from kalman import Kalman
from pandas import DataFrame
from sklearn import neural_network
from sklearn.preprocessing import MinMaxScaler
from joblib import dump, load
from os import path


class Training:
    """Trains and applies an MLP regressor that predicts future
    Kalman-smoothed price points from a sliding window of history.

    Model and both scalers are persisted/restored via joblib using a
    common filename prefix (``<filename>.model`` / ``.input`` / ``.output``).
    """

    # NOTE(review): these are class-level attributes shared by all
    # instances until an instance assignment shadows them — confirm the
    # app only ever creates one Training per model file.
    __model = None
    __scalerInput = MinMaxScaler()
    __scalerOutput = MinMaxScaler()

    def __init__(self, filename):
        # Restore a previously saved model/scalers if the files exist;
        # otherwise start untrained.
        self.__loadModel(filename)

    def train(self, history):
        """Fit the regressor on ``history`` (a {timestamp: value} mapping).

        For every timestamp, a window of ``countDatum`` smoothed values
        becomes one input row; the smoothed values observed at each of the
        configured ``predictionIntervals`` after the window become the
        matching output row. Windows/predictions that run past the end of
        the series are padded with their last known value.
        """
        times, tradeY = zip(*history.items())
        smoothData = Kalman(history).getSmoothData()

        inputData = []
        outputData = []
        countTimes = len(times)
        countPredicts = len(predictionIntervals)
        # BUGFIX: integer division — the original true division produced a
        # float, making range(len(datum), countDatum) below raise TypeError
        # whenever a window needed padding.
        countDatum = predictPeriod // scanInterval

        for timeIndex, timestamp in enumerate(times):
            datum = [smoothData[timeIndex]]
            dataIndex = timeIndex + 1
            while dataIndex < countTimes and len(datum) <= countDatum:
                datum.append(smoothData[dataIndex])
                dataIndex += 1

            predictCounter = 0
            # BUGFIX: always start from a fresh list. The original created
            # ``predict`` only inside the branch below, so when the branch
            # was skipped (window reached the end of the series) the
            # previous iteration's already-stored list was reused — it got
            # mutated in place and appended to outputData a second time.
            predict = []

            if dataIndex < countTimes:
                startPredictTimestamp = times[dataIndex]

                while dataIndex < countTimes and predictCounter < countPredicts:
                    # Collect the first sample at/after each prediction offset.
                    if times[dataIndex] >= startPredictTimestamp + predictionIntervals[predictCounter]:
                        predict.append(smoothData[dataIndex])
                        predictCounter += 1

                    dataIndex += 1

            if predictCounter != countPredicts:
                if predictCounter > 0:
                    # Ran out of data: repeat the last collected prediction.
                    for index in range(predictCounter, countPredicts):
                        predict.append(predict[-1])
                else:
                    # No future data at all: predict "no change".
                    for index in range(0, countPredicts):
                        predict.append(datum[-1])

            if len(datum) == 0:
                continue

            if len(datum) < countDatum:
                # Pad short windows with their last value so all rows have
                # the same width for the scaler/model.
                for index in range(len(datum), countDatum):
                    datum.append(datum[-1])

            inputData.append(datum)
            outputData.append(predict)

        if len(inputData) == 0 or len(inputData) != len(outputData):
            print("Have not data for training")

        else:
            self.__scalerInput.fit(inputData)
            self.__scalerOutput.fit(outputData)

            inputData = DataFrame(self.__scalerInput.transform(inputData), dtype = float)
            outputData = DataFrame(self.__scalerOutput.transform(outputData), dtype = float)

            try:
                self.__model = neural_network.MLPRegressor(solver = "lbfgs", activation = 'logistic', max_iter = 1000000)

                self.__model.fit(inputData, outputData)

                print("Learn score = ", self.__model.score(inputData, outputData))

            except Exception as message:
                # Training failed: drop the (half-fitted) model so callers
                # see the untrained state instead of a broken estimator.
                self.__model = None
                print("ERROR WHEN CREATE TRADE MODEL: " + str(message))

    def tradePredict(self, dayValues, nowTime):
        """Predict future values from the day's series ``dayValues``.

        Returns a dict {nowTime + interval: predicted_value} for each
        configured prediction interval, or None when no model is loaded,
        the input is empty, or prediction fails.

        NOTE(review): when the input is shorter than the trained feature
        width, the caller's list is left-padded IN PLACE via insert(0, ...)
        — confirm callers don't rely on ``dayValues`` staying unchanged.
        """
        countInputValues = len(dayValues)

        try:
            countExpectedValues = self.__scalerInput.n_features_in_
            if self.__model is not None and countInputValues > 0:
                # Trim from the left, or left-pad with the oldest value, so
                # the row matches the width the scaler was fitted on.
                if countInputValues > countExpectedValues:
                    dayValues = dayValues[countInputValues - countExpectedValues:]
                elif countInputValues < countExpectedValues:
                    for i in range(0, countExpectedValues - countInputValues):
                        dayValues.insert(0, dayValues[0])

                inputData = DataFrame(self.__scalerInput.transform([dayValues]), dtype = float)

                outputData = self.__model.predict(inputData)
                outputData = self.__scalerOutput.inverse_transform(outputData)

                result = {}

                if len(outputData) > 0:
                    for index, time in enumerate(predictionIntervals):
                        result[nowTime + time] = outputData[0][index]

                return result
        # BUGFIX: narrowed from a bare ``except:`` — no longer swallows
        # SystemExit/KeyboardInterrupt. Prediction stays best-effort: any
        # runtime failure still yields None.
        except Exception:
            pass

        return None

    def saveModel(self, filename):
        """Persist model and both scalers next to ``filename`` (three files).

        Does nothing when no model has been trained/loaded yet.
        """
        if self.__model is not None:
            with open(filename + ".model", 'wb') as saveFile:
                dump(self.__model, saveFile)

            with open(filename + ".input", 'wb') as saveFile:
                dump(self.__scalerInput, saveFile)

            with open(filename + ".output", 'wb') as saveFile:
                dump(self.__scalerOutput, saveFile)

    def __loadModel(self, filename):
        """Load model/scalers from ``filename``-prefixed files, if present.

        Each artifact is optional; missing files leave the corresponding
        class-level default in place.
        """
        if path.isfile(filename + ".model"):
            with open(filename + ".model", 'rb') as loadFile:
                self.__model = load(loadFile)

        if path.isfile(filename + ".input"):
            with open(filename + ".input", 'rb') as loadFile:
                self.__scalerInput = load(loadFile)

        if path.isfile(filename + ".output"):
            with open(filename + ".output", 'rb') as loadFile:
                self.__scalerOutput = load(loadFile)