Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/examples/FinRL_PaperTrading_Demo_refactored.py
726 views
# Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real money. Many platforms exist for simulated trading (paper trading) which can be used for building and developing the methods discussed. Please use common sense and always first consult a professional before trading or investing.
# install finrl library
# %pip install --upgrade git+https://github.com/AI4Finance-Foundation/FinRL.git
# Alpaca keys
from __future__ import annotations

import argparse

# Credentials and endpoints are supplied as six positional command-line
# arguments: a key / secret / base-URL triple for the market-data source and
# another triple for the (paper) trading account.
# NOTE(review): command-line arguments can be visible to other local users
# (e.g. via `ps`) and may be stored in shell history; consider reading the
# secrets from environment variables instead.
parser = argparse.ArgumentParser()
parser.add_argument("data_key", help="data source api key")
parser.add_argument("data_secret", help="data source api secret")
parser.add_argument("data_url", help="data source api base url")
parser.add_argument("trading_key", help="trading api key")
parser.add_argument("trading_secret", help="trading api secret")
parser.add_argument("trading_url", help="trading api base url")
args = parser.parse_args()
DATA_API_KEY = args.data_key
DATA_API_SECRET = args.data_secret
DATA_API_BASE_URL = args.data_url
TRADING_API_KEY = args.trading_key
TRADING_API_SECRET = args.trading_secret
TRADING_API_BASE_URL = args.trading_url


def mask_secret(value, visible=4):
    """Return a redacted form of a credential that is safe to print.

    Only the last ``visible`` characters are kept, behind a fixed ``"****"``
    prefix, so the credential's length is not revealed either. A value no
    longer than ``visible``, or any value when ``visible <= 0``, is shown as
    ``"****"`` alone. An empty string is returned unchanged so that a missing
    credential is still noticeable in the output.
    """
    if not value:
        return value
    if visible <= 0 or len(value) <= visible:
        return "****"
    return "****" + value[-visible:]


# Echo the configuration for debugging. Keys show only their last four
# characters and secrets are fully hidden, so credentials never end up in
# console output or captured logs.
print("DATA_API_KEY: ", mask_secret(DATA_API_KEY))
print("DATA_API_SECRET: ", mask_secret(DATA_API_SECRET, visible=0))
print("DATA_API_BASE_URL: ", DATA_API_BASE_URL)
print("TRADING_API_KEY: ", mask_secret(TRADING_API_KEY))
print("TRADING_API_SECRET: ", mask_secret(TRADING_API_SECRET, visible=0))
print("TRADING_API_BASE_URL: ", TRADING_API_BASE_URL)
# Project imports: the numpy-based trading environment, the Alpaca
# paper-trading client, the train/test/history helpers and the default
# technical-indicator list.
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.meta.paper_trading.alpaca import PaperTradingAlpaca
from finrl.meta.paper_trading.common import (
    DIA_history,
    alpaca_history,
    test,
    train,
)
from finrl.config import INDICATORS

# Trading universe: the Dow Jones 30 symbols.
from finrl.config_tickers import DOW_30_TICKER

ticker_list = DOW_30_TICKER
env = StockTradingEnv

# ElegantRL hyper-parameters shared by the train/test/paper-trading calls.
# NOTE: "target_step" must be larger than the number of steps in one
# episode; if you lengthen the training period and training raises an
# error, try increasing it.
ERL_PARAMS = dict(
    learning_rate=3e-6,
    batch_size=2048,
    gamma=0.985,
    seed=312,
    net_dimension=[128, 64],
    target_step=5000,
    eval_gap=30,
    eval_times=1,
)
# Set up a sliding window of 6 business days of training followed by 2
# business days of testing, ending on the last business day before today.
import datetime

from pandas.tseries.offsets import BDay  # BDay is business day, not birthday...


def sliding_window_dates(reference, train_days=6, test_days=2):
    """Compute the train/test sliding-window boundaries as ISO date strings.

    The testing window ends on the business day before ``reference`` and
    spans ``test_days`` business days, both ends inclusive. The training
    window spans ``train_days`` business days and ends on the business day
    before the testing window starts. ``BDay`` only skips weekends; exchange
    holidays are not excluded.

    Args:
        reference: a ``datetime.date`` or ``datetime.datetime`` ("today").
            Any time-of-day component is dropped from the results.
        train_days: length of the training window in business days (>= 1).
        test_days: length of the testing window in business days (>= 1).

    Returns:
        Tuple ``(train_start, train_end, test_start, test_end)`` of
        ``"YYYY-MM-DD"`` strings.

    Raises:
        ValueError: if ``train_days`` or ``test_days`` is less than 1.
    """
    if train_days < 1 or test_days < 1:
        raise ValueError("train_days and test_days must both be >= 1")

    def shift_back(day, n):
        # Step back ``n`` business days; the result is a plain ``date``.
        # Only reached with n == 0 for inputs that are already business-day
        # dates, so they are returned as-is.
        if n == 0:
            return day
        return (day - BDay(n)).to_pydatetime().date()

    test_end = shift_back(reference, 1)
    test_start = shift_back(test_end, test_days - 1)
    train_end = shift_back(test_start, 1)
    train_start = shift_back(train_end, train_days - 1)
    return tuple(str(d) for d in (train_start, train_end, test_start, test_end))


today = datetime.datetime.today()

(
    TRAIN_START_DATE,
    TRAIN_END_DATE,
    TEST_START_DATE,
    TEST_END_DATE,
) = sliding_window_dates(today)

# After tuning, the agent is retrained on the training and testing windows
# combined, so the "full" window spans both.
TRAINFULL_START_DATE = TRAIN_START_DATE
TRAINFULL_END_DATE = TEST_END_DATE

print("TRAIN_START_DATE: ", TRAIN_START_DATE)
print("TRAIN_END_DATE: ", TRAIN_END_DATE)
print("TEST_START_DATE: ", TEST_START_DATE)
print("TEST_END_DATE: ", TEST_END_DATE)
print("TRAINFULL_START_DATE: ", TRAINFULL_START_DATE)
print("TRAINFULL_END_DATE: ", TRAINFULL_END_DATE)
# Keyword arguments common to every ElegantRL PPO train/test run below:
# one-minute Alpaca bars for the configured ticker list, using the
# market-data credentials.
_run_config = dict(
    ticker_list=ticker_list,
    data_source="alpaca",
    time_interval="1Min",
    technical_indicator_list=INDICATORS,
    drl_lib="elegantrl",
    env=env,
    model_name="ppo",
    if_vix=True,
    API_KEY=DATA_API_KEY,
    API_SECRET=DATA_API_SECRET,
    API_BASE_URL=DATA_API_BASE_URL,
)

# 1) Train on the training window.
train(
    start_date=TRAIN_START_DATE,
    end_date=TRAIN_END_DATE,
    erl_params=ERL_PARAMS,
    cwd="./papertrading_erl",  # current working dir for this agent
    break_step=1e5,
    **_run_config,
)

# 2) Evaluate that agent on the held-out testing window.
account_value_erl = test(
    start_date=TEST_START_DATE,
    end_date=TEST_END_DATE,
    cwd="./papertrading_erl",
    net_dimension=ERL_PARAMS["net_dimension"],
    **_run_config,
)

# 3) Once tuned, retrain on the combined training + testing window; the
#    paper-trading client below loads its agent from this run's directory.
train(
    start_date=TRAINFULL_START_DATE,
    end_date=TRAINFULL_END_DATE,
    erl_params=ERL_PARAMS,
    cwd="./papertrading_erl_retrain",
    break_step=2e5,
    **_run_config,
)
# The live client needs the DRL state size up front, so it is computed by
# hand. Per the original author's breakdown:
#   1                            cash amount
#   + 2                          turbulence value and turbulence flag
#   + 3 * n_stocks               price, shares held, holding time per stock
#   + len(INDICATORS) * n_stocks technical indicators per stock
action_dim = len(DOW_30_TICKER)
state_dim = 1 + 2 + (3 + len(INDICATORS)) * action_dim

# Paper-trade with the retrained agent, using the trading-account credentials.
paper_trading_erl = PaperTradingAlpaca(
    ticker_list=DOW_30_TICKER,
    time_interval="1Min",
    drl_lib="elegantrl",
    agent="ppo",
    cwd="./papertrading_erl_retrain",
    net_dim=ERL_PARAMS["net_dimension"],
    state_dim=state_dim,
    action_dim=action_dim,
    API_KEY=TRADING_API_KEY,
    API_SECRET=TRADING_API_SECRET,
    API_BASE_URL=TRADING_API_BASE_URL,
    tech_indicator_list=INDICATORS,
    turbulence_thresh=30,
    max_stock=1e2,
)
paper_trading_erl.run()
# Check Portfolio Performance
# ## Get cumulative return of the paper-trading account and of the DJIA
# benchmark over the same period.

# Portfolio history belongs to the account that placed the orders, i.e. the
# trading account used by PaperTradingAlpaca above, so query it with the
# TRADING_API_* credentials. The data-source credentials may point at a
# different account or endpoint.
# The window is anchored to today rather than hard-coded dates: it covers
# the last 10 business days, ending today. The window must be within
# 1 month; change the dates if an error occurs.
HISTORY_END_DATE = str(today.date())
HISTORY_START_DATE = str((today - BDay(9)).to_pydatetime().date())

df_erl, cumu_erl = alpaca_history(
    key=TRADING_API_KEY,
    secret=TRADING_API_SECRET,
    url=TRADING_API_BASE_URL,
    start=HISTORY_START_DATE,
    end=HISTORY_END_DATE,
)

# The benchmark starts on the same day so both series are aligned at index 0.
df_djia, cumu_djia = DIA_history(start=HISTORY_START_DATE)
returns_erl = cumu_erl - 1
returns_dia = cumu_djia - 1
# Trim the benchmark to the agent's length so both curves share the x-axis.
returns_dia = returns_dia[: returns_erl.shape[0]]
# Plot cumulative returns of the agent and the DJIA benchmark and save a PNG.
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker  # NOT `ticker_list`, the symbol list

# Samples per trading session, used for the major x ticks. Presumably the
# history series are 5-minute bars (78 bars = one 6.5-hour session, and the
# 6-bar minor ticks = 30 minutes). TODO confirm against the history helpers.
BARS_PER_DAY = 78


def day_label(x, _pos):
    """Label the major tick at sample index ``x`` with its 1-based day number."""
    return f"Day {int(round(x / BARS_PER_DAY)) + 1}"


plt.figure(dpi=1000)
plt.grid()
plt.grid(which="minor", axis="y")
plt.title("Stock Trading (Paper trading)", fontsize=20)
plt.plot(returns_erl, label="ElegantRL Agent", color="red")
# plt.plot(returns_sb3, label = 'Stable-Baselines3 Agent', color = 'blue' )
# plt.plot(returns_rllib, label = 'RLlib Agent', color = 'green')
plt.plot(returns_dia, label="DJIA", color="grey")
plt.ylabel("Return", fontsize=16)
# The x axis is sample index within the history window, not calendar dates.
plt.xlabel("Trading day", fontsize=16)
plt.xticks(size=14)
plt.yticks(size=14)
ax = plt.gca()
# Locators/formatters must come from matplotlib.ticker; the original code
# called them on `ticker_list` (a list of symbols) and raised AttributeError.
ax.xaxis.set_major_locator(mticker.MultipleLocator(BARS_PER_DAY))
ax.xaxis.set_minor_locator(mticker.MultipleLocator(6))
ax.yaxis.set_minor_locator(mticker.MultipleLocator(0.005))
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1, decimals=2))
# Label session boundaries "Day N" instead of the previous hard-coded,
# stale calendar dates, which did not match the plotted window.
ax.xaxis.set_major_formatter(mticker.FuncFormatter(day_label))
plt.legend(fontsize=10.5)
plt.savefig("papertrading_stock.png")