GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/agents/stablebaselines3/tune_sb3.py
from __future__ import annotations

import datetime

import joblib
import optuna
import pandas as pd
from stable_baselines3 import A2C
from stable_baselines3 import DDPG
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3 import TD3

import finrl.agents.stablebaselines3.hyperparams_opt as hpt
from finrl import config
from finrl.agents.stablebaselines3.models import DRLAgent
from finrl.main import check_and_make_directories
from finrl.plot import backtest_stats


class LoggingCallback:
    def __init__(self, threshold: float, trial_number: int, patience: int):
        """
        threshold: float
            tolerance for the change in the best Sharpe ratio between trials
        trial_number: int
            minimum number of completed trials before early stopping is considered
        patience: int
            number of trials within the threshold to tolerate before stopping the study
        """
        self.threshold = threshold
        self.trial_number = trial_number
        self.patience = patience
        self.cb_list = []  # trials for which the threshold condition was met

    def __call__(self, study: optuna.Study, frozen_trial: optuna.trial.FrozenTrial):
        # Read the best value recorded before this trial, then store the current best
        previous_best_value = study.user_attrs.get("previous_best_value", None)
        study.set_user_attr("previous_best_value", study.best_value)

        # Check whether the minimum number of trials has passed
        if previous_best_value is not None and frozen_trial.number > self.trial_number:
            # Check whether the previous and current objective values have the same sign
            if previous_best_value * study.best_value >= 0:
                # Check the threshold condition
                if abs(previous_best_value - study.best_value) < self.threshold:
                    self.cb_list.append(frozen_trial.number)
                    # Stop once the threshold has been met for more than `patience` trials
                    if len(self.cb_list) > self.patience:
                        print("The study stops now...")
                        print(
                            "With number",
                            frozen_trial.number,
                            "and value ",
                            frozen_trial.value,
                        )
                        print(
                            "The previous and current best values are {} and {} respectively".format(
                                previous_best_value, study.best_value
                            )
                        )
                        study.stop()
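
# Illustrative usage sketch (not part of the original module): LoggingCallback follows
# Optuna's callback interface, so it can also be attached to a plain Optuna study.
# The toy objective below is hypothetical and exists only to show the wiring.
#
#     import optuna
#
#     def toy_objective(trial):
#         x = trial.suggest_float("x", -10, 10)
#         return -((x - 2) ** 2)
#
#     study = optuna.create_study(direction="maximize")
#     study.optimize(
#         toy_objective,
#         n_trials=50,
#         callbacks=[LoggingCallback(threshold=1e-5, trial_number=5, patience=3)],
#     )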


class TuneSB3Optuna:
    """
    Hyperparameter tuning of SB3 agents using Optuna.

    Attributes
    ----------
    env_train: training environment for SB3
    model_name: str
        one of "a2c", "ddpg", "td3", "sac" or "ppo"
    env_trade: testing (trading) environment
    logging_callback: callback used to stop the study early
    total_timesteps: int
        number of training timesteps per trial
    n_trials: int
        number of hyperparameter configurations to evaluate

    Note:
    The default sampler and pruner are the Tree-structured Parzen
    Estimator and the Hyperband pruner, respectively.
    """

    def __init__(
        self,
        env_train,
        model_name: str,
        env_trade,
        logging_callback,
        total_timesteps: int = 50000,
        n_trials: int = 30,
    ):
        self.env_train = env_train
        self.agent = DRLAgent(env=env_train)
        self.model_name = model_name
        self.env_trade = env_trade
        self.total_timesteps = total_timesteps
        self.n_trials = n_trials
        self.logging_callback = logging_callback
        self.MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

        check_and_make_directories(
            [
                config.DATA_SAVE_DIR,
                config.TRAINED_MODEL_DIR,
                config.TENSORBOARD_LOG_DIR,
                config.RESULTS_DIR,
            ]
        )

    def default_sample_hyperparameters(self, trial: optuna.Trial):
        if self.model_name == "a2c":
            return hpt.sample_a2c_params(trial)
        elif self.model_name == "ddpg":
            return hpt.sample_ddpg_params(trial)
        elif self.model_name == "td3":
            return hpt.sample_td3_params(trial)
        elif self.model_name == "sac":
            return hpt.sample_sac_params(trial)
        elif self.model_name == "ppo":
            return hpt.sample_ppo_params(trial)
        raise ValueError(f"Unsupported model_name: {self.model_name}")

    def calculate_sharpe(self, df: pd.DataFrame):
        df["daily_return"] = df["account_value"].pct_change(1)
        if df["daily_return"].std() != 0:
            # Annualized Sharpe ratio, assuming 252 trading days and a zero risk-free rate
            sharpe = (252**0.5) * df["daily_return"].mean() / df["daily_return"].std()
            return sharpe
        else:
            return 0

    def objective(self, trial: optuna.Trial):
        hyperparameters = self.default_sample_hyperparameters(trial)
        policy_kwargs = hyperparameters["policy_kwargs"]
        del hyperparameters["policy_kwargs"]
        model = self.agent.get_model(
            self.model_name, policy_kwargs=policy_kwargs, model_kwargs=hyperparameters
        )
        trained_model = self.agent.train_model(
            model=model,
            tb_log_name=self.model_name,
            total_timesteps=self.total_timesteps,
        )
        trained_model.save(
            f"./{config.TRAINED_MODEL_DIR}/{self.model_name}_{trial.number}.pth"
        )
        df_account_value, _ = DRLAgent.DRL_prediction(
            model=trained_model, environment=self.env_trade
        )
        sharpe = self.calculate_sharpe(df_account_value)

        return sharpe

    def run_optuna(self):
        sampler = optuna.samplers.TPESampler(seed=42)
        study = optuna.create_study(
            study_name=f"{self.model_name}_study",
            direction="maximize",
            sampler=sampler,
            pruner=optuna.pruners.HyperbandPruner(),
        )

        study.optimize(
            self.objective,
            n_trials=self.n_trials,
            catch=(ValueError,),
            callbacks=[self.logging_callback],
        )

        joblib.dump(study, f"{self.model_name}_study.pkl")
        return study

    def backtest(
        self, final_study: optuna.Study
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        print("Hyperparameters after tuning", final_study.best_params)
        print("Best Trial", final_study.best_trial)

        tuned_model = self.MODELS[self.model_name].load(
            f"./{config.TRAINED_MODEL_DIR}/{self.model_name}_{final_study.best_trial.number}.pth",
            env=self.env_train,
        )

        df_account_value_tuned, df_actions_tuned = DRLAgent.DRL_prediction(
            model=tuned_model, environment=self.env_trade
        )

        print("==============Get Backtest Results===========")
        now = datetime.datetime.now().strftime("%Y%m%d-%Hh%M")

        perf_stats_all_tuned = backtest_stats(account_value=df_account_value_tuned)
        perf_stats_all_tuned = pd.DataFrame(perf_stats_all_tuned)
        perf_stats_all_tuned.to_csv(
            "./" + config.RESULTS_DIR + "/perf_stats_all_tuned_" + now + ".csv"
        )

        return df_account_value_tuned, df_actions_tuned, perf_stats_all_tuned
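

# Illustrative usage sketch (not part of the original module).  It assumes that
# `env_train` and `env_trade` are FinRL-style training and trading environments
# built elsewhere (hypothetical names); adjust the arguments to your own setup.
#
#     logging_callback = LoggingCallback(threshold=1e-5, trial_number=5, patience=3)
#     tuner = TuneSB3Optuna(
#         env_train=env_train,
#         model_name="ppo",
#         env_trade=env_trade,
#         logging_callback=logging_callback,
#         total_timesteps=50_000,
#         n_trials=30,
#     )
#     study = tuner.run_optuna()
#     df_account_value, df_actions, perf_stats = tuner.backtest(study)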