# Path: blob/master/finrl/meta/env_stock_trading/env_stocktrading.py
# 732 views
from __future__ import annotations12from typing import List34import gymnasium as gym5import matplotlib6import matplotlib.pyplot as plt7import numpy as np8import pandas as pd9from gymnasium import spaces10from gymnasium.utils import seeding11from stable_baselines3.common.vec_env import DummyVecEnv1213matplotlib.use("Agg")1415# from stable_baselines3.common.logger import Logger, KVWriter, CSVOutputFormat161718class StockTradingEnv(gym.Env):19"""20A stock trading environment for OpenAI gym2122Parameters:23df (pandas.DataFrame): Dataframe containing data24hmax (int): Maximum cash to be traded in each trade per asset.25initial_amount (int): Amount of cash initially available26buy_cost_pct (float, array): Cost for buying shares, each index corresponds to each asset27sell_cost_pct (float, array): Cost for selling shares, each index corresponds to each asset28turbulence_threshold (float): Maximum turbulence allowed in market for purchases to occur. If exceeded, positions are liquidated29print_verbosity(int): When iterating (step), how often to print stats about state of env30"""3132metadata = {"render.modes": ["human"]}3334def __init__(35self,36df: pd.DataFrame,37stock_dim: int,38hmax: int,39initial_amount: int,40num_stock_shares: list[int],41buy_cost_pct: list[float],42sell_cost_pct: list[float],43reward_scaling: float,44state_space: int,45action_space: int,46tech_indicator_list: list[str],47turbulence_threshold=None,48risk_indicator_col="turbulence",49make_plots: bool = False,50print_verbosity=10,51day=0,52initial=True,53previous_state=[],54model_name="",55mode="",56iteration="",57):58self.day = day59self.df = df60self.stock_dim = stock_dim61self.hmax = hmax62self.num_stock_shares = num_stock_shares63self.initial_amount = initial_amount # get the initial cash64self.buy_cost_pct = buy_cost_pct65self.sell_cost_pct = sell_cost_pct66self.reward_scaling = reward_scaling67self.state_space = state_space68self.action_space = action_space69self.tech_indicator_list = 
tech_indicator_list70self.action_space = spaces.Box(low=-1, high=1, shape=(self.action_space,))71self.observation_space = spaces.Box(72low=-np.inf, high=np.inf, shape=(self.state_space,)73)74self.data = self.df.loc[self.day, :]75self.terminal = False76self.make_plots = make_plots77self.print_verbosity = print_verbosity78self.turbulence_threshold = turbulence_threshold79self.risk_indicator_col = risk_indicator_col80self.initial = initial81self.previous_state = previous_state82self.model_name = model_name83self.mode = mode84self.iteration = iteration85# initalize state86self.state = self._initiate_state()8788# initialize reward89self.reward = 090self.turbulence = 091self.cost = 092self.trades = 093self.episode = 094# memorize all the total balance change95self.asset_memory = [96self.initial_amount97+ np.sum(98np.array(self.num_stock_shares)99* np.array(self.state[1 : 1 + self.stock_dim])100)101] # the initial total asset is calculated by cash + sum (num_share_stock_i * price_stock_i)102self.rewards_memory = []103self.actions_memory = []104self.state_memory = (105[]106) # we need sometimes to preserve the state in the middle of trading process107self.date_memory = [self._get_date()]108# self.logger = Logger('results',[CSVOutputFormat])109# self.reset()110self._seed()111112def _sell_stock(self, index, action):113def _do_sell_normal():114if (115self.state[index + 2 * self.stock_dim + 1] != True116): # check if the stock is able to sell, for simlicity we just add it in techical index117# if self.state[index + 1] > 0: # if we use price<0 to denote a stock is unable to trade in that day, the total asset calculation may be wrong for the price is unreasonable118# Sell only if the price is > 0 (no missing data in this particular date)119# perform sell action based on the sign of the action120if self.state[index + self.stock_dim + 1] > 0:121# Sell only if current asset is > 0122sell_num_shares = min(123abs(action), self.state[index + self.stock_dim + 1]124)125sell_amount = 
(126self.state[index + 1]127* sell_num_shares128* (1 - self.sell_cost_pct[index])129)130# update balance131self.state[0] += sell_amount132133self.state[index + self.stock_dim + 1] -= sell_num_shares134self.cost += (135self.state[index + 1]136* sell_num_shares137* self.sell_cost_pct[index]138)139self.trades += 1140else:141sell_num_shares = 0142else:143sell_num_shares = 0144145return sell_num_shares146147# perform sell action based on the sign of the action148if self.turbulence_threshold is not None:149if self.turbulence >= self.turbulence_threshold:150if self.state[index + 1] > 0:151# Sell only if the price is > 0 (no missing data in this particular date)152# if turbulence goes over threshold, just clear out all positions153if self.state[index + self.stock_dim + 1] > 0:154# Sell only if current asset is > 0155sell_num_shares = self.state[index + self.stock_dim + 1]156sell_amount = (157self.state[index + 1]158* sell_num_shares159* (1 - self.sell_cost_pct[index])160)161# update balance162self.state[0] += sell_amount163self.state[index + self.stock_dim + 1] = 0164self.cost += (165self.state[index + 1]166* sell_num_shares167* self.sell_cost_pct[index]168)169self.trades += 1170else:171sell_num_shares = 0172else:173sell_num_shares = 0174else:175sell_num_shares = _do_sell_normal()176else:177sell_num_shares = _do_sell_normal()178179return sell_num_shares180181def _buy_stock(self, index, action):182def _do_buy():183if (184self.state[index + 2 * self.stock_dim + 1] != True185): # check if the stock is able to buy186# if self.state[index + 1] >0:187# Buy only if the price is > 0 (no missing data in this particular date)188available_amount = self.state[0] // (189self.state[index + 1] * (1 + self.buy_cost_pct[index])190) # when buying stocks, we should consider the cost of trading when calculating available_amount, or we may be have cash<0191# print('available_amount:{}'.format(available_amount))192193# update balance194buy_num_shares = min(available_amount, action)195buy_amount 
= (196self.state[index + 1]197* buy_num_shares198* (1 + self.buy_cost_pct[index])199)200self.state[0] -= buy_amount201202self.state[index + self.stock_dim + 1] += buy_num_shares203204self.cost += (205self.state[index + 1] * buy_num_shares * self.buy_cost_pct[index]206)207self.trades += 1208else:209buy_num_shares = 0210211return buy_num_shares212213# perform buy action based on the sign of the action214if self.turbulence_threshold is None:215buy_num_shares = _do_buy()216else:217if self.turbulence < self.turbulence_threshold:218buy_num_shares = _do_buy()219else:220buy_num_shares = 0221pass222223return buy_num_shares224225def _make_plot(self):226plt.plot(self.asset_memory, "r")227plt.savefig(f"results/account_value_trade_{self.episode}.png")228plt.close()229230def step(self, actions):231self.terminal = self.day >= len(self.df.index.unique()) - 1232if self.terminal:233# print(f"Episode: {self.episode}")234if self.make_plots:235self._make_plot()236end_total_asset = self.state[0] + sum(237np.array(self.state[1 : (self.stock_dim + 1)])238* np.array(self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)])239)240df_total_value = pd.DataFrame(self.asset_memory)241tot_reward = (242self.state[0]243+ sum(244np.array(self.state[1 : (self.stock_dim + 1)])245* np.array(246self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]247)248)249- self.asset_memory[0]250) # initial_amount is only cash part of our initial asset251df_total_value.columns = ["account_value"]252df_total_value["date"] = self.date_memory253df_total_value["daily_return"] = df_total_value["account_value"].pct_change(2541255)256if df_total_value["daily_return"].std() != 0:257sharpe = (258(252**0.5)259* df_total_value["daily_return"].mean()260/ df_total_value["daily_return"].std()261)262df_rewards = pd.DataFrame(self.rewards_memory)263df_rewards.columns = ["account_rewards"]264df_rewards["date"] = self.date_memory[:-1]265if self.episode % self.print_verbosity == 0:266print(f"day: {self.day}, episode: 
{self.episode}")267print(f"begin_total_asset: {self.asset_memory[0]:0.2f}")268print(f"end_total_asset: {end_total_asset:0.2f}")269print(f"total_reward: {tot_reward:0.2f}")270print(f"total_cost: {self.cost:0.2f}")271print(f"total_trades: {self.trades}")272if df_total_value["daily_return"].std() != 0:273print(f"Sharpe: {sharpe:0.3f}")274print("=================================")275276if (self.model_name != "") and (self.mode != ""):277df_actions = self.save_action_memory()278df_actions.to_csv(279"results/actions_{}_{}_{}.csv".format(280self.mode, self.model_name, self.iteration281)282)283df_total_value.to_csv(284"results/account_value_{}_{}_{}.csv".format(285self.mode, self.model_name, self.iteration286),287index=False,288)289df_rewards.to_csv(290"results/account_rewards_{}_{}_{}.csv".format(291self.mode, self.model_name, self.iteration292),293index=False,294)295plt.plot(self.asset_memory, "r")296plt.savefig(297"results/account_value_{}_{}_{}.png".format(298self.mode, self.model_name, self.iteration299)300)301plt.close()302303# Add outputs to logger interface304# logger.record("environment/portfolio_value", end_total_asset)305# logger.record("environment/total_reward", tot_reward)306# logger.record("environment/total_reward_pct", (tot_reward / (end_total_asset - tot_reward)) * 100)307# logger.record("environment/total_cost", self.cost)308# logger.record("environment/total_trades", self.trades)309310return self.state, self.reward, self.terminal, False, {}311312else:313actions = actions * self.hmax # actions initially is scaled between 0 to 1314actions = actions.astype(315int316) # convert into integer because we can't by fraction of shares317if self.turbulence_threshold is not None:318if self.turbulence >= self.turbulence_threshold:319actions = np.array([-self.hmax] * self.stock_dim)320begin_total_asset = self.state[0] + sum(321np.array(self.state[1 : (self.stock_dim + 1)])322* np.array(self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)])323)324# 
print("begin_total_asset:{}".format(begin_total_asset))325326argsort_actions = np.argsort(actions)327sell_index = argsort_actions[: np.where(actions < 0)[0].shape[0]]328buy_index = argsort_actions[::-1][: np.where(actions > 0)[0].shape[0]]329330for index in sell_index:331# print(f"Num shares before: {self.state[index+self.stock_dim+1]}")332# print(f'take sell action before : {actions[index]}')333actions[index] = self._sell_stock(index, actions[index]) * (-1)334# print(f'take sell action after : {actions[index]}')335# print(f"Num shares after: {self.state[index+self.stock_dim+1]}")336337for index in buy_index:338# print('take buy action: {}'.format(actions[index]))339actions[index] = self._buy_stock(index, actions[index])340341self.actions_memory.append(actions)342343# state: s -> s+1344self.day += 1345self.data = self.df.loc[self.day, :]346if self.turbulence_threshold is not None:347if len(self.df.tic.unique()) == 1:348self.turbulence = self.data[self.risk_indicator_col]349elif len(self.df.tic.unique()) > 1:350self.turbulence = self.data[self.risk_indicator_col].values[0]351self.state = self._update_state()352353end_total_asset = self.state[0] + sum(354np.array(self.state[1 : (self.stock_dim + 1)])355* np.array(self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)])356)357self.asset_memory.append(end_total_asset)358self.date_memory.append(self._get_date())359self.reward = end_total_asset - begin_total_asset360self.rewards_memory.append(self.reward)361self.reward = self.reward * self.reward_scaling362self.state_memory.append(363self.state364) # add current state in state_recorder for each step365366return self.state, self.reward, self.terminal, False, {}367368def reset(369self,370*,371seed=None,372options=None,373):374# initiate state375self.day = 0376self.data = self.df.loc[self.day, :]377self.state = self._initiate_state()378379if self.initial:380self.asset_memory = [381self.initial_amount382+ np.sum(383np.array(self.num_stock_shares)384* 
np.array(self.state[1 : 1 + self.stock_dim])385)386]387else:388previous_total_asset = self.previous_state[0] + sum(389np.array(self.state[1 : (self.stock_dim + 1)])390* np.array(391self.previous_state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]392)393)394self.asset_memory = [previous_total_asset]395396self.turbulence = 0397self.cost = 0398self.trades = 0399self.terminal = False400# self.iteration=self.iteration401self.rewards_memory = []402self.actions_memory = []403self.date_memory = [self._get_date()]404405self.episode += 1406407return self.state, {}408409def render(self, mode="human", close=False):410return self.state411412def _initiate_state(self):413if self.initial:414# For Initial State415if len(self.df.tic.unique()) > 1:416# for multiple stock417state = (418[self.initial_amount]419+ self.data.close.values.tolist()420+ self.num_stock_shares421+ sum(422(423self.data[tech].values.tolist()424for tech in self.tech_indicator_list425),426[],427)428) # append initial stocks_share to initial state, instead of all zero429else:430# for single stock431state = (432[self.initial_amount]433+ [self.data.close]434+ [0] * self.stock_dim435+ sum(([self.data[tech]] for tech in self.tech_indicator_list), [])436)437else:438# Using Previous State439if len(self.df.tic.unique()) > 1:440# for multiple stock441state = (442[self.previous_state[0]]443+ self.data.close.values.tolist()444+ self.previous_state[445(self.stock_dim + 1) : (self.stock_dim * 2 + 1)446]447+ sum(448(449self.data[tech].values.tolist()450for tech in self.tech_indicator_list451),452[],453)454)455else:456# for single stock457state = (458[self.previous_state[0]]459+ [self.data.close]460+ self.previous_state[461(self.stock_dim + 1) : (self.stock_dim * 2 + 1)462]463+ sum(([self.data[tech]] for tech in self.tech_indicator_list), [])464)465return state466467def _update_state(self):468if len(self.df.tic.unique()) > 1:469# for multiple stock470state = (471[self.state[0]]472+ self.data.close.values.tolist()473+ 
list(self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)])474+ sum(475(476self.data[tech].values.tolist()477for tech in self.tech_indicator_list478),479[],480)481)482483else:484# for single stock485state = (486[self.state[0]]487+ [self.data.close]488+ list(self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)])489+ sum(([self.data[tech]] for tech in self.tech_indicator_list), [])490)491492return state493494def _get_date(self):495if len(self.df.tic.unique()) > 1:496date = self.data.date.unique()[0]497else:498date = self.data.date499return date500501# add save_state_memory to preserve state in the trading process502def save_state_memory(self):503if len(self.df.tic.unique()) > 1:504# date and close price length must match actions length505date_list = self.date_memory[:-1]506df_date = pd.DataFrame(date_list)507df_date.columns = ["date"]508509state_list = self.state_memory510df_states = pd.DataFrame(511state_list,512columns=[513"cash",514"Bitcoin_price",515"Gold_price",516"Bitcoin_num",517"Gold_num",518"Bitcoin_Disable",519"Gold_Disable",520],521)522df_states.index = df_date.date523# df_actions = pd.DataFrame({'date':date_list,'actions':action_list})524else:525date_list = self.date_memory[:-1]526state_list = self.state_memory527df_states = pd.DataFrame({"date": date_list, "states": state_list})528# print(df_states)529return df_states530531def save_asset_memory(self):532date_list = self.date_memory533asset_list = self.asset_memory534# print(len(date_list))535# print(len(asset_list))536df_account_value = pd.DataFrame(537{"date": date_list, "account_value": asset_list}538)539return df_account_value540541def save_action_memory(self):542if len(self.df.tic.unique()) > 1:543# date and close price length must match actions length544date_list = self.date_memory[:-1]545df_date = pd.DataFrame(date_list)546df_date.columns = ["date"]547548action_list = self.actions_memory549df_actions = pd.DataFrame(action_list)550df_actions.columns = 
self.data.tic.values551df_actions.index = df_date.date552# df_actions = pd.DataFrame({'date':date_list,'actions':action_list})553else:554date_list = self.date_memory[:-1]555action_list = self.actions_memory556df_actions = pd.DataFrame({"date": date_list, "actions": action_list})557return df_actions558559def _seed(self, seed=None):560self.np_random, seed = seeding.np_random(seed)561return [seed]562563def get_sb_env(self):564e = DummyVecEnv([lambda: self])565obs = e.reset()566return e, obs567568569