Path: blob/master/finrl/meta/preprocessor/shioajidownloader.py
732 views
"""Contains methods and classes to collect data from1shioaji Finance API2"""34from __future__ import annotations56import pandas as pd7import shioaji as sj8910class SinopacDownloader:1112def __init__(13self,14start_date: str,15end_date: str,16ticker_list: list = [],17api: sj.Shioaji = None,18):19if api is None:20self.api = sj.Shioaji()21self.api.login(22api_key="3Tn2BbtCzbaU1KSy8yyqLa4m7LEJJyhkRCDrK2nknbcu",23secret_key="Epakqh1Nt4inC3hsqowE2XjwQicPNzswkuLjtzj2WKpR",24contracts_cb=lambda security_type: print(25f"{repr(security_type)} fetch done."26),27)28else:29self.api = api30self.start_date = start_date31self.end_date = end_date32self.ticker_list = ticker_list3334def fetch_data(self) -> pd.DataFrame:35"""Fetches data from Shioaji API3637Returns38-------39pd.DataFrame40DataFrame with columns: timestamp, open, high, low, close, volume, amount, ticker41"""42data_df = pd.DataFrame()43num_failures = 044for tic in self.ticker_list:45try:46kbars = self.api.kbars(47self.api.Contracts.Stocks[tic],48start=self.start_date,49end=self.end_date,50)51temp_df = pd.DataFrame({**kbars})52temp_df.ts = pd.to_datetime(temp_df.ts)53temp_df["tic"] = tic54data_df = pd.concat([data_df, temp_df], axis=0)55except Exception as e:56num_failures += 157print(f"Failed to fetch data for ticker {tic}: {e}")5859if num_failures == len(self.ticker_list):60raise ValueError("No data is fetched.")6162data_df = data_df.reset_index(drop=True)63print("Original columns:", data_df.columns)64try:65data_df.columns = [66"timestamp",67"open",68"high",69"low",70"close",71"volume",72"amount",73"tic",74]75except ValueError as e:76print(f"Error renaming columns: {e}")7778data_df["day"] = data_df["timestamp"].dt.dayofweek79data_df["date"] = data_df.timestamp.apply(lambda x: x.strftime("%Y-%m-%d"))80data_df = data_df.dropna().reset_index(drop=True)81data_df = data_df.sort_values(by=["timestamp", "tic"]).reset_index(drop=True)8283print("Shape of DataFrame: ", data_df.shape)84print("Display DataFrame: ", data_df.head())8586return data_df8788def select_equal_rows_stock(self, df: pd.DataFrame) -> pd.DataFrame:89df_check = df.ticker.value_counts().reset_index()90df_check.columns = ["tic", "counts"]91mean_df = df_check.counts.mean()92select_stocks_list = df_check[df_check.counts >= mean_df]["tic"].tolist()93df = df[df.ticker.isin(select_stocks_list)]94return df959697if __name__ == "__main__":98start_date = "2023-04-13"99end_date = "2024-04-13"100ticker_list = ["2330", "2317", "2454", "2303", "2412"]101102# 测试 api 为 None 的情况103downloader = SinopacDownloader(104start_date=start_date, end_date=end_date, ticker_list=ticker_list, api=None105)106df = downloader.fetch_data()107print(df)108print(df.ticker.value_counts())109df = downloader.select_equal_rows_stock(df)110print(df.ticker.value_counts())111print(df)112print(df.shape)113114115