# Path: blob/master/utils.py
# 2909 views
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'Stefan Jansen'

import numpy as np

# Seed NumPy's global RNG at import time so shuffled splits are reproducible.
np.random.seed(42)


def format_time(t):
    """Return a formatted time string 'HH:MM:SS'
    based on a numeric time() value.

    The total is rounded to whole seconds *before* it is split into
    hours, minutes and seconds, so a fractional value such as 59.6 rolls
    over to '00:01:00' instead of being rendered as '00:00:60'.

    Parameters
    ----------
    t : int or float
        Duration in seconds.

    Returns
    -------
    str
        The duration with each field zero-padded to at least two
        characters; the hours field is not capped at 24.
    """
    t = float(np.round(t))
    m, s = divmod(t, 60)
    h, m = divmod(m, 60)
    return f'{h:0>2.0f}:{m:0>2.0f}:{s:0>2.0f}'


class MultipleTimeSeriesCV:
    """Generates tuples of train_idx, test_idx pairs.

    Assumes the index of the data contains a date level (named by
    ``date_idx``, 'date' by default), typically next to a 'symbol' level.
    Splits are built walking backwards from the most recent date, and the
    training window is shifted away from the test window by ``lookahead``
    to purge overlapping outcomes.

    Parameters
    ----------
    n_splits : int
        Number of train/test pairs to generate.
    train_period_length : int
        Base length of the training window, in unique dates. The window
        actually spans ``train_period_length + lookahead - 1`` dates.
    test_period_length : int
        Length of each test window, in unique dates.
    lookahead : int or None
        Offset, in unique dates, used to separate the training window
        from the test window. ``None`` is treated as 1, i.e. the training
        window ends right where the test window starts.
    date_idx : str
        Name of the index level that holds the dates.
    shuffle : bool
        If True, the training positions of every split are returned in
        random order (drawn from NumPy's global random state).
    """

    def __init__(self,
                 n_splits=3,
                 train_period_length=126,
                 test_period_length=21,
                 lookahead=None,
                 date_idx='date',
                 shuffle=False):
        self.n_splits = n_splits
        self.lookahead = lookahead
        self.test_length = test_period_length
        self.train_length = train_period_length
        self.shuffle = shuffle
        self.date_idx = date_idx

    def split(self, X, y=None, groups=None):
        """Yield ``(train_idx, test_idx)`` arrays of integer row positions.

        Parameters
        ----------
        X : pandas object
            Data whose index has a level named ``self.date_idx``.
        y, groups : ignored
            Accepted so the method can be called like other splitters.

        Yields
        ------
        tuple of numpy.ndarray
            Positional row indices of the training and the test rows.

        Raises
        ------
        IndexError
            While iterating, if the data holds too few unique dates for
            the requested window lengths and number of splits.
        """
        # A missing lookahead means "no extra gap"; 1 makes the training
        # window end exactly at the boundary of the test window.
        lookahead = 1 if self.lookahead is None else self.lookahead

        date_values = X.index.get_level_values(self.date_idx)
        # Most recent date first: offset 0 is the last available date.
        days = sorted(date_values.unique(), reverse=True)

        for i in range(self.n_splits):
            # Offsets count backwards from the most recent date, so a
            # window "start" has a larger offset than its "end".
            test_end = i * self.test_length
            test_start = test_end + self.test_length
            train_end = test_start + lookahead - 1
            train_start = train_end + self.train_length + lookahead - 1

            # Windows are half-open: (start date, end date].
            train_mask = ((date_values > days[train_start])
                          & (date_values <= days[train_end]))
            test_mask = ((date_values > days[test_start])
                         & (date_values <= days[test_end]))

            train_idx = np.flatnonzero(train_mask)
            test_idx = np.flatnonzero(test_mask)
            if self.shuffle:
                # permutation returns a shuffled copy; shuffling a
                # temporary list would leave train_idx untouched.
                train_idx = np.random.permutation(train_idx)
            yield train_idx, test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splits; all arguments are ignored."""
        return self.n_splits