Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/train.py
728 views
1
from __future__ import annotations
2
3
from finrl.config import ERL_PARAMS
4
from finrl.config import INDICATORS
5
from finrl.config import RLlib_PARAMS
6
from finrl.config import SAC_PARAMS
7
from finrl.config import TRAIN_END_DATE
8
from finrl.config import TRAIN_START_DATE
9
from finrl.config_tickers import DOW_30_TICKER
10
from finrl.meta.data_processor import DataProcessor
11
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
12
13
# construct environment
14
15
16
def train(
    start_date,
    end_date,
    ticker_list,
    data_source,
    time_interval,
    technical_indicator_list,
    drl_lib,
    env,
    model_name,
    if_vix=True,
    **kwargs,
):
    """Download market data, turn it into arrays and train a DRL agent.

    Args:
        start_date: First date passed to ``DataProcessor.download_data``.
        end_date: Last date passed to ``DataProcessor.download_data``.
        ticker_list: Tickers to download.
        data_source: Data source name handed to ``DataProcessor``.
        time_interval: Bar interval passed to ``download_data``.
        technical_indicator_list: Indicators added by
            ``DataProcessor.add_technical_indicator``.
        drl_lib: One of ``"elegantrl"``, ``"rllib"`` or
            ``"stable_baselines3"``.
        env: Environment class. It is passed as-is to the elegantrl and
            rllib agents and instantiated with ``config=...`` for
            stable-baselines3.
        model_name: Algorithm name forwarded to the agent's ``get_model``.
        if_vix: When true, ``DataProcessor.add_vix`` is applied to the data.
        **kwargs: Forwarded in full to ``DataProcessor``. The following keys
            are additionally read here:
            ``cwd`` (output location, default ``"./" + model_name``),
            ``break_step`` and ``erl_params`` (elegantrl),
            ``total_episodes`` and ``rllib_params`` (rllib; ``rllib_params``
            must provide ``"lr"``, ``"train_batch_size"`` and ``"gamma"``),
            ``total_timesteps`` and ``agent_params`` (stable-baselines3).

    Raises:
        ValueError: If ``drl_lib`` is not a supported library.
        TypeError: If ``drl_lib`` is ``"rllib"`` and ``rllib_params`` is
            missing.
    """
    # Validate the cheap inputs first so a typo in ``drl_lib`` (or a missing
    # rllib configuration) fails before the data download and preprocessing.
    if drl_lib not in ("elegantrl", "rllib", "stable_baselines3"):
        raise ValueError("DRL library input is NOT supported. Please check.")
    if drl_lib == "rllib" and kwargs.get("rllib_params") is None:
        raise TypeError(
            "rllib_params is required when drl_lib is 'rllib' "
            "(expected keys: 'lr', 'train_batch_size', 'gamma')."
        )

    # download data
    dp = DataProcessor(data_source, **kwargs)
    data = dp.download_data(ticker_list, start_date, end_date, time_interval)
    data = dp.clean_data(data)
    data = dp.add_technical_indicator(data, technical_indicator_list)
    if if_vix:
        data = dp.add_vix(data)
    price_array, tech_array, turbulence_array = dp.df_to_array(data, if_vix)

    # read parameters
    cwd = kwargs.get("cwd", "./" + str(model_name))

    if drl_lib == "elegantrl":
        # Imported lazily so only the selected backend has to be installed.
        from finrl.agents.elegantrl.models import DRLAgent as DRLAgent_erl

        break_step = kwargs.get("break_step", 1e6)
        erl_params = kwargs.get("erl_params")
        agent = DRLAgent_erl(
            env=env,
            price_array=price_array,
            tech_array=tech_array,
            turbulence_array=turbulence_array,
        )
        model = agent.get_model(model_name, model_kwargs=erl_params)
        # ``cwd`` is handed to ``train_model``; no explicit save call here.
        trained_model = agent.train_model(
            model=model, cwd=cwd, total_timesteps=break_step
        )
    elif drl_lib == "rllib":
        total_episodes = kwargs.get("total_episodes", 100)
        rllib_params = kwargs.get("rllib_params")
        from finrl.agents.rllib.models import DRLAgent as DRLAgent_rllib

        agent_rllib = DRLAgent_rllib(
            env=env,
            price_array=price_array,
            tech_array=tech_array,
            turbulence_array=turbulence_array,
        )
        model, model_config = agent_rllib.get_model(model_name)
        # Override the values in the returned config with the caller's.
        model_config["lr"] = rllib_params["lr"]
        model_config["train_batch_size"] = rllib_params["train_batch_size"]
        model_config["gamma"] = rllib_params["gamma"]
        # ray.shutdown()
        trained_model = agent_rllib.train_model(
            model=model,
            model_name=model_name,
            model_config=model_config,
            total_episodes=total_episodes,
        )
        trained_model.save(cwd)
    else:  # drl_lib == "stable_baselines3" (validated above)
        total_timesteps = kwargs.get("total_timesteps", 1e6)
        agent_params = kwargs.get("agent_params")
        from finrl.agents.stablebaselines3.models import DRLAgent as DRLAgent_sb3

        # Only this backend takes an environment *instance*, so it is built
        # here rather than unconditionally for every backend.
        env_config = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": True,
        }
        env_instance = env(config=env_config)

        agent = DRLAgent_sb3(env=env_instance)
        model = agent.get_model(model_name, model_kwargs=agent_params)
        trained_model = agent.train_model(
            model=model, tb_log_name=model_name, total_timesteps=total_timesteps
        )
        print("Training is finished!")
        trained_model.save(cwd)
        print("Trained model is saved in " + str(cwd))
101
102
103
if __name__ == "__main__":
    env = StockTradingEnv

    # demo for elegantrl
    # Extra data-source options. For yahoofinance this is empty; other data
    # sources, such as joinquant, need a non-empty dict here.
    kwargs = {}
    train(
        start_date=TRAIN_START_DATE,
        end_date=TRAIN_END_DATE,
        ticker_list=DOW_30_TICKER,
        data_source="yahoofinance",
        time_interval="1D",
        technical_indicator_list=INDICATORS,
        drl_lib="elegantrl",
        env=env,
        model_name="ppo",
        cwd="./test_ppo",
        erl_params=ERL_PARAMS,
        break_step=1e5,
        # Unpack the options so each one reaches train() as its own keyword
        # instead of a single nested ``kwargs={...}`` entry.
        **kwargs,
    )

    ## To use rllib or stable-baselines3 instead, uncomment the matching
    ## demo below.

    # # demo for rllib
    # import ray
    # ray.shutdown()  # always shutdown previous session if any
    # train(
    #     start_date=TRAIN_START_DATE,
    #     end_date=TRAIN_END_DATE,
    #     ticker_list=DOW_30_TICKER,
    #     data_source="yahoofinance",
    #     time_interval="1D",
    #     technical_indicator_list=INDICATORS,
    #     drl_lib="rllib",
    #     env=env,
    #     model_name="ppo",
    #     cwd="./test_ppo",
    #     rllib_params=RLlib_PARAMS,
    #     total_episodes=30,
    # )
    #
    # # demo for stable-baselines3
    # train(
    #     start_date=TRAIN_START_DATE,
    #     end_date=TRAIN_END_DATE,
    #     ticker_list=DOW_30_TICKER,
    #     data_source="yahoofinance",
    #     time_interval="1D",
    #     technical_indicator_list=INDICATORS,
    #     drl_lib="stable_baselines3",
    #     env=env,
    #     model_name="sac",
    #     cwd="./test_sac",
    #     agent_params=SAC_PARAMS,
    #     total_timesteps=1e4,
    # )
161
162