Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/examples/FinRL_StockTrading_2026_3_Backtest.py
1706 views
1
"""
2
Stock NeurIPS2018 Part 3. Backtest
3
4
This series is a reproduction of paper "Deep reinforcement learning for
5
automated stock trading: An ensemble strategy".
6
7
Introducing how to use the agents we trained to do backtest, and compare with baselines such as
8
Mean Variance Optimization and DJIA index.
9
"""
10
11
import matplotlib
12
matplotlib.use("Agg")
13
import matplotlib.pyplot as plt
14
import numpy as np
15
import pandas as pd
16
from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3
17
18
from finrl.agents.stablebaselines3.models import DRLAgent
19
from finrl.config import INDICATORS, TRAINED_MODEL_DIR, TRADE_START_DATE, TRADE_END_DATE
20
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
21
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
22
23
# %% Part 1. Load data
24
25
def _read_indexed_csv(path):
    """Load *path* and promote its first column to an unnamed row index.

    The first column is presumably the index written out by ``DataFrame.to_csv``
    in the preceding notebook part — confirm against the producer.
    """
    frame = pd.read_csv(path)
    frame = frame.set_index(frame.columns[0])
    frame.index.names = [""]
    return frame


train = _read_indexed_csv("train_data.csv")
trade = _read_indexed_csv("trade_data.csv")
32
33
# %% Part 2. Load trained agents
34
35
# Toggle which trained agents participate in the backtest.
if_using_a2c = True
if_using_ddpg = True
if_using_ppo = True
if_using_td3 = True
if_using_sac = True


def _load_agent(algo_cls, path_suffix, enabled):
    """Load a saved Stable-Baselines3 agent, or return None when it is disabled.

    The model file is looked up at ``TRAINED_MODEL_DIR + path_suffix``.
    """
    if not enabled:
        return None
    return algo_cls.load(TRAINED_MODEL_DIR + path_suffix)


trained_a2c = _load_agent(A2C, "/agent_a2c", if_using_a2c)
trained_ddpg = _load_agent(DDPG, "/agent_ddpg", if_using_ddpg)
trained_ppo = _load_agent(PPO, "/agent_ppo", if_using_ppo)
trained_td3 = _load_agent(TD3, "/agent_td3", if_using_td3)
trained_sac = _load_agent(SAC, "/agent_sac", if_using_sac)
46
47
# %% Part 3. Backtesting - DRL agents
48
49
stock_dimension = len(trade.tic.unique())
# Presumably: 1 cash slot + (price, shares held) per stock + each indicator per
# stock — should match StockTradingEnv's observation layout; confirm.
state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

# Two independent lists: a chained assignment (`a = b = [...]`) would make both
# names refer to ONE list, so any in-place edit of one cost schedule would
# silently change the other.
buy_cost_list = [0.001] * stock_dimension
sell_cost_list = [0.001] * stock_dimension
num_stock_shares = [0] * stock_dimension

env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "num_stock_shares": num_stock_shares,
    "buy_cost_pct": buy_cost_list,
    "sell_cost_pct": sell_cost_list,
    "state_space": state_space,
    "stock_dim": stock_dimension,
    "tech_indicator_list": INDICATORS,
    "action_space": stock_dimension,
    "reward_scaling": 1e-4,
}

e_trade_gym = StockTradingEnv(
    df=trade, turbulence_threshold=70, risk_indicator_col="vix", **env_kwargs
)


def _backtest(model, enabled):
    """Run *model* through the shared trade environment.

    Returns the (account value, actions) pair from ``DRLAgent.DRL_prediction``,
    or ``(None, None)`` when the agent is disabled.
    """
    if not enabled:
        return None, None
    # NOTE(review): every agent reuses e_trade_gym; this relies on
    # DRL_prediction resetting the env before each run — confirm.
    return DRLAgent.DRL_prediction(model=model, environment=e_trade_gym)


df_account_value_a2c, df_actions_a2c = _backtest(trained_a2c, if_using_a2c)
df_account_value_ddpg, df_actions_ddpg = _backtest(trained_ddpg, if_using_ddpg)
df_account_value_ppo, df_actions_ppo = _backtest(trained_ppo, if_using_ppo)
df_account_value_td3, df_actions_td3 = _backtest(trained_td3, if_using_td3)
df_account_value_sac, df_actions_sac = _backtest(trained_sac, if_using_sac)
102
103
# %% Part 4. Mean Variance Optimization baseline
104
105
106
def process_df_for_mvo(df):
    """Reshape long (date, tic, close) rows into a wide price table.

    Rows are dates, columns are tickers, and cells are closing prices.
    """
    wide_prices = df.pivot(values="close", index="date", columns="tic")
    return wide_prices
108
109
110
def StockReturnsComputing(StockPrice, Rows, Columns):
    """Return step-over-step percentage returns for a price matrix.

    Parameters
    ----------
    StockPrice : array-like
        2-D prices, one row per time step and one column per asset.
    Rows : int
        Number of leading rows (time steps) of ``StockPrice`` to use.
    Columns : int
        Number of leading columns (assets) of ``StockPrice`` to use.

    Returns
    -------
    numpy.ndarray of shape (max(Rows - 1, 0), Columns), dtype float64
        ``((P[t + 1] - P[t]) / P[t]) * 100`` for each asset and step.
    """
    prices = np.asarray(StockPrice)[:Rows, :Columns]
    previous = prices[:-1]
    # Same operation order as an element-wise loop ((next - prev) / prev) * 100,
    # so results are bit-identical, but computed in one vectorized pass.
    returns = ((prices[1:] - previous) / previous) * 100
    return returns.astype(np.float64, copy=False)
118
119
120
StockData = process_df_for_mvo(train)
TradeData = process_df_for_mvo(trade)

# The buy-and-hold share counts computed below are matched to trade-period
# prices purely by column position (matrix product), so both pivots must list
# the same tickers in the same order; otherwise the portfolio would be valued
# against the wrong stocks (or fail with an opaque shape error).
if not StockData.columns.equals(TradeData.columns):
    only_train = sorted(set(StockData.columns) - set(TradeData.columns))
    only_trade = sorted(set(TradeData.columns) - set(StockData.columns))
    raise ValueError(
        "Train and trade data must cover the same tickers for the MVO baseline; "
        f"only in train: {only_train}, only in trade: {only_trade}"
    )

arStockPrices = np.asarray(StockData)
[Rows, Cols] = arStockPrices.shape
# Daily returns in percent, one column per ticker.
arReturns = StockReturnsComputing(arStockPrices, Rows, Cols)

meanReturns = np.mean(arReturns, axis=0)
covReturns = np.cov(arReturns, rowvar=False)

np.set_printoptions(precision=3, suppress=True)
print("Mean returns of assets in portfolio\n", meanReturns)

from pypfopt.efficient_frontier import EfficientFrontier

# Max-Sharpe portfolio, long-only, at most 50% in any single asset.
ef_mean = EfficientFrontier(meanReturns, covReturns, weight_bounds=(0, 0.5))
raw_weights_mean = ef_mean.max_sharpe()
cleaned_weights_mean = ef_mean.clean_weights()
# The weights are read back by integer position 0..n-1 (the inputs are plain
# arrays), then scaled to dollars out of the 1,000,000 starting capital.
mvo_weights = np.array(
    [1000000 * cleaned_weights_mean[i] for i in range(len(cleaned_weights_mean))]
)

# Reciprocal of each ticker's last training-period close; dollars * (1 / price)
# gives the number of shares bought at the end of the training period.
LastPrice = 1 / StockData.iloc[-1].to_numpy()
Initial_Portfolio = np.multiply(mvo_weights, LastPrice)

# Mark the fixed holdings to market on every trade-period date.
Portfolio_Assets = TradeData @ Initial_Portfolio
MVO_result = pd.DataFrame(Portfolio_Assets, columns=["Mean Var"])
147
148
# %% Part 5. DJIA index baseline
149
150
import yfinance as yf

# Daily DJIA closes over the trade window, reshaped to (date, close) columns.
df_dji = yf.download("^DJI", start=TRADE_START_DATE, end=TRADE_END_DATE)[
    ["Close"]
].reset_index()
df_dji.columns = ["date", "close"]
df_dji["date"] = df_dji["date"].astype(str)

# Rescale the index so its first trading day equals the 1,000,000 starting
# capital used by the agents and the MVO baseline.
fst_day = df_dji["close"].iloc[0]
scaled_close = df_dji["close"].div(fst_day).mul(1000000)
dji = pd.DataFrame({"date": df_dji["date"], "close": scaled_close})
dji = dji.set_index("date")
164
165
# %% Part 6. Compare results
166
167
def _account_value_by_date(df_account_value, enabled):
    """Index an agent's account-value frame by its first column.

    Returns None when the agent is disabled (its frame is None in that case).
    """
    if not enabled:
        return None
    return df_account_value.set_index(df_account_value.columns[0])


df_result_a2c = _account_value_by_date(df_account_value_a2c, if_using_a2c)
df_result_ddpg = _account_value_by_date(df_account_value_ddpg, if_using_ddpg)
df_result_ppo = _account_value_by_date(df_account_value_ppo, if_using_ppo)
df_result_td3 = _account_value_by_date(df_account_value_td3, if_using_td3)
df_result_sac = _account_value_by_date(df_account_value_sac, if_using_sac)

agent_results = {
    "a2c": (if_using_a2c, df_result_a2c),
    "ddpg": (if_using_ddpg, df_result_ddpg),
    "ppo": (if_using_ppo, df_result_ppo),
    "td3": (if_using_td3, df_result_td3),
    "sac": (if_using_sac, df_result_sac),
}
# Only enabled agents get a column, so a disabled agent does not add an
# all-None object column to the comparison table.
result_columns = {
    name: df_result["account_value"]
    for name, (enabled, df_result) in agent_results.items()
    if enabled
}
result_columns["mvo"] = MVO_result["Mean Var"]
result_columns["dji"] = dji["close"]

# The Series are aligned on the union of their indexes (dates).
result = pd.DataFrame(result_columns)

print("\n=== Backtest Results ===")
print(result)
207
208
# %% Part 7. Plot
209
210
plt.rcParams["figure.figsize"] = (15, 5)
211
plt.figure()
212
result.plot()
213
plt.title("Portfolio Value Over Time")
214
plt.xlabel("Date")
215
plt.ylabel("Portfolio Value ($)")
216
plt.savefig("backtest_result.png", dpi=150, bbox_inches="tight")
217
print("\nPlot saved to backtest_result.png")
218
219