AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/examples/Stock_NeurIPS2018_SB3.ipynb
Kernel: Python 3 (ipykernel)


Deep Reinforcement Learning for Stock Trading from Scratch: Multiple Stock Trading

  • PyTorch Version

Content

Part 1. Task Description

We train a DRL agent for stock trading. This task is modeled as a Markov Decision Process (MDP), and the objective function is maximizing (expected) cumulative return.

We specify the state-action-reward as follows:

  • State s: The state space represents an agent's perception of the market environment. Just like a human trader analyzing various information, here our agent passively observes many features and learns by interacting with the market environment (usually by replaying historical data).

  • Action a: The action space includes the allowed actions that an agent can take at each state. For example, a ∈ {−1, 0, 1}, where −1, 0, 1 represent selling, holding, and buying. When an action operates on multiple shares, a ∈ {−k, ..., −1, 0, 1, ..., k}; e.g., "Buy 10 shares of AAPL" and "Sell 10 shares of AAPL" are 10 and −10, respectively.

  • Reward function r(s, a, s′): The reward is an incentive for the agent to learn a better policy. For example, it can be the change in portfolio value when taking action a at state s and arriving at the new state s′, i.e., r(s, a, s′) = v′ − v, where v′ and v represent the portfolio values at states s′ and s, respectively.
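To make the reward definition concrete, here is a small worked example of r(s, a, s′) = v′ − v for a single stock. The prices, holdings, and cash amounts are made up for illustration, transaction costs are ignored, and `portfolio_value` is a hypothetical helper rather than a FinRL function.

```python
# Illustrative numbers only; prices, holdings and cash are made up.

def portfolio_value(cash, prices, shares):
    """Cash plus the market value of all held shares."""
    return cash + sum(p * n for p, n in zip(prices, shares))

# State s: $1,000 in cash and 10 shares of one stock priced at $50.
v = portfolio_value(cash=1_000.0, prices=[50.0], shares=[10])

# Action a = +10: buy 10 more shares at $50 (no transaction cost assumed).
cash_after = 1_000.0 - 10 * 50.0

# State s': the price has moved to $52 and the agent now holds 20 shares.
v_next = portfolio_value(cash=cash_after, prices=[52.0], shares=[20])

reward = v_next - v  # r(s, a, s') = v' - v
print(v, v_next, reward)
```

Buying at the current price does not change the portfolio value by itself; in this sketch the reward comes entirely from the later price move applied to the new holdings.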

Market environment: the 30 constituent stocks of the Dow Jones Industrial Average (DJIA) index, as of the starting date of the testing period.

The data for this case study is obtained from the Yahoo Finance API. It contains Open-High-Low-Close prices and volume.

Part 2. Install Python Packages

2.1. Install packages

## install required packages
!pip install swig
!pip install wrds
!pip install pyportfolioopt
## install finrl library
!pip install -q condacolab
import condacolab
condacolab.install()
!apt-get update -y -qq && apt-get install -y -qq cmake libopenmpi-dev python3-dev zlib1g-dev libgl1-mesa-glx swig
!pip install git+https://github.com/AI4Finance-Foundation/FinRL.git
Requirement already satisfied: swig in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (4.3.0) Requirement already satisfied: wrds in /home/random/.local/lib/python3.12/site-packages (3.2.0) Requirement already satisfied: numpy<1.27,>=1.26 in /home/random/.local/lib/python3.12/site-packages (from wrds) (1.26.4) Requirement already satisfied: packaging<23.3 in /home/random/.local/lib/python3.12/site-packages (from wrds) (23.2) Requirement already satisfied: pandas<2.3,>=2.2 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.2.3) Requirement already satisfied: psycopg2-binary<2.10,>=2.9 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.9.10) Requirement already satisfied: scipy<1.13,>=1.12 in /home/random/.local/lib/python3.12/site-packages (from wrds) (1.12.0) Requirement already satisfied: sqlalchemy<2.1,>=2 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.0.36) Requirement already satisfied: python-dateutil>=2.8.2 in /home/random/.local/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2024.1) Requirement already satisfied: tzdata>=2022.7 in /home/random/.local/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2024.2) Requirement already satisfied: typing-extensions>=4.6.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from sqlalchemy<2.1,>=2->wrds) (4.12.2) Requirement already satisfied: greenlet!=0.4.17 in /home/random/.local/lib/python3.12/site-packages (from sqlalchemy<2.1,>=2->wrds) (3.1.1) Requirement already satisfied: six>=1.5 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas<2.3,>=2.2->wrds) (1.16.0) Requirement already satisfied: pyportfolioopt in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (1.5.6) Requirement already satisfied: 
cvxpy>=1.1.19 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (1.6.0) Requirement already satisfied: ecos<3.0.0,>=2.0.14 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (2.0.14) Requirement already satisfied: numpy>=1.26.0 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (1.26.4) Requirement already satisfied: pandas>=0.19 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (2.2.3) Requirement already satisfied: plotly<6.0.0,>=5.0.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (5.24.1) Requirement already satisfied: scipy>=1.3 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (1.12.0) Requirement already satisfied: osqp>=0.6.2 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (0.6.7.post3) Requirement already satisfied: clarabel>=0.5.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (0.9.0) Requirement already satisfied: scs>=3.2.4.post1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (3.2.7) Requirement already satisfied: python-dateutil>=2.8.2 in /home/random/.local/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2024.1) Requirement already satisfied: tzdata>=2022.7 in /home/random/.local/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2024.2) Requirement already satisfied: tenacity>=6.2.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from plotly<6.0.0,>=5.0.0->pyportfolioopt) (9.0.0) Requirement already satisfied: packaging in /home/random/.local/lib/python3.12/site-packages (from 
plotly<6.0.0,>=5.0.0->pyportfolioopt) (23.2) Requirement already satisfied: qdldl in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from osqp>=0.6.2->cvxpy>=1.1.19->pyportfolioopt) (0.1.7.post4) Requirement already satisfied: six>=1.5 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas>=0.19->pyportfolioopt) (1.16.0) [sudo] password for random: cmake is already the newest version (3.28.3-1build7). libopenmpi-dev is already the newest version (4.1.6-7ubuntu2). python3-dev is already the newest version (3.12.3-0ubuntu2). zlib1g-dev is already the newest version (1:1.3.dfsg-3.1ubuntu2.1). libgl1-mesa-glx is already the newest version (23.0.4-0ubuntu1~22.04.1). swig is already the newest version (4.2.0-2ubuntu1). 0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded. Collecting git+https://github.com/AI4Finance-Foundation/FinRL.git Cloning https://github.com/AI4Finance-Foundation/FinRL.git to /tmp/pip-req-build-flt95p98 Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-flt95p98 Resolved https://github.com/AI4Finance-Foundation/FinRL.git to commit ef471fcea1f3667442f5ecbf7b4c214610a5dd55 Installing build dependencies ... done Getting requirements to build wheel ... done Preparing metadata (pyproject.toml) ... done Collecting elegantrl@ git+https://github.com/AI4Finance-Foundation/ElegantRL.git (from finrl==0.3.6) Cloning https://github.com/AI4Finance-Foundation/ElegantRL.git to /tmp/pip-install-u43l6ss9/elegantrl_36782baa6d82461e89b600dda61820c8 Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/ElegantRL.git /tmp/pip-install-u43l6ss9/elegantrl_36782baa6d82461e89b600dda61820c8 Resolved https://github.com/AI4Finance-Foundation/ElegantRL.git to commit 59d9a33e2b3ba2d77c052c2810bb61059736d88c Preparing metadata (setup.py) ... 
done
Requirement already satisfied: alpaca-trade-api<4,>=3 in /home/random/.local/lib/python3.12/site-packages (from finrl==0.3.6) (3.2.0)
Collecting ccxt<4,>=3 (from finrl==0.3.6)
  Using cached ccxt-3.1.60-py2.py3-none-any.whl.metadata (108 kB)
Requirement already satisfied: exchange-calendars<5,>=4 in /home/random/.local/lib/python3.12/site-packages (from finrl==0.3.6) (4.6)
Collecting jqdatasdk<2,>=1 (from finrl==0.3.6)
  Using cached jqdatasdk-1.9.7-py3-none-any.whl.metadata (5.8 kB)
Collecting pyfolio<0.10,>=0.9 (from finrl==0.3.6)
  Using cached pyfolio-0.9.2.tar.gz (91 kB)
  Preparing metadata (setup.py) ... error
  error: subprocess-exited-with-error

  × python setup.py egg_info did not run successfully.
  exit code: 1
  ╰─> [18 lines of output]
      /tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py:468: SyntaxWarning: invalid escape sequence '\s'
        LONG_VERSION_PY['git'] = '''
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/setup.py", line 71, in <module>
          version=versioneer.get_version(),
                  ^^^^^^^^^^^^^^^^^^^^^^^^
        File "/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py", line 1407, in get_version
          return get_versions()["version"]
                 ^^^^^^^^^^^^^^
        File "/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py", line 1341, in get_versions
          cfg = get_config_from_root(root)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py", line 399, in get_config_from_root
          parser = configparser.SafeConfigParser()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      AttributeError: module 'configparser' has no attribute 'SafeConfigParser'. Did you mean: 'RawConfigParser'?
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.
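The install log above ends with an AttributeError for `configparser.SafeConfigParser`. That alias was removed from the standard library in Python 3.12, which is the interpreter version shown in the paths, so the `versioneer.py` bundled with pyfolio 0.9.2 cannot run there. The short check below reproduces the condition and shows that `ConfigParser` still handles the same kind of input; the `[versioneer]` section is a made-up sample. Possible ways out, none of them verified against this notebook, include running on an older Python or installing a maintained fork such as `pyfolio-reloaded`.

```python
import configparser
import sys

# SafeConfigParser was a deprecated alias of ConfigParser; it no longer exists on Python 3.12+.
has_safe_parser = hasattr(configparser, "SafeConfigParser")
print(f"Python {sys.version_info.major}.{sys.version_info.minor}: "
      f"SafeConfigParser available = {has_safe_parser}")

# ConfigParser is the supported class on every Python 3 version.
parser = configparser.ConfigParser()
parser.read_string("[versioneer]\nVCS = git\n")  # made-up sample config
vcs = parser.get("versioneer", "VCS")
print(vcs)
```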

2.2. A list of Python packages

  • Yahoo Finance API

  • pandas

  • numpy

  • matplotlib

  • stockstats

  • OpenAI Gym

  • stable-baselines3

  • tensorflow

  • pyfolio

2.3. Import Packages

import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# matplotlib.use('Agg')
%matplotlib inline
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3.common.logger import configure
from finrl.meta.data_processor import DataProcessor
from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor
from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline
from pprint import pprint

import sys
sys.path.append("../FinRL")

import itertools
2025-01-04 15:29:19.697527: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-04 15:29:19.724461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1736000959.745993 24692 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736000959.755250 24692 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-04 15:29:19.798332: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/home/random/anaconda3/envs/finrl/lib/python3.12/site-packages/pyfolio/pos.py:25: UserWarning: Module "zipline.assets" not found; multipliers will not be applied to position notionals.
  warnings.warn(

2.4. Create Folders

from finrl import config
from finrl import config_tickers
import os
from finrl.main import check_and_make_directories
from finrl.config import (
    DATA_SAVE_DIR,
    TRAINED_MODEL_DIR,
    TENSORBOARD_LOG_DIR,
    RESULTS_DIR,
    INDICATORS,
    TRAIN_START_DATE,
    TRAIN_END_DATE,
    TEST_START_DATE,
    TEST_END_DATE,
    TRADE_START_DATE,
    TRADE_END_DATE,
)
check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])

Part 3. Download Data

Yahoo Finance provides stock data, financial news, financial reports, etc. Yahoo Finance is free.

  • FinRL uses a class YahooDownloader in FinRL-Meta to fetch data via Yahoo Finance API

  • Call Limit: Using the Public API (without authentication), you are limited to 2,000 requests per hour per IP (or up to a total of 48,000 requests a day).


class YahooDownloader: Retrieving daily stock data from Yahoo Finance API

Attributes
----------
    start_date : str
        start date of the data (modified from config.py)
    end_date : str
        end date of the data (modified from config.py)
    ticker_list : list
        a list of stock tickers (modified from config.py)

Methods
-------
    fetch_data()
# from config.py, TRAIN_START_DATE is a string
TRAIN_START_DATE
# from config.py, TRAIN_END_DATE is a string
TRAIN_END_DATE
'2020-07-31'
TRAIN_START_DATE = '2010-01-01'
TRAIN_END_DATE = '2021-10-01'
TRADE_START_DATE = '2021-10-01'
TRADE_END_DATE = '2023-03-01'
#df = YahooDownloader(start_date = TRAIN_START_DATE,
#                     end_date = TRADE_END_DATE,
#                     ticker_list = config_tickers.DOW_30_TICKER).fetch_data()
yfp = YahooFinanceProcessor()
df = yfp.scrap_data(['AXP', 'AMGN', 'AAPL'], '2010-01-01', '2010-02-01')
print(df)
Processing AXP (1/3)... 33.33% complete.
Processing AMGN (2/3)... 66.67% complete.
Processing AAPL (3/3)... 100.00% complete.
          Date   Open   High    Low  Close  Adj Close      Volume  tick  day
0   2010-01-04   7.62   7.66   7.59   7.64       6.45   493729600  AAPL    3
1   2010-01-04  56.63  57.87  56.56  57.72      40.92     5277400  AMGN    3
2   2010-01-04  40.81  41.10  40.39  40.92      32.83     6894300   AXP    3
3   2010-01-05   7.66   7.70   7.62   7.66       6.46   601904800  AAPL    4
4   2010-01-05  57.33  57.69  56.27  57.22      40.56     7882800  AMGN    4
5   2010-01-05  40.83  41.23  40.37  40.83      32.76    10641200   AXP    4
6   2010-01-06   7.66   7.69   7.53   7.53       6.36   552160000  AAPL    5
7   2010-01-06  56.94  57.39  56.50  56.79      40.26     6015100  AMGN    5
8   2010-01-06  41.23  41.67  41.17  41.49      33.29     8399400   AXP    5
9   2010-01-07   7.56   7.57   7.47   7.52       6.34   477131200  AAPL    6
10  2010-01-07  56.41  56.53  54.65  56.27      39.89    10371600  AMGN    6
11  2010-01-07  41.26  42.24  41.11  41.98      33.83     8981700   AXP    6
12  2010-01-08   7.51   7.57   7.47   7.57       6.39   447610800  AAPL    7
13  2010-01-08  56.07  56.83  55.64  56.77      40.24     6576000  AMGN    7
14  2010-01-08  41.76  42.48  41.40  41.95      33.80     7907700   AXP    7
15  2010-01-11   7.60   7.61   7.44   7.50       6.33   462229600  AAPL   10
16  2010-01-11  56.93  57.36  56.62  57.02      40.42     4062700  AMGN   10
17  2010-01-11  41.74  41.96  41.25  41.47      33.42     7396000   AXP   10
18  2010-01-12   7.47   7.49   7.37   7.42       6.26   594459600  AAPL   11
19  2010-01-12  57.14  57.42  54.82  56.03      39.72    11268300  AMGN   11
20  2010-01-12  41.27  42.35  41.25  42.02      33.86    12657300   AXP   11
21  2010-01-13   7.42   7.53   7.29   7.52       6.35   605892000  AAPL   12
22  2010-01-13  56.35  56.75  55.96  56.53      40.07     5056200  AMGN   12
23  2010-01-13  41.85  42.24  41.57  42.15      33.96    10137200   AXP   12
24  2010-01-14   7.50   7.52   7.47   7.48       6.31   432894000  AAPL   13
25  2010-01-14  56.35  56.53  55.91  56.16      39.81     4668900  AMGN   13
26  2010-01-14  42.04  42.74  42.02  42.68      34.39     8238400   AXP   13
27  2010-01-15   7.53   7.56   7.35   7.35       6.20   594067600  AAPL   14
28  2010-01-15  56.03  56.51  55.65  56.25      39.87     7240000  AMGN   14
29  2010-01-15  42.52  42.84  42.02  42.39      34.16    13629000   AXP   14
30  2010-01-19   7.44   7.69   7.40   7.68       6.48   730007600  AAPL   18
31  2010-01-19  56.41  57.75  56.24  57.55      40.80     8570100  AMGN   18
32  2010-01-19  42.24  43.05  42.11  42.96      34.62     9533800   AXP   18
33  2010-01-20   7.68   7.70   7.48   7.56       6.38   612152800  AAPL   19
34  2010-01-20  57.62  57.62  56.41  57.20      40.55     6625700  AMGN   19
35  2010-01-20  42.93  43.25  42.26  42.98      34.63    11643000   AXP   19
36  2010-01-21   7.57   7.62   7.40   7.43       6.27   608154400  AAPL   20
37  2010-01-21  57.43  57.56  56.31  56.63      40.14     5833700  AMGN   20
38  2010-01-21  42.99  43.10  41.53  42.16      33.97    16974300   AXP   20
39  2010-01-22   7.39   7.41   7.04   7.06       5.96   881767600  AAPL   21
40  2010-01-22  56.67  57.30  56.53  56.60      40.12     5967600  AMGN   21
41  2010-01-22  41.36  41.49  38.19  38.59      31.09    26170800   AXP   21
42  2010-01-25   7.23   7.31   7.15   7.25       6.12  1065699600  AAPL   24
43  2010-01-25  56.72  56.79  55.55  55.71      39.49     6719400  AMGN   24
44  2010-01-25  39.10  39.29  37.50  37.79      30.45    17587600   AXP   24
45  2010-01-26   7.36   7.63   7.24   7.36       6.20  1867110000  AAPL   25
46  2010-01-26  56.20  56.87  55.70  56.58      40.11    14880300  AMGN   25
47  2010-01-26  37.54  39.23  37.52  38.10      30.70    15709900   AXP   25
48  2010-01-27   7.39   7.52   7.13   7.42       6.26  1722568400  AAPL   26
49  2010-01-27  56.35  57.88  56.35  57.74      40.93     9695000  AMGN   26
50  2010-01-27  37.96  38.84  37.83  38.67      31.16    12908300   AXP   26
51  2010-01-28   7.32   7.34   7.10   7.12       6.00  1173502400  AAPL   27
52  2010-01-28  57.87  58.78  57.56  58.08      41.17    11638200  AMGN   27
53  2010-01-28  38.67  38.67  36.83  37.43      30.16    14148600   AXP   27
54  2010-01-29   7.18   7.22   6.79   6.86       5.79  1245952400  AAPL   28
55  2010-01-29  58.35  58.93  58.16  58.48      41.45     9465700  AMGN   28
56  2010-01-29  37.60  38.77  37.36  37.66      30.35    14219900   AXP   28
print(config_tickers.DOW_30_TICKER)
['AXP', 'AMGN', 'AAPL', 'BA', 'CAT', 'CSCO', 'CVX', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'KO', 'JPM', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'CRM', 'VZ', 'V', 'WBA', 'WMT', 'DIS', 'DOW']
df.shape
(57, 9)
df.sort_values(['date','tic'],ignore_index=True).head()
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_24692/1255811168.py in ?()
----> 1 df.sort_values(['date','tic'],ignore_index=True).head()

~/.local/lib/python3.12/site-packages/pandas/core/frame.py in ?(self, by, axis, ascending, inplace, kind, na_position, ignore_index, key)
   7168                     f"Length of ascending ({len(ascending)})"  # type: ignore[arg-type]
   7169                     f" != length of by ({len(by)})"
   7170                 )
   7171         if len(by) > 1:
-> 7172             keys = [self._get_label_or_level_values(x, axis=axis) for x in by]
   7173
   7174             # need to rewrap columns in Series to apply key function
   7175             if key is not None:

~/.local/lib/python3.12/site-packages/pandas/core/generic.py in ?(self, key, axis)
   1907             values = self.xs(key, axis=other_axes[0])._values
   1908         elif self._is_level_reference(key, axis=axis):
   1909             values = self.axes[axis].get_level_values(key)._values
   1910         else:
-> 1911             raise KeyError(key)
   1912
   1913         # Check for duplicates
   1914         if values.ndim > 1:

KeyError: 'date'
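The KeyError above follows from the column names in the printed frame: `scrap_data` returned `Date` and `tick`, while the sort asks for `date` and `tic`. A minimal sketch of one way to reconcile them is shown below on a tiny stand-in frame. The rename mapping is an assumption based only on the printed header; later steps of the notebook may expect further columns that this sketch does not address.

```python
import pandas as pd

# Tiny stand-in for the frame printed above: note 'Date' and 'tick', not 'date' and 'tic'.
raw = pd.DataFrame({
    "Date": ["2010-01-05", "2010-01-04", "2010-01-04"],
    "Close": [7.66, 57.72, 7.64],
    "tick": ["AAPL", "AMGN", "AAPL"],
})

# Assumed mapping: lower-case names, with 'tick' renamed to 'tic'.
renamed = raw.rename(columns={"Date": "date", "Close": "close", "tick": "tic"})

# The sort from the failing cell now finds both keys.
ordered = renamed.sort_values(["date", "tic"], ignore_index=True)
print(ordered)
```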

Part 4: Preprocess Data

We need to check for missing data and do feature engineering to convert the data point into a state.

  • Adding technical indicators. In practical trading, various information needs to be taken into account, such as historical prices, current holding shares, technical indicators, etc. Here, we demonstrate two trend-following technical indicators: MACD and RSI.

  • Adding turbulence index. Risk aversion reflects whether an investor prefers to protect capital. It also influences one's trading strategy when facing different levels of market volatility. To control the risk in a worst-case scenario, such as the financial crisis of 2007–2008, FinRL employs the turbulence index, which measures extreme fluctuations in asset prices.
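The text above does not spell out how the turbulence index is computed. One common definition, assumed here, is the squared Mahalanobis distance of the current return vector from its historical mean. The function name and the synthetic returns below are illustrative and are not taken from FinRL's code.

```python
import numpy as np

def turbulence(history, current):
    """Squared Mahalanobis distance of `current` returns from `history`.

    history: (n_days, n_assets) array of past returns
    current: (n_assets,) array of today's returns
    """
    mu = history.mean(axis=0)
    cov = np.cov(history, rowvar=False)
    diff = current - mu
    # pinv tolerates a (near-)singular covariance matrix
    return float(diff @ np.linalg.pinv(cov) @ diff)

rng = np.random.default_rng(0)
past = rng.normal(0.0, 0.01, size=(250, 3))  # synthetic daily returns, 3 assets

calm = turbulence(past, np.array([0.001, -0.002, 0.0]))
shock = turbulence(past, np.array([-0.08, -0.07, -0.09]))
print(calm, shock)
```

A day on which every asset moves far outside its usual range scores much higher than an ordinary day, which is what makes the value usable as a risk switch.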

fe = FeatureEngineer(
                    use_technical_indicator=True,
                    tech_indicator_list = INDICATORS,
                    use_vix=True,
                    use_turbulence=True,
                    user_defined_feature = False)

processed = fe.preprocess_data(df)
list_ticker = processed["tic"].unique().tolist()
list_date = list(pd.date_range(processed['date'].min(),processed['date'].max()).astype(str))
combination = list(itertools.product(list_date,list_ticker))

processed_full = pd.DataFrame(combination,columns=["date","tic"]).merge(processed,on=["date","tic"],how="left")
processed_full = processed_full[processed_full['date'].isin(processed['date'])]
processed_full = processed_full.sort_values(['date','tic'])

processed_full = processed_full.fillna(0)
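The cell above builds every (date, ticker) pair, left-merges the processed rows onto that grid, keeps only dates that occur in the data, and fills the remaining gaps with 0. The sketch below replays the same steps on a three-row synthetic frame in which ticker B is missing on the second day; the data is made up for illustration.

```python
import itertools
import pandas as pd

# Synthetic input: ticker B has no row on 2021-01-05.
processed = pd.DataFrame({
    "date": ["2021-01-04", "2021-01-04", "2021-01-05"],
    "tic": ["A", "B", "A"],
    "close": [10.0, 20.0, 11.0],
})

dates = list(pd.date_range(processed["date"].min(), processed["date"].max()).astype(str))
tickers = processed["tic"].unique().tolist()

# Every (date, ticker) pair, then a left merge to pull in the known rows.
grid = pd.DataFrame(list(itertools.product(dates, tickers)), columns=["date", "tic"])
full = grid.merge(processed, on=["date", "tic"], how="left")

# Keep only dates present in the data (drops weekends and holidays), fill gaps with 0.
full = full[full["date"].isin(processed["date"])]
full = full.sort_values(["date", "tic"]).fillna(0)
print(full)
```

Note that the missing (2021-01-05, B) row ends up with a close of 0, so a ticker without data on some day is treated as having a zero price that day.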
processed_full.sort_values(['date','tic'],ignore_index=True).head(10)
mvo_df = processed_full.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]

Part 5. Build A Market Environment in OpenAI Gym-style

The training process involves observing stock price changes, taking an action, and calculating the reward. By interacting with the market environment, the agent will eventually derive a trading strategy that may maximize (expected) rewards.

Our market environment, based on OpenAI Gym, simulates stock markets with historical market data.

Data Split

We split the data into training set and testing set as follows:

Training data period: 2010-01-01 to 2021-10-01

Trading data period: 2021-10-01 to 2023-03-01

train = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE)
trade = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE)
train_length = len(train)
trade_length = len(trade)
print(train_length)
print(trade_length)
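Because the training period ends on the same date the trading period starts, the two sets only stay disjoint if the split treats each period as half-open, i.e. start ≤ date < end. The sketch below shows such a split on synthetic data. `split_by_date` is a hypothetical stand-in written for this example, and whether FinRL's `data_split` behaves exactly this way is an assumption.

```python
import pandas as pd

def split_by_date(df, start, end, date_col="date"):
    """Hypothetical helper: keep rows with start <= date < end."""
    part = df[(df[date_col] >= start) & (df[date_col] < end)]
    part = part.sort_values([date_col, "tic"], ignore_index=True)
    # Rows of the same day share one index value: 0, 0, ..., 1, 1, ...
    part.index = part[date_col].factorize()[0]
    return part

data = pd.DataFrame({
    "date": ["2021-09-29", "2021-09-30", "2021-10-01", "2021-10-04"] * 2,
    "tic": ["A"] * 4 + ["B"] * 4,
    "close": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
})

train_part = split_by_date(data, "2021-09-01", "2021-10-01")
trade_part = split_by_date(data, "2021-10-01", "2021-11-01")
print(train_part, trade_part, sep="\n\n")
```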
train.tail()
trade.head()
INDICATORS
stock_dimension = len(train.tic.unique())
state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")
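The state-space formula can be read term by term: one slot for the cash balance, one price and one holding per stock, and one value per indicator per stock. The worked example below uses 29 stocks and 8 indicators as illustrative sizes; the notebook derives the real values from the data and from `INDICATORS`, and the reading of the individual terms is our interpretation of the formula.

```python
# Illustrative sizes; the notebook derives both from the data and from INDICATORS.
stock_dimension = 29
n_indicators = 8

cash_slots = 1                                   # account balance
price_slots = stock_dimension                    # one price per stock
holding_slots = stock_dimension                  # shares held per stock
indicator_slots = n_indicators * stock_dimension # every indicator for every stock

state_space = cash_slots + price_slots + holding_slots + indicator_slots
print(state_space)
```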
buy_cost_list = sell_cost_list = [0.001] * stock_dimension
num_stock_shares = [0] * stock_dimension

env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "num_stock_shares": num_stock_shares,
    "buy_cost_pct": buy_cost_list,
    "sell_cost_pct": sell_cost_list,
    "state_space": state_space,
    "stock_dim": stock_dimension,
    "tech_indicator_list": INDICATORS,
    "action_space": stock_dimension,
    "reward_scaling": 1e-4
}

e_train_gym = StockTradingEnv(df = train, **env_kwargs)

Environment for Training

env_train, _ = e_train_gym.get_sb_env()
print(type(env_train))

Part 6: Train DRL Agents

  • The DRL algorithms are from Stable Baselines 3. Users are also encouraged to try ElegantRL and Ray RLlib.

  • FinRL includes fine-tuned standard DRL algorithms, such as DQN, DDPG, Multi-Agent DDPG, PPO, SAC, A2C and TD3. We also allow users to design their own DRL algorithms by adapting these DRL algorithms.

agent = DRLAgent(env = env_train)

# Set the corresponding values to 'True' for the algorithms that you want to use
if_using_a2c = True
if_using_ddpg = True
if_using_ppo = True
if_using_td3 = True
if_using_sac = True

Agent Training: 5 algorithms (A2C, DDPG, PPO, TD3, SAC)

Agent 1: A2C

agent = DRLAgent(env = env_train)
model_a2c = agent.get_model("a2c")

if if_using_a2c:
    # set up logger
    tmp_path = RESULTS_DIR + '/a2c'
    new_logger_a2c = configure(tmp_path, ["stdout", "csv", "tensorboard"])
    # Set new logger
    model_a2c.set_logger(new_logger_a2c)
trained_a2c = agent.train_model(model=model_a2c, tb_log_name='a2c', total_timesteps=50000) if if_using_a2c else None

Agent 2: DDPG

agent = DRLAgent(env = env_train)
model_ddpg = agent.get_model("ddpg")

if if_using_ddpg:
    # set up logger
    tmp_path = RESULTS_DIR + '/ddpg'
    new_logger_ddpg = configure(tmp_path, ["stdout", "csv", "tensorboard"])
    # Set new logger
    model_ddpg.set_logger(new_logger_ddpg)
trained_ddpg = agent.train_model(model=model_ddpg, tb_log_name='ddpg', total_timesteps=50000) if if_using_ddpg else None

Agent 3: PPO

agent = DRLAgent(env = env_train)
PPO_PARAMS = {
    "n_steps": 2048,
    "ent_coef": 0.01,
    "learning_rate": 0.00025,
    "batch_size": 128,
}
model_ppo = agent.get_model("ppo",model_kwargs = PPO_PARAMS)

if if_using_ppo:
    # set up logger
    tmp_path = RESULTS_DIR + '/ppo'
    new_logger_ppo = configure(tmp_path, ["stdout", "csv", "tensorboard"])
    # Set new logger
    model_ppo.set_logger(new_logger_ppo)
trained_ppo = agent.train_model(model=model_ppo, tb_log_name='ppo', total_timesteps=50000) if if_using_ppo else None

Agent 4: TD3

agent = DRLAgent(env = env_train)
TD3_PARAMS = {"batch_size": 100,
              "buffer_size": 1000000,
              "learning_rate": 0.001}

model_td3 = agent.get_model("td3",model_kwargs = TD3_PARAMS)

if if_using_td3:
    # set up logger
    tmp_path = RESULTS_DIR + '/td3'
    new_logger_td3 = configure(tmp_path, ["stdout", "csv", "tensorboard"])
    # Set new logger
    model_td3.set_logger(new_logger_td3)
trained_td3 = agent.train_model(model=model_td3, tb_log_name='td3', total_timesteps=50000) if if_using_td3 else None

Agent 5: SAC

agent = DRLAgent(env = env_train)
SAC_PARAMS = {
    "batch_size": 128,
    "buffer_size": 100000,
    "learning_rate": 0.0001,
    "learning_starts": 100,
    "ent_coef": "auto_0.1",
}

model_sac = agent.get_model("sac",model_kwargs = SAC_PARAMS)

if if_using_sac:
    # set up logger
    tmp_path = RESULTS_DIR + '/sac'
    new_logger_sac = configure(tmp_path, ["stdout", "csv", "tensorboard"])
    # Set new logger
    model_sac.set_logger(new_logger_sac)
trained_sac = agent.train_model(model=model_sac, tb_log_name='sac', total_timesteps=50000) if if_using_sac else None

In-sample Performance

Assume that the initial capital is $1,000,000.

Set turbulence threshold

Set the turbulence threshold to be greater than the maximum of the in-sample turbulence data. If the current turbulence index is greater than the threshold, we assume that the current market is volatile.

data_risk_indicator = processed_full[(processed_full.date<TRAIN_END_DATE) & (processed_full.date>=TRAIN_START_DATE)]
insample_risk_indicator = data_risk_indicator.drop_duplicates(subset=['date'])
insample_risk_indicator.vix.describe()
insample_risk_indicator.vix.quantile(0.996)
insample_risk_indicator.turbulence.describe()
insample_risk_indicator.turbulence.quantile(0.996)
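The cells above inspect the distribution of the in-sample risk indicators and their 0.996 quantile. As a small illustration of how a quantile can serve as a volatility threshold, the sketch below flags the days whose indicator exceeds a chosen quantile. The series is synthetic, and the 0.9 level is picked only because ten points are too few for 0.996 to be informative.

```python
import pandas as pd

# Synthetic in-sample risk indicator, one value per trading day.
risk = pd.Series([12.0, 15.0, 14.0, 18.0, 22.0, 35.0, 16.0, 13.0, 80.0, 17.0])

threshold = risk.quantile(0.9)         # the notebook inspects the 0.996 quantile
volatile_days = risk[risk > threshold] # days treated as volatile under this threshold
print(threshold, volatile_days.index.tolist())
```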

Trading (Out-of-sample Performance)

We retrain periodically to take full advantage of the data, e.g., quarterly, monthly, or weekly, and we tune the parameters along the way. In this notebook we tune the parameters only once, using the in-sample data from 2010-01 to 2021-10, so there is some alpha decay as the trading period extends.

Numerous hyperparameters – e.g. the learning rate, the total number of samples to train on – influence the learning process and are usually determined by testing some variations.

e_trade_gym = StockTradingEnv(df = trade, turbulence_threshold = 70,risk_indicator_col='vix', **env_kwargs)
# env_trade, obs_trade = e_trade_gym.get_sb_env()
trade.head()
trained_model = trained_a2c
df_account_value_a2c, df_actions_a2c = DRLAgent.DRL_prediction(
    model=trained_model,
    environment = e_trade_gym)

trained_model = trained_ddpg
df_account_value_ddpg, df_actions_ddpg = DRLAgent.DRL_prediction(
    model=trained_model,
    environment = e_trade_gym)

trained_model = trained_ppo
df_account_value_ppo, df_actions_ppo = DRLAgent.DRL_prediction(
    model=trained_model,
    environment = e_trade_gym)

trained_model = trained_td3
df_account_value_td3, df_actions_td3 = DRLAgent.DRL_prediction(
    model=trained_model,
    environment = e_trade_gym)

trained_model = trained_sac
df_account_value_sac, df_actions_sac = DRLAgent.DRL_prediction(
    model=trained_model,
    environment = e_trade_gym)
df_account_value_a2c.shape

Part 6.5: Mean Variance Optimization

Mean-variance optimization (MVO) is a classic strategy in portfolio management. Here, we go through the whole mean-variance optimization process and add the result as a baseline for comparison.

First, process dataframe to the form for MVO weight calculation.

def process_df_for_mvo(df):
  df = df.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]
  fst = df
  fst = fst.iloc[0:stock_dimension, :]
  tic = fst['tic'].tolist()

  mvo = pd.DataFrame()

  for k in range(len(tic)):
    mvo[tic[k]] = 0

  for i in range(df.shape[0]//stock_dimension):
    n = df
    n = n.iloc[i * stock_dimension:(i+1) * stock_dimension, :]
    date = n['date'][i*stock_dimension]
    mvo.loc[date] = n['close'].tolist()

  return mvo
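The function above turns the long (date, tic, close) frame into a wide table with one row per date and one column per ticker. For comparison, the same reshaping can be sketched with pandas' `pivot`. This is an alternative shown on made-up prices, not the notebook's method, and like the loop it assumes every date has a row for every ticker.

```python
import pandas as pd

long_df = pd.DataFrame({
    "date": ["2021-01-04", "2021-01-04", "2021-01-05", "2021-01-05"],
    "tic": ["AAPL", "AXP", "AAPL", "AXP"],
    "close": [129.4, 118.0, 131.0, 118.7],  # made-up prices
})

# One row per date, one column per ticker, closing prices as values.
wide = long_df.pivot(index="date", columns="tic", values="close")
print(wide)
```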

Helper functions for mean returns and variance-covariance matrix

# Codes in this section partially refer to Dr G A Vijayalakshmi Pai
# https://www.kaggle.com/code/vijipai/lesson-5-mean-variance-optimization-of-portfolios/notebook

def StockReturnsComputing(StockPrice, Rows, Columns):
  import numpy as np
  StockReturn = np.zeros([Rows-1, Columns])
  for j in range(Columns):        # j: Assets
    for i in range(Rows-1):     # i: Daily Prices
      StockReturn[i,j]=((StockPrice[i+1, j]-StockPrice[i,j])/StockPrice[i,j])* 100

  return StockReturn
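The nested loop above computes the percentage change between consecutive daily prices for each asset. The same quantity can be sketched without explicit loops using NumPy; `stock_returns_vectorized` is a name introduced here for illustration, and the price matrix is made up.

```python
import numpy as np

def stock_returns_vectorized(prices):
    """Percentage change between consecutive rows, per column."""
    prices = np.asarray(prices, dtype=float)
    return np.diff(prices, axis=0) / prices[:-1] * 100

# Rows are days, columns are assets (made-up prices).
prices = np.array([[100.0, 50.0],
                   [110.0, 45.0],
                   [121.0, 54.0]])
returns = stock_returns_vectorized(prices)
print(returns)
```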

Calculate the weights for mean-variance

train_mvo = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE).reset_index()
trade_mvo = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE).reset_index()

StockData = process_df_for_mvo(train_mvo)
TradeData = process_df_for_mvo(trade_mvo)

TradeData.to_numpy()

#compute asset returns
arStockPrices = np.asarray(StockData)
[Rows, Cols]=arStockPrices.shape
arReturns = StockReturnsComputing(arStockPrices, Rows, Cols)

#compute mean returns and variance covariance matrix of returns
meanReturns = np.mean(arReturns, axis = 0)
covReturns = np.cov(arReturns, rowvar=False)

#set precision for printing results
np.set_printoptions(precision=3, suppress = True)

#display mean returns and variance-covariance matrix of returns
print('Mean returns of assets in k-portfolio 1\n', meanReturns)
print('Variance-Covariance matrix of returns\n', covReturns)

Use PyPortfolioOpt

from pypfopt.efficient_frontier import EfficientFrontier

ef_mean = EfficientFrontier(meanReturns, covReturns, weight_bounds=(0, 0.5))
raw_weights_mean = ef_mean.max_sharpe()
cleaned_weights_mean = ef_mean.clean_weights()
# one weight per asset, scaled by the initial capital (the original hard-coded 29 assets)
mvo_weights = np.array([1000000 * cleaned_weights_mean[i] for i in range(stock_dimension)])
mvo_weights

LastPrice = np.array([1/p for p in StockData.tail(1).to_numpy()[0]])
Initial_Portfolio = np.multiply(mvo_weights, LastPrice)
Initial_Portfolio

Portfolio_Assets = TradeData @ Initial_Portfolio
MVO_result = pd.DataFrame(Portfolio_Assets, columns=["Mean Var"])
# MVO_result
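The last three cells turn weights into a buy-and-hold baseline: the capital allocated to each asset is divided by its last training price to get a share count, and the trade-period price matrix is then multiplied by that share vector to get the portfolio value per day. The sketch below walks through the same arithmetic with three hypothetical assets; all numbers are made up.

```python
import numpy as np

capital = 1_000_000.0
weights = np.array([0.5, 0.3, 0.2])              # hypothetical cleaned weights
last_train_price = np.array([100.0, 50.0, 20.0]) # hypothetical last training prices

# Dollars per asset divided by price gives the (fractional) shares bought once.
shares = capital * weights / last_train_price

# Rows are trading days, columns are assets.
trade_prices = np.array([[100.0, 50.0, 20.0],
                         [110.0, 50.0, 19.0]])
portfolio_value = trade_prices @ shares
print(shares, portfolio_value)
```

The share counts stay fixed over the whole trading period in this sketch, so the baseline is never rebalanced.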

Part 7: Backtesting Results

Backtesting plays a key role in evaluating the performance of a trading strategy. An automated backtesting tool is preferred because it reduces human error. We usually use the Quantopian pyfolio package to backtest our trading strategies. It is easy to use and consists of various individual plots that provide a comprehensive picture of the performance of a trading strategy.

df_result_a2c = df_account_value_a2c.set_index(df_account_value_a2c.columns[0])
df_result_a2c.rename(columns = {'account_value':'a2c'}, inplace = True)
df_result_ddpg = df_account_value_ddpg.set_index(df_account_value_ddpg.columns[0])
df_result_ddpg.rename(columns = {'account_value':'ddpg'}, inplace = True)
df_result_td3 = df_account_value_td3.set_index(df_account_value_td3.columns[0])
df_result_td3.rename(columns = {'account_value':'td3'}, inplace = True)
df_result_ppo = df_account_value_ppo.set_index(df_account_value_ppo.columns[0])
df_result_ppo.rename(columns = {'account_value':'ppo'}, inplace = True)
df_result_sac = df_account_value_sac.set_index(df_account_value_sac.columns[0])
df_result_sac.rename(columns = {'account_value':'sac'}, inplace = True)
df_account_value_a2c.to_csv("df_account_value_a2c.csv")

#baseline stats
print("==============Get Baseline Stats===========")
df_dji_ = get_baseline(
        ticker="^DJI",
        start = TRADE_START_DATE,
        end = TRADE_END_DATE)
stats = backtest_stats(df_dji_, value_col_name = 'close')
df_dji = pd.DataFrame()
df_dji['date'] = df_account_value_a2c['date']
df_dji['account_value'] = df_dji_['close'] / df_dji_['close'][0] * env_kwargs["initial_amount"]
df_dji.to_csv("df_dji.csv")
df_dji = df_dji.set_index(df_dji.columns[0])
df_dji.to_csv("df_dji+.csv")

result = pd.DataFrame()
result = pd.merge(result, df_result_a2c, how='outer', left_index=True, right_index=True)
result = pd.merge(result, df_result_ddpg, how='outer', left_index=True, right_index=True)
result = pd.merge(result, df_result_td3, how='outer', left_index=True, right_index=True)
result = pd.merge(result, df_result_ppo, how='outer', left_index=True, right_index=True)
result = pd.merge(result, df_result_sac, how='outer', left_index=True, right_index=True)
result = pd.merge(result, MVO_result, how='outer', left_index=True, right_index=True)
print(result.head())
result = pd.merge(result, df_dji, how='outer', left_index=True, right_index=True)

# result.columns = ['a2c', 'ddpg', 'td3', 'ppo', 'sac', 'mean var', 'dji']

# print("result: ", result)
result.to_csv("result.csv")
df_result_ddpg
%matplotlib inline
plt.rcParams["figure.figsize"] = (15,5)
plt.figure();
result.plot();