Path: finrl/agents/elegantrl/models.py
"""1DRL models from ElegantRL: https://github.com/AI4Finance-Foundation/ElegantRL2"""34from __future__ import annotations56import torch7from elegantrl.agents import *8from elegantrl.train.config import Config9from elegantrl.train.run import train_agent1011MODELS = {12"ddpg": AgentDDPG,13"td3": AgentTD3,14"sac": AgentSAC,15"ppo": AgentPPO,16"a2c": AgentA2C,17}18OFF_POLICY_MODELS = ["ddpg", "td3", "sac"]19ON_POLICY_MODELS = ["ppo"]20# MODEL_KWARGS = {x: config.__dict__[f"{x.upper()}_PARAMS"] for x in MODELS.keys()}21#22# NOISE = {23# "normal": NormalActionNoise,24# "ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,25# }262728class DRLAgent:29"""Implementations of DRL algorithms30Attributes31----------32env: gym environment class33user-defined class34Methods35-------36get_model()37setup DRL algorithms38train_model()39train DRL algorithms in a train dataset40and output the trained model41DRL_prediction()42make a prediction in a test dataset and get results43"""4445def __init__(self, env, price_array, tech_array, turbulence_array):46self.env = env47self.price_array = price_array48self.tech_array = tech_array49self.turbulence_array = turbulence_array5051def get_model(self, model_name, model_kwargs):52self.env_config = {53"price_array": self.price_array,54"tech_array": self.tech_array,55"turbulence_array": self.turbulence_array,56"if_train": True,57}58self.model_kwargs = model_kwargs59self.gamma = model_kwargs.get("gamma", 0.985)6061env = self.env62env.env_num = 163agent = MODELS[model_name]64if model_name not in MODELS:65raise NotImplementedError("NotImplementedError")6667stock_dim = self.price_array.shape[1]68self.state_dim = 1 + 2 + 3 * stock_dim + self.tech_array.shape[1]69self.action_dim = stock_dim70self.env_args = {71"env_name": "StockEnv",72"config": self.env_config,73"state_dim": self.state_dim,74"action_dim": self.action_dim,75"if_discrete": False,76"max_step": self.price_array.shape[0] - 1,77}7879model = Config(agent_class=agent, env_class=env, env_args=self.env_args)80model.if_off_policy = model_name in OFF_POLICY_MODELS81if model_kwargs is not None:82try:83model.break_step = int(842e585) # break training if 'total_step > break_step'86model.net_dims = (87128,8864,89) # the middle layer dimension of MultiLayer Perceptron90model.gamma = self.gamma # discount factor of future rewards91model.horizon_len = model.max_step92model.repeat_times = 16 # repeatedly update network using ReplayBuffer to keep critic's loss small93model.learning_rate = model_kwargs.get("learning_rate", 1e-4)94model.state_value_tau = 0.1 # the tau of normalize for value and state `std = (1-std)*std + tau*std`95model.eval_times = model_kwargs.get("eval_times", 2**5)96model.eval_per_step = int(2e4)97except BaseException:98raise ValueError(99"Fail to read arguments, please check 'model_kwargs' input."100)101return model102103def train_model(self, model, cwd, total_timesteps=5000):104model.cwd = cwd105model.break_step = total_timesteps106train_agent(model)107108@staticmethod109def DRL_prediction(model_name, cwd, net_dimension, environment, env_args):110import torch111112gpu_id = 0 # >=0 means GPU ID, -1 means CPU113agent_class = MODELS[model_name]114stock_dim = env_args["price_array"].shape[1]115state_dim = 1 + 2 + 3 * stock_dim + env_args["tech_array"].shape[1]116action_dim = stock_dim117env_args = {118"env_num": 1,119"env_name": "StockEnv",120"state_dim": state_dim,121"action_dim": action_dim,122"if_discrete": False,123"max_step": env_args["price_array"].shape[0] - 1,124"config": env_args,125}126127actor_path = 
f"{cwd}/act.pth"128net_dim = [2**7]129130"""init"""131env = environment132env_class = env133args = Config(agent_class=agent_class, env_class=env_class, env_args=env_args)134args.cwd = cwd135act = agent_class(136net_dim, env.state_dim, env.action_dim, gpu_id=gpu_id, args=args137).act138parameters_dict = {}139act = torch.load(actor_path)140for name, param in act.named_parameters():141parameters_dict[name] = torch.tensor(param.detach().cpu().numpy())142143act.load_state_dict(parameters_dict)144145if_discrete = env.if_discrete146device = next(act.parameters()).device147state = env.reset()148episode_returns = [] # the cumulative_return / initial_account149episode_total_assets = [env.initial_total_asset]150max_step = env.max_step151for steps in range(max_step):152s_tensor = torch.as_tensor(153state, dtype=torch.float32, device=device154).unsqueeze(0)155a_tensor = act(s_tensor).argmax(dim=1) if if_discrete else act(s_tensor)156action = (157a_tensor.detach().cpu().numpy()[0]158) # not need detach(), because using torch.no_grad() outside159state, reward, done, _ = env.step(action)160total_asset = env.amount + (env.price_ary[env.day] * env.stocks).sum()161episode_total_assets.append(total_asset)162episode_return = total_asset / env.initial_total_asset163episode_returns.append(episode_return)164if done:165break166print("Test Finished!")167print("episode_retuen", episode_return)168return episode_total_assets169170171