Path: blob/master/finrl/applications/stock_trading/stock_trading.py
732 views
"""Train and backtest DRL stock-trading agents on the Dow 30 tickers.

The module exposes :func:`stock_trading`, which downloads price data,
engineers features, trains the selected Stable-Baselines3 agents through
FinRL's ``DRLAgent`` wrapper, runs them on the trading window and compares
their account values with the ``^DJI`` baseline.
"""
from __future__ import annotations

import itertools
import sys

import pandas as pd
from stable_baselines3.common.logger import configure

from finrl.agents.stablebaselines3.models import DRLAgent
from finrl.config import DATA_SAVE_DIR
from finrl.config import INDICATORS
from finrl.config import RESULTS_DIR
from finrl.config import TENSORBOARD_LOG_DIR
from finrl.config import TRAINED_MODEL_DIR
from finrl.config_tickers import DOW_30_TICKER
from finrl.main import check_and_make_directories
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.meta.preprocessor.preprocessors import data_split
from finrl.meta.preprocessor.preprocessors import FeatureEngineer
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.plot import backtest_stats
from finrl.plot import get_baseline
from finrl.plot import plot_return

# matplotlib.use('Agg')


def _train_agent(
    env_train,
    algo_name: str,
    model_kwargs: dict | None = None,
    total_timesteps: int = 50000,
):
    """Build, attach a logger to, and train one agent; return the trained model.

    Args:
        env_train: Training environment handed to ``DRLAgent``.
        algo_name: Algorithm key passed to ``DRLAgent.get_model`` (e.g. ``"ppo"``);
            also used as the log sub-directory and TensorBoard run name.
        model_kwargs: Optional hyper-parameters. When ``None`` the
            ``model_kwargs`` argument is omitted from the ``get_model`` call.
        total_timesteps: Training budget forwarded to ``train_model``.
    """
    agent = DRLAgent(env=env_train)
    if model_kwargs is None:
        model = agent.get_model(algo_name)
    else:
        model = agent.get_model(algo_name, model_kwargs=model_kwargs)

    # Each algorithm logs to its own sub-directory of RESULTS_DIR.
    log_dir = RESULTS_DIR + "/" + algo_name
    model.set_logger(configure(log_dir, ["stdout", "csv", "tensorboard"]))

    return agent.train_model(
        model=model, tb_log_name=algo_name, total_timesteps=total_timesteps
    )


def _predict(trained_model, environment):
    """Run a trained model on ``environment``; return ``(account_values, actions)``."""
    account_values, actions = DRLAgent.DRL_prediction(
        model=trained_model, environment=environment
    )
    # In the script version the first element may itself be a
    # (account_values, actions) tuple; unwrap it. The notebook version does
    # not need this check.
    if isinstance(account_values, tuple):
        actions = account_values[1]
        account_values = account_values[0]
    return account_values, actions


def stock_trading(
    train_start_date: str,
    train_end_date: str,
    trade_start_date: str,
    trade_end_date: str,
    if_store_actions: bool = True,
    if_store_result: bool = True,
    if_using_a2c: bool = True,
    if_using_ddpg: bool = True,
    if_using_ppo: bool = True,
    if_using_sac: bool = True,
    if_using_td3: bool = True,
):
    """Train the selected agents and backtest them against the DJI baseline.

    Args:
        train_start_date: First date of the training window; also the start
            of the downloaded data.
        train_end_date: End date passed to ``data_split`` for the training set.
        trade_start_date: Start date of the trading (backtest) window.
        trade_end_date: End of the trading window and of the downloaded data.
            Baseline rows dated on or after it are dropped.
        if_store_actions: Write each agent's actions to ``actions_<algo>.csv``
            in the current working directory.
        if_store_result: Write the merged account values to ``result.csv``.
        if_using_a2c: Train and evaluate A2C.
        if_using_ddpg: Train and evaluate DDPG.
        if_using_ppo: Train and evaluate PPO.
        if_using_sac: Train and evaluate SAC.
        if_using_td3: Train and evaluate TD3.

    Raises:
        ValueError: If no row survives the NaN filter after merging the
            agents' account values with the baseline.
    """
    sys.path.append("../FinRL")
    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )
    date_col = "date"
    tic_col = "tic"

    # ---- data download and feature engineering -------------------------
    df = YahooDownloader(
        start_date=train_start_date, end_date=trade_end_date, ticker_list=DOW_30_TICKER
    ).fetch_data()
    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list=INDICATORS,
        use_vix=True,
        use_turbulence=True,
        user_defined_feature=False,
    )
    processed = fe.preprocess_data(df)

    # Build the full (date, ticker) grid so every ticker has a row on every
    # date present in the processed data; missing values become 0.
    list_ticker = processed[tic_col].unique().tolist()
    list_date = list(
        pd.date_range(processed[date_col].min(), processed[date_col].max()).astype(str)
    )
    combination = list(itertools.product(list_date, list_ticker))

    init_train_trade_data = pd.DataFrame(
        combination, columns=[date_col, tic_col]
    ).merge(processed, on=[date_col, tic_col], how="left")
    init_train_trade_data = init_train_trade_data[
        init_train_trade_data[date_col].isin(processed[date_col])
    ]
    init_train_trade_data = init_train_trade_data.sort_values([date_col, tic_col])
    init_train_trade_data = init_train_trade_data.fillna(0)

    init_train_data = data_split(
        init_train_trade_data, train_start_date, train_end_date
    )
    init_trade_data = data_split(
        init_train_trade_data, trade_start_date, trade_end_date
    )

    # ---- environment configuration -------------------------------------
    stock_dimension = len(init_train_data.tic.unique())
    # NOTE(review): presumably 1 cash slot + price and holding per stock +
    # one slot per indicator per stock -- confirm against StockTradingEnv.
    state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")
    # Separate list objects, so a mutation of one cannot leak into the other.
    buy_cost_list = [0.001] * stock_dimension
    sell_cost_list = [0.001] * stock_dimension
    num_stock_shares = [0] * stock_dimension

    initial_amount = 1000000
    env_kwargs = {
        "hmax": 100,
        "initial_amount": initial_amount,
        "num_stock_shares": num_stock_shares,
        "buy_cost_pct": buy_cost_list,
        "sell_cost_pct": sell_cost_list,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,
    }

    e_train_gym = StockTradingEnv(df=init_train_data, **env_kwargs)
    env_train, _ = e_train_gym.get_sb_env()
    print(type(env_train))

    # ---- training --------------------------------------------------------
    # (algorithm, enabled?, hyper-parameters) in training order. The dicts are
    # created per call so nothing is shared between invocations.
    agent_specs = [
        ("a2c", if_using_a2c, None),
        ("ddpg", if_using_ddpg, None),
        (
            "ppo",
            if_using_ppo,
            {
                "n_steps": 2048,
                "ent_coef": 0.01,
                "learning_rate": 0.00025,
                "batch_size": 128,
            },
        ),
        (
            "sac",
            if_using_sac,
            {
                "batch_size": 128,
                "buffer_size": 100000,
                "learning_rate": 0.0001,
                "learning_starts": 100,
                "ent_coef": "auto_0.1",
            },
        ),
        (
            "td3",
            if_using_td3,
            {"batch_size": 100, "buffer_size": 1000000, "learning_rate": 0.001},
        ),
    ]

    trained_models = {}
    for algo_name, enabled, model_kwargs in agent_specs:
        if enabled:
            trained_models[algo_name] = _train_agent(
                env_train, algo_name, model_kwargs, total_timesteps=50000
            )

    # ---- trading ---------------------------------------------------------
    e_trade_gym = StockTradingEnv(
        df=init_trade_data,
        turbulence_threshold=70,
        risk_indicator_col="vix",
        **env_kwargs,
    )
    # env_trade, obs_trade = e_trade_gym.get_sb_env()

    # Dicts keep insertion order, so predictions run in training order.
    predictions = {
        algo_name: _predict(model, e_trade_gym)
        for algo_name, model in trained_models.items()
    }

    # Order used for the saved action files and for the result columns.
    report_order = ("a2c", "ddpg", "td3", "ppo", "sac")

    # store actions
    if if_store_actions:
        for algo_name in report_order:
            if algo_name in predictions:
                predictions[algo_name][1].to_csv(f"actions_{algo_name}.csv")

    # ---- baseline and comparison table -----------------------------------
    dji_ = get_baseline(ticker="^DJI", start=trade_start_date, end=trade_end_date)
    dji = pd.DataFrame()
    dji[date_col] = dji_[date_col]
    dji["DJI"] = dji_["close"]
    # Keep rows in [trade_start_date, trade_end_date); the baseline may
    # contain values outside that window.
    dji = dji.loc[
        (dji[date_col] >= trade_start_date) & (dji[date_col] < trade_end_date)
    ]

    result = dji
    for algo_name in report_order:
        if algo_name not in predictions:
            continue
        account_values = predictions[algo_name][0].rename(
            columns={"account_value": algo_name.upper()}
        )
        # No `on=` given: pandas joins on the columns both frames share.
        result = pd.merge(result, account_values, how="left")

    # Remove rows with NaN, then renumber so the normalisation below writes
    # into a fresh frame rather than into the output of `dropna`.
    result = result.dropna(axis=0, how="any").reset_index(drop=True)
    if result.empty:
        raise ValueError(
            "No rows left after merging the account values with the DJI "
            "baseline; check that the trade dates overlap the baseline data."
        )

    # Names of the strategy columns, including DJI.
    col_strategies = [
        col
        for col in result.columns
        if col != date_col and col != "" and "Unnamed" not in col
    ]

    # Rescale DJI so that its first row equals initial_amount.
    result["DJI"] = result["DJI"] / result["DJI"].iloc[0] * initial_amount

    # stats
    for col in col_strategies:
        stats = backtest_stats(result, value_col_name=col)
        print("\nstats of " + col + ": \n", stats)

    # print and save result
    print("result: ", result)
    if if_store_result:
        result.to_csv("result.csv")

    # plot fig
    plot_return(
        result=result,
        column_as_x=date_col,
        if_need_calc_return=True,
        savefig_filename="stock_trading.png",
        xlabel="Date",
        ylabel="Return",
        if_transfer_date=True,
        num_days_xticks=20,
    )


if __name__ == "__main__":
    stock_trading(
        train_start_date="2009-01-01",
        train_end_date="2022-09-01",
        trade_start_date="2022-09-01",
        trade_end_date="2023-11-01",
        if_store_actions=True,
        if_store_result=True,
        if_using_a2c=True,
        if_using_ddpg=True,
        if_using_ppo=True,
        if_using_sac=True,
        if_using_td3=True,
    )