# Path: blob/master/unit_tests/preprocessors/test_yahoodownloader.py
# (728 views)
from __future__ import annotations

import unittest

import pandas as pd
from pandas.testing import assert_frame_equal

from finrl.meta.preprocessor.yahoodownloader import YahooDownloader


class TestYahooDownloaderAdjustPrices(unittest.TestCase):
    """Unit tests for ``YahooDownloader._adjust_prices``.

    The fixture mimics raw downloaded data that still carries an ``adjcp``
    (adjusted close) column. The tests check the rescaled price columns and
    that the ``adjcp`` and temporary ``adj`` columns are absent from the result.
    """

    def setUp(self):
        """Set up a dummy YahooDownloader instance and test data."""
        # These init params are not used by _adjust_prices but are needed for
        # instantiation.
        self.downloader = YahooDownloader(
            start_date="2020-01-01",
            end_date="2020-01-03",
            ticker_list=["AAPL"],
        )
        # Sample DataFrame similar to what fetch_data might produce before
        # adjustment. adjcp differs from close on every row, so each row gets
        # a non-trivial adjcp/close ratio.
        self.raw_data = pd.DataFrame(
            {
                "date": ["2020-01-01", "2020-01-02"],
                "open": [100.0, 102.0],
                "high": [105.0, 106.0],
                "low": [98.0, 100.0],
                "close": [104.0, 105.0],
                "adjcp": [
                    100.0,
                    102.9,
                ],  # Adjusted close price - crucial for the method
                "volume": [10000, 11000],
                "tic": ["AAPL", "AAPL"],
            }
        )

    def _assert_fixture_has_price_columns(self):
        """Check that the fixture carries the 'adjcp' and 'close' columns."""
        self.assertIn("adjcp", self.raw_data.columns)
        self.assertIn("close", self.raw_data.columns)

    def test_adjust_prices_calculates_correctly(self):
        """Test that prices are adjusted correctly based on adjcp/close ratio."""
        self._assert_fixture_has_price_columns()

        adjusted_df = self.downloader._adjust_prices(self.raw_data.copy())

        # Per-row adjustment ratio, computed for all fixture rows at once so
        # the expectation stays in sync with the fixture if rows are added.
        raw = self.raw_data
        adj_ratio = raw["adjcp"] / raw["close"]

        expected_data = pd.DataFrame(
            {
                "date": raw["date"],
                "open": raw["open"] * adj_ratio,
                "high": raw["high"] * adj_ratio,
                "low": raw["low"] * adj_ratio,
                # close becomes adjcp. Float columns are compared with
                # assert_frame_equal's default (non-exact) tolerance.
                "close": raw["adjcp"],
                "volume": raw["volume"],
                "tic": raw["tic"],
            }
        )

        # Select only the columns present in the expected output for
        # comparison, in the same order, on a fresh index on both sides.
        adjusted_df_compare = adjusted_df[expected_data.columns].reset_index(drop=True)
        expected_data = expected_data.reset_index(drop=True)

        # Use pandas testing utility for robust DataFrame comparison.
        assert_frame_equal(adjusted_df_compare, expected_data, check_dtype=True)

    def test_adjust_prices_drops_columns(self):
        """Test that 'adjcp' and the temporary 'adj' columns are dropped."""
        self._assert_fixture_has_price_columns()

        adjusted_df = self.downloader._adjust_prices(self.raw_data.copy())

        for dropped in ("adjcp", "adj"):
            with self.subTest(dropped_column=dropped):
                self.assertNotIn(dropped, adjusted_df.columns)

        # Ensure other essential columns remain. Note: "close" is now the
        # *adjusted* close.
        for kept in ("open", "close", "tic", "date", "volume"):
            with self.subTest(kept_column=kept):
                self.assertIn(kept, adjusted_df.columns)


if __name__ == "__main__":
    unittest.main()