Path: blob/master/finrl/agents/stablebaselines3/models.py
# DRL models from Stable Baselines 3
from __future__ import annotations

import statistics
import time

import numpy as np
import pandas as pd
from stable_baselines3 import A2C
from stable_baselines3 import DDPG
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3 import TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv

from finrl import config
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.meta.preprocessor.preprocessors import data_split

MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

MODEL_KWARGS = {x: config.__dict__[f"{x.upper()}_PARAMS"] for x in MODELS.keys()}

NOISE = {
    "normal": NormalActionNoise,
    "ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,
}


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        try:
            self.logger.record(key="train/reward", value=self.locals["rewards"][0])

        except BaseException as error:
            try:
                self.logger.record(key="train/reward", value=self.locals["reward"][0])

            except BaseException as inner_error:
                # Handle the case where neither "rewards" nor "reward" is found
                self.logger.record(key="train/reward", value=None)
                # Print the original error and the inner error for debugging
                print("Original Error:", error)
                print("Inner Error:", inner_error)
        return True

    def _on_rollout_end(self) -> bool:
        try:
            rollout_buffer_rewards = self.locals["rollout_buffer"].rewards.flatten()
            self.logger.record(
                key="train/reward_min", value=min(rollout_buffer_rewards)
            )
            self.logger.record(
                key="train/reward_mean", value=statistics.mean(rollout_buffer_rewards)
            )
            self.logger.record(
                key="train/reward_max", value=max(rollout_buffer_rewards)
            )
        except BaseException as error:
            # Handle the case where "rewards" is not found
            self.logger.record(key="train/reward_min", value=None)
            self.logger.record(key="train/reward_mean", value=None)
            self.logger.record(key="train/reward_max", value=None)
            print("Logging Error:", error)
        return True


class DRLAgent:
    """Provides implementations for DRL algorithms

    Attributes
    ----------
    env : gym environment class
        user-defined class

    Methods
    -------
    get_model()
        set up a DRL algorithm
    train_model()
        train a DRL algorithm on a training dataset
        and output the trained model
    DRL_prediction()
        make a prediction on a test dataset and get results
    """

    def __init__(self, env):
        self.env = env

    def get_model(
        self,
        model_name,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        verbose=1,
        seed=None,
        tensorboard_log=None,
    ):
        if model_name not in MODELS:
            raise ValueError(
                f"Model '{model_name}' not found in MODELS."
            )  # this is more informative than NotImplementedError("NotImplementedError")

        if model_kwargs is None:
            model_kwargs = MODEL_KWARGS[model_name]

        if "action_noise" in model_kwargs:
            n_actions = self.env.action_space.shape[-1]
            model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](
                mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
            )
        print(model_kwargs)
        return MODELS[model_name](
            policy=policy,
            env=self.env,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **model_kwargs,
        )
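
    # A minimal usage sketch for get_model; `env_train` is an assumed,
    # already-built StockTradingEnv instance (not defined in this module):
    #
    #   agent = DRLAgent(env=env_train)
    #   model_sac = agent.get_model(
    #       "sac",
    #       model_kwargs={"action_noise": "ornstein_uhlenbeck"},
    #   )
    #   # The "action_noise" string is resolved through the NOISE dict above and
    #   # replaced with an OrnsteinUhlenbeckActionNoise object before SAC is built.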

    @staticmethod
    def train_model(
        model,
        tb_log_name,
        total_timesteps=5000,
        callbacks: list[BaseCallback] | None = None,
    ):  # this function is a static method, so it can be called without creating an instance of the class
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=(
                CallbackList(
                    [TensorboardCallback()] + [callback for callback in callbacks]
                )
                if callbacks is not None
                else TensorboardCallback()
            ),
        )
        return model

    @staticmethod
    def DRL_prediction(model, environment, deterministic=True):
        """make a prediction and get results"""
        test_env, test_obs = environment.get_sb_env()
        account_memory = None  # This helps avoid unnecessary list creation
        actions_memory = None  # optimize memory consumption
        # state_memory=[] #add memory pool to store states

        test_env.reset()
        max_steps = len(environment.df.index.unique()) - 1

        for i in range(len(environment.df.index.unique())):
            action, _states = model.predict(test_obs, deterministic=deterministic)
            # account_memory = test_env.env_method(method_name="save_asset_memory")
            # actions_memory = test_env.env_method(method_name="save_action_memory")
            test_obs, rewards, dones, info = test_env.step(action)

            if (
                i == max_steps - 1
            ):  # more descriptive condition for early termination to clarify the logic
                account_memory = test_env.env_method(method_name="save_asset_memory")
                actions_memory = test_env.env_method(method_name="save_action_memory")
            # add current state to state memory
            # state_memory=test_env.env_method(method_name="save_state_memory")

            if dones[0]:
                print("hit end!")
                break
        return account_memory[0], actions_memory[0]

    @staticmethod
    def DRL_prediction_load_from_file(model_name, environment, cwd, deterministic=True):
        if model_name not in MODELS:
            raise ValueError(
                f"Model '{model_name}' not found in MODELS."
            )  # this is more informative than NotImplementedError("NotImplementedError")
        try:
            # load agent
            model = MODELS[model_name].load(cwd)
            print("Successfully loaded model from", cwd)
        except BaseException as error:
            raise ValueError(f"Failed to load agent. Error: {str(error)}") from error

        # test on the testing env
        state = environment.reset()
        episode_returns = []  # the cumulative_return / initial_account
        episode_total_assets = [environment.initial_total_asset]
        done = False
        while not done:
            action = model.predict(state, deterministic=deterministic)[0]
            state, reward, done, _ = environment.step(action)

            total_asset = (
                environment.amount
                + (environment.price_ary[environment.day] * environment.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / environment.initial_total_asset
            episode_returns.append(episode_return)

        print("episode_return", episode_return)
        print("Test Finished!")
        return episode_total_assets
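

# A minimal end-to-end sketch for DRLAgent; `env_train` and `e_trade_gym` are
# assumed StockTradingEnv instances for the train and trade splits (built by the
# user, not in this module), and the timestep count is illustrative only:
#
#   agent = DRLAgent(env=env_train)
#   model_a2c = agent.get_model("a2c")  # hyperparameters default to MODEL_KWARGS["a2c"]
#   trained_a2c = DRLAgent.train_model(
#       model_a2c, tb_log_name="a2c", total_timesteps=50_000
#   )
#   df_account_value, df_actions = DRLAgent.DRL_prediction(
#       model=trained_a2c, environment=e_trade_gym
#   )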


class DRLEnsembleAgent:
    @staticmethod
    def get_model(
        model_name,
        env,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        seed=None,
        verbose=1,
    ):
        if model_name not in MODELS:
            raise ValueError(
                f"Model '{model_name}' not found in MODELS."
            )  # this is more informative than NotImplementedError("NotImplementedError")

        if model_kwargs is None:
            temp_model_kwargs = MODEL_KWARGS[model_name]
        else:
            temp_model_kwargs = model_kwargs.copy()

        if "action_noise" in temp_model_kwargs:
            n_actions = env.action_space.shape[-1]
            temp_model_kwargs["action_noise"] = NOISE[
                temp_model_kwargs["action_noise"]
            ](mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
        print(temp_model_kwargs)
        return MODELS[model_name](
            policy=policy,
            env=env,
            tensorboard_log=f"{config.TENSORBOARD_LOG_DIR}/{model_name}",
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **temp_model_kwargs,
        )

    @staticmethod
    def train_model(
        model,
        model_name,
        tb_log_name,
        iter_num,
        total_timesteps=5000,
        callbacks: list[BaseCallback] | None = None,
    ):
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=(
                CallbackList(
                    [TensorboardCallback()] + [callback for callback in callbacks]
                )
                if callbacks is not None
                else TensorboardCallback()
            ),
        )
        model.save(
            f"{config.TRAINED_MODEL_DIR}/{model_name.upper()}_{total_timesteps // 1000}k_{iter_num}"
        )
        return model

    @staticmethod
    def get_validation_sharpe(iteration, model_name):
        """Calculate Sharpe ratio based on validation results"""
        df_total_value = pd.read_csv(
            f"results/account_value_validation_{model_name}_{iteration}.csv"
        )
        # If the agent did not make any transaction
        if df_total_value["daily_return"].var() == 0:
            if df_total_value["daily_return"].mean() > 0:
                return np.inf
            else:
                return 0.0
        else:
            return (
                (4**0.5)  # scaling factor applied to the daily mean/std ratio
                * df_total_value["daily_return"].mean()
                / df_total_value["daily_return"].std()
            )
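
    # Worked example for get_validation_sharpe (values are illustrative only):
    # with a daily_return mean of 0.001 and a daily_return std of 0.02, the
    # returned value is (4 ** 0.5) * 0.001 / 0.02 = 2 * 0.05 = 0.1.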

    def __init__(
        self,
        df,
        train_period,
        val_test_period,
        rebalance_window,
        validation_window,
        stock_dim,
        hmax,
        initial_amount,
        buy_cost_pct,
        sell_cost_pct,
        reward_scaling,
        state_space,
        action_space,
        tech_indicator_list,
        print_verbosity,
    ):
        self.df = df
        self.train_period = train_period
        self.val_test_period = val_test_period

        self.unique_trade_date = df[
            (df.date > val_test_period[0]) & (df.date <= val_test_period[1])
        ].date.unique()
        self.rebalance_window = rebalance_window
        self.validation_window = validation_window

        self.stock_dim = stock_dim
        self.hmax = hmax
        self.initial_amount = initial_amount
        self.buy_cost_pct = buy_cost_pct
        self.sell_cost_pct = sell_cost_pct
        self.reward_scaling = reward_scaling
        self.state_space = state_space
        self.action_space = action_space
        self.tech_indicator_list = tech_indicator_list
        self.print_verbosity = print_verbosity
        self.train_env = None  # assigned in run_ensemble_strategy()

    def DRL_validation(self, model, test_data, test_env, test_obs):
        """validation process"""
        for _ in range(len(test_data.index.unique())):
            action, _states = model.predict(test_obs)
            test_obs, rewards, dones, info = test_env.step(action)

    def DRL_prediction(
        self, model, name, last_state, iter_num, turbulence_threshold, initial
    ):
        """make a prediction based on trained model"""

        # trading env
        trade_data = data_split(
            self.df,
            start=self.unique_trade_date[iter_num - self.rebalance_window],
            end=self.unique_trade_date[iter_num],
        )
        trade_env = DummyVecEnv(
            [
                lambda: StockTradingEnv(
                    df=trade_data,
                    stock_dim=self.stock_dim,
                    hmax=self.hmax,
                    initial_amount=self.initial_amount,
                    num_stock_shares=[0] * self.stock_dim,
                    buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
                    sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
                    reward_scaling=self.reward_scaling,
                    state_space=self.state_space,
                    action_space=self.action_space,
                    tech_indicator_list=self.tech_indicator_list,
                    turbulence_threshold=turbulence_threshold,
                    initial=initial,
                    previous_state=last_state,
                    model_name=name,
                    mode="trade",
                    iteration=iter_num,
                    print_verbosity=self.print_verbosity,
                )
            ]
        )

        trade_obs = trade_env.reset()

        for i in range(len(trade_data.index.unique())):
            action, _states = model.predict(trade_obs)
            trade_obs, rewards, dones, info = trade_env.step(action)
            if i == (len(trade_data.index.unique()) - 2):
                # print(env_test.render())
                last_state = trade_env.envs[0].render()

        df_last_state = pd.DataFrame({"last_state": last_state})
        df_last_state.to_csv(f"results/last_state_{name}_{i}.csv", index=False)
        return last_state

    def _train_window(
        self,
        model_name,
        model_kwargs,
        sharpe_list,
        validation_start_date,
        validation_end_date,
        timesteps_dict,
        i,
        validation,
        turbulence_threshold,
    ):
        """
        Train the model for a single window.
        """
        if model_kwargs is None:
            return None, sharpe_list, -1

        print(f"======{model_name} Training========")
        model = self.get_model(
            model_name, self.train_env, policy="MlpPolicy", model_kwargs=model_kwargs
        )
        model = self.train_model(
            model,
            model_name,
            tb_log_name=f"{model_name}_{i}",
            iter_num=i,
            total_timesteps=timesteps_dict[model_name],
        )  # 100_000
        print(
            f"======{model_name} Validation from: ",
            validation_start_date,
            "to ",
            validation_end_date,
        )
        val_env = DummyVecEnv(
            [
                lambda: StockTradingEnv(
                    df=validation,
                    stock_dim=self.stock_dim,
                    hmax=self.hmax,
                    initial_amount=self.initial_amount,
                    num_stock_shares=[0] * self.stock_dim,
                    buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
                    sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
                    reward_scaling=self.reward_scaling,
                    state_space=self.state_space,
                    action_space=self.action_space,
                    tech_indicator_list=self.tech_indicator_list,
                    turbulence_threshold=turbulence_threshold,
                    iteration=i,
                    model_name=model_name,
                    mode="validation",
                    print_verbosity=self.print_verbosity,
                )
            ]
        )
        val_obs = val_env.reset()
        self.DRL_validation(
            model=model,
            test_data=validation,
            test_env=val_env,
            test_obs=val_obs,
        )
        sharpe = self.get_validation_sharpe(i, model_name=model_name)
        print(f"{model_name} Sharpe Ratio: ", sharpe)
        sharpe_list.append(sharpe)
        return model, sharpe_list, sharpe
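
    # Rolling-window arithmetic used by run_ensemble_strategy and DRL_prediction
    # (illustrative numbers; 63-day windows are a common choice, not a requirement):
    # with rebalance_window = 63 and validation_window = 63, the first iteration
    # has i = 126, so validation covers unique_trade_date[0:63] (data_split treats
    # `end` as exclusive) and trading covers unique_trade_date[63:126].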

    def run_ensemble_strategy(
        self,
        A2C_model_kwargs,
        PPO_model_kwargs,
        DDPG_model_kwargs,
        SAC_model_kwargs,
        TD3_model_kwargs,
        timesteps_dict,
    ):
        """Ensemble Strategy that combines A2C, PPO, DDPG, SAC, and TD3"""
        # Model Parameters
        kwargs = {
            "a2c": A2C_model_kwargs,
            "ppo": PPO_model_kwargs,
            "ddpg": DDPG_model_kwargs,
            "sac": SAC_model_kwargs,
            "td3": TD3_model_kwargs,
        }
        # Model Sharpe Ratios
        model_dct = {k: {"sharpe_list": [], "sharpe": -1} for k in MODELS.keys()}

        print("============Start Ensemble Strategy============")
        # for the ensemble model, it's necessary to feed the last state
        # of the previous model to the current model as the initial state
        last_state_ensemble = []

        model_use = []
        validation_start_date_list = []
        validation_end_date_list = []
        iteration_list = []

        insample_turbulence = self.df[
            (self.df.date < self.train_period[1])
            & (self.df.date >= self.train_period[0])
        ]
        insample_turbulence_threshold = np.quantile(
            insample_turbulence.turbulence.values, 0.90
        )

        start = time.time()
        for i in range(
            self.rebalance_window + self.validation_window,
            len(self.unique_trade_date),
            self.rebalance_window,
        ):
            validation_start_date = self.unique_trade_date[
                i - self.rebalance_window - self.validation_window
            ]
            validation_end_date = self.unique_trade_date[i - self.rebalance_window]

            validation_start_date_list.append(validation_start_date)
            validation_end_date_list.append(validation_end_date)
            iteration_list.append(i)

            print("============================================")
            # initial state is empty
            if i - self.rebalance_window - self.validation_window == 0:
                # initial state
                initial = True
            else:
                # previous state
                initial = False

            # Tuning turbulence index based on historical data
            # Turbulence lookback window is one quarter (63 days)
            end_date_index = self.df.index[
                self.df["date"]
                == self.unique_trade_date[
                    i - self.rebalance_window - self.validation_window
                ]
            ].to_list()[-1]
            start_date_index = end_date_index - 63 + 1

            historical_turbulence = self.df.iloc[
                start_date_index : (end_date_index + 1), :
            ]

            historical_turbulence = historical_turbulence.drop_duplicates(
                subset=["date"]
            )

            historical_turbulence_mean = np.mean(
                historical_turbulence.turbulence.values
            )

            # print(historical_turbulence_mean)
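
            # Index arithmetic for the lookback above (illustrative numbers): if
            # end_date_index is 500, then start_date_index = 500 - 63 + 1 = 438,
            # so .iloc[438:501] keeps 63 rows; duplicate dates are dropped next.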

            if historical_turbulence_mean > insample_turbulence_threshold:
                # if the mean of the historical data is greater than the 90% quantile of insample turbulence data
                # then we assume that the current market is volatile,
                # therefore we set the 90% quantile of insample turbulence data as the turbulence threshold
                # meaning the current turbulence can't exceed the 90% quantile of insample turbulence data
                turbulence_threshold = insample_turbulence_threshold
            else:
                # if the mean of the historical data is less than the 90% quantile of insample turbulence data
                # then we tune up the turbulence_threshold, meaning we lower the risk
                turbulence_threshold = np.quantile(
                    insample_turbulence.turbulence.values, 1
                )

            # the final threshold is the 99% quantile, overriding the value chosen above
            turbulence_threshold = np.quantile(
                insample_turbulence.turbulence.values, 0.99
            )
            print("turbulence_threshold: ", turbulence_threshold)

            # Environment Setup starts
            # training env
            train = data_split(
                self.df,
                start=self.train_period[0],
                end=self.unique_trade_date[
                    i - self.rebalance_window - self.validation_window
                ],
            )
            self.train_env = DummyVecEnv(
                [
                    lambda: StockTradingEnv(
                        df=train,
                        stock_dim=self.stock_dim,
                        hmax=self.hmax,
                        initial_amount=self.initial_amount,
                        num_stock_shares=[0] * self.stock_dim,
                        buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
                        sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
                        reward_scaling=self.reward_scaling,
                        state_space=self.state_space,
                        action_space=self.action_space,
                        tech_indicator_list=self.tech_indicator_list,
                        print_verbosity=self.print_verbosity,
                    )
                ]
            )

            validation = data_split(
                self.df,
                start=self.unique_trade_date[
                    i - self.rebalance_window - self.validation_window
                ],
                end=self.unique_trade_date[i - self.rebalance_window],
            )
            # Environment Setup ends

            # Training and Validation starts
            print(
                "======Model training from: ",
                self.train_period[0],
                "to ",
                self.unique_trade_date[
                    i - self.rebalance_window - self.validation_window
                ],
            )
            # print("training: ",len(data_split(df, start=20090000, end=test.datadate.unique()[i-rebalance_window]) ))
            # print("==============Model Training===========")
            # Train Each Model
            for model_name in MODELS.keys():
                # Train The Model
                model, sharpe_list, sharpe = self._train_window(
                    model_name,
                    kwargs[model_name],
                    model_dct[model_name]["sharpe_list"],
                    validation_start_date,
                    validation_end_date,
                    timesteps_dict,
                    i,
                    validation,
                    turbulence_threshold,
                )
                # Save the model's sharpe ratios, and the model itself
                model_dct[model_name]["sharpe_list"] = sharpe_list
                model_dct[model_name]["model"] = model
                model_dct[model_name]["sharpe"] = sharpe

            print(
                "======Best Model Retraining from: ",
                self.train_period[0],
                "to ",
                self.unique_trade_date[i - self.rebalance_window],
            )
            # Environment setup for model retraining up to first trade date
            # train_full = data_split(self.df, start=self.train_period[0],
            #                         end=self.unique_trade_date[i - self.rebalance_window])
            # self.train_full_env = DummyVecEnv([lambda: StockTradingEnv(train_full,
            #                                                            self.stock_dim,
            #                                                            self.hmax,
            #                                                            self.initial_amount,
            #                                                            self.buy_cost_pct,
            #                                                            self.sell_cost_pct,
            #                                                            self.reward_scaling,
            #                                                            self.state_space,
            #                                                            self.action_space,
            #                                                            self.tech_indicator_list,
            #                                                            print_verbosity=self.print_verbosity)])
            # Model Selection based on sharpe ratio
            # Same order as MODELS: {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}
            sharpes = [model_dct[k]["sharpe"] for k in MODELS.keys()]
            # Find the model with the highest sharpe ratio
            max_mod = list(MODELS.keys())[np.argmax(sharpes)]
            model_use.append(max_mod.upper())
            model_ensemble = model_dct[max_mod]["model"]
            # Training and Validation ends
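
            # Selection example (hypothetical Sharpe values): with
            # sharpes = [0.5, -0.1, 0.2, 0.3, 0.4] in MODELS order
            # (a2c, ddpg, td3, sac, ppo), np.argmax picks index 0,
            # so "A2C" is used for the upcoming trading window.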

            # Trading starts
            print(
                "======Trading from: ",
                self.unique_trade_date[i - self.rebalance_window],
                "to ",
                self.unique_trade_date[i],
            )
            # print("Used Model: ", model_ensemble)
            last_state_ensemble = self.DRL_prediction(
                model=model_ensemble,
                name="ensemble",
                last_state=last_state_ensemble,
                iter_num=i,
                turbulence_threshold=turbulence_threshold,
                initial=initial,
            )
            # Trading ends

        end = time.time()
        print("Ensemble Strategy took: ", (end - start) / 60, " minutes")

        df_summary = pd.DataFrame(
            [
                iteration_list,
                validation_start_date_list,
                validation_end_date_list,
                model_use,
                model_dct["a2c"]["sharpe_list"],
                model_dct["ppo"]["sharpe_list"],
                model_dct["ddpg"]["sharpe_list"],
                model_dct["sac"]["sharpe_list"],
                model_dct["td3"]["sharpe_list"],
            ]
        ).T
        df_summary.columns = [
            "Iter",
            "Val Start",
            "Val End",
            "Model Used",
            "A2C Sharpe",
            "PPO Sharpe",
            "DDPG Sharpe",
            "SAC Sharpe",
            "TD3 Sharpe",
        ]

        return df_summary
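

# End-to-end sketch for the ensemble agent. `processed_df`, `stock_dimension`,
# `state_space`, and `INDICATORS` are assumed to come from the user's own
# preprocessing step; the hyperparameter dicts, dates, and window sizes below
# are illustrative only:
#
#   ensemble_agent = DRLEnsembleAgent(
#       df=processed_df,
#       train_period=("2010-01-01", "2021-10-01"),
#       val_test_period=("2021-10-01", "2023-03-01"),
#       rebalance_window=63,
#       validation_window=63,
#       stock_dim=stock_dimension,
#       hmax=100,
#       initial_amount=1_000_000,
#       buy_cost_pct=0.001,
#       sell_cost_pct=0.001,
#       reward_scaling=1e-4,
#       state_space=state_space,
#       action_space=stock_dimension,
#       tech_indicator_list=INDICATORS,
#       print_verbosity=5,
#   )
#   summary = ensemble_agent.run_ensemble_strategy(
#       A2C_model_kwargs={"n_steps": 5, "ent_coef": 0.01, "learning_rate": 0.0005},
#       PPO_model_kwargs={"n_steps": 2048, "ent_coef": 0.01, "learning_rate": 0.00025},
#       DDPG_model_kwargs={"buffer_size": 10_000, "learning_rate": 0.0005},
#       SAC_model_kwargs={"buffer_size": 10_000, "learning_rate": 0.0001},
#       TD3_model_kwargs={"buffer_size": 10_000, "learning_rate": 0.0001},
#       timesteps_dict={"a2c": 10_000, "ppo": 10_000, "ddpg": 10_000,
#                       "sac": 10_000, "td3": 10_000},
#   )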