# Source: GitHub repository AI4Finance-Foundation/FinRL
# Path: blob/master/unit_tests/preprocessors/test_yahoodownloader.py
1
from __future__ import annotations
2
3
import unittest
4
5
import pandas as pd
6
from pandas.testing import assert_frame_equal
7
8
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
9
10
11
class TestYahooDownloaderAdjustPrices(unittest.TestCase):
    """Unit tests for ``YahooDownloader._adjust_prices``.

    A small hand-built frame is passed to the method and the tests check
    that ``open``/``high``/``low`` are rescaled by ``adjcp / close``, that
    ``close`` takes the ``adjcp`` value, and that the ``adjcp`` and ``adj``
    columns are absent from the result.
    """

    def setUp(self):
        """Set up a dummy YahooDownloader instance and test data."""
        # These init params are not used by _adjust_prices but are needed
        # for instantiation.
        self.downloader = YahooDownloader(
            start_date="2020-01-01",
            end_date="2020-01-03",
            ticker_list=["AAPL"],
        )
        # Sample DataFrame similar to what fetch_data might produce before
        # adjustment.
        self.raw_data = pd.DataFrame(
            {
                "date": ["2020-01-01", "2020-01-02"],
                "open": [100.0, 102.0],
                "high": [105.0, 106.0],
                "low": [98.0, 100.0],
                "close": [104.0, 105.0],
                # Adjusted close price - crucial for the method
                "adjcp": [100.0, 102.9],
                "volume": [10000, 11000],
                "tic": ["AAPL", "AAPL"],
            }
        )

    def test_adjust_prices_calculates_correctly(self):
        """Test that prices are adjusted correctly based on adjcp/close ratio."""
        raw = self.raw_data

        # Explicitly ensure columns exist before passing to the method
        self.assertIn("adjcp", raw.columns)
        self.assertIn("close", raw.columns)

        # Pass a copy so the fixture stays untouched for the expected-value
        # computation below.
        adjusted_df = self.downloader._adjust_prices(raw.copy())

        # Per-row adjustment factor, computed column-wise from the fixture.
        adj_ratio = raw["adjcp"] / raw["close"]

        expected_data = pd.DataFrame(
            {
                "date": raw["date"],
                "open": raw["open"] * adj_ratio,
                "high": raw["high"] * adj_ratio,
                "low": raw["low"] * adj_ratio,
                "close": raw["adjcp"],  # close becomes adjcp
                "volume": raw["volume"],
                "tic": raw["tic"],
            }
        )

        # Select only the columns present in the expected output for
        # comparison and ensure the same column order and index.
        adjusted_df_compare = adjusted_df[expected_data.columns].reset_index(drop=True)
        expected_data = expected_data.reset_index(drop=True)

        # Use pandas testing utility for robust DataFrame comparison
        assert_frame_equal(adjusted_df_compare, expected_data, check_dtype=True)

    def test_adjust_prices_drops_columns(self):
        """Test that 'adjcp' and the temporary 'adj' columns are dropped."""
        # Explicitly ensure columns exist before passing to the method
        self.assertIn("adjcp", self.raw_data.columns)
        self.assertIn("close", self.raw_data.columns)

        adjusted_df = self.downloader._adjust_prices(self.raw_data.copy())

        self.assertNotIn("adjcp", adjusted_df.columns)
        self.assertNotIn("adj", adjusted_df.columns)

        # Ensure other essential columns remain
        self.assertIn("open", adjusted_df.columns)
        # Note: this is the *new* adjusted close
        self.assertIn("close", adjusted_df.columns)
        self.assertIn("tic", adjusted_df.columns)
        self.assertIn("date", adjusted_df.columns)
        self.assertIn("volume", adjusted_df.columns)
100
101
102
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
104
105