Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/applications/stock_trading/ensemble_stock_trading.py
732 views
1
from __future__ import annotations
2
3
4
def main():
    """Run the FinRL DRL ensemble (A2C/PPO/DDPG) stock-trading example end to end.

    Steps performed:

    1. Download daily Yahoo Finance data for ``DOW_30_TICKER`` from
       ``TRAIN_START_DATE`` to ``TEST_END_DATE``.
    2. Add the indicators listed in ``finrl.config.INDICATORS`` plus a
       turbulence feature via ``FeatureEngineer``; NaN, +inf and -inf values
       are replaced with 0.
    3. Run ``DRLEnsembleAgent.run_ensemble_strategy`` with rolling
       ``rebalance_window`` / ``validation_window`` periods.
    4. Concatenate the per-window ``results/account_value_trade_ensemble_<i>.csv``
       files, print the annualised Sharpe ratio, and compare the resulting
       account-value curve with the Dow Jones index (``^DJI``).

    Side effects: creates the configured data/model/log/results directories,
    downloads market data over the network, reads CSVs from ``results/`` and
    draws plots.

    Raises:
        RuntimeError: if the test period has fewer trading days than
            ``rebalance_window + validation_window``, so no ensemble
            account-value file exists to load.
    """
    import sys
    import warnings

    warnings.filterwarnings("ignore")

    import matplotlib

    # matplotlib.use('Agg')  # uncomment for headless runs (no display)

    import numpy as np
    import pandas as pd

    from finrl.agents.stablebaselines3.models import DRLEnsembleAgent
    from finrl.config import (
        DATA_SAVE_DIR,
        INDICATORS,
        RESULTS_DIR,
        TENSORBOARD_LOG_DIR,
        TRAINED_MODEL_DIR,
    )
    from finrl.config_tickers import DOW_30_TICKER
    from finrl.main import check_and_make_directories
    from finrl.meta.preprocessor.preprocessors import FeatureEngineer
    from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
    from finrl.plot import backtest_plot, backtest_stats, get_baseline

    # NOTE(review): finrl is already imported above, so this entry cannot
    # change which finrl package is used; kept for backward compatibility.
    sys.path.append("../FinRL-Library")

    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )
    print(DOW_30_TICKER)

    # Experiment-specific periods, used instead of the finrl.config defaults.
    TRAIN_START_DATE = "2009-04-01"
    TRAIN_END_DATE = "2021-01-01"
    TEST_START_DATE = "2021-01-01"
    TEST_END_DATE = "2022-06-01"

    df = YahooDownloader(
        start_date=TRAIN_START_DATE,
        end_date=TEST_END_DATE,
        ticker_list=DOW_30_TICKER,
    ).fetch_data()

    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=INDICATORS,
        use_turbulence=True,
        user_defined_feature=False,
    )
    # Replace missing values and BOTH infinities with 0 so that no
    # non-finite feature reaches the environment (the original only
    # replaced +inf).
    processed = fe.preprocess_data(df).fillna(0).replace([np.inf, -np.inf], 0)

    stock_dimension = len(processed.tic.unique())
    # Presumably: 1 cash slot + (price, holding) per stock + one value per
    # indicator per stock -- TODO confirm against StockTradingEnv.
    state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

    env_kwargs = {
        "hmax": 100,
        "initial_amount": 1000000,
        "buy_cost_pct": 0.001,
        "sell_cost_pct": 0.001,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,
        "print_verbosity": 5,
    }

    # Number of days between model retrainings.
    rebalance_window = 63
    # Number of days used for validation and for trading (e.g. 63 means both
    # the validation and the trading period are 63 days long).
    validation_window = 63

    ensemble_agent = DRLEnsembleAgent(
        df=processed,
        train_period=(TRAIN_START_DATE, TRAIN_END_DATE),
        val_test_period=(TEST_START_DATE, TEST_END_DATE),
        rebalance_window=rebalance_window,
        validation_window=validation_window,
        **env_kwargs,
    )

    A2C_model_kwargs = {"n_steps": 5, "ent_coef": 0.005, "learning_rate": 0.0007}
    PPO_model_kwargs = {
        "ent_coef": 0.01,
        "n_steps": 2048,
        "learning_rate": 0.00025,
        "batch_size": 128,
    }
    DDPG_model_kwargs = {
        # "action_noise":"ornstein_uhlenbeck",
        "buffer_size": 10_000,
        "learning_rate": 0.0005,
        "batch_size": 64,
    }
    timesteps_dict = {"a2c": 10_000, "ppo": 10_000, "ddpg": 10_000}

    ensemble_agent.run_ensemble_strategy(
        A2C_model_kwargs, PPO_model_kwargs, DDPG_model_kwargs, timesteps_dict
    )

    unique_trade_date = processed[
        (processed.date > TEST_START_DATE) & (processed.date <= TEST_END_DATE)
    ].date.unique()
    df_trade_date = pd.DataFrame({"datadate": unique_trade_date})

    # One account-value CSV is expected per step of this rolling range; the
    # files are read from "results/", where the ensemble run writes them.
    window_ids = range(
        rebalance_window + validation_window,
        len(unique_trade_date) + 1,
        rebalance_window,
    )
    account_frames = [
        pd.read_csv(f"results/account_value_trade_ensemble_{i}.csv")
        for i in window_ids
    ]
    if not account_frames:
        raise RuntimeError(
            "No ensemble trade windows: the test period "
            f"({TEST_START_DATE} to {TEST_END_DATE}) has "
            f"{len(unique_trade_date)} trading days, fewer than "
            "rebalance_window + validation_window "
            f"({rebalance_window + validation_window})."
        )
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    df_account_value = pd.concat(account_frames, ignore_index=True)

    daily_returns = df_account_value.account_value.pct_change(1)
    # Annualised using sqrt(252) trading days per year.
    sharpe = (252**0.5) * daily_returns.mean() / daily_returns.std()
    print("Sharpe Ratio: ", sharpe)

    # Attach the trade dates that follow the first validation window.
    df_account_value = df_account_value.join(
        df_trade_date[validation_window:].reset_index(drop=True)
    )

    df_account_value.account_value.plot()

    print("==============Get Backtest Results===========")
    # NOTE(review): return value unused; presumably called for what it prints.
    backtest_stats(account_value=df_account_value)

    first_date = df_account_value["date"].iloc[0]
    last_date = df_account_value["date"].iloc[-1]

    print("==============Get Baseline Stats===========")
    baseline_df = get_baseline(ticker="^DJI", start=first_date, end=last_date)
    backtest_stats(baseline_df, value_col_name="close")

    print("==============Compare to DJIA===========")
    # Alternative baselines: S&P 500 "^GSPC", NASDAQ 100 "^NDX".
    backtest_plot(
        df_account_value,
        baseline_ticker="^DJI",
        baseline_start=first_date,
        baseline_end=last_date,
    )
if __name__ == "__main__":
    # Execute the full train-and-backtest pipeline only when this file is run
    # as a script; importing the module has no side effects beyond definitions.
    main()