# Exported from a CoCalc share view.
# GitHub repository: AI4Finance-Foundation/FinRL
# Path: examples/FinRL_StockTrading_2026_2_train.py
"""
2
Stock NeurIPS2018 Part 2. Train
3
4
This series is a reproduction of paper "Deep reinforcement learning for
5
automated stock trading: An ensemble strategy".
6
7
Introduce how to use FinRL to make data into the gym form environment, and train DRL agents on it.
8
"""
9
10
import pandas as pd
11
from stable_baselines3.common.logger import configure
12
13
from finrl.agents.stablebaselines3.models import DRLAgent
14
from finrl.config import INDICATORS, TRAINED_MODEL_DIR, RESULTS_DIR
15
from finrl.main import check_and_make_directories
16
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
17
18
# %% Part 1. Prepare directories

# RESULTS_DIR is created up front as well: the per-agent SB3 loggers below
# write their csv/tensorboard output under RESULTS_DIR/<algo>, and the
# original call only ensured TRAINED_MODEL_DIR existed.
check_and_make_directories([TRAINED_MODEL_DIR, RESULTS_DIR])
# %% Part 2. Build environment

# train_data.csv comes from Part 1 of the series; its first column is the
# original DataFrame index, which is restored here under an anonymous name.
train = pd.read_csv("train_data.csv")
train = train.set_index(train.columns[0])
train.index.names = [""]

# Observation vector = cash balance (1) + per-ticker price and holding
# (2 per ticker) + one slot per technical indicator per ticker.
stock_dimension = len(train.tic.unique())
state_space = 1 + (2 + len(INDICATORS)) * stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")
# Flat 0.1% transaction cost on both sides; buy and sell lists intentionally
# alias the same list, matching the original chained assignment.
buy_cost_list = [0.001] * stock_dimension
sell_cost_list = buy_cost_list
# Start with no shares held in any ticker.
num_stock_shares = [0] * stock_dimension

# Keyword configuration for StockTradingEnv.
env_kwargs = dict(
    hmax=100,                        # max shares traded per action per ticker
    initial_amount=1000000,          # starting cash
    num_stock_shares=num_stock_shares,
    buy_cost_pct=buy_cost_list,
    sell_cost_pct=sell_cost_list,
    state_space=state_space,
    stock_dim=stock_dimension,
    tech_indicator_list=INDICATORS,
    action_space=stock_dimension,    # one continuous action per ticker
    reward_scaling=1e-4,
)

e_train_gym = StockTradingEnv(df=train, **env_kwargs)
# Wrap into a Stable-Baselines3 vectorized environment.
env_train, _ = e_train_gym.get_sb_env()
print(type(env_train))
# %% Part 3. Train DRL Agents

# Toggle individual algorithms on or off here; all enabled by default.
if_using_a2c = if_using_ddpg = if_using_ppo = if_using_td3 = if_using_sac = True
# --- Agent 1: A2C ---
agent = DRLAgent(env=env_train)
model_a2c = agent.get_model("a2c")

# Train and persist only when the toggle is on; otherwise leave the handle None.
trained_a2c = None
if if_using_a2c:
    # Send A2C training logs to stdout, csv and tensorboard under RESULTS_DIR/a2c.
    new_logger_a2c = configure(RESULTS_DIR + "/a2c", ["stdout", "csv", "tensorboard"])
    model_a2c.set_logger(new_logger_a2c)
    trained_a2c = agent.train_model(
        model=model_a2c, tb_log_name="a2c", total_timesteps=20000
    )
    trained_a2c.save(TRAINED_MODEL_DIR + "/agent_a2c")
# --- Agent 2: DDPG ---
agent = DRLAgent(env=env_train)
model_ddpg = agent.get_model("ddpg")

# Train and persist only when the toggle is on; otherwise leave the handle None.
trained_ddpg = None
if if_using_ddpg:
    # Send DDPG training logs to stdout, csv and tensorboard under RESULTS_DIR/ddpg.
    new_logger_ddpg = configure(RESULTS_DIR + "/ddpg", ["stdout", "csv", "tensorboard"])
    model_ddpg.set_logger(new_logger_ddpg)
    trained_ddpg = agent.train_model(
        model=model_ddpg, tb_log_name="ddpg", total_timesteps=20000
    )
    trained_ddpg.save(TRAINED_MODEL_DIR + "/agent_ddpg")
# --- Agent 3: PPO ---
agent = DRLAgent(env=env_train)

# PPO hyperparameters passed through to Stable-Baselines3.
PPO_PARAMS = dict(
    n_steps=2048,
    ent_coef=0.01,
    learning_rate=0.00025,
    batch_size=128,
)
model_ppo = agent.get_model("ppo", model_kwargs=PPO_PARAMS)

# Train and persist only when the toggle is on; otherwise leave the handle None.
trained_ppo = None
if if_using_ppo:
    # Send PPO training logs to stdout, csv and tensorboard under RESULTS_DIR/ppo.
    new_logger_ppo = configure(RESULTS_DIR + "/ppo", ["stdout", "csv", "tensorboard"])
    model_ppo.set_logger(new_logger_ppo)
    trained_ppo = agent.train_model(
        model=model_ppo, tb_log_name="ppo", total_timesteps=20000
    )
    trained_ppo.save(TRAINED_MODEL_DIR + "/agent_ppo")
# --- Agent 4: TD3 ---
agent = DRLAgent(env=env_train)

# TD3 hyperparameters passed through to Stable-Baselines3.
TD3_PARAMS = dict(
    batch_size=100,
    buffer_size=1000000,
    learning_rate=0.001,
)
model_td3 = agent.get_model("td3", model_kwargs=TD3_PARAMS)

# Train and persist only when the toggle is on; otherwise leave the handle None.
trained_td3 = None
if if_using_td3:
    # Send TD3 training logs to stdout, csv and tensorboard under RESULTS_DIR/td3.
    new_logger_td3 = configure(RESULTS_DIR + "/td3", ["stdout", "csv", "tensorboard"])
    model_td3.set_logger(new_logger_td3)
    trained_td3 = agent.train_model(
        model=model_td3, tb_log_name="td3", total_timesteps=20000
    )
    trained_td3.save(TRAINED_MODEL_DIR + "/agent_td3")
# --- Agent 5: SAC ---
agent = DRLAgent(env=env_train)

# SAC hyperparameters passed through to Stable-Baselines3.
SAC_PARAMS = dict(
    batch_size=128,
    buffer_size=100000,
    learning_rate=0.0001,
    learning_starts=100,
    ent_coef="auto_0.1",  # automatic entropy tuning with initial coefficient 0.1
)
model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)

# Train and persist only when the toggle is on; otherwise leave the handle None.
trained_sac = None
if if_using_sac:
    # Send SAC training logs to stdout, csv and tensorboard under RESULTS_DIR/sac.
    new_logger_sac = configure(RESULTS_DIR + "/sac", ["stdout", "csv", "tensorboard"])
    model_sac.set_logger(new_logger_sac)
    trained_sac = agent.train_model(
        model=model_sac, tb_log_name="sac", total_timesteps=20000
    )
    trained_sac.save(TRAINED_MODEL_DIR + "/agent_sac")
# Final confirmation; identical output to the original multi-argument print.
print(f"All agents trained and saved to {TRAINED_MODEL_DIR}")