Path: blob/master/finrl/agents/rllib/models.py
# DRL models from RLlib
from __future__ import annotations

import ray
from ray.rllib.algorithms.a2c import a2c
from ray.rllib.algorithms.ddpg import ddpg
from ray.rllib.algorithms.ppo import ppo
from ray.rllib.algorithms.sac import sac
from ray.rllib.algorithms.td3 import td3

MODELS = {"a2c": a2c, "ddpg": ddpg, "td3": td3, "sac": sac, "ppo": ppo}


# MODEL_KWARGS = {x: config.__dict__[f"{x.upper()}_PARAMS"] for x in MODELS.keys()}


class DRLAgent:
    """Implementations of DRL algorithms

    Attributes
    ----------
    env: gym environment class
        user-defined class
    price_array: numpy array
        OHLC data
    tech_array: numpy array
        technical data
    turbulence_array: numpy array
        turbulence/risk data

    Methods
    -------
    get_model()
        set up a DRL algorithm
    train_model()
        train a DRL algorithm on a training dataset
        and output the trained model
    DRL_prediction()
        make a prediction on a test dataset and get results
    """

    def __init__(self, env, price_array, tech_array, turbulence_array):
        self.env = env
        self.price_array = price_array
        self.tech_array = tech_array
        self.turbulence_array = turbulence_array

    def get_model(
        self,
        model_name,
        # policy="MlpPolicy",
        # policy_kwargs=None,
        # model_kwargs=None,
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model '{model_name}' is not supported.")

        # if model_kwargs is None:
        #     model_kwargs = MODEL_KWARGS[model_name]

        model = MODELS[model_name]
        # get the algorithm's default configuration from RLlib
        if model_name == "a2c":
            model_config = model.A2C_DEFAULT_CONFIG.copy()
        elif model_name == "td3":
            model_config = model.TD3_DEFAULT_CONFIG.copy()
        else:
            model_config = model.DEFAULT_CONFIG.copy()
        # pass env, log_level, price_array, tech_array, and turbulence_array to config
        model_config["env"] = self.env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": self.price_array,
            "tech_array": self.tech_array,
            "turbulence_array": self.turbulence_array,
            "if_train": True,
        }

        return model, model_config

    def train_model(
        self, model, model_name, model_config, total_episodes=100, init_ray=True
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model '{model_name}' is not supported.")
        if init_ray:
            ray.init(
                ignore_reinit_error=True
            )  # other Ray APIs will not work until `ray.init()` is called

        if model_name == "ppo":
            trainer = model.PPOTrainer(env=self.env, config=model_config)
        elif model_name == "a2c":
            trainer = model.A2CTrainer(env=self.env, config=model_config)
        elif model_name == "ddpg":
            trainer = model.DDPGTrainer(env=self.env, config=model_config)
        elif model_name == "td3":
            trainer = model.TD3Trainer(env=self.env, config=model_config)
        elif model_name == "sac":
            trainer = model.SACTrainer(env=self.env, config=model_config)

        for _ in range(total_episodes):
            trainer.train()

        ray.shutdown()

        # save the trained model
        cwd = "./test_" + str(model_name)
        trainer.save(cwd)

        return trainer
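    # Note: with the legacy Ray `Trainer.save()` API used above, a run of
    # `total_episodes` training iterations yields a checkpoint directory of
    # the form ./test_<model_name>/checkpoint_00NNNN/checkpoint-NNNN (the
    # exact layout depends on the Ray version). The default `agent_path`
    # below assumes train_model's default run of 100 episodes.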
{132"price_array": price_array,133"tech_array": tech_array,134"turbulence_array": turbulence_array,135"if_train": False,136}137env_config = {138"price_array": price_array,139"tech_array": tech_array,140"turbulence_array": turbulence_array,141"if_train": False,142}143env_instance = env(config=env_config)144145# ray.init() # Other Ray APIs will not work until `ray.init()` is called.146if model_name == "ppo":147trainer = MODELS[model_name].PPOTrainer(env=env, config=model_config)148elif model_name == "a2c":149trainer = MODELS[model_name].A2CTrainer(env=env, config=model_config)150elif model_name == "ddpg":151trainer = MODELS[model_name].DDPGTrainer(env=env, config=model_config)152elif model_name == "td3":153trainer = MODELS[model_name].TD3Trainer(env=env, config=model_config)154elif model_name == "sac":155trainer = MODELS[model_name].SACTrainer(env=env, config=model_config)156157try:158trainer.restore(agent_path)159print("Restoring from checkpoint path", agent_path)160except BaseException:161raise ValueError("Fail to load agent!")162163# test on the testing env164state = env_instance.reset()165episode_returns = [] # the cumulative_return / initial_account166episode_total_assets = [env_instance.initial_total_asset]167done = False168while not done:169action = trainer.compute_single_action(state)170state, reward, done, _ = env_instance.step(action)171172total_asset = (173env_instance.amount174+ (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()175)176episode_total_assets.append(total_asset)177episode_return = total_asset / env_instance.initial_total_asset178episode_returns.append(episode_return)179ray.shutdown()180print("episode return: " + str(episode_return))181print("Test Finished!")182return episode_total_assets183184185