# Path: blob/master/finrl/applications/stock_trading/fundamental_stock_trading.py
# 732 views
"""Train and backtest DRL stock-trading agents on Dow-30 prices plus fundamentals.

The script downloads daily prices, derives financial ratios from a quarterly
fundamentals CSV, merges both into one daily feature table, trains the enabled
Stable-Baselines3 agents in a custom gym environment and compares the result
with the ^DJI baseline.
"""
from __future__ import annotations

import itertools

import numpy as np
import pandas as pd

# Raw fundamentals column -> readable column name.  The key order is the
# column order of the frame the ratios are computed from.
FUNDAMENTAL_COLUMNS = {
    "datadate": "date",  # Date
    "tic": "tic",  # Ticker
    "oiadpq": "op_inc_q",  # Quarterly operating income
    "revtq": "rev_q",  # Quarterly revenue
    "niq": "net_inc_q",  # Quarterly net income
    "atq": "tot_assets",  # Total assets
    "teqq": "sh_equity",  # Shareholder's equity
    "epspiy": "eps_incl_ex",  # EPS (basic) incl. extraordinary items
    "ceqq": "com_eq",  # Common equity
    "cshoq": "sh_outstanding",  # Common shares outstanding
    "dvpspq": "div_per_sh",  # Dividends per share
    "actq": "cur_assets",  # Current assets
    "lctq": "cur_liabilities",  # Current liabilities
    "cheq": "cash_eq",  # Cash & equivalents
    "rectq": "receivables",  # Receivables
    "cogsq": "cogs_q",  # Cost of goods sold
    "invtq": "inventories",  # Inventories
    "apq": "payables",  # Accounts payable
    "dlttq": "long_debt",  # Long-term debt
    "dlcq": "short_debt",  # Debt in current liabilities
    "ltq": "tot_liabilities",  # Liabilities
}


def trailing_sum(values, tickers, quarters=4):
    """Sum each row with the ``quarters - 1`` rows before it, per ticker.

    Args:
        values: Numeric Series of quarterly figures; missing values count as 0.
        tickers: Series of ticker symbols aligned with ``values``.
        quarters: Window length in rows, including the current row.

    Returns:
        A Series aligned with ``values``.  A row is NaN when fewer than
        ``quarters`` rows are available, or when the first row of the window
        carries a different ticker than the current row (rows are expected to
        be grouped by ticker and ordered by date).
    """
    filled = values.fillna(0)
    total = filled.copy()
    # Adding shifted copies keeps the sums exact (a window of zeros stays
    # exactly 0.0) and leaves NaN in the first ``quarters - 1`` rows.
    for lag in range(1, quarters):
        total = total + filled.shift(lag)
    same_ticker = tickers.eq(tickers.shift(quarters - 1))
    return total.where(same_ticker)


def compute_financial_ratios(fund):
    """Derive per-quarter financial ratios from raw fundamentals.

    Flow figures (operating income, revenue, net income, COGS) are summed over
    a trailing window of four quarters that *includes* the current quarter;
    the previous ``iloc[i - 3 : i]`` slice silently dropped the current
    quarter and summed only three.

    Args:
        fund: DataFrame containing every key of ``FUNDAMENTAL_COLUMNS``; the
            ``datadate`` column must be parseable with the ``%Y%m%d`` format.

    Returns:
        DataFrame with ``date``, ``tic`` and one column per ratio (plus the
        per-share items ``EPS``, ``BPS`` and ``DPS``).  NaN and +/- infinity
        are replaced with 0.
    """
    data = fund[list(FUNDAMENTAL_COLUMNS)].rename(columns=FUNDAMENTAL_COLUMNS)
    tickers = data["tic"]

    ttm_op_inc = trailing_sum(data["op_inc_q"], tickers)
    ttm_rev = trailing_sum(data["rev_q"], tickers)
    ttm_net_inc = trailing_sum(data["net_inc_q"], tickers)
    ttm_cogs = trailing_sum(data["cogs_q"], tickers)

    ratios = pd.DataFrame(
        {
            "date": pd.to_datetime(data["date"], format="%Y%m%d"),
            "tic": tickers,
            # Profitability ratios
            "OPM": ttm_op_inc / ttm_rev,  # Operating margin
            "NPM": ttm_net_inc / ttm_rev,  # Net profit margin
            "ROA": ttm_net_inc / data["tot_assets"],  # Return on assets
            "ROE": ttm_net_inc / data["sh_equity"],  # Return on equity
            # Per-share items, used later for the valuation ratios
            "EPS": data["eps_incl_ex"],
            "BPS": data["com_eq"] / data["sh_outstanding"],  # Need to check units
            "DPS": data["div_per_sh"],
            # Liquidity ratios
            "cur_ratio": data["cur_assets"] / data["cur_liabilities"],
            "quick_ratio": (data["cash_eq"] + data["receivables"])
            / data["cur_liabilities"],
            "cash_ratio": data["cash_eq"] / data["cur_liabilities"],
            # Efficiency ratios
            "inv_turnover": ttm_cogs / data["inventories"],
            "acc_rec_turnover": ttm_rev / data["receivables"],
            "acc_pay_turnover": ttm_cogs / data["payables"],
            # Leverage ratios
            "debt_ratio": data["tot_liabilities"] / data["tot_assets"],
            "debt_to_equity": data["tot_liabilities"] / data["sh_equity"],
        }
    )

    # Replace NAs and infinite values (both signs) with zero.
    return ratios.fillna(0).replace([np.inf, -np.inf], 0)


def build_daily_dataset(prices, ratios):
    """Merge daily prices with quarterly ratios into one daily feature table.

    Args:
        prices: DataFrame with at least ``date`` (datetime), ``tic``,
            ``close`` and ``day`` columns.
        ratios: Output of :func:`compute_financial_ratios`.

    Returns:
        DataFrame with one row per (calendar date, ticker) between the first
        and last price date, sorted by ticker then date.  ``PE``, ``PB`` and
        ``Div_yield`` are added; ``day``, ``EPS``, ``BPS`` and ``DPS`` are
        dropped; NaN and +/- infinity are replaced with 0.
    """
    tickers = prices["tic"].unique().tolist()
    calendar = list(pd.date_range(prices["date"].min(), prices["date"].max()))
    grid = pd.DataFrame(
        list(itertools.product(calendar, tickers)), columns=["date", "tic"]
    )

    merged = grid.merge(prices, on=["date", "tic"], how="left")
    merged = merged.merge(ratios, how="left", on=["date", "tic"])
    merged = merged.sort_values(["tic", "date"])

    # Backfill the ratio data to make them daily.
    # NOTE(review): this fills over the whole frame, so the tail of one ticker
    # can be filled from the head of the next, and every gap is filled with a
    # *future* value.  Kept as-is to preserve the published behaviour.
    merged = merged.bfill(axis="rows")

    # Calculate P/E, P/B and dividend yield using the daily closing price.
    merged["PE"] = merged["close"] / merged["EPS"]
    merged["PB"] = merged["close"] / merged["BPS"]
    merged["Div_yield"] = merged["DPS"] / merged["close"]

    # Drop per-share items that were only needed for the calculation above.
    merged = merged.drop(columns=["day", "EPS", "BPS", "DPS"])

    # Replace NAs and infinite values (both signs) with zero.
    return merged.fillna(0).replace([np.inf, -np.inf], 0)


def main():
    """Download data, build features, train the enabled agents and backtest.

    Third-party imports are kept local to this function so that importing the
    module does not require finrl, gym or stable-baselines3.
    """
    import datetime
    import sys
    from pprint import pprint

    # Presumably meant to make a sibling FinRL checkout importable; it has to
    # run before the finrl imports below to have any effect on them.
    sys.path.append("../FinRL")

    import gym
    import matplotlib
    import matplotlib.pyplot as plt
    from gym import spaces
    from gym.utils import seeding
    from stable_baselines3.common.logger import configure
    from stable_baselines3.common.vec_env import DummyVecEnv

    from finrl import config
    from finrl import config_tickers
    from finrl.agents.stablebaselines3.models import DRLAgent
    from finrl.config import (
        DATA_SAVE_DIR,
        TRAINED_MODEL_DIR,
        TENSORBOARD_LOG_DIR,
        RESULTS_DIR,
        INDICATORS,
        TRAIN_START_DATE,
        TRAIN_END_DATE,
        TEST_START_DATE,
        TEST_END_DATE,
        TRADE_START_DATE,
        TRADE_END_DATE,
    )
    from finrl.config_tickers import DOW_30_TICKER
    from finrl.main import check_and_make_directories

    # The library environment is shadowed by the local class defined below.
    from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
    from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
    from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
    from finrl.plot import (
        backtest_stats,
        backtest_plot,
        get_daily_return,
        get_baseline,
    )

    matplotlib.use("Agg")

    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )

    print(DOW_30_TICKER)

    # Override the library defaults with the periods used in this experiment.
    TRAIN_START_DATE = "2009-01-01"
    TRAIN_END_DATE = "2019-01-01"
    TEST_START_DATE = "2019-01-01"
    TEST_END_DATE = "2021-01-01"

    df = YahooDownloader(
        start_date=TRAIN_START_DATE, end_date=TEST_END_DATE, ticker_list=DOW_30_TICKER
    ).fetch_data()
    df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")

    # Import fundamental data from the author's GitHub repository.
    url = "https://raw.githubusercontent.com/mariko-sawada/FinRL_with_fundamental_data/main/dow_30_fundamental_wrds.csv"
    fund = pd.read_csv(url)

    final_ratios = compute_financial_ratios(fund)
    processed_full = build_daily_dataset(df, final_ratios)

    train_data = data_split(processed_full, TRAIN_START_DATE, TRAIN_END_DATE)
    trade_data = data_split(processed_full, TEST_START_DATE, TEST_END_DATE)
    # Check the length of the two datasets
    print(len(train_data))
    print(len(trade_data))

    class StockTradingEnv(gym.Env):
        """A stock trading environment for OpenAI gym.

        The state is a flat list laid out as
        ``[cash] + prices + holdings + indicator values`` where ``prices`` and
        ``holdings`` each have ``stock_dim`` entries.
        """

        metadata = {"render.modes": ["human"]}

        def __init__(
            self,
            df,
            stock_dim,
            hmax,
            initial_amount,
            buy_cost_pct,
            sell_cost_pct,
            reward_scaling,
            state_space,
            action_space,
            tech_indicator_list,
            turbulence_threshold=None,
            risk_indicator_col="turbulence",
            make_plots=False,
            print_verbosity=10,
            day=0,
            initial=True,
            previous_state=None,
            model_name="",
            mode="",
            iteration="",
        ):
            self.day = day
            self.df = df
            self.stock_dim = stock_dim
            self.hmax = hmax
            self.initial_amount = initial_amount
            self.buy_cost_pct = buy_cost_pct
            self.sell_cost_pct = sell_cost_pct
            self.reward_scaling = reward_scaling
            self.state_space = state_space
            self.tech_indicator_list = tech_indicator_list
            # ``action_space`` arrives as a dimension and is stored as a Box.
            self.action_space = spaces.Box(low=-1, high=1, shape=(action_space,))
            self.observation_space = spaces.Box(
                low=-np.inf, high=np.inf, shape=(self.state_space,)
            )
            self.data = self.df.loc[self.day, :]
            self.terminal = False
            self.make_plots = make_plots
            self.print_verbosity = print_verbosity
            self.turbulence_threshold = turbulence_threshold
            self.risk_indicator_col = risk_indicator_col
            self.initial = initial
            # A fresh list per instance; a ``[]`` default would be shared.
            self.previous_state = [] if previous_state is None else previous_state
            self.model_name = model_name
            self.mode = mode
            self.iteration = iteration
            # initialize state
            self.state = self._initiate_state()

            # initialize reward
            self.reward = 0
            self.turbulence = 0
            self.cost = 0
            self.trades = 0
            self.episode = 0
            # memorize all the total balance change
            self.asset_memory = [self.initial_amount]
            self.rewards_memory = []
            self.actions_memory = []
            self.date_memory = [self._get_date()]
            self._seed()

        def _total_asset(self):
            """Return cash plus the market value of all holdings in the state."""
            prices = np.array(self.state[1 : (self.stock_dim + 1)])
            holdings = np.array(
                self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]
            )
            return self.state[0] + sum(prices * holdings)

        def _sell_stock(self, index, action):
            """Sell shares of stock ``index`` and return the number sold.

            Nothing is sold when the price is not positive (missing data) or
            no shares are held.  When turbulence is at or above the threshold
            the whole position is cleared; otherwise at most ``abs(action)``
            shares are sold.
            """
            price = self.state[index + 1]
            holding_slot = index + self.stock_dim + 1
            held = self.state[holding_slot]
            if not (price > 0 and held > 0):
                return 0

            liquidate = (
                self.turbulence_threshold is not None
                and self.turbulence >= self.turbulence_threshold
            )
            sell_num_shares = held if liquidate else min(abs(action), held)

            gross = price * sell_num_shares
            # update balance
            self.state[0] += gross * (1 - self.sell_cost_pct)
            self.state[holding_slot] = held - sell_num_shares
            self.cost += gross * self.sell_cost_pct
            self.trades += 1
            return sell_num_shares

        def _buy_stock(self, index, action):
            """Buy shares of stock ``index`` and return the number bought.

            Nothing is bought while turbulence is at or above the threshold or
            when the price is not positive (missing data).
            """
            if self.turbulence_threshold is not None and not (
                self.turbulence < self.turbulence_threshold
            ):
                return 0

            price = self.state[index + 1]
            if not price > 0:
                return 0

            # Include the transaction cost so the order never costs more than
            # the available cash.
            available_amount = self.state[0] // (price * (1 + self.buy_cost_pct))
            buy_num_shares = min(available_amount, action)

            gross = price * buy_num_shares
            # update balance
            self.state[0] -= gross * (1 + self.buy_cost_pct)
            self.state[index + self.stock_dim + 1] += buy_num_shares
            self.cost += gross * self.buy_cost_pct
            self.trades += 1
            return buy_num_shares

        def _make_plot(self):
            plt.plot(self.asset_memory, "r")
            plt.savefig(f"results/account_value_trade_{self.episode}.png")
            plt.close()

        def _finish_episode(self):
            """Report and optionally persist results; return the final step tuple."""
            if self.make_plots:
                self._make_plot()
            end_total_asset = self._total_asset()
            tot_reward = end_total_asset - self.initial_amount

            df_total_value = pd.DataFrame(self.asset_memory)
            df_total_value.columns = ["account_value"]
            df_total_value["date"] = self.date_memory
            df_total_value["daily_return"] = df_total_value[
                "account_value"
            ].pct_change(1)
            daily_std = df_total_value["daily_return"].std()
            sharpe = None
            if daily_std != 0:
                sharpe = (
                    (252**0.5) * df_total_value["daily_return"].mean() / daily_std
                )

            df_rewards = pd.DataFrame(self.rewards_memory)
            df_rewards.columns = ["account_rewards"]
            df_rewards["date"] = self.date_memory[:-1]

            if self.episode % self.print_verbosity == 0:
                print(f"day: {self.day}, episode: {self.episode}")
                print(f"begin_total_asset: {self.asset_memory[0]:0.2f}")
                print(f"end_total_asset: {end_total_asset:0.2f}")
                print(f"total_reward: {tot_reward:0.2f}")
                print(f"total_cost: {self.cost:0.2f}")
                print(f"total_trades: {self.trades}")
                if sharpe is not None:
                    print(f"Sharpe: {sharpe:0.3f}")
                print("=================================")

            if (self.model_name != "") and (self.mode != ""):
                suffix = f"{self.mode}_{self.model_name}_{self.iteration}"
                df_actions = self.save_action_memory()
                df_actions.to_csv(f"results/actions_{suffix}.csv")
                df_total_value.to_csv(
                    f"results/account_value_{suffix}.csv", index=False
                )
                df_rewards.to_csv(
                    f"results/account_rewards_{suffix}.csv", index=False
                )
                plt.plot(self.asset_memory, "r")
                # ``index=False`` is a ``to_csv`` option, not a ``savefig`` one.
                plt.savefig(f"results/account_value_{suffix}.png")
                plt.close()

            return self.state, self.reward, self.terminal, {}

        def step(self, actions):
            """Apply ``actions`` for one day.

            Returns:
                ``(state, reward, terminal, info)``; ``reward`` is the change
                in total asset value multiplied by ``reward_scaling``.
            """
            self.terminal = self.day >= len(self.df.index.unique()) - 1
            if self.terminal:
                return self._finish_episode()

            # The action space is Box(-1, 1); scale to a share count and
            # truncate because fractional shares cannot be traded.
            actions = actions * self.hmax
            actions = actions.astype(int)
            if self.turbulence_threshold is not None:
                if self.turbulence >= self.turbulence_threshold:
                    actions = np.array([-self.hmax] * self.stock_dim)
            begin_total_asset = self._total_asset()

            # Most negative actions are sold first, most positive bought first.
            argsort_actions = np.argsort(actions)
            sell_index = argsort_actions[: np.where(actions < 0)[0].shape[0]]
            buy_index = argsort_actions[::-1][: np.where(actions > 0)[0].shape[0]]

            for index in sell_index:
                actions[index] = self._sell_stock(index, actions[index]) * (-1)

            for index in buy_index:
                actions[index] = self._buy_stock(index, actions[index])

            self.actions_memory.append(actions)

            # state: s -> s+1
            self.day += 1
            self.data = self.df.loc[self.day, :]
            if self.turbulence_threshold is not None:
                if len(self.df.tic.unique()) == 1:
                    self.turbulence = self.data[self.risk_indicator_col]
                elif len(self.df.tic.unique()) > 1:
                    self.turbulence = self.data[self.risk_indicator_col].values[0]
            self.state = self._update_state()

            end_total_asset = self._total_asset()
            self.asset_memory.append(end_total_asset)
            self.date_memory.append(self._get_date())
            self.reward = end_total_asset - begin_total_asset
            self.rewards_memory.append(self.reward)
            self.reward = self.reward * self.reward_scaling

            return self.state, self.reward, self.terminal, {}

        def reset(self):
            """Rewind to day 0 and return the initial state."""
            # Rewind the data cursor first so the state is built from day-0
            # prices rather than from the last day of the previous episode.
            self.day = 0
            self.data = self.df.loc[self.day, :]
            self.state = self._initiate_state()

            if self.initial:
                self.asset_memory = [self.initial_amount]
            else:
                previous_total_asset = self.previous_state[0] + sum(
                    np.array(self.state[1 : (self.stock_dim + 1)])
                    * np.array(
                        self.previous_state[
                            (self.stock_dim + 1) : (self.stock_dim * 2 + 1)
                        ]
                    )
                )
                self.asset_memory = [previous_total_asset]

            self.turbulence = 0
            self.cost = 0
            self.trades = 0
            self.terminal = False
            self.rewards_memory = []
            self.actions_memory = []
            self.date_memory = [self._get_date()]

            self.episode += 1

            return self.state

        def render(self, mode="human", close=False):
            return self.state

        def _market_features(self):
            """Return ``(prices, indicator values)`` lists for ``self.data``."""
            if len(self.df.tic.unique()) > 1:
                # for multiple stock
                prices = self.data.close.values.tolist()
                indicators = [
                    value
                    for tech in self.tech_indicator_list
                    for value in self.data[tech].values.tolist()
                ]
            else:
                # for single stock
                prices = [self.data.close]
                indicators = [self.data[tech] for tech in self.tech_indicator_list]
            return prices, indicators

        def _initiate_state(self):
            """Build the first state, fresh or continued from ``previous_state``."""
            prices, indicators = self._market_features()
            if self.initial:
                # For Initial State
                cash = [self.initial_amount]
                holdings = [0] * self.stock_dim
            else:
                # Using Previous State
                cash = [self.previous_state[0]]
                holdings = list(
                    self.previous_state[
                        (self.stock_dim + 1) : (self.stock_dim * 2 + 1)
                    ]
                )
            return cash + prices + holdings + indicators

        def _update_state(self):
            """Rebuild the state from ``self.data``, keeping cash and holdings."""
            prices, indicators = self._market_features()
            holdings = list(
                self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]
            )
            return [self.state[0]] + prices + holdings + indicators

        def _get_date(self):
            if len(self.df.tic.unique()) > 1:
                date = self.data.date.unique()[0]
            else:
                date = self.data.date
            return date

        def save_asset_memory(self):
            """Return a DataFrame of dates and account values recorded so far."""
            return pd.DataFrame(
                {"date": self.date_memory, "account_value": self.asset_memory}
            )

        def save_action_memory(self):
            """Return a DataFrame of the actions recorded so far."""
            # date and close price length must match actions length
            date_list = self.date_memory[:-1]
            action_list = self.actions_memory
            if len(self.df.tic.unique()) > 1:
                df_date = pd.DataFrame(date_list)
                df_date.columns = ["date"]

                df_actions = pd.DataFrame(action_list)
                df_actions.columns = self.data.tic.values
                df_actions.index = df_date.date
            else:
                df_actions = pd.DataFrame(
                    {"date": date_list, "actions": action_list}
                )
            return df_actions

        def _seed(self, seed=None):
            self.np_random, seed = seeding.np_random(seed)
            return [seed]

        def get_sb_env(self):
            """Wrap this environment in a DummyVecEnv and return it with its reset obs."""
            e = DummyVecEnv([lambda: self])
            obs = e.reset()
            return e, obs

    ratio_list = [
        "OPM",
        "NPM",
        "ROA",
        "ROE",
        "cur_ratio",
        "quick_ratio",
        "cash_ratio",
        "inv_turnover",
        "acc_rec_turnover",
        "acc_pay_turnover",
        "debt_ratio",
        "debt_to_equity",
        "PE",
        "PB",
        "Div_yield",
    ]

    stock_dimension = len(train_data.tic.unique())
    state_space = 1 + 2 * stock_dimension + len(ratio_list) * stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

    # Parameters for the environment
    env_kwargs = {
        "hmax": 100,
        "initial_amount": 1000000,
        "buy_cost_pct": 0.001,
        "sell_cost_pct": 0.001,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": ratio_list,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,
    }

    # Establish the training environment using the StockTradingEnv class above
    e_train_gym = StockTradingEnv(df=train_data, **env_kwargs)

    env_train, _ = e_train_gym.get_sb_env()
    print(type(env_train))

    if_using_a2c = False
    if_using_ddpg = False
    if_using_ppo = False
    if_using_td3 = False
    if_using_sac = True

    PPO_PARAMS = {
        "n_steps": 2048,
        "ent_coef": 0.01,
        "learning_rate": 0.00025,
        "batch_size": 128,
    }
    TD3_PARAMS = {"batch_size": 100, "buffer_size": 1000000, "learning_rate": 0.001}
    SAC_PARAMS = {
        "batch_size": 128,
        "buffer_size": 1000000,
        "learning_rate": 0.0001,
        "learning_starts": 100,
        "ent_coef": "auto_0.1",
    }

    # name -> (enabled, model kwargs or None for library defaults, timesteps)
    algorithms = {
        "ppo": (if_using_ppo, PPO_PARAMS, 50000),
        "ddpg": (if_using_ddpg, None, 50000),
        "a2c": (if_using_a2c, None, 50000),
        "td3": (if_using_td3, TD3_PARAMS, 30000),
        "sac": (if_using_sac, SAC_PARAMS, 30000),
    }

    # Only enabled algorithms are instantiated, so disabled ones no longer
    # allocate a model (and its buffers) that is never trained.
    trained_models = {}
    for name, (enabled, params, total_timesteps) in algorithms.items():
        if not enabled:
            continue
        agent = DRLAgent(env=env_train)
        extra = {} if params is None else {"model_kwargs": params}
        model = agent.get_model(name, **extra)

        # set up logger
        tmp_path = RESULTS_DIR + "/" + name
        model.set_logger(configure(tmp_path, ["stdout", "csv", "tensorboard"]))

        trained_models[name] = agent.train_model(
            model=model, tb_log_name=name, total_timesteps=total_timesteps
        )

    e_trade_gym = StockTradingEnv(df=trade_data, **env_kwargs)

    account_values = {}
    for name, trained in trained_models.items():
        df_account_value, _df_actions = DRLAgent.DRL_prediction(
            model=trained, environment=e_trade_gym
        )
        account_values[name] = df_account_value

    print("==============Get Backtest Results===========")
    now = datetime.datetime.now().strftime("%Y%m%d-%Hh%M")

    for name, df_account_value in account_values.items():
        print(f"\n {name}:")
        perf_stats_all = pd.DataFrame(backtest_stats(account_value=df_account_value))
        perf_stats_all.to_csv(
            "./" + config.RESULTS_DIR + "/perf_stats_all_" + name + "_" + now + ".csv"
        )

    # baseline stats
    print("==============Get Baseline Stats===========")
    baseline_df = get_baseline(ticker="^DJI", start=TEST_START_DATE, end=TEST_END_DATE)

    backtest_stats(baseline_df, value_col_name="close")

    print("==============Compare to DJIA===========")

    # S&P 500: ^GSPC
    # Dow Jones Index: ^DJI
    # NASDAQ 100: ^NDX

    for df_account_value in account_values.values():
        backtest_plot(
            df_account_value,
            baseline_ticker="^DJI",
            baseline_start=TEST_START_DATE,
            baseline_end=TEST_END_DATE,
        )


if __name__ == "__main__":
    main()