Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/unit_tests/test_core.py
726 views
1
from __future__ import annotations
2
3
import os
4
from typing import List
5
6
import pandas as pd
7
import pytest
8
9
from finrl import config
10
from finrl import config_tickers
11
from finrl.config import DATA_SAVE_DIR
12
from finrl.config import RESULTS_DIR
13
from finrl.config import TENSORBOARD_LOG_DIR
14
from finrl.config import TRAINED_MODEL_DIR
15
from finrl.main import check_and_make_directories
16
from finrl.meta.preprocessor.preprocessors import FeatureEngineer
17
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
18
19
20
@pytest.fixture(scope="session")
def DIRS():
    """Session fixture: the directory names imported from ``finrl.config``."""
    directories = [
        DATA_SAVE_DIR,
        TRAINED_MODEL_DIR,
        TENSORBOARD_LOG_DIR,
        RESULTS_DIR,
    ]
    return directories
23
24
25
@pytest.fixture(scope="session")
def ticker_list():
    """Session fixture: the ticker symbols in ``config_tickers.DOW_30_TICKER``."""
    dow_30_symbols = config_tickers.DOW_30_TICKER
    return dow_30_symbols
28
29
30
@pytest.fixture(scope="session")
def ticker_list_small():
    """Session fixture: a two-symbol ticker list for tests that need a small one."""
    symbols = ["AAPL", "GOOG"]
    return symbols
33
34
35
@pytest.fixture(scope="session")
def indicators():
    """Session fixture: the indicator names configured in ``config.INDICATORS``."""
    configured_indicators = config.INDICATORS
    return configured_indicators
38
39
40
@pytest.fixture(scope="session")
def old_start_date():
    """Session fixture: an early start date, used for the multi-year download."""
    multi_year_start = "2009-01-01"
    return multi_year_start
43
44
45
@pytest.fixture(scope="session")
def start_date():
    """Session fixture: start date of the short (under one year) download window."""
    window_start = "2021-01-01"
    return window_start
48
49
50
@pytest.fixture(scope="session")
def end_date():
    """Session fixture: end date shared by every download in this module."""
    window_end = "2021-10-31"
    return window_end
53
54
55
def test_check_and_make_directories(DIRS: list[str]) -> None:
    """
    Tests the creation of directories

    parameters:
    ----------
    DIRS : a List of str, which indicate the name of the folders to create
    """
    assert isinstance(DIRS, list)
    check_and_make_directories(DIRS)
    # Named `directory` rather than `dir` so the builtin is not shadowed.
    for directory in DIRS:
        # isdir instead of exists: a regular file with the same name must not
        # satisfy a test that is about directory creation.
        assert os.path.isdir(directory), f"{directory!r} is not a directory"
66
67
68
def test_download_large(ticker_list: list[str], start_date: str, end_date: str) -> None:
    """
    Downloads data for the full ticker list with YahooDownloader and checks
    that the result is a DataFrame with one of the expected shapes.
    """
    # Validate the fixtures before using them.
    assert isinstance(ticker_list, list)
    assert len(ticker_list) > 0
    assert isinstance(ticker_list[0], str)
    assert isinstance(start_date, str)
    assert isinstance(end_date, str)

    downloader = YahooDownloader(
        start_date=start_date,
        end_date=end_date,
        ticker_list=ticker_list,
    )
    frame = downloader.fetch_data()
    assert isinstance(frame, pd.DataFrame)

    # Two row counts are accepted.
    # NOTE(review): presumably 30 tickers x 210 or 209 trading days - confirm.
    accepted_shapes = ((6300, 8), (6270, 8))
    assert frame.shape in accepted_shapes
82
83
84
def test_feature_engineer_no_turbulence(
    ticker_list: list[str],
    indicators: list[str],
    start_date: str,
    end_date: str,
) -> None:
    """
    Tests FeatureEngineer.preprocess_data with technical indicators and VIX
    enabled and turbulence disabled - WIP
    """
    # Validate the fixtures before using them.
    assert isinstance(ticker_list, list)
    assert len(ticker_list) > 0
    assert isinstance(ticker_list[0], str)
    assert isinstance(start_date, str)
    assert isinstance(end_date, str)
    assert isinstance(indicators, list)
    assert isinstance(indicators[0], str)

    downloader = YahooDownloader(
        start_date=start_date,
        end_date=end_date,
        ticker_list=ticker_list,
    )
    raw_frame = downloader.fetch_data()

    engineer = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=indicators,
        use_vix=True,
        use_turbulence=False,
        user_defined_feature=False,
    )
    processed = engineer.preprocess_data(raw_frame)
    assert isinstance(processed, pd.DataFrame)
112
113
114
def test_feature_engineer_turbulence_less_than_a_year(
    ticker_list: list[str],
    indicators: list[str],
    start_date: str,
    end_date: str,
) -> None:
    """
    Tests FeatureEngineer.preprocess_data with turbulence enabled when the
    start and end dates are less than 1 year apart.
    Preprocessing is expected to raise an error.
    """
    # Validate the fixtures before using them.
    assert isinstance(ticker_list, list)
    assert len(ticker_list) > 0
    assert isinstance(ticker_list[0], str)
    assert isinstance(start_date, str)
    assert isinstance(end_date, str)
    assert isinstance(indicators, list)
    assert isinstance(indicators[0], str)

    downloader = YahooDownloader(
        start_date=start_date,
        end_date=end_date,
        ticker_list=ticker_list,
    )
    raw_frame = downloader.fetch_data()

    engineer = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=indicators,
        use_vix=True,
        use_turbulence=True,
        user_defined_feature=False,
    )
    # Only the preprocessing call sits inside the raises-context, so a failure
    # during download is not mistaken for the expected error.
    with pytest.raises(Exception):
        engineer.preprocess_data(raw_frame)
146
147
148
def test_feature_engineer_turbulence_more_than_a_year(
    ticker_list: list[str],
    indicators: list[str],
    old_start_date: str,
    end_date: str,
) -> None:
    """
    Tests the feature_engineer function - with turbulence, start and end date
    are more than 1 year apart.
    the code should succeed and return a DataFrame
    """
    assert isinstance(ticker_list, list)
    assert len(ticker_list) > 0
    assert isinstance(ticker_list[0], str)
    # Same start-date type check the sibling tests apply to `start_date`.
    assert isinstance(old_start_date, str)
    assert isinstance(end_date, str)
    assert isinstance(indicators, list)
    assert isinstance(indicators[0], str)

    df = YahooDownloader(
        start_date=old_start_date, end_date=end_date, ticker_list=ticker_list
    ).fetch_data()
    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=indicators,
        use_vix=True,
        use_turbulence=True,
        user_defined_feature=False,
    )
    # Unlike the under-one-year case, preprocessing is expected to complete.
    assert isinstance(fe.preprocess_data(df), pd.DataFrame)
177
178