Path: blob/master/finrl/agents/stablebaselines3/tune_sb3.py
from __future__ import annotations

import datetime

import joblib
import optuna
import pandas as pd
from stable_baselines3 import A2C
from stable_baselines3 import DDPG
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3 import TD3

import finrl.agents.stablebaselines3.hyperparams_opt as hpt
from finrl import config
from finrl.agents.stablebaselines3.models import DRLAgent
from finrl.main import check_and_make_directories
from finrl.plot import backtest_stats


class LoggingCallback:
    def __init__(self, threshold: int, trial_number: int, patience: int):
        """
        threshold: int      tolerance for improvement in the Sharpe ratio
        trial_number: int   minimum number of trials before early stopping is considered
        patience: int       number of consecutive trials within the threshold before stopping
        """
        self.threshold = threshold
        self.trial_number = trial_number
        self.patience = patience
        self.cb_list = []  # Trials list for which threshold is reached

    def __call__(self, study: optuna.Study, frozen_trial: optuna.trial.FrozenTrial):
        # Read the best value recorded after the previous trial, then store the
        # current best value for the next invocation
        previous_best_value = study.user_attrs.get(
            "previous_best_value", study.best_value
        )
        study.set_user_attr("previous_best_value", study.best_value)

        # Checking if the minimum number of trials have passed
        if frozen_trial.number > self.trial_number:
            # Checking if the previous and current objective values have the same sign
            if previous_best_value * study.best_value >= 0:
                # Checking for the threshold condition
                if abs(previous_best_value - study.best_value) < self.threshold:
                    self.cb_list.append(frozen_trial.number)
                    # Stop once the threshold condition has held for more than
                    # `patience` trials
                    if len(self.cb_list) > self.patience:
                        print("The study stops now...")
                        print(
                            "With number",
                            frozen_trial.number,
                            "and value ",
                            frozen_trial.value,
                        )
                        print(
                            "The previous and current best values are {} and {} respectively".format(
                                previous_best_value, study.best_value
                            )
                        )
                        study.stop()


class TuneSB3Optuna:
    """
    Hyperparameter tuning of SB3 agents using Optuna

    Attributes
    ----------
    env_train: training environment for SB3
    model_name: str, one of "a2c", "ddpg", "td3", "sac", "ppo"
    env_trade: testing (trading) environment
    logging_callback: callback used during tuning
    total_timesteps: int, training timesteps per trial
    n_trials: number of hyperparameter configurations to try

    Note:
        The default sampler and pruner are the Tree-structured
        Parzen Estimator and the Hyperband pruner, respectively.
    """

    def __init__(
        self,
        env_train,
        model_name: str,
        env_trade,
        logging_callback,
        total_timesteps: int = 50000,
        n_trials: int = 30,
    ):
        self.env_train = env_train
        self.agent = DRLAgent(env=env_train)
        self.model_name = model_name
        self.env_trade = env_trade
        self.total_timesteps = total_timesteps
        self.n_trials = n_trials
        self.logging_callback = logging_callback
        self.MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

        check_and_make_directories(
            [
                config.DATA_SAVE_DIR,
                config.TRAINED_MODEL_DIR,
                config.TENSORBOARD_LOG_DIR,
                config.RESULTS_DIR,
            ]
        )

    def default_sample_hyperparameters(self, trial: optuna.Trial):
        if self.model_name == "a2c":
            return hpt.sample_a2c_params(trial)
        elif self.model_name == "ddpg":
            return hpt.sample_ddpg_params(trial)
        elif self.model_name == "td3":
            return hpt.sample_td3_params(trial)
        elif self.model_name == "sac":
            return hpt.sample_sac_params(trial)
        elif self.model_name == "ppo":
            return hpt.sample_ppo_params(trial)

    def calculate_sharpe(self, df: pd.DataFrame):
        df["daily_return"] = df["account_value"].pct_change(1)
        if df["daily_return"].std() != 0:
            sharpe = (
                (252**0.5) * df["daily_return"].mean() / df["daily_return"].std()
            )
            return sharpe
        else:
            return 0

    def objective(self, trial: optuna.Trial):
        hyperparameters = self.default_sample_hyperparameters(trial)
        policy_kwargs = hyperparameters["policy_kwargs"]
        del hyperparameters["policy_kwargs"]
        model = self.agent.get_model(
            self.model_name, policy_kwargs=policy_kwargs, model_kwargs=hyperparameters
        )
        trained_model = self.agent.train_model(
            model=model,
            tb_log_name=self.model_name,
            total_timesteps=self.total_timesteps,
        )
        trained_model.save(
            f"./{config.TRAINED_MODEL_DIR}/{self.model_name}_{trial.number}.pth"
        )
        df_account_value, _ = DRLAgent.DRL_prediction(
            model=trained_model, environment=self.env_trade
        )
        sharpe = self.calculate_sharpe(df_account_value)

        return sharpe

    def run_optuna(self):
        sampler = optuna.samplers.TPESampler(seed=42)
        study = optuna.create_study(
            study_name=f"{self.model_name}_study",
            direction="maximize",
            sampler=sampler,
            pruner=optuna.pruners.HyperbandPruner(),
        )

        study.optimize(
            self.objective,
            n_trials=self.n_trials,
            catch=(ValueError,),
            callbacks=[self.logging_callback],
        )

        joblib.dump(study, f"{self.model_name}_study.pkl")
        return study

    def backtest(
        self, final_study: optuna.Study
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        print("Hyperparameters after tuning", final_study.best_params)
        print("Best Trial", final_study.best_trial)

        tuned_model = self.MODELS[self.model_name].load(
            f"./{config.TRAINED_MODEL_DIR}/{self.model_name}_{final_study.best_trial.number}.pth",
            env=self.env_train,
        )

        df_account_value_tuned, df_actions_tuned = DRLAgent.DRL_prediction(
            model=tuned_model, environment=self.env_trade
        )

        print("==============Get Backtest Results===========")
        now = datetime.datetime.now().strftime("%Y%m%d-%Hh%M")

        perf_stats_all_tuned = backtest_stats(account_value=df_account_value_tuned)
        perf_stats_all_tuned = pd.DataFrame(perf_stats_all_tuned)
        perf_stats_all_tuned.to_csv(
            "./" + config.RESULTS_DIR + "/perf_stats_all_tuned_" + now + ".csv"
        )

        return df_account_value_tuned, df_actions_tuned, perf_stats_all_tuned