Path: blob/master/finrl/meta/env_stock_trading/env_stocktrading_cashpenalty.py
from __future__ import annotations

import random
import time
from copy import deepcopy

import gym
import matplotlib
import numpy as np
import pandas as pd
from gym import spaces
from stable_baselines3.common import logger
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import SubprocVecEnv

matplotlib.use("Agg")


class StockTradingEnvCashpenalty(gym.Env):
    """
    A stock trading environment for OpenAI gym.
    This environment penalizes the model for not maintaining a reserve of cash,
    which trains the model to manage cash reserves in addition to trading.

    Reward at any step is given as follows:
        r_i = (sum(cash, asset_value) - initial_cash
               - max(0, sum(cash, asset_value) * cash_penalty_proportion - cash)) / days_elapsed
    This reward function takes into account a liquidity requirement as well as
    long-term accrued rewards.

    Parameters:
        df (pandas.DataFrame): Dataframe containing data
        buy_cost_pct (float): cost for buying shares
        sell_cost_pct (float): cost for selling shares
        date_col_name (str): name of the date column in df
        hmax (int, array): maximum cash to be traded in each trade per asset.
            If an array is provided, each index corresponds to one asset
        discrete_actions (bool): whether to discretize the action space
        shares_increment (int): shares can only be traded in multiples of this
            number in each trade. Only applicable if discrete_actions=True
        turbulence_threshold (float): maximum market turbulence under which
            purchases may occur. If exceeded, all positions are liquidated
        print_verbosity (int): how often (in steps) to print stats about the
            state of the env
        initial_amount (int, float): amount of cash initially available
        daily_information_cols (list(str)): columns used to build the state
            space from the dataframe. These can be OHLC columns or any other
            variables such as technical indicators and the turbulence index
        cache_indicator_data (bool): if True, precompute all state vectors at
            init to speed up stepping, at the cost of memory
        cash_penalty_proportion (int, float): fraction of total assets that
            must be held as cash; shortfalls below this reserve are penalized
            in the reward
        random_start (bool): if True, each episode starts at a random date in
            the first half of the dataset
        patient (bool): if True, skip buys when out of cash instead of ending
            the episode
        currency (str): currency symbol used when printing stats

    RL Inputs and Outputs:
        action space: [<n_assets>,] in range {-1, 1}
        state space: [start_cash, [shares_i for i in assets], [[indicator_j for j in indicators] for i in assets]]

    TODO:
        Organize functions
        Write README
        Document tests
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        df,
        buy_cost_pct=3e-3,
        sell_cost_pct=3e-3,
        date_col_name="date",
        hmax=10,
        discrete_actions=False,
        shares_increment=1,
        turbulence_threshold=None,
        print_verbosity=10,
        initial_amount=1e6,
        daily_information_cols=["open", "close", "high", "low", "volume"],
        cache_indicator_data=True,
        cash_penalty_proportion=0.1,
        random_start=True,
        patient=False,
        currency="$",
    ):
        self.df = df
        self.stock_col = "tic"
        self.assets = df[self.stock_col].unique()
        self.dates = df[date_col_name].sort_values().unique()
        self.random_start = random_start
        self.discrete_actions = discrete_actions
        self.patient = patient
        self.currency = currency
        self.df = self.df.set_index(date_col_name)
        self.shares_increment = shares_increment
        self.hmax = hmax
        self.initial_amount = initial_amount
        self.print_verbosity = print_verbosity
        self.buy_cost_pct = buy_cost_pct
        self.sell_cost_pct = sell_cost_pct
        self.turbulence_threshold = turbulence_threshold
        self.daily_information_cols = daily_information_cols
        self.state_space = (
            1 + len(self.assets) + len(self.assets) * len(self.daily_information_cols)
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(len(self.assets),))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.state_space,)
        )
        self.turbulence = 0
        self.episode = -1  # initialize so we can call reset
        self.episode_history = []
        self.printed_header = False
        self.cache_indicator_data = cache_indicator_data
        self.cached_data = None
        self.cash_penalty_proportion = cash_penalty_proportion
        if self.cache_indicator_data:
            print("caching data")
            self.cached_data = [
                self.get_date_vector(i) for i, _ in enumerate(self.dates)
            ]
            print("data cached!")

    def seed(self, seed=None):
        if seed is None:
            seed = int(round(time.time() * 1000))
        random.seed(seed)

    @property
    def current_step(self):
        return self.date_index - self.starting_point

    @property
    def cash_on_hand(self):
        # amount of cash held at current timestep
        return self.state_memory[-1][0]

    @property
    def holdings(self):
        # quantity of shares held at current timestep
        return self.state_memory[-1][1 : len(self.assets) + 1]

    @property
    def closings(self):
        return np.array(self.get_date_vector(self.date_index, cols=["close"]))
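    # Illustrative note (not part of the original docs): with two assets and
    # the default daily_information_cols, the state vector built by reset()
    # and step() has length 1 + 2 + 2 * 5 = 13, matching state_space in
    # __init__, and is laid out as
    #   [cash,
    #    shares_asset1, shares_asset2,
    #    open_1, close_1, high_1, low_1, volume_1,
    #    open_2, close_2, high_2, low_2, volume_2]
    # following the per-asset column order produced by get_date_vector().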
    def reset(
        self,
        *,
        seed=None,
        options=None,
    ):
        self.seed(seed)
        self.sum_trades = 0
        if self.random_start:
            starting_point = random.choice(range(int(len(self.dates) * 0.5)))
            self.starting_point = starting_point
        else:
            self.starting_point = 0
        self.date_index = self.starting_point
        self.turbulence = 0
        self.episode += 1
        self.actions_memory = []
        self.transaction_memory = []
        self.state_memory = []
        self.account_information = {
            "cash": [],
            "asset_value": [],
            "total_assets": [],
            "reward": [],
        }
        init_state = np.array(
            [self.initial_amount]
            + [0] * len(self.assets)
            + self.get_date_vector(self.date_index)
        )
        self.state_memory.append(init_state)
        return init_state

    def get_date_vector(self, date, cols=None):
        if (cols is None) and (self.cached_data is not None):
            return self.cached_data[date]
        else:
            date = self.dates[date]
            if cols is None:
                cols = self.daily_information_cols
            trunc_df = self.df.loc[[date]]
            v = []
            for a in self.assets:
                subset = trunc_df[trunc_df[self.stock_col] == a]
                v += subset.loc[date, cols].tolist()
            assert len(v) == len(self.assets) * len(cols)
            return v

    def return_terminal(self, reason="Last Date", reward=0):
        state = self.state_memory[-1]
        self.log_step(reason=reason, terminal_reward=reward)
        # add outputs to logger interface
        gl_pct = self.account_information["total_assets"][-1] / self.initial_amount
        logger.record("environment/GainLoss_pct", (gl_pct - 1) * 100)
        logger.record(
            "environment/total_assets",
            int(self.account_information["total_assets"][-1]),
        )
        reward_pct = self.account_information["total_assets"][-1] / self.initial_amount
        logger.record("environment/total_reward_pct", (reward_pct - 1) * 100)
        logger.record("environment/total_trades", self.sum_trades)
        logger.record(
            "environment/avg_daily_trades",
            self.sum_trades / (self.current_step),
        )
        logger.record(
            "environment/avg_daily_trades_per_asset",
            self.sum_trades / (self.current_step) / len(self.assets),
        )
        logger.record("environment/completed_steps", self.current_step)
        logger.record(
            "environment/sum_rewards", np.sum(self.account_information["reward"])
        )
        logger.record(
            "environment/cash_proportion",
            self.account_information["cash"][-1]
            / self.account_information["total_assets"][-1],
        )
        return state, reward, True, {}

    def log_step(self, reason, terminal_reward=None):
        if terminal_reward is None:
            terminal_reward = self.account_information["reward"][-1]
        cash_pct = (
            self.account_information["cash"][-1]
            / self.account_information["total_assets"][-1]
        )
        gl_pct = self.account_information["total_assets"][-1] / self.initial_amount
        rec = [
            self.episode,
            self.date_index - self.starting_point,
            reason,
            f"{self.currency}{'{:0,.0f}'.format(float(self.account_information['cash'][-1]))}",
            f"{self.currency}{'{:0,.0f}'.format(float(self.account_information['total_assets'][-1]))}",
            f"{terminal_reward * 100:0.5f}%",
            f"{(gl_pct - 1) * 100:0.5f}%",
            f"{cash_pct * 100:0.2f}%",
        ]
        self.episode_history.append(rec)
        print(self.template.format(*rec))

    def log_header(self):
        if self.printed_header is False:
            # column widths: 4, 4, 15, 15, 15, 10, 10, 10
            self.template = "{0:4}|{1:4}|{2:15}|{3:15}|{4:15}|{5:10}|{6:10}|{7:10}"
            print(
                self.template.format(
                    "EPISODE",
                    "STEPS",
                    "TERMINAL_REASON",
                    "CASH",
                    "TOT_ASSETS",
                    "TERMINAL_REWARD_unsc",
                    "GAINLOSS_PCT",
                    "CASH_PROPORTION",
                )
            )
            self.printed_header = True
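    # Worked example of the reward below (illustrative numbers, not from the
    # original file): with initial_amount=1e6, cash_penalty_proportion=0.1,
    # and, after 10 steps, total_assets=1,050,000 of which cash=50,000:
    #   required reserve = 0.1 * 1,050,000 = 105,000
    #   cash_penalty     = max(0, 105,000 - 50,000) = 55,000
    #   reward           = ((1,050,000 - 55,000) / 1,000,000 - 1) / 10 = -0.0005
    # i.e. a nominal 5% gain still earns a negative reward when the cash
    # reserve is too thin.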
    def get_reward(self):
        if self.current_step == 0:
            return 0
        else:
            assets = self.account_information["total_assets"][-1]
            cash = self.account_information["cash"][-1]
            cash_penalty = max(0, (assets * self.cash_penalty_proportion - cash))
            assets -= cash_penalty
            reward = (assets / self.initial_amount) - 1
            reward /= self.current_step
            return reward

    def get_transactions(self, actions):
        """
        This function takes a raw 'action' from the model and turns it into
        realistic transactions. It includes the optional discretization logic
        as well as the turbulence logic.
        """
        # record actions of the model
        self.actions_memory.append(actions)

        # multiply actions by the hmax value to get dollar amounts
        actions = actions * self.hmax

        # do nothing for shares with zero value
        actions = np.where(self.closings > 0, actions, 0)

        # discretize optionally
        if self.discrete_actions:
            # convert to integer share counts because we can't buy fractions of shares
            actions = actions // self.closings
            actions = actions.astype(int)
            # round actions toward zero to the nearest multiple of shares_increment,
            # leaving exact multiples intact
            actions = np.where(
                actions >= 0,
                (actions // self.shares_increment) * self.shares_increment,
                -((-actions // self.shares_increment) * self.shares_increment),
            )
        else:
            actions = actions / self.closings

        # can't sell more than we have
        actions = np.maximum(actions, -np.array(self.holdings))

        # deal with turbulence
        if self.turbulence_threshold is not None:
            # if turbulence goes over the threshold, just clear out all positions
            if self.turbulence >= self.turbulence_threshold:
                actions = -(np.array(self.holdings))
                self.log_step(reason="TURBULENCE")

        return actions
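    # Worked example of get_transactions (illustrative numbers): with hmax=100
    # and an action of 0.5 on a stock closing at $20, the model asks to trade
    # $50 of stock, i.e. 2.5 shares in continuous mode. With
    # discrete_actions=True that floors to 2 shares, and with
    # shares_increment=5 it rounds toward zero to 0 shares (no trade).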
    def step(self, actions):
        # track the cumulative magnitude of the model's raw actions
        self.sum_trades += np.sum(np.abs(actions))
        self.log_header()
        # print if it's time
        if (self.current_step + 1) % self.print_verbosity == 0:
            self.log_step(reason="update")
        # if we're at the end
        if self.date_index == len(self.dates) - 1:
            # if we hit the end, set reward to total gains (or losses)
            return self.return_terminal(reward=self.get_reward())
        else:
            """
            First, we compute the value of our holdings, save it, and log everything.
            Then we can reward the model for its earnings.
            """
            # compute value of cash + assets
            begin_cash = self.cash_on_hand
            assert min(self.holdings) >= 0
            asset_value = np.dot(self.holdings, self.closings)
            # log the values of cash, assets, and total assets
            self.account_information["cash"].append(begin_cash)
            self.account_information["asset_value"].append(asset_value)
            self.account_information["total_assets"].append(begin_cash + asset_value)

            # compute reward once we've computed the value of things
            reward = self.get_reward()
            self.account_information["reward"].append(reward)

            # now, let's get down to the business at hand
            transactions = self.get_transactions(actions)

            # compute our proceeds from sells, and add to cash
            sells = -np.clip(transactions, -np.inf, 0)
            proceeds = np.dot(sells, self.closings)
            costs = proceeds * self.sell_cost_pct
            coh = begin_cash + proceeds
            # compute the cost of our buys
            buys = np.clip(transactions, 0, np.inf)
            spend = np.dot(buys, self.closings)
            costs += spend * self.buy_cost_pct
            # if we run out of cash...
            if (spend + costs) > coh:
                if self.patient:
                    # ... just don't buy anything until we get additional cash
                    self.log_step(reason="CASH SHORTAGE")
                    transactions = np.where(transactions > 0, 0, transactions)
                    spend = 0
                    costs = 0
                else:
                    # ... end the episode and penalize
                    return self.return_terminal(
                        reason="CASH SHORTAGE", reward=self.get_reward()
                    )
            self.transaction_memory.append(
                transactions
            )  # capture the model's actual transactions
            # verify we didn't do anything impossible here
            assert (spend + costs) <= coh
            # update our holdings
            coh = coh - spend - costs
            holdings_updated = self.holdings + transactions
            self.date_index += 1
            if self.turbulence_threshold is not None:
                self.turbulence = self.get_date_vector(
                    self.date_index, cols=["turbulence"]
                )[0]
            # update state
            state = (
                [coh] + list(holdings_updated) + self.get_date_vector(self.date_index)
            )
            self.state_memory.append(state)
            return state, reward, False, {}
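    # The two helpers below wrap deep copies of this environment in
    # Stable-Baselines3 vectorized wrappers. A minimal, hypothetical usage:
    #     from stable_baselines3 import PPO
    #     env, obs = env_instance.get_sb_env()
    #     model = PPO("MlpPolicy", env)  # any SB3 agent works here
    #     model.learn(total_timesteps=10_000)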
    def get_sb_env(self):
        def get_self():
            return deepcopy(self)

        e = DummyVecEnv([get_self])
        obs = e.reset()
        return e, obs

    def get_multiproc_env(self, n=10):
        def get_self():
            return deepcopy(self)

        e = SubprocVecEnv([get_self for _ in range(n)], start_method="fork")
        obs = e.reset()
        return e, obs

    def save_asset_memory(self):
        if self.current_step == 0:
            return None
        else:
            self.account_information["date"] = self.dates[
                -len(self.account_information["cash"]) :
            ]
            return pd.DataFrame(self.account_information)

    def save_action_memory(self):
        if self.current_step == 0:
            return None
        else:
            return pd.DataFrame(
                {
                    "date": self.dates[-len(self.account_information["cash"]) :],
                    "actions": self.actions_memory,
                    "transactions": self.transaction_memory,
                }
            )
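# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the environment). The
# toy DataFrame below is fabricated -- real usage passes a FinRL-preprocessed
# dataframe with one row per (date, tic) pair.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dates = [str(d.date()) for d in pd.date_range("2021-01-01", periods=30)]
    n_rows = 2 * len(dates)
    toy_df = pd.DataFrame(
        {
            "date": np.repeat(dates, 2),
            "tic": ["AAA", "BBB"] * len(dates),
            "open": rng.uniform(9, 11, n_rows),
            "close": rng.uniform(9, 11, n_rows),
            "high": rng.uniform(11, 12, n_rows),
            "low": rng.uniform(8, 9, n_rows),
            "volume": rng.uniform(1e5, 2e5, n_rows),
        }
    )
    env = StockTradingEnvCashpenalty(
        df=toy_df, initial_amount=1e4, random_start=False, print_verbosity=5
    )
    obs = env.reset()
    for _ in range(10):
        # sample a random action per asset as a smoke test of the step loop
        obs, reward, done, info = env.step(env.action_space.sample())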