Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/agents/portfolio_optimization/models.py
732 views
1
"""
2
DRL models to solve the portfolio optimization task with reinforcement learning.
3
This agent was developed to work with environments like PortfolioOptimizationEnv.
4
"""
5
6
from __future__ import annotations
7
8
from .algorithms import PolicyGradient
9
10
# Registry mapping the model-name strings accepted by DRLAgent.get_model
# to their implementing algorithm classes. Currently only the policy
# gradient algorithm ("pg") is available.
MODELS = {"pg": PolicyGradient}
11
12
13
class DRLAgent:
    """Implementation for DRL algorithms for portfolio optimization.

    Note:
        During testing, the agent is optimized through online learning.
        The parameters of the policy are updated repeatedly after a constant
        period of time. To disable it, set the learning rate to 0.

    Attributes:
        env: Gym environment class.
    """

    def __init__(self, env):
        """Agent initialization.

        Args:
            env: Gym environment to be used in training.
        """
        self.env = env

    def get_model(
        self, model_name, device="cpu", model_kwargs=None, policy_kwargs=None
    ):
        """Setups DRL model.

        Args:
            model_name: Name of the model according to MODELS list.
            device: Device used to instantiate neural networks.
            model_kwargs: Arguments to be passed to model class.
            policy_kwargs: Arguments to be passed to policy class.

        Note:
            model_kwargs and policy_kwargs are dictionaries. The keys must be strings
            with the same names as the class arguments. Example for model_kwargs::

                { "lr": 0.01, "policy": EIIE }

        Raises:
            NotImplementedError: If model_name is not registered in MODELS.

        Returns:
            An instance of the model.
        """
        if model_name not in MODELS:
            raise NotImplementedError("The model requested was not implemented.")

        model = MODELS[model_name]
        # Copy the incoming dictionaries so the "device" and "policy_kwargs"
        # entries added below do not mutate the caller's dicts (which could
        # leak state across repeated get_model calls).
        model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)

        # add device settings
        model_kwargs["device"] = device
        policy_kwargs["device"] = device

        # add policy_kwargs inside model_kwargs
        model_kwargs["policy_kwargs"] = policy_kwargs

        return model(self.env, **model_kwargs)

    @staticmethod
    def train_model(model, episodes=100):
        """Trains portfolio optimization model.

        Args:
            model: Instance of the model.
            episodes: Number of episodes.

        Returns:
            An instance of the trained model.
        """
        model.train(episodes)
        return model

    @staticmethod
    def DRL_validation(
        model,
        test_env,
        policy=None,
        online_training_period=10,
        learning_rate=None,
        optimizer=None,
    ):
        """Tests a model in a testing environment.

        Args:
            model: Instance of the model.
            test_env: Gym environment to be used in testing.
            policy: Policy architecture to be used. If None, it will use the training
                architecture.
            online_training_period: Period in which an online training will occur. To
                disable online learning, use a very big value.
            learning_rate: Policy neural network learning rate. If None, it will use
                the training learning rate.
            optimizer: Optimizer of neural network. If None, it will use the training
                optimizer.
        """
        model.test(test_env, policy, online_training_period, learning_rate, optimizer)
109
110