# Source: blob/master/finrl/applications/stock_trading/stock_trading_rolling_window.py
from __future__ import annotations12import itertools34import pandas as pd5from stable_baselines3.common.logger import configure67from finrl.agents.stablebaselines3.models import DRLAgent8from finrl.config import DATA_SAVE_DIR9from finrl.config import INDICATORS10from finrl.config import RESULTS_DIR11from finrl.config import TENSORBOARD_LOG_DIR12from finrl.config import TEST_END_DATE13from finrl.config import TEST_START_DATE14from finrl.config import TRAINED_MODEL_DIR15from finrl.config_tickers import DOW_30_TICKER16from finrl.main import check_and_make_directories17from finrl.meta.data_processors.func import calc_train_trade_data18from finrl.meta.data_processors.func import calc_train_trade_starts_ends_if_rolling19from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv20from finrl.meta.preprocessor.preprocessors import data_split21from finrl.meta.preprocessor.preprocessors import FeatureEngineer22from finrl.meta.preprocessor.yahoodownloader import YahooDownloader23from finrl.plot import backtest_stats24from finrl.plot import get_baseline25from finrl.plot import plot_return2627# matplotlib.use('Agg')282930def stock_trading_rolling_window(31train_start_date: str,32train_end_date: str,33trade_start_date: str,34trade_end_date: str,35rolling_window_length: int,36if_store_actions: bool = True,37if_store_result: bool = True,38if_using_a2c: bool = True,39if_using_ddpg: bool = True,40if_using_ppo: bool = True,41if_using_sac: bool = True,42if_using_td3: bool = True,43):44# sys.path.append("../FinRL")45check_and_make_directories(46[DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]47)48date_col = "date"49tic_col = "tic"50df = YahooDownloader(51start_date=train_start_date, end_date=trade_end_date, ticker_list=DOW_30_TICKER52).fetch_data()53fe = FeatureEngineer(54use_technical_indicator=True,55tech_indicator_list=INDICATORS,56use_vix=True,57use_turbulence=True,58user_defined_feature=False,59)6061processed = fe.preprocess_data(df)62list_ticker 
= processed[tic_col].unique().tolist()63list_date = list(64pd.date_range(processed[date_col].min(), processed[date_col].max()).astype(str)65)66combination = list(itertools.product(list_date, list_ticker))6768init_train_trade_data = pd.DataFrame(69combination, columns=[date_col, tic_col]70).merge(processed, on=[date_col, tic_col], how="left")71init_train_trade_data = init_train_trade_data[72init_train_trade_data[date_col].isin(processed[date_col])73]74init_train_trade_data = init_train_trade_data.sort_values([date_col, tic_col])7576init_train_trade_data = init_train_trade_data.fillna(0)7778init_train_data = data_split(79init_train_trade_data, train_start_date, train_end_date80)81init_trade_data = data_split(82init_train_trade_data, trade_start_date, trade_end_date83)8485stock_dimension = len(init_train_data.tic.unique())86state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension87print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")88buy_cost_list = sell_cost_list = [0.001] * stock_dimension89num_stock_shares = [0] * stock_dimension9091initial_amount = 100000092env_kwargs = {93"hmax": 100,94"initial_amount": initial_amount,95"num_stock_shares": num_stock_shares,96"buy_cost_pct": buy_cost_list,97"sell_cost_pct": sell_cost_list,98"state_space": state_space,99"stock_dim": stock_dimension,100"tech_indicator_list": INDICATORS,101"action_space": stock_dimension,102"reward_scaling": 1e-4,103}104105# split the init_train_data and init_trade_data to subsets106init_train_dates = init_train_data[date_col].unique()107init_trade_dates = init_trade_data[date_col].unique()108109(110train_starts,111train_ends,112trade_starts,113trade_ends,114) = calc_train_trade_starts_ends_if_rolling(115init_train_dates, init_trade_dates, rolling_window_length116)117118result = pd.DataFrame()119actions_a2c = pd.DataFrame(columns=DOW_30_TICKER)120actions_ddpg = pd.DataFrame(columns=DOW_30_TICKER)121actions_ppo = pd.DataFrame(columns=DOW_30_TICKER)122actions_sac = 
pd.DataFrame(columns=DOW_30_TICKER)123actions_td3 = pd.DataFrame(columns=DOW_30_TICKER)124125for i in range(len(train_starts)):126print("i: ", i)127train_data, trade_data = calc_train_trade_data(128i,129train_starts,130train_ends,131trade_starts,132trade_ends,133init_train_data,134init_trade_data,135date_col,136)137e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)138env_train, _ = e_train_gym.get_sb_env()139140# train141142if if_using_a2c:143if len(result) >= 1:144env_kwargs["initial_amount"] = result["A2C"].iloc[-1]145e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)146env_train, _ = e_train_gym.get_sb_env()147agent = DRLAgent(env=env_train)148model_a2c = agent.get_model("a2c")149# set up logger150tmp_path = RESULTS_DIR + "/a2c"151new_logger_a2c = configure(tmp_path, ["stdout", "csv", "tensorboard"])152# Set new logger153model_a2c.set_logger(new_logger_a2c)154trained_a2c = agent.train_model(155model=model_a2c, tb_log_name="a2c", total_timesteps=50000156)157158if if_using_ddpg:159if len(result) >= 1:160env_kwargs["initial_amount"] = result["DDPG"].iloc[-1]161e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)162env_train, _ = e_train_gym.get_sb_env()163agent = DRLAgent(env=env_train)164model_ddpg = agent.get_model("ddpg")165# set up logger166tmp_path = RESULTS_DIR + "/ddpg"167new_logger_ddpg = configure(tmp_path, ["stdout", "csv", "tensorboard"])168# Set new logger169model_ddpg.set_logger(new_logger_ddpg)170trained_ddpg = agent.train_model(171model=model_ddpg, tb_log_name="ddpg", total_timesteps=40000172)173174if if_using_ppo:175if len(result) >= 1:176env_kwargs["initial_amount"] = result["PPO"].iloc[-1]177e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)178env_train, _ = e_train_gym.get_sb_env()179agent = DRLAgent(env=env_train)180PPO_PARAMS = {181"n_steps": 2048,182"ent_coef": 0.005,183"learning_rate": 0.0001,184"batch_size": 64,185}186model_ppo = agent.get_model("ppo", model_kwargs=PPO_PARAMS)187# set up logger188tmp_path = 
RESULTS_DIR + "/ppo"189new_logger_ppo = configure(tmp_path, ["stdout", "csv", "tensorboard"])190# Set new logger191model_ppo.set_logger(new_logger_ppo)192trained_ppo = agent.train_model(193model=model_ppo, tb_log_name="ppo", total_timesteps=50000194)195196if if_using_sac:197if len(result) >= 1:198env_kwargs["initial_amount"] = result["SAC"].iloc[-1]199e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)200env_train, _ = e_train_gym.get_sb_env()201agent = DRLAgent(env=env_train)202SAC_PARAMS = {203"batch_size": 64,204"buffer_size": 100000,205"learning_rate": 0.00015,206"learning_starts": 100,207"ent_coef": "auto_0.1",208}209model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)210# set up logger211tmp_path = RESULTS_DIR + "/sac"212new_logger_sac = configure(tmp_path, ["stdout", "csv", "tensorboard"])213# Set new logger214model_sac.set_logger(new_logger_sac)215trained_sac = agent.train_model(216model=model_sac, tb_log_name="sac", total_timesteps=50000217)218219if if_using_td3:220if len(result) >= 1:221env_kwargs["initial_amount"] = result["TD3"].iloc[-1]222e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)223env_train, _ = e_train_gym.get_sb_env()224agent = DRLAgent(env=env_train)225TD3_PARAMS = {226"batch_size": 64,227"buffer_size": 100000,228"learning_rate": 0.0008,229}230model_td3 = agent.get_model("td3", model_kwargs=TD3_PARAMS)231# set up logger232tmp_path = RESULTS_DIR + "/td3"233new_logger_td3 = configure(tmp_path, ["stdout", "csv", "tensorboard"])234# Set new logger235model_td3.set_logger(new_logger_td3)236trained_td3 = agent.train_model(237model=model_td3, tb_log_name="td3", total_timesteps=50000238)239240# trade241# this e_trade_gym is initialized, then it will be used if i == 0242e_trade_gym = StockTradingEnv(243df=trade_data,244turbulence_threshold=70,245risk_indicator_col="vix",246**env_kwargs,247)248249if if_using_a2c:250if len(result) >= 1:251env_kwargs["initial_amount"] = result["A2C"].iloc[-1]252e_trade_gym = 
StockTradingEnv(253df=trade_data,254turbulence_threshold=70,255risk_indicator_col="vix",256**env_kwargs,257)258result_a2c, actions_i_a2c = DRLAgent.DRL_prediction(259model=trained_a2c, environment=e_trade_gym260)261262if if_using_ddpg:263if len(result) >= 1:264env_kwargs["initial_amount"] = result["DDPG"].iloc[-1]265e_trade_gym = StockTradingEnv(266df=trade_data,267turbulence_threshold=70,268risk_indicator_col="vix",269**env_kwargs,270)271result_ddpg, actions_i_ddpg = DRLAgent.DRL_prediction(272model=trained_ddpg, environment=e_trade_gym273)274275if if_using_ppo:276if len(result) >= 1:277env_kwargs["initial_amount"] = result["PPO"].iloc[-1]278e_trade_gym = StockTradingEnv(279df=trade_data,280turbulence_threshold=70,281risk_indicator_col="vix",282**env_kwargs,283)284result_ppo, actions_i_ppo = DRLAgent.DRL_prediction(285model=trained_ppo, environment=e_trade_gym286)287288if if_using_sac:289if len(result) >= 1:290env_kwargs["initial_amount"] = result["SAC"].iloc[-1]291e_trade_gym = StockTradingEnv(292df=trade_data,293turbulence_threshold=70,294risk_indicator_col="vix",295**env_kwargs,296)297result_sac, actions_i_sac = DRLAgent.DRL_prediction(298model=trained_sac, environment=e_trade_gym299)300301if if_using_td3:302if len(result) >= 1:303env_kwargs["initial_amount"] = result["TD3"].iloc[-1]304e_trade_gym = StockTradingEnv(305df=trade_data,306turbulence_threshold=70,307risk_indicator_col="vix",308**env_kwargs,309)310result_td3, actions_i_td3 = DRLAgent.DRL_prediction(311model=trained_td3, environment=e_trade_gym312)313314# in python version, we should check isinstance, but in notebook version, it is not necessary315if if_using_a2c and isinstance(result_a2c, tuple):316actions_i_a2c = result_a2c[1]317result_a2c = result_a2c[0]318if if_using_ddpg and isinstance(result_ddpg, tuple):319actions_i_ddpg = result_ddpg[1]320result_ddpg = result_ddpg[0]321if if_using_ppo and isinstance(result_ppo, tuple):322actions_i_ppo = result_ppo[1]323result_ppo = result_ppo[0]324if 
if_using_sac and isinstance(result_sac, tuple):325actions_i_sac = result_sac[1]326result_sac = result_sac[0]327if if_using_td3 and isinstance(result_td3, tuple):328actions_i_td3 = result_td3[1]329result_td3 = result_td3[0]330331# merge actions332actions_a2c = pd.concat([actions_a2c, actions_i_a2c]) if if_using_a2c else None333actions_ddpg = (334pd.concat([actions_ddpg, actions_i_ddpg]) if if_using_ddpg else None335)336actions_ppo = pd.concat([actions_ppo, actions_i_ppo]) if if_using_ppo else None337actions_sac = pd.concat([actions_sac, actions_i_sac]) if if_using_sac else None338actions_td3 = pd.concat([actions_td3, actions_i_td3]) if if_using_td3 else None339340# dji_i341trade_start = trade_starts[i]342trade_end = trade_ends[i]343dji_i_ = get_baseline(ticker="^DJI", start=trade_start, end=trade_end)344dji_i = pd.DataFrame()345dji_i[date_col] = dji_i_[date_col]346dji_i["DJI"] = dji_i_["close"]347# dji_i.rename(columns={'account_value': 'DJI'}, inplace=True)348349# select the rows between trade_start and trade_end (not included), since some values may not in this region350dji_i = dji_i.loc[351(dji_i[date_col] >= trade_start) & (dji_i[date_col] < trade_end)352]353354# init result_i by dji_i355result_i = dji_i356357# rename column name of result_a2c, result_ddpg, etc., and then put them to result_i358if if_using_a2c:359result_a2c.rename(columns={"account_value": "A2C"}, inplace=True)360result_i = pd.merge(result_i, result_a2c, how="left")361if if_using_ddpg:362result_ddpg.rename(columns={"account_value": "DDPG"}, inplace=True)363result_i = pd.merge(result_i, result_ddpg, how="left")364if if_using_ppo:365result_ppo.rename(columns={"account_value": "PPO"}, inplace=True)366result_i = pd.merge(result_i, result_ppo, how="left")367if if_using_sac:368result_sac.rename(columns={"account_value": "SAC"}, inplace=True)369result_i = pd.merge(result_i, result_sac, how="left")370if if_using_td3:371result_td3.rename(columns={"account_value": "TD3"}, inplace=True)372result_i = 
pd.merge(result_i, result_td3, how="left")373374# remove the rows with nan375result_i = result_i.dropna(axis=0, how="any")376377# merge result_i to result378result = pd.concat([result, result_i], axis=0)379380# store actions381if if_store_actions:382actions_a2c.to_csv("actions_a2c.csv") if if_using_a2c else None383actions_ddpg.to_csv("actions_ddpg.csv") if if_using_ddpg else None384actions_ppo.to_csv("actions_ppo.csv") if if_using_ppo else None385actions_sac.to_csv("actions_sac.csv") if if_using_sac else None386actions_td3.to_csv("actions_td3.csv") if if_using_td3 else None387388# calc the column name of strategies, including DJI389col_strategies = []390for col in result.columns:391if col != date_col and col != "" and "Unnamed" not in col:392col_strategies.append(col)393394# make sure that the first row of DJI is initial_amount395col = "DJI"396result[col] = result[col] / result[col].iloc[0] * initial_amount397result = result.reset_index(drop=True)398399# stats400for col in col_strategies:401stats = backtest_stats(result, value_col_name=col)402print("\nstats of " + col + ": \n", stats)403404# print and save result405print("result: ", result)406if if_store_result:407result.to_csv("result.csv")408409# plot fig410plot_return(411result=result,412column_as_x=date_col,413if_need_calc_return=True,414savefig_filename="stock_trading_rolling_window.png",415xlabel="Date",416ylabel="Return",417if_transfer_date=True,418num_days_xticks=20,419)420421422if __name__ == "__main__":423train_start_date = "2009-01-01"424train_end_date = "2022-07-01"425trade_start_date = "2022-07-01"426trade_end_date = "2022-11-01"427rolling_window_length = 22 # num of trading days in a rolling window428if_store_actions = True429if_store_result = True430if_using_a2c = True431if_using_ddpg = True432if_using_ppo = True433if_using_sac = True434if_using_td3 = 
True435stock_trading_rolling_window(436train_start_date=train_start_date,437train_end_date=train_end_date,438trade_start_date=trade_start_date,439trade_end_date=trade_end_date,440rolling_window_length=rolling_window_length,441if_store_actions=if_store_actions,442if_store_result=if_store_result,443if_using_a2c=if_using_a2c,444if_using_ddpg=if_using_ddpg,445if_using_ppo=if_using_ppo,446if_using_sac=if_using_sac,447if_using_td3=if_using_td3,448)449450451