Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/meta/preprocessor/shioajidownloader.py
732 views
1
"""Contains methods and classes to collect data from
2
shioaji Finance API
3
"""
4
5
from __future__ import annotations
6
7
import pandas as pd
8
import shioaji as sj
9
10
11
class SinopacDownloader:
    """Download K-bar data (open/high/low/close/volume/amount) for a list of
    stock tickers through the Shioaji API and return it as one DataFrame.

    Attributes:
        api: Shioaji client used for ``kbars`` requests.
        start_date: First date requested (passed as ``start`` to ``api.kbars``).
        end_date: Last date requested (passed as ``end`` to ``api.kbars``).
        ticker_list: Stock codes looked up in ``api.Contracts.Stocks``.
    """

    # Canonical per-row columns produced from each Kbars response, in order.
    _KBAR_COLUMNS = ("timestamp", "open", "high", "low", "close", "volume", "amount")

    def __init__(
        self,
        start_date: str,
        end_date: str,
        ticker_list: list | None = None,
        api: sj.Shioaji | None = None,
        api_key: str | None = None,
        secret_key: str | None = None,
    ):
        """Create a downloader, logging in to Shioaji if no client is given.

        Args:
            start_date: Start date string forwarded to ``api.kbars``.
            end_date: End date string forwarded to ``api.kbars``.
            ticker_list: Stock codes to download; ``None`` means no tickers.
            api: An already logged-in Shioaji client. When ``None``, a new
                client is created and logged in with the credentials below.
            api_key: Shioaji API key; falls back to ``$SHIOAJI_API_KEY``.
            secret_key: Shioaji secret key; falls back to ``$SHIOAJI_SECRET_KEY``.

        Raises:
            ValueError: If ``api`` is ``None`` and no credentials are available.
        """
        if api is None:
            # Resolve credentials first so a missing key fails fast, before any
            # client object or network session is created.
            api_key, secret_key = self._resolve_credentials(api_key, secret_key)
            self.api = sj.Shioaji()
            self.api.login(
                api_key=api_key,
                secret_key=secret_key,
                contracts_cb=lambda security_type: print(
                    f"{repr(security_type)} fetch done."
                ),
            )
        else:
            self.api = api
        self.start_date = start_date
        self.end_date = end_date
        # Avoid a shared mutable default argument.
        self.ticker_list = [] if ticker_list is None else ticker_list

    @staticmethod
    def _resolve_credentials(api_key, secret_key):
        """Return ``(api_key, secret_key)`` from the arguments or the environment."""
        import os

        # NOTE: credentials must never be committed to source. Keys that were
        # previously hard-coded in this module should be treated as
        # compromised and rotated.
        api_key = api_key or os.environ.get("SHIOAJI_API_KEY")
        secret_key = secret_key or os.environ.get("SHIOAJI_SECRET_KEY")
        if not api_key or not secret_key:
            raise ValueError(
                "Shioaji credentials are required when no `api` is given: pass "
                "api_key/secret_key or set SHIOAJI_API_KEY and SHIOAJI_SECRET_KEY."
            )
        return api_key, secret_key

    @classmethod
    def _kbars_to_frame(cls, kbars, tic: str) -> pd.DataFrame:
        """Convert one Kbars response into a DataFrame with canonical columns.

        Raises:
            ValueError: If the response lacks one of the expected fields.
        """
        frame = pd.DataFrame({**kbars})
        # Match fields by (lower-cased) name instead of by position, so the
        # result does not depend on the order or capitalization of the fields.
        frame.columns = [str(col).lower() for col in frame.columns]
        frame = frame.rename(columns={"ts": "timestamp"})
        missing = [col for col in cls._KBAR_COLUMNS if col not in frame.columns]
        if missing:
            raise ValueError(f"Kbars response is missing columns: {missing}")
        frame = frame[list(cls._KBAR_COLUMNS)].copy()
        frame["timestamp"] = pd.to_datetime(frame["timestamp"])
        frame["tic"] = tic
        return frame

    def fetch_data(self) -> pd.DataFrame:
        """Fetches data from Shioaji API

        Tickers whose download or conversion fails are reported and skipped.

        Returns
        -------
        pd.DataFrame
            DataFrame with columns: timestamp, open, high, low, close, volume,
            amount, tic, day, date -- sorted by (timestamp, tic). ``day`` is
            the weekday (Monday=0) and ``date`` is ``YYYY-MM-DD``.

        Raises
        ------
        ValueError
            If no ticker yielded data (all failed, or ticker_list is empty).
        """
        frames = []
        for tic in self.ticker_list:
            try:
                kbars = self.api.kbars(
                    self.api.Contracts.Stocks[tic],
                    start=self.start_date,
                    end=self.end_date,
                )
                frames.append(self._kbars_to_frame(kbars, tic))
            except Exception as e:
                # Best effort: one bad ticker should not abort the whole batch.
                print(f"Failed to fetch data for ticker {tic}: {e}")

        if not frames:
            raise ValueError("No data is fetched.")

        # Concatenate once instead of growing a DataFrame inside the loop.
        data_df = pd.concat(frames, ignore_index=True)
        data_df["day"] = data_df["timestamp"].dt.dayofweek
        data_df["date"] = data_df["timestamp"].dt.strftime("%Y-%m-%d")
        data_df = data_df.dropna()
        data_df = data_df.sort_values(by=["timestamp", "tic"]).reset_index(drop=True)

        print("Shape of DataFrame: ", data_df.shape)
        print("Display DataFrame: ", data_df.head())

        return data_df

    def select_equal_rows_stock(self, df: pd.DataFrame) -> pd.DataFrame:
        """Keep only tickers whose row count is at least the mean row count.

        Args:
            df: Data as returned by ``fetch_data`` (ticker in the ``tic``
                column; a legacy ``ticker`` column is also accepted).

        Returns:
            The rows of ``df`` belonging to the retained tickers.
        """
        # fetch_data labels tickers in "tic"; accept "ticker" for older callers.
        tic_col = "tic" if "tic" in df.columns else "ticker"
        counts = df[tic_col].value_counts()
        selected = counts[counts >= counts.mean()].index
        return df[df[tic_col].isin(selected)]
if __name__ == "__main__":
    start_date = "2023-04-13"
    end_date = "2024-04-13"
    ticker_list = ["2330", "2317", "2454", "2303", "2412"]

    # Exercise the api=None path: the downloader creates and logs in its own
    # Shioaji client.
    downloader = SinopacDownloader(
        start_date=start_date, end_date=end_date, ticker_list=ticker_list, api=None
    )
    df = downloader.fetch_data()
    print(df)
    # fetch_data stores each row's ticker in the "tic" column (there is no
    # "ticker" column), so count rows per ticker via df["tic"].
    print(df["tic"].value_counts())
    df = downloader.select_equal_rows_stock(df)
    print(df["tic"].value_counts())
    print(df)
    print(df.shape)