Path: blob/master/examples/FinRL_StockTrading_2026_2_train.py
"""1Stock NeurIPS2018 Part 2. Train23This series is a reproduction of paper "Deep reinforcement learning for4automated stock trading: An ensemble strategy".56Introduce how to use FinRL to make data into the gym form environment, and train DRL agents on it.7"""89import pandas as pd10from stable_baselines3.common.logger import configure1112from finrl.agents.stablebaselines3.models import DRLAgent13from finrl.config import INDICATORS, TRAINED_MODEL_DIR, RESULTS_DIR14from finrl.main import check_and_make_directories15from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv1617# %% Part 1. Prepare directories1819check_and_make_directories([TRAINED_MODEL_DIR])2021# %% Part 2. Build environment2223train = pd.read_csv("train_data.csv")24train = train.set_index(train.columns[0])25train.index.names = [""]2627stock_dimension = len(train.tic.unique())28state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension29print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")3031buy_cost_list = sell_cost_list = [0.001] * stock_dimension32num_stock_shares = [0] * stock_dimension3334env_kwargs = {35"hmax": 100,36"initial_amount": 1000000,37"num_stock_shares": num_stock_shares,38"buy_cost_pct": buy_cost_list,39"sell_cost_pct": sell_cost_list,40"state_space": state_space,41"stock_dim": stock_dimension,42"tech_indicator_list": INDICATORS,43"action_space": stock_dimension,44"reward_scaling": 1e-4,45}4647e_train_gym = StockTradingEnv(df=train, **env_kwargs)48env_train, _ = e_train_gym.get_sb_env()49print(type(env_train))5051# %% Part 3. Train DRL Agents5253if_using_a2c = True54if_using_ddpg = True55if_using_ppo = True56if_using_td3 = True57if_using_sac = True5859# --- Agent 1: A2C ---60agent = DRLAgent(env=env_train)61model_a2c = agent.get_model("a2c")62if if_using_a2c:63tmp_path = RESULTS_DIR + "/a2c"64new_logger_a2c = configure(tmp_path, ["stdout", "csv", "tensorboard"])65model_a2c.set_logger(new_logger_a2c)6667trained_a2c = (68agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=20000)69if if_using_a2c70else None71)72if if_using_a2c:73trained_a2c.save(TRAINED_MODEL_DIR + "/agent_a2c")7475# --- Agent 2: DDPG ---76agent = DRLAgent(env=env_train)77model_ddpg = agent.get_model("ddpg")78if if_using_ddpg:79tmp_path = RESULTS_DIR + "/ddpg"80new_logger_ddpg = configure(tmp_path, ["stdout", "csv", "tensorboard"])81model_ddpg.set_logger(new_logger_ddpg)8283trained_ddpg = (84agent.train_model(model=model_ddpg, tb_log_name="ddpg", total_timesteps=20000)85if if_using_ddpg86else None87)88if if_using_ddpg:89trained_ddpg.save(TRAINED_MODEL_DIR + "/agent_ddpg")9091# --- Agent 3: PPO ---92agent = DRLAgent(env=env_train)93PPO_PARAMS = {94"n_steps": 2048,95"ent_coef": 0.01,96"learning_rate": 0.00025,97"batch_size": 128,98}99model_ppo = agent.get_model("ppo", model_kwargs=PPO_PARAMS)100if if_using_ppo:101tmp_path = RESULTS_DIR + "/ppo"102new_logger_ppo = configure(tmp_path, ["stdout", "csv", "tensorboard"])103model_ppo.set_logger(new_logger_ppo)104105trained_ppo = (106agent.train_model(model=model_ppo, tb_log_name="ppo", total_timesteps=20000)107if if_using_ppo108else None109)110if if_using_ppo:111trained_ppo.save(TRAINED_MODEL_DIR + "/agent_ppo")112113# --- Agent 4: TD3 ---114agent = DRLAgent(env=env_train)115TD3_PARAMS = {116"batch_size": 100,117"buffer_size": 1000000,118"learning_rate": 0.001,119}120model_td3 = agent.get_model("td3", model_kwargs=TD3_PARAMS)121if if_using_td3:122tmp_path = RESULTS_DIR + "/td3"123new_logger_td3 = configure(tmp_path, ["stdout", "csv", 
"tensorboard"])124model_td3.set_logger(new_logger_td3)125126trained_td3 = (127agent.train_model(model=model_td3, tb_log_name="td3", total_timesteps=20000)128if if_using_td3129else None130)131if if_using_td3:132trained_td3.save(TRAINED_MODEL_DIR + "/agent_td3")133134# --- Agent 5: SAC ---135agent = DRLAgent(env=env_train)136SAC_PARAMS = {137"batch_size": 128,138"buffer_size": 100000,139"learning_rate": 0.0001,140"learning_starts": 100,141"ent_coef": "auto_0.1",142}143model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)144if if_using_sac:145tmp_path = RESULTS_DIR + "/sac"146new_logger_sac = configure(tmp_path, ["stdout", "csv", "tensorboard"])147model_sac.set_logger(new_logger_sac)148149trained_sac = (150agent.train_model(model=model_sac, tb_log_name="sac", total_timesteps=20000)151if if_using_sac152else None153)154if if_using_sac:155trained_sac.save(TRAINED_MODEL_DIR + "/agent_sac")156157print("All agents trained and saved to", TRAINED_MODEL_DIR)158159160