Path: blob/master/finrl/applications/stock_trading/ensemble_stock_trading.py
732 views
from __future__ import annotations


def main():
    """Run the FinRL ensemble (A2C / PPO / DDPG) stock-trading example end to end.

    Steps, in order:
      1. Download daily data for the DOW 30 tickers via ``YahooDownloader``.
      2. Add technical indicators and turbulence with ``FeatureEngineer``.
      3. Train and trade with ``DRLEnsembleAgent.run_ensemble_strategy``.
      4. Read the per-window ``results/account_value_trade_ensemble_<i>.csv``
         files back, print the annualised Sharpe ratio, and compare the
         account value against the ^DJI baseline.

    The function takes no arguments and returns ``None``; everything it
    produces is printed, plotted, or written to disk by the FinRL helpers.
    """
    import warnings

    warnings.filterwarnings("ignore")
    import pandas as pd
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    # matplotlib.use('Agg')
    import datetime

    from finrl.config_tickers import DOW_30_TICKER
    from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
    from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
    from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
    from finrl.agents.stablebaselines3.models import DRLAgent, DRLEnsembleAgent
    from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline

    from pprint import pprint

    import sys

    sys.path.append("../FinRL-Library")

    import itertools

    import os
    from finrl.main import check_and_make_directories
    from finrl.config import (
        DATA_SAVE_DIR,
        TRAINED_MODEL_DIR,
        TENSORBOARD_LOG_DIR,
        RESULTS_DIR,
        INDICATORS,
        TRAIN_START_DATE,
        TRAIN_END_DATE,
        TEST_START_DATE,
        TEST_END_DATE,
        TRADE_START_DATE,
        TRADE_END_DATE,
    )

    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )
    print(DOW_30_TICKER)

    # These local assignments override the date values imported from
    # finrl.config for the rest of this function.
    TRAIN_START_DATE = "2009-04-01"
    TRAIN_END_DATE = "2021-01-01"
    TEST_START_DATE = "2021-01-01"
    TEST_END_DATE = "2022-06-01"

    df = YahooDownloader(
        start_date=TRAIN_START_DATE, end_date=TEST_END_DATE, ticker_list=DOW_30_TICKER
    ).fetch_data()

    # NOTE(review): the result of this expression is discarded (looks like a
    # notebook leftover); kept so the statement sequence matches the original.
    df.sort_values(["date", "tic"]).head()

    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=INDICATORS,
        use_turbulence=True,
        user_defined_feature=False,
    )

    processed = fe.preprocess_data(df)
    processed = processed.copy()
    processed = processed.fillna(0)
    processed = processed.replace(np.inf, 0)

    stock_dimension = len(processed.tic.unique())
    # presumably: 1 cash slot + (price, holdings) per stock + one slot per
    # indicator per stock -- confirm against StockTradingEnv.
    state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

    env_kwargs = {
        "hmax": 100,
        "initial_amount": 1000000,
        "buy_cost_pct": 0.001,
        "sell_cost_pct": 0.001,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,
        "print_verbosity": 5,
    }

    # rebalance_window is the number of days to retrain the model
    rebalance_window = 63
    # validation_window is the number of days to do validation and trading
    # (e.g. if validation_window=63, then both validation and trading period
    # will be 63 days)
    validation_window = 63

    ensemble_agent = DRLEnsembleAgent(
        df=processed,
        train_period=(TRAIN_START_DATE, TRAIN_END_DATE),
        val_test_period=(TEST_START_DATE, TEST_END_DATE),
        rebalance_window=rebalance_window,
        validation_window=validation_window,
        **env_kwargs,
    )

    A2C_model_kwargs = {"n_steps": 5, "ent_coef": 0.005, "learning_rate": 0.0007}

    PPO_model_kwargs = {
        "ent_coef": 0.01,
        "n_steps": 2048,
        "learning_rate": 0.00025,
        "batch_size": 128,
    }

    DDPG_model_kwargs = {
        # "action_noise":"ornstein_uhlenbeck",
        "buffer_size": 10_000,
        "learning_rate": 0.0005,
        "batch_size": 64,
    }

    timesteps_dict = {"a2c": 10_000, "ppo": 10_000, "ddpg": 10_000}
    df_summary = ensemble_agent.run_ensemble_strategy(
        A2C_model_kwargs, PPO_model_kwargs, DDPG_model_kwargs, timesteps_dict
    )

    unique_trade_date = processed[
        (processed.date > TEST_START_DATE) & (processed.date <= TEST_END_DATE)
    ].date.unique()

    df_trade_date = pd.DataFrame({"datadate": unique_trade_date})

    # Gather the per-window account-value files and concatenate them once.
    # DataFrame.append was removed in pandas 2.0, so the frames are collected
    # in a list and combined with pd.concat instead.
    window_frames = [
        pd.read_csv(f"results/account_value_trade_ensemble_{i}.csv")
        for i in range(
            rebalance_window + validation_window,
            len(unique_trade_date) + 1,
            rebalance_window,
        )
    ]
    if window_frames:
        df_account_value = pd.concat(window_frames, ignore_index=True)
    else:
        # pd.concat rejects an empty list; keep the original empty starting frame.
        df_account_value = pd.DataFrame()

    # Annualised Sharpe ratio from daily returns (252 trading days per year).
    daily_returns = df_account_value.account_value.pct_change(1)
    sharpe = (252**0.5) * daily_returns.mean() / daily_returns.std()
    print("Sharpe Ratio: ", sharpe)

    # Attach the trade dates, skipping the first validation_window entries,
    # aligned on a fresh 0-based index.
    df_account_value = df_account_value.join(
        df_trade_date[validation_window:].reset_index(drop=True)
    )

    df_account_value.account_value.plot()

    print("==============Get Backtest Results===========")

    perf_stats_all = backtest_stats(account_value=df_account_value)
    perf_stats_all = pd.DataFrame(perf_stats_all)

    # baseline stats
    print("==============Get Baseline Stats===========")
    first_date = df_account_value.loc[0, "date"]
    last_date = df_account_value.loc[len(df_account_value) - 1, "date"]
    baseline_df = get_baseline(
        ticker="^DJI",
        start=first_date,
        end=last_date,
    )

    stats = backtest_stats(baseline_df, value_col_name="close")

    print("==============Compare to DJIA===========")

    # S&P 500: ^GSPC
    # Dow Jones Index: ^DJI
    # NASDAQ 100: ^NDX
    backtest_plot(
        df_account_value,
        baseline_ticker="^DJI",
        baseline_start=first_date,
        baseline_end=last_date,
    )


if __name__ == "__main__":
    main()