GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/agents/rllib/models.py
# DRL models from RLlib
from __future__ import annotations

import ray
from ray.rllib.algorithms.a2c import a2c
from ray.rllib.algorithms.ddpg import ddpg
from ray.rllib.algorithms.ppo import ppo
from ray.rllib.algorithms.sac import sac
from ray.rllib.algorithms.td3 import td3

MODELS = {"a2c": a2c, "ddpg": ddpg, "td3": td3, "sac": sac, "ppo": ppo}


# MODEL_KWARGS = {x: config.__dict__[f"{x.upper()}_PARAMS"] for x in MODELS.keys()}


class DRLAgent:
    """Implementations of DRL algorithms from RLlib.

    Attributes
    ----------
        env: gym environment class
            user-defined environment class
        price_array: numpy array
            OHLC price data
        tech_array: numpy array
            technical indicator data
        turbulence_array: numpy array
            turbulence/risk data

    Methods
    -------
        get_model()
            set up a DRL algorithm and its default RLlib configuration
        train_model()
            train the DRL algorithm on a training dataset
            and return the trained trainer
        DRL_prediction()
            make predictions on a test dataset and return the results
    """

    def __init__(self, env, price_array, tech_array, turbulence_array):
        self.env = env
        self.price_array = price_array
        self.tech_array = tech_array
        self.turbulence_array = turbulence_array

    def get_model(
        self,
        model_name,
        # policy="MlpPolicy",
        # policy_kwargs=None,
        # model_kwargs=None,
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model {model_name} is not supported.")

        # if model_kwargs is None:
        #     model_kwargs = MODEL_KWARGS[model_name]

        model = MODELS[model_name]
        # get the algorithm's default configuration from RLlib
        if model_name == "a2c":
            model_config = model.A2C_DEFAULT_CONFIG.copy()
        elif model_name == "td3":
            model_config = model.TD3_DEFAULT_CONFIG.copy()
        else:
            model_config = model.DEFAULT_CONFIG.copy()
        # pass env, log_level, price_array, tech_array, and turbulence_array to the config
        model_config["env"] = self.env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": self.price_array,
            "tech_array": self.tech_array,
            "turbulence_array": self.turbulence_array,
            "if_train": True,
        }

        return model, model_config
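
    # Sketch of how the returned pair is typically consumed (names mirror this
    # class's own API; the episode count is an arbitrary example value):
    #
    #     model, model_config = agent.get_model("ppo")
    #     trainer = agent.train_model(model, "ppo", model_config, total_episodes=100)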

    def train_model(
        self, model, model_name, model_config, total_episodes=100, init_ray=True
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model {model_name} is not supported.")
        if init_ray:
            ray.init(
                ignore_reinit_error=True
            )  # Other Ray APIs will not work until `ray.init()` is called.

        if model_name == "ppo":
            trainer = model.PPOTrainer(env=self.env, config=model_config)
        elif model_name == "a2c":
            trainer = model.A2CTrainer(env=self.env, config=model_config)
        elif model_name == "ddpg":
            trainer = model.DDPGTrainer(env=self.env, config=model_config)
        elif model_name == "td3":
            trainer = model.TD3Trainer(env=self.env, config=model_config)
        elif model_name == "sac":
            trainer = model.SACTrainer(env=self.env, config=model_config)

        for _ in range(total_episodes):
            trainer.train()

        ray.shutdown()

        # save the trained model
        cwd = "./test_" + str(model_name)
        trainer.save(cwd)

        return trainer
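
    # Note: with the default total_episodes=100 and model_name="ppo", the
    # checkpoint written by trainer.save() above is what the default
    # agent_path of DRL_prediction() below points at
    # ("./test_ppo/checkpoint_000100/checkpoint-100"); the exact checkpoint
    # directory layout depends on the installed RLlib version.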

    @staticmethod
    def DRL_prediction(
        model_name,
        env,
        price_array,
        tech_array,
        turbulence_array,
        agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model {model_name} is not supported.")

        if model_name == "a2c":
            model_config = MODELS[model_name].A2C_DEFAULT_CONFIG.copy()
        elif model_name == "td3":
            model_config = MODELS[model_name].TD3_DEFAULT_CONFIG.copy()
        else:
            model_config = MODELS[model_name].DEFAULT_CONFIG.copy()
        model_config["env"] = env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_config = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_instance = env(config=env_config)

        # ray.init()  # Other Ray APIs will not work until `ray.init()` is called.
        if model_name == "ppo":
            trainer = MODELS[model_name].PPOTrainer(env=env, config=model_config)
        elif model_name == "a2c":
            trainer = MODELS[model_name].A2CTrainer(env=env, config=model_config)
        elif model_name == "ddpg":
            trainer = MODELS[model_name].DDPGTrainer(env=env, config=model_config)
        elif model_name == "td3":
            trainer = MODELS[model_name].TD3Trainer(env=env, config=model_config)
        elif model_name == "sac":
            trainer = MODELS[model_name].SACTrainer(env=env, config=model_config)

        try:
            trainer.restore(agent_path)
            print("Restoring from checkpoint path", agent_path)
        except BaseException as e:
            raise ValueError("Failed to load agent!") from e

        # run the trained agent on the test environment
        state = env_instance.reset()
        episode_returns = []  # cumulative return relative to the initial account value
        episode_total_assets = [env_instance.initial_total_asset]
        done = False
        while not done:
            action = trainer.compute_single_action(state)
            state, reward, done, _ = env_instance.step(action)

            total_asset = (
                env_instance.amount
                + (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / env_instance.initial_total_asset
            episode_returns.append(episode_return)
        ray.shutdown()
        print("episode return: " + str(episode_return))
        print("Test Finished!")
        return episode_total_assets
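

# ---------------------------------------------------------------------------
# End-to-end usage sketch. Everything in this block is illustrative:
# ``StockTradingEnv`` stands in for any gym-style environment class that
# accepts a ``config`` dict (FinRL ships such numpy-based environments), and
# the random arrays are toy placeholders for real market data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    # Hypothetical import; replace with the environment class used in your
    # project (the module path below is a placeholder, not part of this file).
    from my_project.envs import StockTradingEnv

    # Toy data: 1000 time steps, 5 tickers, 10 technical indicators per ticker.
    price_array = np.random.rand(1000, 5) * 100.0
    tech_array = np.random.rand(1000, 50)
    turbulence_array = np.random.rand(1000)

    agent = DRLAgent(
        env=StockTradingEnv,
        price_array=price_array,
        tech_array=tech_array,
        turbulence_array=turbulence_array,
    )
    model, model_config = agent.get_model("ppo")
    trainer = agent.train_model(model, "ppo", model_config, total_episodes=10)

    # Evaluate the checkpoint written by train_model() (path is illustrative
    # and depends on the RLlib version and the number of training iterations).
    account_values = DRLAgent.DRL_prediction(
        model_name="ppo",
        env=StockTradingEnv,
        price_array=price_array,
        tech_array=tech_array,
        turbulence_array=turbulence_array,
        agent_path="./test_ppo/checkpoint_000010/checkpoint-10",
    )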