GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/agents/rllib/drllibv2.py
# @Author: Astarag Mohapatra
from __future__ import annotations

import ray

# Compare numeric version components rather than raw strings, so that e.g.
# "2.10.0" is not mistakenly treated as older than "2.2.0".
assert tuple(int(part) for part in ray.__version__.split(".")[:2]) >= (2, 2), (
    "Please install ray 2.2.0 or newer with 'pip install ray[rllib] ray[tune] lz4'; "
    "lz4 is needed for population based tuning"
)

from pprint import pprint

from ray import tune
from ray.tune.search import ConcurrencyLimiter
from ray.rllib.algorithms import Algorithm
from ray.tune import register_env

from ray.air import RunConfig, FailureConfig, ScalingConfig
from ray.tune.tune_config import TuneConfig
from ray.air.config import CheckpointConfig

import psutil

# Override Ray's internal system-memory detection with the total memory
# reported by psutil.
psutil_memory_in_bytes = psutil.virtual_memory().total
ray._private.utils.get_system_memory = lambda: psutil_memory_in_bytes

from typing import Dict, Optional, Any, List, Union


class DRLlibv2:
    """
    It instantiates an RLlib trainable and wraps it with Ray Tune functionality
    for hyperparameter tuning.

    Params
    -------------------------------------
    trainable:
        Any Trainable class, or a registered algorithm name such as "PPO", that takes a config as parameter
    train_env:
        Training environment instance
    train_env_name: str
        Name under which the training environment is registered
    params: dict
        Hyperparameter dictionary (the Tune search space)
    run_name: str
        Name of the tune run
    framework: str
        "torch" for PyTorch or "tf" for TensorFlow
    local_dir: str
        Directory in which to save the results and TensorBoard plots
    num_workers: int
        Number of rollout workers
    search_alg:
        Search algorithm (e.g. OptunaSearch) used to sample hyperparameter configurations
    concurrent_trials: int
        Number of hyperparameter trials to run concurrently
    num_samples: int
        Number of hyperparameter configurations to sample
    scheduler:
        Trial scheduler for stopping suboptimal trials early
    log_level: str = "WARN"
        Verbosity level, e.g. "DEBUG"
    num_gpus: Union[float, int] = 0
        GPUs per trial
    num_cpus: Union[float, int] = 2
        CPUs for rollout collection
    dataframe_save: str
        CSV path for saving the tune results
    metric: str
        Metric to optimize during hyperparameter search (used by Bayesian methods)
    mode: str
        "max" to maximize or "min" to minimize the metric
    max_failures: int
        Number of trial failures tolerated before a TuneError is raised
    training_iterations: int
        Number of training iterations, i.e. how many times session.report() is called
    checkpoint_num_to_keep: int
        Number of checkpoints to keep
    checkpoint_freq: int
        Checkpoint frequency in training iterations
    reuse_actors: bool
        Whether to reuse actors between trials

    Methods
    -------------------------------------
    train_tune_model: Takes the params dictionary and fits the trainable, sklearn style
    restore_agent: Restores previously errored or stopped trials or experiments
    infer_results: Returns the results dataframe and trial information
    get_test_agent: Returns the trained agent for inference

    Example
    ---------------------------------------
    def sample_ppo_params():
        return {
            "entropy_coeff": tune.loguniform(0.00000001, 0.1),
            "lr": tune.loguniform(5e-5, 0.001),
            "sgd_minibatch_size": tune.choice([32, 64, 128, 256, 512]),
            "lambda": tune.choice([0.1, 0.3, 0.5, 0.7, 0.9, 1.0]),
        }

    optuna_search = OptunaSearch(
        metric="episode_reward_mean",
        mode="max")

    drl_agent = DRLlibv2(
        trainable="PPO",
        train_env=env(train_env_config),
        train_env_name="StockTrading_train",
        framework="torch",
        num_workers=1,
        log_level="DEBUG",
        run_name="test",
        local_dir="test",
        params=sample_ppo_params(),
        num_samples=1,
        num_gpus=1,
        training_iterations=10,
        search_alg=optuna_search,
        checkpoint_freq=5,
    )

    # Tune or train the model
    res = drl_agent.train_tune_model()

    # Get the tune results
    results_df, best_result = drl_agent.infer_results()

    # Get the best testing agent
    test_agent = drl_agent.get_test_agent('StockTrading_testenv', test_env_instance)
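
    # Optionally resume an interrupted or errored run (illustrative; with no
    # checkpoint_path it falls back to the best checkpoint of the current results)
    res = drl_agent.restore_agent()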
"""
124
125
def __init__(
126
self,
127
trainable: str | Any,
128
train_env_name: str,
129
params: dict,
130
train_env=None,
131
run_name: str = "tune_run",
132
framework: str = "torch",
133
local_dir: str = "tune_results",
134
num_workers: int = 1,
135
search_alg=None,
136
concurrent_trials: int = 0,
137
num_samples: int = 0,
138
scheduler=None,
139
log_level: str = "WARN",
140
num_gpus: float | int = 0,
141
num_cpus: float | int = 2,
142
dataframe_save: str = "tune.csv",
143
metric: str = "episode_reward_mean",
144
mode: str | list[str] = "max",
145
max_failures: int = 0,
146
training_iterations: int = 100,
147
checkpoint_num_to_keep: None | int = None,
148
checkpoint_freq: int = 0,
149
reuse_actors: bool = False,
150
):
151
if train_env is not None:
152
register_env(train_env_name, lambda config: train_env)
153
154
self.params = params
155
self.params["framework"] = framework
156
self.params["log_level"] = log_level
157
self.params["num_gpus"] = num_gpus
158
self.params["num_workers"] = num_workers
159
self.params["env"] = train_env_name
160
161
self.run_name = run_name
162
self.local_dir = local_dir
163
self.search_alg = search_alg
164
if concurrent_trials != 0:
165
self.search_alg = ConcurrencyLimiter(
166
self.search_alg, max_concurrent=concurrent_trials
167
)
168
self.scheduler = scheduler
169
self.num_samples = num_samples
170
self.trainable = trainable
171
if isinstance(self.trainable, str):
172
self.trainable.upper()
173
self.num_cpus = num_cpus
174
self.num_gpus = num_gpus
175
self.dataframe_save = dataframe_save
176
self.metric = metric
177
self.mode = mode
178
self.max_failures = max_failures
179
self.training_iterations = training_iterations
180
self.checkpoint_freq = checkpoint_freq
181
self.checkpoint_num_to_keep = checkpoint_num_to_keep
182
self.reuse_actors = reuse_actors
183

    def train_tune_model(self):
        """
        Tunes and trains the model
        Returns the tune results object
        """
        ray.init(
            num_cpus=self.num_cpus, num_gpus=self.num_gpus, ignore_reinit_error=True
        )

        # Search behaviour comes from TuneConfig; run duration, failure handling
        # and checkpointing come from RunConfig.
        tuner = tune.Tuner(
            self.trainable,
            param_space=self.params,
            tune_config=TuneConfig(
                search_alg=self.search_alg,
                scheduler=self.scheduler,  # forward the optional trial scheduler so it takes effect
                num_samples=self.num_samples,
                metric=self.metric,
                mode=self.mode,
                reuse_actors=self.reuse_actors,
            ),
            run_config=RunConfig(
                name=self.run_name,
                local_dir=self.local_dir,
                failure_config=FailureConfig(
                    max_failures=self.max_failures, fail_fast=False
                ),
                stop={"training_iteration": self.training_iterations},
                checkpoint_config=CheckpointConfig(
                    num_to_keep=self.checkpoint_num_to_keep,
                    checkpoint_score_attribute=self.metric,
                    checkpoint_score_order=self.mode,
                    checkpoint_frequency=self.checkpoint_freq,
                    checkpoint_at_end=True,
                ),
                verbose=3,
            ),
        )

        self.results = tuner.fit()
        if self.search_alg is not None:
            self.search_alg.save_to_dir(self.local_dir)
        # ray.shutdown()
        return self.results

    def infer_results(self, to_dataframe: str | None = None, mode: str = "a"):
        """
        Get the tune results in a dataframe along with the best result object
        """
        results_df = self.results.get_dataframe()

        if to_dataframe is None:
            to_dataframe = self.dataframe_save

        results_df.to_csv(to_dataframe, mode=mode)

        best_result = self.results.get_best_result()
        # best_metric = best_result.metrics
        # best_checkpoint = best_result.checkpoint
        # best_trial_dir = best_result.log_dir

        return results_df, best_result

    def restore_agent(
        self,
        checkpoint_path: str = "",
        restore_search: bool = False,
        resume_unfinished: bool = True,
        resume_errored: bool = False,
        restart_errored: bool = False,
    ):
        """
        Restore errored or stopped trials
        """
        # if restore_search:
        #     self.search_alg = self.search_alg.restore_from_dir(self.local_dir)
        if checkpoint_path == "":
            checkpoint_path = self.results.get_best_result().checkpoint._local_path

        restored_agent = tune.Tuner.restore(
            checkpoint_path,
            restart_errored=restart_errored,
            resume_unfinished=resume_unfinished,
            resume_errored=resume_errored,
        )
        print(restored_agent)
        self.results = restored_agent.fit()

        if self.search_alg is not None:
            self.search_alg.save_to_dir(self.local_dir)
        return self.results

    def get_test_agent(self, test_env_name: str, test_env=None, checkpoint=None):
        """
        Get the trained agent for testing and inference
        """
        if test_env is not None:
            register_env(test_env_name, lambda config: test_env)

        if checkpoint is None:
            checkpoint = self.results.get_best_result().checkpoint

        testing_agent = Algorithm.from_checkpoint(checkpoint)
        # testing_agent.config['env'] = test_env_name

        return testing_agent
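

# Minimal usage sketch: a quick smoke test of DRLlibv2, assuming ray[rllib] is
# installed. CartPole-v1 is only a stand-in environment for illustration; in
# FinRL experiments pass a StockTradingEnv instance instead, as in the class
# docstring example above.
if __name__ == "__main__":
    try:
        import gymnasium as gym  # newer RLlib releases expect gymnasium envs
    except ImportError:
        import gym  # older Ray 2.x releases still use classic gym

    demo_agent = DRLlibv2(
        trainable="PPO",
        train_env=gym.make("CartPole-v1"),
        train_env_name="CartPole_demo",
        params={"lr": 5e-5},
        num_samples=1,
        training_iterations=1,
        num_cpus=2,
    )
    demo_agent.train_tune_model()
    demo_df, demo_best = demo_agent.infer_results()
    pprint(demo_best.metrics)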