# Path: blob/master/finrl/meta/preprocessor/yahoodownloader.py
# 732 views
"""Contains methods and classes to collect data from1Yahoo Finance API2"""34from __future__ import annotations56import pandas as pd7import yfinance as yf8910class YahooDownloader:11"""Provides methods for retrieving daily stock data from12Yahoo Finance API1314Attributes15----------16start_date : str17start date of the data (modified from neofinrl_config.py)18end_date : str19end date of the data (modified from neofinrl_config.py)20ticker_list : list21a list of stock tickers (modified from neofinrl_config.py)2223Methods24-------25fetch_data()26Fetches data from yahoo API2728"""2930def __init__(self, start_date: str, end_date: str, ticker_list: list):31self.start_date = start_date32self.end_date = end_date33self.ticker_list = ticker_list3435def fetch_data(self, proxy=None, auto_adjust=False) -> pd.DataFrame:36"""Fetches data from Yahoo API37Parameters38----------3940Returns41-------42`pd.DataFrame`437 columns: A date, open, high, low, close, volume and tick symbol44for the specified stock ticker45"""46# Download and save the data in a pandas DataFrame:47data_df = pd.DataFrame()48num_failures = 049for tic in self.ticker_list:50temp_df = yf.download(51tic,52start=self.start_date,53end=self.end_date,54proxy=proxy,55auto_adjust=auto_adjust,56)57if temp_df.columns.nlevels != 1:58temp_df.columns = temp_df.columns.droplevel(1)59temp_df["tic"] = tic60if len(temp_df) > 0:61# data_df = data_df.append(temp_df)62data_df = pd.concat([data_df, temp_df], axis=0)63else:64num_failures += 165if num_failures == len(self.ticker_list):66raise ValueError("no data is fetched.")67# reset the index, we want to use numbers as index instead of dates68data_df = data_df.reset_index()69try:70# convert the column names to standardized names71data_df.rename(72columns={73"Date": "date",74"Adj Close": "adjcp",75"Close": "close",76"High": "high",77"Low": "low",78"Volume": "volume",79"Open": "open",80"tic": "tic",81},82inplace=True,83)8485if not auto_adjust:86data_df = 
self._adjust_prices(data_df)87except NotImplementedError:88print("the features are not supported currently")89# create day of the week column (monday = 0)90data_df["day"] = data_df["date"].dt.dayofweek91# convert date to standard string format, easy to filter92data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))93# drop missing data94data_df = data_df.dropna()95data_df = data_df.reset_index(drop=True)96print("Shape of DataFrame: ", data_df.shape)97# print("Display DataFrame: ", data_df.head())9899data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)100101return data_df102103def _adjust_prices(self, data_df: pd.DataFrame) -> pd.DataFrame:104# use adjusted close price instead of close price105data_df["adj"] = data_df["adjcp"] / data_df["close"]106for col in ["open", "high", "low", "close"]:107data_df[col] *= data_df["adj"]108109# drop the adjusted close price column110return data_df.drop(["adjcp", "adj"], axis=1)111112def select_equal_rows_stock(self, df):113df_check = df.tic.value_counts()114df_check = pd.DataFrame(df_check).reset_index()115df_check.columns = ["tic", "counts"]116mean_df = df_check.counts.mean()117equal_list = list(df.tic.value_counts() >= mean_df)118names = df.tic.value_counts().index119select_stocks_list = list(names[equal_list])120df = df[df.tic.isin(select_stocks_list)]121return df122123124