"""DRL models to solve the portfolio optimization task with reinforcement learning.

This agent was developed to work with environments like PortfolioOptimizationEnv.
"""

from __future__ import annotations

from .algorithms import PolicyGradient

# Registry mapping the short model names accepted by ``DRLAgent.get_model`` to
# their implementing classes.
MODELS = {"pg": PolicyGradient}


class DRLAgent:
    """Implementation for DRL algorithms for portfolio optimization.

    Note:
        During testing, the agent is optimized through online learning.
        The parameters of the policy are updated repeatedly after a constant
        period of time. To disable it, set the learning rate to 0.

    Attributes:
        env: Gym environment used to build (and train) models.
    """

    def __init__(self, env):
        """Agent initialization.

        Args:
            env: Gym environment to be used in training.
        """
        self.env = env

    def get_model(
        self,
        model_name: str,
        device: str = "cpu",
        model_kwargs: dict | None = None,
        policy_kwargs: dict | None = None,
    ):
        """Sets up a DRL model.

        Args:
            model_name: Name of the model according to the MODELS registry.
            device: Device used to instantiate neural networks.
            model_kwargs: Arguments to be passed to the model class.
            policy_kwargs: Arguments to be passed to the policy class.

        Note:
            model_kwargs and policy_kwargs are dictionaries. The keys must be
            strings with the same names as the class arguments. The dictionaries
            passed by the caller are copied, never modified. Example for
            model_kwargs::

                { "lr": 0.01, "policy": EIIE }

        Returns:
            An instance of the model.

        Raises:
            NotImplementedError: If ``model_name`` is not a key of MODELS.
        """
        if model_name not in MODELS:
            raise NotImplementedError("The model requested was not implemented.")

        model = MODELS[model_name]
        # Copy so that injecting "device"/"policy_kwargs" below does not leak
        # into dictionaries the caller may reuse for other calls.
        model_kwargs = dict(model_kwargs or {})
        policy_kwargs = dict(policy_kwargs or {})

        # add device settings (overrides any "device" key supplied by the caller)
        model_kwargs["device"] = device
        policy_kwargs["device"] = device

        # add policy_kwargs inside model_kwargs
        model_kwargs["policy_kwargs"] = policy_kwargs

        return model(self.env, **model_kwargs)

    @staticmethod
    def train_model(model, episodes: int = 100):
        """Trains a portfolio optimization model.

        Args:
            model: Instance of the model.
            episodes: Number of episodes.

        Returns:
            The same model instance, after ``model.train(episodes)`` has run.
        """
        model.train(episodes)
        return model

    @staticmethod
    def DRL_validation(
        model,
        test_env,
        policy=None,
        online_training_period: int = 10,
        learning_rate: float | None = None,
        optimizer=None,
    ):
        """Tests a model in a testing environment.

        All arguments are forwarded positionally to ``model.test``; nothing is
        returned.

        Args:
            model: Instance of the model.
            test_env: Gym environment to be used in testing.
            policy: Policy architecture to be used. If None, it will use the
                training architecture.
            online_training_period: Period in which an online training will
                occur. To disable online learning, use a very big value.
            learning_rate: Policy neural network learning rate. If None, it will
                use the training learning rate.
            optimizer: Optimizer of the neural network. If None, it will use the
                training optimizer.
        """
        model.test(test_env, policy, online_training_period, learning_rate, optimizer)