Path: blob/master/finrl/agents/rllib/drllibv2.py
# @Author: Astarag Mohapatra
from __future__ import annotations

import ray

assert (
    ray.__version__ > "2.0.0"
), "Please install ray>=2.2.0 via 'pip install ray[rllib] ray[tune] lz4'; lz4 is required for population-based tuning"

from ray import tune
from ray.tune.search import ConcurrencyLimiter
from ray.rllib.algorithms import Algorithm
from ray.tune import register_env

from ray.air import RunConfig, FailureConfig
from ray.tune.tune_config import TuneConfig
from ray.air.config import CheckpointConfig

import psutil

# Ray can misdetect available memory inside containers; report the value
# psutil sees instead.
psutil_memory_in_bytes = psutil.virtual_memory().total
ray._private.utils.get_system_memory = lambda: psutil_memory_in_bytes
from typing import Any


class DRLlibv2:
    """
    Instantiates an RLlib model with Ray Tune functionality.

    Params
    -------------------------------------
    trainable:
        Any Trainable class that takes a config as parameter
    train_env:
        Training environment instance
    train_env_name: str
        Name of the training environment
    params: dict
        Hyperparameters dictionary
    run_name: str
        Tune run name
    framework: str
        "torch" for PyTorch or "tf" for TensorFlow
    local_dir: str
        Directory to save the results and tensorboard plots
    num_workers: int
        Number of rollout workers
    search_alg:
        Search algorithm for hyperparameter optimization (e.g. OptunaSearch)
    concurrent_trials: int
        Number of concurrent hyperparameter trials to run
    num_samples: int
        Number of hyperparameter configurations to sample
    scheduler:
        Trial scheduler for stopping suboptimal trials early
    log_level: str = "WARN"
        Verbosity; use "DEBUG" for full logs
    num_gpus: Union[float, int] = 0
        Number of GPUs for the run
    num_cpus: Union[float, int] = 2
        Number of CPUs for rollout collection
    dataframe_save: str
        CSV file for saving the tune results
    metric: str
        Metric for hyperparameter optimization in Bayesian methods
    mode: str
        "max" or "min": maximize or minimize the metric
    max_failures: int
        Number of trial failures tolerated before raising TuneError
    training_iterations: int
        Number of times session.report() is called
    checkpoint_num_to_keep: int
        Number of checkpoints to keep
    checkpoint_freq: int
        Checkpoint frequency w.r.t. training iterations
    reuse_actors: bool
        Reuse actors between trials

    It has the following methods:

    Methods
    -------------------------------------
        train_tune_model: takes the params dictionary and fits, sklearn-style, with the trainable class
        restore_agent: restores previously errored or stopped trials or experiments
        infer_results: returns the results dataframe and trial information
        get_test_agent: returns the testing agent for inference

    Example
    ---------------------------------------
    def sample_ppo_params():
        return {
            "entropy_coeff": tune.loguniform(0.00000001, 0.1),
            "lr": tune.loguniform(5e-5, 0.001),
            "sgd_minibatch_size": tune.choice([32, 64, 128, 256, 512]),
            "lambda": tune.choice([0.1, 0.3, 0.5, 0.7, 0.9, 1.0]),
        }
    optuna_search = OptunaSearch(
        metric="episode_reward_mean",
        mode="max")
    drl_agent = DRLlibv2(
        trainable="PPO",
        train_env=env(train_env_config),
        train_env_name="StockTrading_train",
        framework="torch",
        num_workers=1,
        log_level="DEBUG",
        run_name="test",
        local_dir="test",
        params=sample_ppo_params(),
        num_samples=1,
        num_gpus=1,
        training_iterations=10,
        search_alg=optuna_search,
        checkpoint_freq=5,
    )
    # Tune or train the model
    res = drl_agent.train_tune_model()

    # Get the tune results
    results_df, best_result = drl_agent.infer_results()

    # Get the best testing agent
    test_agent = drl_agent.get_test_agent("StockTrading_testenv", test_env_instance)
    """

    def __init__(
        self,
        trainable: str | Any,
        train_env_name: str,
        params: dict,
        train_env=None,
        run_name: str = "tune_run",
        framework: str = "torch",
        local_dir: str = "tune_results",
        num_workers: int = 1,
        search_alg=None,
        concurrent_trials: int = 0,
        num_samples: int = 0,
        scheduler=None,
        log_level: str = "WARN",
        num_gpus: float | int = 0,
        num_cpus: float | int = 2,
        dataframe_save: str = "tune.csv",
        metric: str = "episode_reward_mean",
        mode: str | list[str] = "max",
        max_failures: int = 0,
        training_iterations: int = 100,
        checkpoint_num_to_keep: None | int = None,
        checkpoint_freq: int = 0,
        reuse_actors: bool = False,
    ):
        if train_env is not None:
            # Register the environment instance under the given name so RLlib
            # can look it up through params["env"].
            register_env(train_env_name, lambda config: train_env)

        self.params = params
        self.params["framework"] = framework
        self.params["log_level"] = log_level
        self.params["num_gpus"] = num_gpus
        self.params["num_workers"] = num_workers
        self.params["env"] = train_env_name

        self.run_name = run_name
        self.local_dir = local_dir
        self.search_alg = search_alg
        if concurrent_trials != 0:
            # Cap the number of trials the search algorithm runs in parallel.
            self.search_alg = ConcurrencyLimiter(
                self.search_alg, max_concurrent=concurrent_trials
            )
        self.scheduler = scheduler
        self.num_samples = num_samples
        self.trainable = trainable
        if isinstance(self.trainable, str):
            # RLlib registers algorithm names in upper case (e.g. "PPO").
            self.trainable = self.trainable.upper()
        self.num_cpus = num_cpus
        self.num_gpus = num_gpus
        self.dataframe_save = dataframe_save
        self.metric = metric
        self.mode = mode
        self.max_failures = max_failures
        self.training_iterations = training_iterations
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_num_to_keep = checkpoint_num_to_keep
        self.reuse_actors = reuse_actors

    def train_tune_model(self):
        """
        Tunes and trains the model.
        Returns the results object.
        """
        ray.init(
            num_cpus=self.num_cpus, num_gpus=self.num_gpus, ignore_reinit_error=True
        )

        tuner = tune.Tuner(
            self.trainable,
            param_space=self.params,
            tune_config=TuneConfig(
                search_alg=self.search_alg,
                scheduler=self.scheduler,
                num_samples=self.num_samples,
                metric=self.metric,
                mode=self.mode,
                reuse_actors=self.reuse_actors,
            ),
            run_config=RunConfig(
                name=self.run_name,
                local_dir=self.local_dir,
                failure_config=FailureConfig(
                    max_failures=self.max_failures, fail_fast=False
                ),
                stop={"training_iteration": self.training_iterations},
                checkpoint_config=CheckpointConfig(
                    num_to_keep=self.checkpoint_num_to_keep,
                    checkpoint_score_attribute=self.metric,
                    checkpoint_score_order=self.mode,
                    checkpoint_frequency=self.checkpoint_freq,
                    checkpoint_at_end=True,
                ),
                verbose=3,
            ),
        )

        self.results = tuner.fit()
        if self.search_alg is not None:
            # Persist the search algorithm state so it can be restored later.
            self.search_alg.save_to_dir(self.local_dir)
        # ray.shutdown()
        return self.results

    def infer_results(self, to_dataframe: str | None = None, mode: str = "a"):
        """
        Gets the tune results as a dataframe and the best result object.
        """
        results_df = self.results.get_dataframe()

        if to_dataframe is None:
            to_dataframe = self.dataframe_save

        results_df.to_csv(to_dataframe, mode=mode)

        best_result = self.results.get_best_result()

        return results_df, best_result

    def restore_agent(
        self,
        checkpoint_path: str = "",
        restore_search: bool = False,
        resume_unfinished: bool = True,
        resume_errored: bool = False,
        restart_errored: bool = False,
    ):
        """
        Restores errored or stopped trials.
        """
        # if restore_search:
        #     self.search_alg = self.search_alg.restore_from_dir(self.local_dir)
        if checkpoint_path == "":
            # Fall back to the best trial's checkpoint directory (note that
            # _local_path is a private Checkpoint attribute).
            checkpoint_path = self.results.get_best_result().checkpoint._local_path

        restored_agent = tune.Tuner.restore(
            checkpoint_path,
            restart_errored=restart_errored,
            resume_unfinished=resume_unfinished,
            resume_errored=resume_errored,
        )
        print(restored_agent)
        self.results = restored_agent.fit()

        if self.search_alg is not None:
            self.search_alg.save_to_dir(self.local_dir)
        return self.results

    def get_test_agent(self, test_env_name: str, test_env=None, checkpoint=None):
        """
        Gets the test agent.
        """
        if test_env is not None:
            register_env(test_env_name, lambda config: test_env)

        if checkpoint is None:
            checkpoint = self.results.get_best_result().checkpoint

        testing_agent = Algorithm.from_checkpoint(checkpoint)
        # testing_agent.config['env'] = test_env_name

        return testing_agent
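
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). The helper
# below shows one way to roll out the trained policy returned by
# get_test_agent on a test environment. It assumes an old-style gym API
# (reset() -> obs, step() -> (obs, reward, done, info)); adapt it for the
# newer gymnasium 5-tuple convention if your environment uses it.
# ---------------------------------------------------------------------------
def run_inference_episode(drl_agent: DRLlibv2, test_env_name: str, test_env):
    """Roll out one episode with the best checkpointed policy."""
    test_agent = drl_agent.get_test_agent(test_env_name, test_env)
    obs = test_env.reset()
    done, episode_reward = False, 0.0
    while not done:
        # compute_single_action is RLlib's single-observation inference API
        # on an Algorithm instance.
        action = test_agent.compute_single_action(obs)
        obs, reward, done, info = test_env.step(action)
        episode_reward += reward
    return episode_reward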
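
# ---------------------------------------------------------------------------
# Resume sketch (illustrative; not part of the original module).
# restore_agent wraps tune.Tuner.restore, so an interrupted experiment can be
# resumed from its experiment directory; the default path below is a
# hypothetical example following the local_dir/run_name layout used by
# train_tune_model with this class's default arguments.
# ---------------------------------------------------------------------------
def resume_experiment(
    drl_agent: DRLlibv2, experiment_path: str = "tune_results/tune_run"
):
    """Resume unfinished trials and rerun any that errored."""
    return drl_agent.restore_agent(
        checkpoint_path=experiment_path,
        resume_unfinished=True,
        resume_errored=True,
    )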