Path: blob/master/finrl/meta/preprocessor/tusharedownloader.py
732 views
"""Contains methods and classes to collect data from1tushare API2"""34from __future__ import annotations56import pandas as pd7import tushare as ts8from tqdm import tqdm91011class TushareDownloader:12"""Provides methods for retrieving daily stock data from13tushare API14Attributes15----------16start_date : str17start date of the data (modified from config.py)18end_date : str19end date of the data (modified from config.py)20ticker_list : list21a list of stock tickers (modified from config.py)22Methods23-------24fetch_data()25Fetches data from tushare API26date: date27Open: opening price28High: the highest price29Close: closing price30Low: lowest price31Volume: volume32Price_change: price change33P_change: fluctuation34ma5: 5-day average price35Ma10: 10 average daily price36Ma20:20 average daily price37V_ma5:5 daily average38V_ma10:10 daily average39V_ma20:20 daily average40"""4142def __init__(self, start_date: str, end_date: str, ticker_list: list):43self.start_date = start_date44self.end_date = end_date45self.ticker_list = ticker_list4647def fetch_data(self) -> pd.DataFrame:48"""Fetches data from Alpaca49Parameters50----------51Returns52-------53`pd.DataFrame`547 columns: A date, open, high, low, close, volume and tick symbol55for the specified stock ticker56"""57# Download and save the data in a pandas DataFrame:58data_df = pd.DataFrame()59for tic in tqdm(self.ticker_list, total=len(self.ticker_list)):60temp_df = ts.get_hist_data(61tic[0:6], start=self.start_date, end=self.end_date62)63temp_df["tic"] = tic[0:6]64# data_df = data_df.append(temp_df)65data_df = pd.concat([data_df, temp_df], axis=0, ignore_index=True)6667data_df = data_df.reset_index(level="date")6869# create day of the week column (monday = 0)70data_df = data_df.drop(71[72"price_change",73"p_change",74"ma5",75"ma10",76"ma20",77"v_ma5",78"v_ma10",79"v_ma20",80],811,82)83data_df["day"] = pd.to_datetime(data_df["date"]).dt.dayofweek84# rank desc85data_df = data_df.sort_index(axis=0, ascending=False)86# convert date to standard string format, easy to filter87data_df["date"] = pd.to_datetime(data_df["date"])88data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))89# drop missing data90data_df = data_df.dropna()91data_df = data_df.reset_index(drop=True)92print("Shape of DataFrame: ", data_df.shape)93# print("Display DataFrame: ", data_df.head())94print(data_df)95data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)96return data_df9798def select_equal_rows_stock(self, df):99df_check = df.tic.value_counts()100df_check = pd.DataFrame(df_check).reset_index()101df_check.columns = ["tic", "counts"]102mean_df = df_check.counts.mean()103equal_list = list(df.tic.value_counts() >= mean_df)104names = df.tic.value_counts().index105select_stocks_list = list(names[equal_list])106df = df[df.tic.isin(select_stocks_list)]107return df108109110