Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/applications/stock_trading/stock_trading.py
732 views
1
from __future__ import annotations
2
3
import itertools
4
import sys
5
6
import pandas as pd
7
from stable_baselines3.common.logger import configure
8
9
from finrl.agents.stablebaselines3.models import DRLAgent
10
from finrl.config import DATA_SAVE_DIR
11
from finrl.config import INDICATORS
12
from finrl.config import RESULTS_DIR
13
from finrl.config import TENSORBOARD_LOG_DIR
14
from finrl.config import TRAINED_MODEL_DIR
15
from finrl.config_tickers import DOW_30_TICKER
16
from finrl.main import check_and_make_directories
17
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
18
from finrl.meta.preprocessor.preprocessors import data_split
19
from finrl.meta.preprocessor.preprocessors import FeatureEngineer
20
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
21
from finrl.plot import backtest_stats
22
from finrl.plot import get_baseline
23
from finrl.plot import plot_return
24
25
# matplotlib.use('Agg')
26
27
28
def _train_agent(env_train, algo: str, model_kwargs: dict | None, total_timesteps: int):
    """Create one DRL model for ``algo``, attach a file logger, and train it.

    Args:
        env_train: Training environment handed to ``DRLAgent``.
        algo: Algorithm key passed to ``DRLAgent.get_model`` (e.g. ``"ppo"``);
            also used as the log sub-directory and tensorboard run name.
        model_kwargs: Hyperparameter overrides, or ``None`` to call
            ``get_model`` without overrides.
        total_timesteps: Training budget forwarded to ``train_model``.

    Returns:
        Whatever ``DRLAgent.train_model`` returns (the trained model).
    """
    agent = DRLAgent(env=env_train)
    # Forward model_kwargs only when overrides exist, so algorithms without
    # overrides are built with exactly the call ``get_model(algo)``.
    if model_kwargs is None:
        model = agent.get_model(algo)
    else:
        model = agent.get_model(algo, model_kwargs=model_kwargs)

    # Per-algorithm log directory under RESULTS_DIR.
    log_dir = RESULTS_DIR + "/" + algo
    model.set_logger(configure(log_dir, ["stdout", "csv", "tensorboard"]))

    return agent.train_model(
        model=model, tb_log_name=algo, total_timesteps=total_timesteps
    )


def _predict(model, environment):
    """Run ``model`` on ``environment`` and return ``(account_values, actions)``."""
    account_values, actions = DRLAgent.DRL_prediction(
        model=model, environment=environment
    )
    # The script version may hand back a tuple in the first slot (the notebook
    # version does not); unwrap it so callers always get two data frames.
    if isinstance(account_values, tuple):
        actions = account_values[1]
        account_values = account_values[0]
    return account_values, actions


def stock_trading(
    train_start_date: str,
    train_end_date: str,
    trade_start_date: str,
    trade_end_date: str,
    if_store_actions: bool = True,
    if_store_result: bool = True,
    if_using_a2c: bool = True,
    if_using_ddpg: bool = True,
    if_using_ppo: bool = True,
    if_using_sac: bool = True,
    if_using_td3: bool = True,
):
    """Train the selected DRL agents on DOW 30 data and backtest them against ^DJI.

    The function downloads data for ``DOW_30_TICKER`` from ``train_start_date``
    to ``trade_end_date``, runs ``FeatureEngineer`` on it, splits it into a
    training and a trading window, trains every enabled algorithm, runs each
    trained model on the trading window, and compares the resulting account
    values with the ``^DJI`` baseline (rescaled to start at the initial
    amount). Statistics are printed per strategy and a return plot is
    requested from ``plot_return``.

    Args:
        train_start_date: Start of the training window (also the start of the
            data download).
        train_end_date: End of the training window.
        trade_start_date: Start of the trading (backtest) window.
        trade_end_date: End of the trading window (also the end of the data
            download). Baseline rows on this date are excluded.
        if_store_actions: Write ``actions_<algo>.csv`` for each enabled
            algorithm into the current working directory.
        if_store_result: Write the merged account-value table to
            ``result.csv`` in the current working directory.
        if_using_a2c: Train and evaluate A2C.
        if_using_ddpg: Train and evaluate DDPG.
        if_using_ppo: Train and evaluate PPO.
        if_using_sac: Train and evaluate SAC.
        if_using_td3: Train and evaluate TD3.

    Returns:
        None. Results are printed, optionally written to CSV, and plotted.
    """
    sys.path.append("../FinRL")
    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )
    date_col = "date"
    tic_col = "tic"
    initial_amount = 1000000
    total_timesteps = 50000

    # ------------------------------------------------------------------
    # Download and preprocess
    # ------------------------------------------------------------------
    df = YahooDownloader(
        start_date=train_start_date, end_date=trade_end_date, ticker_list=DOW_30_TICKER
    ).fetch_data()
    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=INDICATORS,
        use_vix=True,
        use_turbulence=True,
        user_defined_feature=False,
    )
    processed = fe.preprocess_data(df)

    # Build the full (date, ticker) grid so every ticker has a row on every
    # date present in the processed data; gaps are filled with 0 below.
    list_ticker = processed[tic_col].unique().tolist()
    list_date = list(
        pd.date_range(processed[date_col].min(), processed[date_col].max()).astype(str)
    )
    combination = list(itertools.product(list_date, list_ticker))

    init_train_trade_data = pd.DataFrame(
        combination, columns=[date_col, tic_col]
    ).merge(processed, on=[date_col, tic_col], how="left")
    # date_range also produced calendar days absent from the data; drop them.
    init_train_trade_data = init_train_trade_data[
        init_train_trade_data[date_col].isin(processed[date_col])
    ]
    init_train_trade_data = init_train_trade_data.sort_values([date_col, tic_col])
    init_train_trade_data = init_train_trade_data.fillna(0)

    init_train_data = data_split(
        init_train_trade_data, train_start_date, train_end_date
    )
    init_trade_data = data_split(
        init_train_trade_data, trade_start_date, trade_end_date
    )

    # ------------------------------------------------------------------
    # Environment configuration
    # ------------------------------------------------------------------
    stock_dimension = len(init_train_data.tic.unique())
    state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")
    # Two independent lists, so the buy and sell costs never alias each other.
    buy_cost_list = [0.001] * stock_dimension
    sell_cost_list = [0.001] * stock_dimension
    num_stock_shares = [0] * stock_dimension

    env_kwargs = {
        "hmax": 100,
        "initial_amount": initial_amount,
        "num_stock_shares": num_stock_shares,
        "buy_cost_pct": buy_cost_list,
        "sell_cost_pct": sell_cost_list,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,
    }

    e_train_gym = StockTradingEnv(df=init_train_data, **env_kwargs)
    env_train, _ = e_train_gym.get_sb_env()
    print(type(env_train))

    # ------------------------------------------------------------------
    # Train
    # ------------------------------------------------------------------
    # algorithm -> (enabled, hyperparameter overrides); the dict order is the
    # training and prediction order.
    algo_settings = {
        "a2c": (if_using_a2c, None),
        "ddpg": (if_using_ddpg, None),
        "ppo": (
            if_using_ppo,
            {
                "n_steps": 2048,
                "ent_coef": 0.01,
                "learning_rate": 0.00025,
                "batch_size": 128,
            },
        ),
        "sac": (
            if_using_sac,
            {
                "batch_size": 128,
                "buffer_size": 100000,
                "learning_rate": 0.0001,
                "learning_starts": 100,
                "ent_coef": "auto_0.1",
            },
        ),
        "td3": (
            if_using_td3,
            {"batch_size": 100, "buffer_size": 1000000, "learning_rate": 0.001},
        ),
    }

    trained_models = {}
    for algo, (enabled, model_kwargs) in algo_settings.items():
        if enabled:
            trained_models[algo] = _train_agent(
                env_train, algo, model_kwargs, total_timesteps
            )

    # ------------------------------------------------------------------
    # Trade
    # ------------------------------------------------------------------
    e_trade_gym = StockTradingEnv(
        df=init_trade_data,
        turbulence_threshold=70,
        risk_indicator_col="vix",
        **env_kwargs,
    )

    # algorithm -> (account values frame, actions frame)
    predictions = {
        algo: _predict(model, e_trade_gym) for algo, model in trained_models.items()
    }

    # Order used for the CSV files and for the columns of the result table.
    report_order = ("a2c", "ddpg", "td3", "ppo", "sac")

    if if_store_actions:
        for algo in report_order:
            if algo in predictions:
                predictions[algo][1].to_csv(f"actions_{algo}.csv")

    # ------------------------------------------------------------------
    # Baseline and comparison table
    # ------------------------------------------------------------------
    dji_ = get_baseline(ticker="^DJI", start=trade_start_date, end=trade_end_date)
    dji = pd.DataFrame()
    dji[date_col] = dji_[date_col]
    dji["DJI"] = dji_["close"]
    # Keep rows in [trade_start_date, trade_end_date); the baseline may
    # contain values outside this region.
    dji = dji.loc[
        (dji[date_col] >= trade_start_date) & (dji[date_col] < trade_end_date)
    ]

    result = dji
    for algo in report_order:
        if algo in predictions:
            account_values = predictions[algo][0].rename(
                columns={"account_value": algo.upper()}
            )
            result = pd.merge(result, account_values, how="left")

    # Remove the rows with NaN (dates missing from any strategy).
    result = result.dropna(axis=0, how="any")

    # Names of the strategy columns, including DJI.
    col_strategies = [
        col
        for col in result.columns
        if col != date_col and col != "" and "Unnamed" not in col
    ]

    # Rescale DJI so its first row equals initial_amount, making it directly
    # comparable with the agents' account values.
    result["DJI"] = result["DJI"] / result["DJI"].iloc[0] * initial_amount
    result = result.reset_index(drop=True)

    # ------------------------------------------------------------------
    # Report
    # ------------------------------------------------------------------
    for col in col_strategies:
        stats = backtest_stats(result, value_col_name=col)
        print("\nstats of " + col + ": \n", stats)

    print("result: ", result)
    if if_store_result:
        result.to_csv("result.csv")

    plot_return(
        result=result,
        column_as_x=date_col,
        if_need_calc_return=True,
        savefig_filename="stock_trading.png",
        xlabel="Date",
        ylabel="Return",
        if_transfer_date=True,
        num_days_xticks=20,
    )
303
304
305
if __name__ == "__main__":
    # Default experiment: train on 2009-01-01 .. 2022-09-01, trade on
    # 2022-09-01 .. 2023-11-01, with every algorithm and every output enabled.
    run_config = {
        "train_start_date": "2009-01-01",
        "train_end_date": "2022-09-01",
        "trade_start_date": "2022-09-01",
        "trade_end_date": "2023-11-01",
        "if_store_actions": True,
        "if_store_result": True,
        "if_using_a2c": True,
        "if_using_ddpg": True,
        "if_using_ppo": True,
        "if_using_sac": True,
        "if_using_td3": True,
    }

    stock_trading(**run_config)
331
332