GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/agents/stablebaselines3/models.py
# DRL models from Stable Baselines 3
from __future__ import annotations

import statistics
import time

import numpy as np
import pandas as pd
from stable_baselines3 import A2C
from stable_baselines3 import DDPG
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3 import TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv

from finrl import config
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.meta.preprocessor.preprocessors import data_split

MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

MODEL_KWARGS = {x: config.__dict__[f"{x.upper()}_PARAMS"] for x in MODELS.keys()}
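# Default hyperparameters for each algorithm come from finrl.config
# (e.g. config.A2C_PARAMS); pass model_kwargs to get_model() to override them.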

NOISE = {
    "normal": NormalActionNoise,
    "ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,
}


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        try:
            self.logger.record(key="train/reward", value=self.locals["rewards"][0])

        except BaseException as error:
            try:
                self.logger.record(key="train/reward", value=self.locals["reward"][0])

            except BaseException as inner_error:
                # Handle the case where neither "rewards" nor "reward" is found
                self.logger.record(key="train/reward", value=None)
                # Print the original error and the inner error for debugging
                print("Original Error:", error)
                print("Inner Error:", inner_error)
        return True

    def _on_rollout_end(self) -> bool:
        try:
            rollout_buffer_rewards = self.locals["rollout_buffer"].rewards.flatten()
            self.logger.record(
                key="train/reward_min", value=min(rollout_buffer_rewards)
            )
            self.logger.record(
                key="train/reward_mean", value=statistics.mean(rollout_buffer_rewards)
            )
            self.logger.record(
                key="train/reward_max", value=max(rollout_buffer_rewards)
            )
        except BaseException as error:
            # Handle the case where "rewards" is not found
            self.logger.record(key="train/reward_min", value=None)
            self.logger.record(key="train/reward_mean", value=None)
            self.logger.record(key="train/reward_max", value=None)
            print("Logging Error:", error)
        return True


class DRLAgent:
    """Provides implementations for DRL algorithms

    Attributes
    ----------
    env : gym environment class
        user-defined environment class

    Methods
    -------
    get_model()
        set up a DRL algorithm
    train_model()
        train a DRL algorithm on a training dataset
        and output the trained model
    DRL_prediction()
        make a prediction on a test dataset and get results
    """

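    # Typical usage (a minimal sketch; assumes `env_train` is the vectorized
    # training env returned by StockTradingEnv.get_sb_env() and `e_trade_gym`
    # is a StockTradingEnv for the trade period -- names are illustrative):
    #
    #   agent = DRLAgent(env=env_train)
    #   model = agent.get_model("ppo")
    #   trained = DRLAgent.train_model(model, tb_log_name="ppo", total_timesteps=50_000)
    #   account_value, actions = DRLAgent.DRL_prediction(trained, environment=e_trade_gym)
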
    def __init__(self, env):
        self.env = env

    def get_model(
        self,
        model_name,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        verbose=1,
        seed=None,
        tensorboard_log=None,
    ):
        if model_name not in MODELS:
            raise ValueError(
                f"Model '{model_name}' not found in MODELS."
            )  # this is more informative than NotImplementedError("NotImplementedError")

        if model_kwargs is None:
            model_kwargs = MODEL_KWARGS[model_name]
        # work on a copy so the shared defaults / caller's dict are not mutated
        model_kwargs = model_kwargs.copy()

        if "action_noise" in model_kwargs:
            n_actions = self.env.action_space.shape[-1]
            model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](
                mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
            )
        print(model_kwargs)
        return MODELS[model_name](
            policy=policy,
            env=self.env,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **model_kwargs,
        )

    @staticmethod
    def train_model(
        model,
        tb_log_name,
        total_timesteps=5000,
        callbacks: list[BaseCallback] | None = None,
    ):  # this is a static method, so it can be called without creating an instance of the class
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=(
                CallbackList(
                    [TensorboardCallback()] + [callback for callback in callbacks]
                )
                if callbacks is not None
                else TensorboardCallback()
            ),
        )
        return model

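    # Hypothetical example: extra Stable Baselines 3 callbacks can be passed in
    # addition to the default TensorboardCallback, e.g.
    #   from stable_baselines3.common.callbacks import CheckpointCallback
    #   DRLAgent.train_model(model, "ppo", callbacks=[CheckpointCallback(1000, "./ckpt")])
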
    @staticmethod
    def DRL_prediction(model, environment, deterministic=True):
        """make a prediction and get results"""
        test_env, test_obs = environment.get_sb_env()
        account_memory = None  # This helps avoid unnecessary list creation
        actions_memory = None  # optimize memory consumption
        # state_memory=[] #add memory pool to store states

        test_env.reset()
        max_steps = len(environment.df.index.unique()) - 1

        for i in range(len(environment.df.index.unique())):
            action, _states = model.predict(test_obs, deterministic=deterministic)
            # account_memory = test_env.env_method(method_name="save_asset_memory")
            # actions_memory = test_env.env_method(method_name="save_action_memory")
            test_obs, rewards, dones, info = test_env.step(action)

            if (
                i == max_steps - 1
            ):  # more descriptive condition for early termination to clarify the logic
                account_memory = test_env.env_method(method_name="save_asset_memory")
                actions_memory = test_env.env_method(method_name="save_action_memory")
                # add current state to state memory
                # state_memory=test_env.env_method(method_name="save_state_memory")

            if dones[0]:
                print("hit end!")
                break
        return account_memory[0], actions_memory[0]

    @staticmethod
    def DRL_prediction_load_from_file(model_name, environment, cwd, deterministic=True):
        if model_name not in MODELS:
            raise ValueError(
                f"Model '{model_name}' not found in MODELS."
            )  # this is more informative than NotImplementedError("NotImplementedError")
        try:
            # load agent
            model = MODELS[model_name].load(cwd)
            print("Successfully loaded model", cwd)
        except BaseException as error:
            raise ValueError(f"Failed to load agent. Error: {str(error)}") from error

        # test on the testing env
        state = environment.reset()
        episode_returns = []  # the cumulative_return / initial_account
        episode_total_assets = [environment.initial_total_asset]
        done = False
        while not done:
            action = model.predict(state, deterministic=deterministic)[0]
            state, reward, done, _ = environment.step(action)

            total_asset = (
                environment.amount
                + (environment.price_ary[environment.day] * environment.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / environment.initial_total_asset
            episode_returns.append(episode_return)

        print("episode_return", episode_return)
        print("Test Finished!")
        return episode_total_assets


class DRLEnsembleAgent:
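    """Rolling-window ensemble of the five DRL agents (A2C, DDPG, TD3, SAC, PPO).

    For each rebalance window, every agent is trained and then validated on the
    following validation window; the agent with the highest validation Sharpe
    ratio is used to trade the next window.

    Example (a minimal sketch; `processed` and `env_kwargs` are assumed to be
    prepared as in the FinRL ensemble-strategy tutorial):

        agent = DRLEnsembleAgent(
            df=processed,
            train_period=(TRAIN_START_DATE, TRAIN_END_DATE),
            val_test_period=(TEST_START_DATE, TEST_END_DATE),
            rebalance_window=63,
            validation_window=63,
            **env_kwargs,
        )
        summary = agent.run_ensemble_strategy(
            A2C_model_kwargs, PPO_model_kwargs, DDPG_model_kwargs,
            SAC_model_kwargs, TD3_model_kwargs, timesteps_dict,
        )
    """
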
    @staticmethod
    def get_model(
        model_name,
        env,
        policy="MlpPolicy",
        policy_kwargs=None,
        model_kwargs=None,
        seed=None,
        verbose=1,
    ):
        if model_name not in MODELS:
            raise ValueError(
                f"Model '{model_name}' not found in MODELS."
            )  # this is more informative than NotImplementedError("NotImplementedError")

        if model_kwargs is None:
            temp_model_kwargs = MODEL_KWARGS[model_name].copy()
        else:
            temp_model_kwargs = model_kwargs.copy()

        if "action_noise" in temp_model_kwargs:
            n_actions = env.action_space.shape[-1]
            temp_model_kwargs["action_noise"] = NOISE[
                temp_model_kwargs["action_noise"]
            ](mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
        print(temp_model_kwargs)
        return MODELS[model_name](
            policy=policy,
            env=env,
            tensorboard_log=f"{config.TENSORBOARD_LOG_DIR}/{model_name}",
            verbose=verbose,
            policy_kwargs=policy_kwargs,
            seed=seed,
            **temp_model_kwargs,
        )

    @staticmethod
    def train_model(
        model,
        model_name,
        tb_log_name,
        iter_num,
        total_timesteps=5000,
        callbacks: list[BaseCallback] | None = None,
    ):
        model = model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=(
                CallbackList(
                    [TensorboardCallback()] + [callback for callback in callbacks]
                )
                if callbacks is not None
                else TensorboardCallback()
            ),
        )
        model.save(
            f"{config.TRAINED_MODEL_DIR}/{model_name.upper()}_{total_timesteps // 1000}k_{iter_num}"
        )
        return model

    @staticmethod
    def get_validation_sharpe(iteration, model_name):
        """Calculate Sharpe ratio based on validation results"""
        df_total_value = pd.read_csv(
            f"results/account_value_validation_{model_name}_{iteration}.csv"
        )
        # If the agent did not make any transaction
        if df_total_value["daily_return"].var() == 0:
            if df_total_value["daily_return"].mean() > 0:
                return np.inf
            else:
                return 0.0
        else:
            return (
                (4**0.5)
                * df_total_value["daily_return"].mean()
                / df_total_value["daily_return"].std()
            )

    def __init__(
        self,
        df,
        train_period,
        val_test_period,
        rebalance_window,
        validation_window,
        stock_dim,
        hmax,
        initial_amount,
        buy_cost_pct,
        sell_cost_pct,
        reward_scaling,
        state_space,
        action_space,
        tech_indicator_list,
        print_verbosity,
    ):
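        """
        df : preprocessed DataFrame with at least `date` and `turbulence` columns.
        train_period, val_test_period : (start, end) date pairs used to split `df`.
        rebalance_window : number of trading days between retraining runs.
        validation_window : number of trading days used for validation and
            model selection (e.g. 63 days, roughly one quarter).
        The remaining arguments are forwarded to StockTradingEnv.
        """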
        self.df = df
        self.train_period = train_period
        self.val_test_period = val_test_period

        self.unique_trade_date = df[
            (df.date > val_test_period[0]) & (df.date <= val_test_period[1])
        ].date.unique()
        self.rebalance_window = rebalance_window
        self.validation_window = validation_window

        self.stock_dim = stock_dim
        self.hmax = hmax
        self.initial_amount = initial_amount
        self.buy_cost_pct = buy_cost_pct
        self.sell_cost_pct = sell_cost_pct
        self.reward_scaling = reward_scaling
        self.state_space = state_space
        self.action_space = action_space
        self.tech_indicator_list = tech_indicator_list
        self.print_verbosity = print_verbosity
        self.train_env = None  # defined in run_ensemble_strategy()

    def DRL_validation(self, model, test_data, test_env, test_obs):
        """validation process"""
        for _ in range(len(test_data.index.unique())):
            action, _states = model.predict(test_obs)
            test_obs, rewards, dones, info = test_env.step(action)

    def DRL_prediction(
        self, model, name, last_state, iter_num, turbulence_threshold, initial
    ):
        """make a prediction based on trained model"""

        # trading env
        trade_data = data_split(
            self.df,
            start=self.unique_trade_date[iter_num - self.rebalance_window],
            end=self.unique_trade_date[iter_num],
        )
        trade_env = DummyVecEnv(
            [
                lambda: StockTradingEnv(
                    df=trade_data,
                    stock_dim=self.stock_dim,
                    hmax=self.hmax,
                    initial_amount=self.initial_amount,
                    num_stock_shares=[0] * self.stock_dim,
                    buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
                    sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
                    reward_scaling=self.reward_scaling,
                    state_space=self.state_space,
                    action_space=self.action_space,
                    tech_indicator_list=self.tech_indicator_list,
                    turbulence_threshold=turbulence_threshold,
                    initial=initial,
                    previous_state=last_state,
                    model_name=name,
                    mode="trade",
                    iteration=iter_num,
                    print_verbosity=self.print_verbosity,
                )
            ]
        )

        trade_obs = trade_env.reset()

        for i in range(len(trade_data.index.unique())):
            action, _states = model.predict(trade_obs)
            trade_obs, rewards, dones, info = trade_env.step(action)
            if i == (len(trade_data.index.unique()) - 2):
                # print(env_test.render())
                last_state = trade_env.envs[0].render()

        df_last_state = pd.DataFrame({"last_state": last_state})
        df_last_state.to_csv(f"results/last_state_{name}_{i}.csv", index=False)
        return last_state

    def _train_window(
        self,
        model_name,
        model_kwargs,
        sharpe_list,
        validation_start_date,
        validation_end_date,
        timesteps_dict,
        i,
        validation,
        turbulence_threshold,
    ):
        """
        Train and validate one agent for a single rebalance window.
        Returns (model, sharpe_list, sharpe).
        """
        if model_kwargs is None:
            return None, sharpe_list, -1

        print(f"======{model_name} Training========")
        model = self.get_model(
            model_name, self.train_env, policy="MlpPolicy", model_kwargs=model_kwargs
        )
        model = self.train_model(
            model,
            model_name,
            tb_log_name=f"{model_name}_{i}",
            iter_num=i,
            total_timesteps=timesteps_dict[model_name],
        )  # 100_000
        print(
            f"======{model_name} Validation from: ",
            validation_start_date,
            "to ",
            validation_end_date,
        )
        val_env = DummyVecEnv(
            [
                lambda: StockTradingEnv(
                    df=validation,
                    stock_dim=self.stock_dim,
                    hmax=self.hmax,
                    initial_amount=self.initial_amount,
                    num_stock_shares=[0] * self.stock_dim,
                    buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
                    sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
                    reward_scaling=self.reward_scaling,
                    state_space=self.state_space,
                    action_space=self.action_space,
                    tech_indicator_list=self.tech_indicator_list,
                    turbulence_threshold=turbulence_threshold,
                    iteration=i,
                    model_name=model_name,
                    mode="validation",
                    print_verbosity=self.print_verbosity,
                )
            ]
        )
        val_obs = val_env.reset()
        self.DRL_validation(
            model=model,
            test_data=validation,
            test_env=val_env,
            test_obs=val_obs,
        )
        sharpe = self.get_validation_sharpe(i, model_name=model_name)
        print(f"{model_name} Sharpe Ratio: ", sharpe)
        sharpe_list.append(sharpe)
        return model, sharpe_list, sharpe

    def run_ensemble_strategy(
        self,
        A2C_model_kwargs,
        PPO_model_kwargs,
        DDPG_model_kwargs,
        SAC_model_kwargs,
        TD3_model_kwargs,
        timesteps_dict,
    ):
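        """Ensemble strategy that combines A2C, PPO, DDPG, SAC, and TD3.

        Walks forward through the trade dates in steps of `rebalance_window`:
        each agent is trained on data up to the start of the validation window,
        the agent with the highest validation Sharpe ratio is selected, and that
        agent trades the next `rebalance_window` days. Returns a summary
        DataFrame of validation windows, the selected model, and each agent's
        Sharpe ratio per iteration.
        """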
        # Model Parameters
        kwargs = {
            "a2c": A2C_model_kwargs,
            "ppo": PPO_model_kwargs,
            "ddpg": DDPG_model_kwargs,
            "sac": SAC_model_kwargs,
            "td3": TD3_model_kwargs,
        }
        # Model Sharpe Ratios
        model_dct = {k: {"sharpe_list": [], "sharpe": -1} for k in MODELS.keys()}

        print("============Start Ensemble Strategy============")
        # for the ensemble model, it's necessary to feed the last state
        # of the previous model to the current model as the initial state
        last_state_ensemble = []

        model_use = []
        validation_start_date_list = []
        validation_end_date_list = []
        iteration_list = []

        insample_turbulence = self.df[
            (self.df.date < self.train_period[1])
            & (self.df.date >= self.train_period[0])
        ]
        insample_turbulence_threshold = np.quantile(
            insample_turbulence.turbulence.values, 0.90
        )

        start = time.time()
        for i in range(
            self.rebalance_window + self.validation_window,
            len(self.unique_trade_date),
            self.rebalance_window,
        ):
            validation_start_date = self.unique_trade_date[
                i - self.rebalance_window - self.validation_window
            ]
            validation_end_date = self.unique_trade_date[i - self.rebalance_window]

            validation_start_date_list.append(validation_start_date)
            validation_end_date_list.append(validation_end_date)
            iteration_list.append(i)

            print("============================================")
            # initial state is empty
            if i - self.rebalance_window - self.validation_window == 0:
                # initial state
                initial = True
            else:
                # previous state
                initial = False

            # Tune turbulence index based on historical data
            # Turbulence lookback window is one quarter (63 days)
            end_date_index = self.df.index[
                self.df["date"]
                == self.unique_trade_date[
                    i - self.rebalance_window - self.validation_window
                ]
            ].to_list()[-1]
            start_date_index = end_date_index - 63 + 1

            historical_turbulence = self.df.iloc[
                start_date_index : (end_date_index + 1), :
            ]

            historical_turbulence = historical_turbulence.drop_duplicates(
                subset=["date"]
            )

            historical_turbulence_mean = np.mean(
                historical_turbulence.turbulence.values
            )

            # print(historical_turbulence_mean)

            if historical_turbulence_mean > insample_turbulence_threshold:
                # if the mean of the historical data is greater than the 90% quantile of insample turbulence data
                # then we assume that the current market is volatile,
                # therefore we set the 90% quantile of insample turbulence data as the turbulence threshold
                # meaning the current turbulence can't exceed the 90% quantile of insample turbulence data
                turbulence_threshold = insample_turbulence_threshold
            else:
                # if the mean of the historical data is less than the 90% quantile of insample turbulence data
                # then we tune up the turbulence_threshold, meaning we lower the risk
                turbulence_threshold = np.quantile(
                    insample_turbulence.turbulence.values, 1
                )

            turbulence_threshold = np.quantile(
                insample_turbulence.turbulence.values, 0.99
            )
            print("turbulence_threshold: ", turbulence_threshold)

570
# Environment Setup starts
571
# training env
572
train = data_split(
573
self.df,
574
start=self.train_period[0],
575
end=self.unique_trade_date[
576
i - self.rebalance_window - self.validation_window
577
],
578
)
579
self.train_env = DummyVecEnv(
580
[
581
lambda: StockTradingEnv(
582
df=train,
583
stock_dim=self.stock_dim,
584
hmax=self.hmax,
585
initial_amount=self.initial_amount,
586
num_stock_shares=[0] * self.stock_dim,
587
buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
588
sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
589
reward_scaling=self.reward_scaling,
590
state_space=self.state_space,
591
action_space=self.action_space,
592
tech_indicator_list=self.tech_indicator_list,
593
print_verbosity=self.print_verbosity,
594
)
595
]
596
)
597
598
validation = data_split(
599
self.df,
600
start=self.unique_trade_date[
601
i - self.rebalance_window - self.validation_window
602
],
603
end=self.unique_trade_date[i - self.rebalance_window],
604
)
605
# Environment Setup ends
606
607
# Training and Validation starts
608
print(
609
"======Model training from: ",
610
self.train_period[0],
611
"to ",
612
self.unique_trade_date[
613
i - self.rebalance_window - self.validation_window
614
],
615
)
616
# print("training: ",len(data_split(df, start=20090000, end=test.datadate.unique()[i-rebalance_window]) ))
617
# print("==============Model Training===========")
618
# Train Each Model
619
for model_name in MODELS.keys():
620
# Train The Model
621
model, sharpe_list, sharpe = self._train_window(
622
model_name,
623
kwargs[model_name],
624
model_dct[model_name]["sharpe_list"],
625
validation_start_date,
626
validation_end_date,
627
timesteps_dict,
628
i,
629
validation,
630
turbulence_threshold,
631
)
632
# Save the model's sharpe ratios, and the model itself
633
model_dct[model_name]["sharpe_list"] = sharpe_list
634
model_dct[model_name]["model"] = model
635
model_dct[model_name]["sharpe"] = sharpe
636
637
print(
638
"======Best Model Retraining from: ",
639
self.train_period[0],
640
"to ",
641
self.unique_trade_date[i - self.rebalance_window],
642
)
643
# Environment setup for model retraining up to first trade date
644
# train_full = data_split(self.df, start=self.train_period[0],
645
# end=self.unique_trade_date[i - self.rebalance_window])
646
# self.train_full_env = DummyVecEnv([lambda: StockTradingEnv(train_full,
647
# self.stock_dim,
648
# self.hmax,
649
# self.initial_amount,
650
# self.buy_cost_pct,
651
# self.sell_cost_pct,
652
# self.reward_scaling,
653
# self.state_space,
654
# self.action_space,
655
# self.tech_indicator_list,
656
# print_verbosity=self.print_verbosity
657
# )])
658
# Model Selection based on sharpe ratio
659
# Same order as MODELS: {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}
660
sharpes = [model_dct[k]["sharpe"] for k in MODELS.keys()]
661
# Find the model with the highest sharpe ratio
662
max_mod = list(MODELS.keys())[np.argmax(sharpes)]
663
model_use.append(max_mod.upper())
664
model_ensemble = model_dct[max_mod]["model"]
665
# Training and Validation ends
666
667
# Trading starts
668
print(
669
"======Trading from: ",
670
self.unique_trade_date[i - self.rebalance_window],
671
"to ",
672
self.unique_trade_date[i],
673
)
674
# print("Used Model: ", model_ensemble)
675
last_state_ensemble = self.DRL_prediction(
676
model=model_ensemble,
677
name="ensemble",
678
last_state=last_state_ensemble,
679
iter_num=i,
680
turbulence_threshold=turbulence_threshold,
681
initial=initial,
682
)
683
# Trading ends
684
685
end = time.time()
686
print("Ensemble Strategy took: ", (end - start) / 60, " minutes")
687
688
df_summary = pd.DataFrame(
689
[
690
iteration_list,
691
validation_start_date_list,
692
validation_end_date_list,
693
model_use,
694
model_dct["a2c"]["sharpe_list"],
695
model_dct["ppo"]["sharpe_list"],
696
model_dct["ddpg"]["sharpe_list"],
697
model_dct["sac"]["sharpe_list"],
698
model_dct["td3"]["sharpe_list"],
699
]
700
).T
701
df_summary.columns = [
702
"Iter",
703
"Val Start",
704
"Val End",
705
"Model Used",
706
"A2C Sharpe",
707
"PPO Sharpe",
708
"DDPG Sharpe",
709
"SAC Sharpe",
710
"TD3 Sharpe",
711
]
712
713
return df_summary
714
715