Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/applications/stock_trading/stock_trading_rolling_window.py
732 views
1
from __future__ import annotations
2
3
import itertools
4
5
import pandas as pd
6
from stable_baselines3.common.logger import configure
7
8
from finrl.agents.stablebaselines3.models import DRLAgent
9
from finrl.config import DATA_SAVE_DIR
10
from finrl.config import INDICATORS
11
from finrl.config import RESULTS_DIR
12
from finrl.config import TENSORBOARD_LOG_DIR
13
from finrl.config import TEST_END_DATE
14
from finrl.config import TEST_START_DATE
15
from finrl.config import TRAINED_MODEL_DIR
16
from finrl.config_tickers import DOW_30_TICKER
17
from finrl.main import check_and_make_directories
18
from finrl.meta.data_processors.func import calc_train_trade_data
19
from finrl.meta.data_processors.func import calc_train_trade_starts_ends_if_rolling
20
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
21
from finrl.meta.preprocessor.preprocessors import data_split
22
from finrl.meta.preprocessor.preprocessors import FeatureEngineer
23
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
24
from finrl.plot import backtest_stats
25
from finrl.plot import get_baseline
26
from finrl.plot import plot_return
27
28
# matplotlib.use('Agg')
29
30
31
def stock_trading_rolling_window(
    train_start_date: str,
    train_end_date: str,
    trade_start_date: str,
    trade_end_date: str,
    rolling_window_length: int,
    if_store_actions: bool = True,
    if_store_result: bool = True,
    if_using_a2c: bool = True,
    if_using_ddpg: bool = True,
    if_using_ppo: bool = True,
    if_using_sac: bool = True,
    if_using_td3: bool = True,
) -> None:
    """Train and back-test DRL agents on DOW-30 data over rolling windows.

    The function downloads DOW_30_TICKER prices from ``train_start_date`` to
    ``trade_end_date``, adds the configured features, and splits the data into
    rolling train/trade windows via ``calc_train_trade_starts_ends_if_rolling``.
    For every window each enabled agent is trained from scratch and then run
    on the trade slice; from the second window on, an agent's environments
    start from that agent's last account value in the accumulated result.
    The per-window account values are merged with a ``^DJI`` baseline, stats
    are printed for every strategy column, and the combined curve is plotted.

    Args:
        train_start_date: First date of the initial training range.
        train_end_date: End date passed to ``data_split`` for training data.
        trade_start_date: Start date passed to ``data_split`` for trade data.
        trade_end_date: End date of the trade range (also the download end).
        rolling_window_length: Window length forwarded to
            ``calc_train_trade_starts_ends_if_rolling``.
        if_store_actions: Write ``actions_<agent>.csv`` for each enabled agent.
        if_store_result: Write the merged account values to ``result.csv``.
        if_using_a2c: Enable the A2C agent.
        if_using_ddpg: Enable the DDPG agent.
        if_using_ppo: Enable the PPO agent.
        if_using_sac: Enable the SAC agent.
        if_using_td3: Enable the TD3 agent.

    Raises:
        ValueError: If no window produced any result row, so the baseline
            cannot be normalised and no statistics can be computed.
    """
    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )
    date_col = "date"
    tic_col = "tic"

    df = YahooDownloader(
        start_date=train_start_date, end_date=trade_end_date, ticker_list=DOW_30_TICKER
    ).fetch_data()
    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=INDICATORS,
        use_vix=True,
        use_turbulence=True,
        user_defined_feature=False,
    )
    processed = fe.preprocess_data(df)

    # Build the full (date, ticker) grid so every ticker has a row on every
    # date present in the processed data; gaps are filled with 0 below.
    list_ticker = processed[tic_col].unique().tolist()
    list_date = list(
        pd.date_range(processed[date_col].min(), processed[date_col].max()).astype(str)
    )
    combination = list(itertools.product(list_date, list_ticker))

    init_train_trade_data = pd.DataFrame(
        combination, columns=[date_col, tic_col]
    ).merge(processed, on=[date_col, tic_col], how="left")
    init_train_trade_data = init_train_trade_data[
        init_train_trade_data[date_col].isin(processed[date_col])
    ]
    init_train_trade_data = init_train_trade_data.sort_values([date_col, tic_col])
    init_train_trade_data = init_train_trade_data.fillna(0)

    init_train_data = data_split(
        init_train_trade_data, train_start_date, train_end_date
    )
    init_trade_data = data_split(
        init_train_trade_data, trade_start_date, trade_end_date
    )

    stock_dimension = len(init_train_data.tic.unique())
    state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")
    # Two independent lists: the original aliased one list under both keys.
    buy_cost_list = [0.001] * stock_dimension
    sell_cost_list = [0.001] * stock_dimension
    num_stock_shares = [0] * stock_dimension

    initial_amount = 1000000
    env_kwargs = {
        "hmax": 100,
        "initial_amount": initial_amount,
        "num_stock_shares": num_stock_shares,
        "buy_cost_pct": buy_cost_list,
        "sell_cost_pct": sell_cost_list,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,
    }

    # Split the initial train/trade ranges into rolling sub-windows.
    init_train_dates = init_train_data[date_col].unique()
    init_trade_dates = init_trade_data[date_col].unique()
    (
        train_starts,
        train_ends,
        trade_starts,
        trade_ends,
    ) = calc_train_trade_starts_ends_if_rolling(
        init_train_dates, init_trade_dates, rolling_window_length
    )

    # (enabled, model name, model_kwargs or None, total_timesteps).
    # The order matters: agents are trained and traded in this sequence and
    # each one overwrites env_kwargs["initial_amount"] in turn.
    agent_specs = [
        (if_using_a2c, "a2c", None, 50000),
        (if_using_ddpg, "ddpg", None, 40000),
        (
            if_using_ppo,
            "ppo",
            {
                "n_steps": 2048,
                "ent_coef": 0.005,
                "learning_rate": 0.0001,
                "batch_size": 64,
            },
            50000,
        ),
        (
            if_using_sac,
            "sac",
            {
                "batch_size": 64,
                "buffer_size": 100000,
                "learning_rate": 0.00015,
                "learning_starts": 100,
                "ent_coef": "auto_0.1",
            },
            50000,
        ),
        (
            if_using_td3,
            "td3",
            {
                "batch_size": 64,
                "buffer_size": 100000,
                "learning_rate": 0.0008,
            },
            50000,
        ),
    ]
    enabled_agents = [spec[1:] for spec in agent_specs if spec[0]]

    result = pd.DataFrame()
    actions = {
        name: pd.DataFrame(columns=DOW_30_TICKER) for name, _, _ in enabled_agents
    }

    for i in range(len(train_starts)):
        print("i: ", i)
        train_data, trade_data = calc_train_trade_data(
            i,
            train_starts,
            train_ends,
            trade_starts,
            trade_ends,
            init_train_data,
            init_trade_data,
            date_col,
        )
        # `result` only grows at the end of an iteration, so this is constant
        # for the whole window.
        has_history = len(result) >= 1

        # train
        # This shared environment is used by every agent in the first window;
        # later windows rebuild it per agent with that agent's last value.
        e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)
        env_train, _ = e_train_gym.get_sb_env()

        trained = {}
        for name, model_kwargs, total_timesteps in enabled_agents:
            if has_history:
                env_kwargs["initial_amount"] = result[name.upper()].iloc[-1]
                e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)
                env_train, _ = e_train_gym.get_sb_env()
            trained[name] = _train_agent(
                env_train, name, model_kwargs, total_timesteps
            )

        # trade
        # This e_trade_gym is shared by all agents when there is no history.
        e_trade_gym = _make_trade_env(trade_data, env_kwargs)

        predictions = {}
        for name, _, _ in enabled_agents:
            if has_history:
                env_kwargs["initial_amount"] = result[name.upper()].iloc[-1]
                e_trade_gym = _make_trade_env(trade_data, env_kwargs)
            predictions[name] = _predict_account_value(trained[name], e_trade_gym)

        # merge actions
        for name, (_, actions_i) in predictions.items():
            actions[name] = pd.concat([actions[name], actions_i])

        # DJI baseline for this window
        trade_start = trade_starts[i]
        trade_end = trade_ends[i]
        dji_i_ = get_baseline(ticker="^DJI", start=trade_start, end=trade_end)
        dji_i = pd.DataFrame()
        dji_i[date_col] = dji_i_[date_col]
        dji_i["DJI"] = dji_i_["close"]

        # Keep rows in [trade_start, trade_end): the baseline may return
        # values outside this region.
        dji_i = dji_i.loc[
            (dji_i[date_col] >= trade_start) & (dji_i[date_col] < trade_end)
        ]

        # Start from the baseline and attach one column per enabled agent.
        result_i = dji_i
        for name, (account_value, _) in predictions.items():
            account_value = account_value.rename(
                columns={"account_value": name.upper()}
            )
            result_i = pd.merge(result_i, account_value, how="left")

        # remove the rows with nan
        result_i = result_i.dropna(axis=0, how="any")

        # merge result_i to result
        result = pd.concat([result, result_i], axis=0)

    # store actions
    if if_store_actions:
        for name, actions_frame in actions.items():
            actions_frame.to_csv(f"actions_{name}.csv")

    if result.empty:
        # Without this guard the DJI normalisation below fails with an opaque
        # KeyError/IndexError.
        raise ValueError(
            "No backtest rows were produced; check the date ranges and "
            "rolling_window_length."
        )

    # calc the column name of strategies, including DJI
    col_strategies = [
        col
        for col in result.columns
        if col != date_col and col != "" and "Unnamed" not in col
    ]

    # make sure that the first row of DJI is initial_amount
    result["DJI"] = result["DJI"] / result["DJI"].iloc[0] * initial_amount
    result = result.reset_index(drop=True)

    # stats
    for col in col_strategies:
        stats = backtest_stats(result, value_col_name=col)
        print("\nstats of " + col + ": \n", stats)

    # print and save result
    print("result: ", result)
    if if_store_result:
        result.to_csv("result.csv")

    # plot fig
    plot_return(
        result=result,
        column_as_x=date_col,
        if_need_calc_return=True,
        savefig_filename="stock_trading_rolling_window.png",
        xlabel="Date",
        ylabel="Return",
        if_transfer_date=True,
        num_days_xticks=20,
    )


def _train_agent(env_train, name: str, model_kwargs: dict | None, total_timesteps: int):
    """Build, attach a logger to, and train one agent; return the trained model."""
    agent = DRLAgent(env=env_train)
    if model_kwargs is None:
        model = agent.get_model(name)
    else:
        # Pass a copy so every window starts from a fresh kwargs dict, as the
        # original per-iteration dict literals did.
        model = agent.get_model(name, model_kwargs=dict(model_kwargs))
    # Logs go to RESULTS_DIR/<name> in stdout, csv and tensorboard formats.
    model.set_logger(
        configure(RESULTS_DIR + "/" + name, ["stdout", "csv", "tensorboard"])
    )
    return agent.train_model(
        model=model, tb_log_name=name, total_timesteps=total_timesteps
    )


def _make_trade_env(trade_data, env_kwargs: dict):
    """Create the trading environment with the VIX-based turbulence settings."""
    return StockTradingEnv(
        df=trade_data,
        turbulence_threshold=70,
        risk_indicator_col="vix",
        **env_kwargs,
    )


def _predict_account_value(model, environment):
    """Run one agent on the trade environment; return (account_value, actions)."""
    account_value, actions_i = DRLAgent.DRL_prediction(
        model=model, environment=environment
    )
    # In the script version the first element may itself be a tuple; unwrap it.
    if isinstance(account_value, tuple):
        actions_i = account_value[1]
        account_value = account_value[0]
    return account_value, actions_i
421
422
423
if __name__ == "__main__":
    # Example run: train from 2009 up to mid-2022, then trade July-October
    # 2022 in rolling windows with all five agents enabled.
    run_config = {
        "train_start_date": "2009-01-01",
        "train_end_date": "2022-07-01",
        "trade_start_date": "2022-07-01",
        "trade_end_date": "2022-11-01",
        "rolling_window_length": 22,  # num of trading days in a rolling window
        "if_store_actions": True,
        "if_store_result": True,
        "if_using_a2c": True,
        "if_using_ddpg": True,
        "if_using_ppo": True,
        "if_using_sac": True,
        "if_using_td3": True,
    }
    stock_trading_rolling_window(**run_config)
450
451