# Path: blob/master/examples/FinRL_StockTrading_2026_3_Backtest.py
# 1706 views
"""1Stock NeurIPS2018 Part 3. Backtest23This series is a reproduction of paper "Deep reinforcement learning for4automated stock trading: An ensemble strategy".56Introducing how to use the agents we trained to do backtest, and compare with baselines such as7Mean Variance Optimization and DJIA index.8"""910import matplotlib11matplotlib.use("Agg")12import matplotlib.pyplot as plt13import numpy as np14import pandas as pd15from stable_baselines3 import A2C, DDPG, PPO, SAC, TD31617from finrl.agents.stablebaselines3.models import DRLAgent18from finrl.config import INDICATORS, TRAINED_MODEL_DIR, TRADE_START_DATE, TRADE_END_DATE19from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv20from finrl.meta.preprocessor.yahoodownloader import YahooDownloader2122# %% Part 1. Load data2324train = pd.read_csv("train_data.csv")25trade = pd.read_csv("trade_data.csv")2627train = train.set_index(train.columns[0])28train.index.names = [""]29trade = trade.set_index(trade.columns[0])30trade.index.names = [""]3132# %% Part 2. Load trained agents3334if_using_a2c = True35if_using_ddpg = True36if_using_ppo = True37if_using_td3 = True38if_using_sac = True3940trained_a2c = A2C.load(TRAINED_MODEL_DIR + "/agent_a2c") if if_using_a2c else None41trained_ddpg = DDPG.load(TRAINED_MODEL_DIR + "/agent_ddpg") if if_using_ddpg else None42trained_ppo = PPO.load(TRAINED_MODEL_DIR + "/agent_ppo") if if_using_ppo else None43trained_td3 = TD3.load(TRAINED_MODEL_DIR + "/agent_td3") if if_using_td3 else None44trained_sac = SAC.load(TRAINED_MODEL_DIR + "/agent_sac") if if_using_sac else None4546# %% Part 3. 
# %% Part 3. Backtesting - DRL agents

stock_dimension = len(trade.tic.unique())
# Observation size: 1 + 2 values per stock + one value per indicator per stock.
# Presumably cash, then (price, shares held) per stock, then the indicators,
# matching FinRL's StockTradingEnv layout -- confirm against the env.
state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

# Build two independent lists. A chained assignment would make the buy and
# sell costs the same list object.
buy_cost_list = [0.001] * stock_dimension
sell_cost_list = [0.001] * stock_dimension
num_stock_shares = [0] * stock_dimension

env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "num_stock_shares": num_stock_shares,
    "buy_cost_pct": buy_cost_list,
    "sell_cost_pct": sell_cost_list,
    "state_space": state_space,
    "stock_dim": stock_dimension,
    "tech_indicator_list": INDICATORS,
    "action_space": stock_dimension,
    "reward_scaling": 1e-4,
}

e_trade_gym = StockTradingEnv(
    df=trade, turbulence_threshold=70, risk_indicator_col="vix", **env_kwargs
)


def _backtest(model, enabled):
    """Run ``DRLAgent.DRL_prediction`` for one agent on the trade environment.

    Returns the ``(account_value, actions)`` pair produced by DRL_prediction,
    or ``(None, None)`` when the agent is disabled.
    """
    if not enabled:
        return None, None
    return DRLAgent.DRL_prediction(model=model, environment=e_trade_gym)


df_account_value_a2c, df_actions_a2c = _backtest(trained_a2c, if_using_a2c)
df_account_value_ddpg, df_actions_ddpg = _backtest(trained_ddpg, if_using_ddpg)
df_account_value_ppo, df_actions_ppo = _backtest(trained_ppo, if_using_ppo)
df_account_value_td3, df_actions_td3 = _backtest(trained_td3, if_using_td3)
df_account_value_sac, df_actions_sac = _backtest(trained_sac, if_using_sac)
# %% Part 4. Mean Variance Optimization baseline


def process_df_for_mvo(df):
    """Pivot long-format rows into a date-by-ticker table of close prices."""
    return df.pivot(index="date", columns="tic", values="close")


def StockReturnsComputing(StockPrice, Rows, Columns):
    """Return simple period-over-period returns, in percent.

    Args:
        StockPrice: 2-D array-like of prices; rows are consecutive periods,
            columns are assets.
        Rows: number of leading rows of ``StockPrice`` to use (>= 1).
        Columns: number of leading columns of ``StockPrice`` to use.

    Returns:
        A ``(Rows - 1, Columns)`` float array whose entry ``[i, j]`` is
        ``(p[i + 1, j] - p[i, j]) / p[i, j] * 100``.

    Raises:
        ValueError: if ``Rows < 1`` or ``StockPrice`` is not a 2-D array with
            at least ``Rows`` rows and ``Columns`` columns.
    """
    prices = np.asarray(StockPrice)
    if (
        Rows < 1
        or prices.ndim != 2
        or prices.shape[0] < Rows
        or prices.shape[1] < Columns
    ):
        raise ValueError(
            f"expected at least a {Rows}x{Columns} price array (Rows >= 1), "
            f"got shape {prices.shape}"
        )
    prices = prices[:Rows, :Columns]
    # Vectorised form of the element-wise loop. It uses the same arithmetic
    # order (difference, divide, scale), so the results match exactly.
    return np.diff(prices, axis=0) / prices[:-1] * 100


StockData = process_df_for_mvo(train)
TradeData = process_df_for_mvo(trade)

arStockPrices = np.asarray(StockData)
Rows, Cols = arStockPrices.shape
arReturns = StockReturnsComputing(arStockPrices, Rows, Cols)

meanReturns = np.mean(arReturns, axis=0)
covReturns = np.cov(arReturns, rowvar=False)

np.set_printoptions(precision=3, suppress=True)
print("Mean returns of assets in portfolio\n", meanReturns)

from pypfopt.efficient_frontier import EfficientFrontier

# Max-Sharpe portfolio with no short positions and at most 50% in any asset.
ef_mean = EfficientFrontier(meanReturns, covReturns, weight_bounds=(0, 0.5))
raw_weights_mean = ef_mean.max_sharpe()
cleaned_weights_mean = ef_mean.clean_weights()

# Start from the same capital as the DRL agents so the curves are comparable.
initial_capital = env_kwargs["initial_amount"]
# Dollar amount allocated to each asset, in the column order of StockData.
mvo_weights = initial_capital * np.array(list(cleaned_weights_mean.values()))

# Shares bought of each asset at the last close of the training period.
LastPrice = 1 / StockData.iloc[-1].to_numpy()
Initial_Portfolio = np.multiply(mvo_weights, LastPrice)

# Daily value of that fixed share holding over the trading period.
Portfolio_Assets = TradeData @ Initial_Portfolio
MVO_result = Portfolio_Assets.to_frame("Mean Var")

# %% Part 5. DJIA index baseline

import yfinance as yf

df_dji = yf.download("^DJI", start=TRADE_START_DATE, end=TRADE_END_DATE)
if df_dji.empty:
    # Without this check, the .iloc[0] below would fail with an opaque IndexError.
    raise RuntimeError(
        f"No ^DJI data downloaded for {TRADE_START_DATE} to {TRADE_END_DATE}"
    )
df_dji = df_dji[["Close"]].reset_index()
# Flatten the column labels. Newer yfinance versions may return a
# (field, ticker) MultiIndex even for a single ticker.
df_dji.columns = ["date", "close"]
df_dji["date"] = df_dji["date"].astype(str)

# Rescale the index so it starts at the same capital as the portfolios.
fst_day = df_dji["close"].iloc[0]
dji = df_dji.assign(
    close=df_dji["close"].div(fst_day).mul(initial_capital)
).set_index("date")
# %% Part 6. Compare results

# Account-value history returned for each DRL agent; None means the agent
# was disabled and is left out of the comparison.
agent_account_values = {
    "a2c": df_account_value_a2c,
    "ddpg": df_account_value_ddpg,
    "ppo": df_account_value_ppo,
    "td3": df_account_value_td3,
    "sac": df_account_value_sac,
}

# One column per strategy, aligned on the date index (first column of each
# agent's account-value frame).
result = pd.DataFrame(
    {
        **{
            name: df.set_index(df.columns[0])["account_value"]
            for name, df in agent_account_values.items()
            if df is not None
        },
        "mvo": MVO_result["Mean Var"],
        "dji": dji["close"],
    }
)

print("\n=== Backtest Results ===")
print(result)

# %% Part 7. Plot

plt.rcParams["figure.figsize"] = (15, 5)
# Plot into an explicit axes. DataFrame.plot() without ax= opens its own
# figure, so a separate plt.figure() call would just leave an empty one.
fig, ax = plt.subplots()
result.plot(ax=ax)
ax.set_title("Portfolio Value Over Time")
ax.set_xlabel("Date")
ax.set_ylabel("Portfolio Value ($)")
fig.savefig("backtest_result.png", dpi=150, bbox_inches="tight")
plt.close(fig)
print("\nPlot saved to backtest_result.png")