Path: blob/master/examples/FinRL_PaperTrading_Demo_refactored.py
726 views
# Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real
# money. Many platforms exist for simulated trading (paper trading) which can be used for
# building and developing the methods discussed. Please use common sense and always first
# consult a professional before trading or investing.
#
# Install the FinRL library first:
# %pip install --upgrade git+https://github.com/AI4Finance-Foundation/FinRL.git
"""FinRL paper-trading demo with an ElegantRL PPO agent on the Dow Jones 30 symbols.

Usage (six positional arguments)::

    python FinRL_PaperTrading_Demo_refactored.py
        DATA_KEY DATA_SECRET DATA_URL TRADING_KEY TRADING_SECRET TRADING_URL

Pipeline: train on a short sliding window of 1-minute Alpaca bars, evaluate on the
following business days, retrain on the combined window, paper-trade the retrained
agent through Alpaca, then plot the account's cumulative return against the DJIA.
"""
from __future__ import annotations

import argparse
import datetime

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from pandas.tseries.offsets import BDay  # BDay is business day, not birthday...

from finrl.config import INDICATORS
from finrl.config_tickers import DOW_30_TICKER  # Dow Jones 30 symbols
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.meta.paper_trading.alpaca import PaperTradingAlpaca
from finrl.meta.paper_trading.common import DIA_history, alpaca_history, test, train


def _mask(secret: str) -> str:
    """Return *secret* with all but its last 4 characters replaced by ``*``.

    Lets the operator sanity-check which credential was passed without printing the
    full secret into terminal scrollback or CI logs. Secrets of 4 characters or fewer
    are masked entirely.
    """
    if len(secret) <= 4:
        return "*" * len(secret)
    return "*" * (len(secret) - 4) + secret[-4:]


# Alpaca keys: one credential set for the market-data source, one for trading.
parser = argparse.ArgumentParser()
parser.add_argument("data_key", help="data source api key")
parser.add_argument("data_secret", help="data source api secret")
parser.add_argument("data_url", help="data source api base url")
parser.add_argument("trading_key", help="trading api key")
parser.add_argument("trading_secret", help="trading api secret")
parser.add_argument("trading_url", help="trading api base url")
args = parser.parse_args()
DATA_API_KEY = args.data_key
DATA_API_SECRET = args.data_secret
DATA_API_BASE_URL = args.data_url
TRADING_API_KEY = args.trading_key
TRADING_API_SECRET = args.trading_secret
TRADING_API_BASE_URL = args.trading_url

# Never echo secrets in clear text; keys and URLs are printed as-is for debugging.
print("DATA_API_KEY: ", DATA_API_KEY)
print("DATA_API_SECRET: ", _mask(DATA_API_SECRET))
print("DATA_API_BASE_URL: ", DATA_API_BASE_URL)
print("TRADING_API_KEY: ", TRADING_API_KEY)
print("TRADING_API_SECRET: ", _mask(TRADING_API_SECRET))
print("TRADING_API_BASE_URL: ", TRADING_API_BASE_URL)

ticker_list = DOW_30_TICKER
env = StockTradingEnv
# If you use a larger dataset (a longer period) and training raises an error, try
# increasing "target_step"; it should be larger than the number of episode steps.
ERL_PARAMS = {
    "learning_rate": 3e-6,
    "batch_size": 2048,
    "gamma": 0.985,
    "seed": 312,
    "net_dimension": [128, 64],
    "target_step": 5000,
    "eval_gap": 30,
    "eval_times": 1,
}

# Sliding window: 6 business days of training followed by 2 business days of testing.
# The test window ends at `today - BDay(1)`; each earlier boundary steps back in BDays.
today = datetime.datetime.today()
_test_end = (today - BDay(1)).to_pydatetime().date()
_test_start = (_test_end - BDay(1)).to_pydatetime().date()
_train_end = (_test_start - BDay(1)).to_pydatetime().date()
_train_start = (_train_end - BDay(5)).to_pydatetime().date()

TRAIN_START_DATE = str(_train_start)
TRAIN_END_DATE = str(_train_end)
TEST_START_DATE = str(_test_start)
TEST_END_DATE = str(_test_end)
# After tuning, the agent is retrained on the union of the training and testing windows.
TRAINFULL_START_DATE = TRAIN_START_DATE
TRAINFULL_END_DATE = TEST_END_DATE

print("TRAIN_START_DATE: ", TRAIN_START_DATE)
print("TRAIN_END_DATE: ", TRAIN_END_DATE)
print("TEST_START_DATE: ", TEST_START_DATE)
print("TEST_END_DATE: ", TEST_END_DATE)
print("TRAINFULL_START_DATE: ", TRAINFULL_START_DATE)
print("TRAINFULL_END_DATE: ", TRAINFULL_END_DATE)

# Keyword arguments shared by every train()/test() call below: 1-minute Alpaca bars
# fetched with the *data* credentials, VIX enabled, ElegantRL PPO agent.
_DATA_KWARGS = {
    "ticker_list": ticker_list,
    "data_source": "alpaca",
    "time_interval": "1Min",
    "technical_indicator_list": INDICATORS,
    "drl_lib": "elegantrl",
    "env": env,
    "model_name": "ppo",
    "if_vix": True,
    "API_KEY": DATA_API_KEY,
    "API_SECRET": DATA_API_SECRET,
    "API_BASE_URL": DATA_API_BASE_URL,
}

train(
    start_date=TRAIN_START_DATE,
    end_date=TRAIN_END_DATE,
    erl_params=ERL_PARAMS,
    cwd="./papertrading_erl",  # current_working_dir
    break_step=1e5,
    **_DATA_KWARGS,
)

account_value_erl = test(
    start_date=TEST_START_DATE,
    end_date=TEST_END_DATE,
    cwd="./papertrading_erl",
    net_dimension=ERL_PARAMS["net_dimension"],
    **_DATA_KWARGS,
)

train(
    start_date=TRAINFULL_START_DATE,  # After tuning well, retrain on the training and testing sets
    end_date=TRAINFULL_END_DATE,
    erl_params=ERL_PARAMS,
    cwd="./papertrading_erl_retrain",
    break_step=2e5,
    **_DATA_KWARGS,
)

action_dim = len(DOW_30_TICKER)
# The DRL state dimension must be computed by hand for paper trading:
#   amount (1) + (turbulence, turbulence_bool) (2)
#   + (price, shares, cd (holding time)) per stock (3 * stock_dim)
#   + technical indicators per stock (len(INDICATORS) * stock_dim)
state_dim = 1 + 2 + 3 * action_dim + len(INDICATORS) * action_dim

paper_trading_erl = PaperTradingAlpaca(
    ticker_list=DOW_30_TICKER,
    time_interval="1Min",
    drl_lib="elegantrl",
    agent="ppo",
    cwd="./papertrading_erl_retrain",
    net_dim=ERL_PARAMS["net_dimension"],
    state_dim=state_dim,
    action_dim=action_dim,
    API_KEY=TRADING_API_KEY,
    API_SECRET=TRADING_API_SECRET,
    API_BASE_URL=TRADING_API_BASE_URL,
    tech_indicator_list=INDICATORS,
    turbulence_thresh=30,
    max_stock=1e2,
)

paper_trading_erl.run()

# Check Portfolio Performance
# ## Get cumulative return
# The history window must be within 1 month; change the dates if an error occurs.
HISTORY_START = "2022-09-01"
HISTORY_END = "2022-09-12"
# Portfolio history belongs to the account that placed the paper trades above, so it is
# queried with the *trading* credentials (the data credentials may point at a different
# account or endpoint).
df_erl, cumu_erl = alpaca_history(
    key=TRADING_API_KEY,
    secret=TRADING_API_SECRET,
    url=TRADING_API_BASE_URL,
    start=HISTORY_START,
    end=HISTORY_END,
)

df_djia, cumu_djia = DIA_history(start=HISTORY_START)
returns_erl = cumu_erl - 1
returns_dia = cumu_djia - 1
# Trim the DJIA series so both curves cover the same number of points.
returns_dia = returns_dia[: returns_erl.shape[0]]

# plot and save
plt.figure(dpi=1000)
plt.grid()
plt.grid(which="minor", axis="y")
plt.title("Stock Trading (Paper trading)", fontsize=20)
plt.plot(returns_erl, label="ElegantRL Agent", color="red")
# Other agents can be overlaid the same way, e.g.:
# plt.plot(returns_sb3, label="Stable-Baselines3 Agent", color="blue")
# plt.plot(returns_rllib, label="RLlib Agent", color="green")
plt.plot(returns_dia, label="DJIA", color="grey")
plt.ylabel("Return", fontsize=16)
plt.xlabel(f"{HISTORY_START} to {HISTORY_END}", fontsize=16)
plt.xticks(size=14)
plt.yticks(size=14)
ax = plt.gca()
# Use matplotlib's tick module here -- `ticker_list` is the list of stock symbols.
# NOTE(review): 78 presumably equals the number of history points per trading day
# (6.5 h of 5-minute bars); confirm against alpaca_history's sampling.
ax.xaxis.set_major_locator(mticker.MultipleLocator(78))
ax.xaxis.set_minor_locator(mticker.MultipleLocator(6))
ax.yaxis.set_minor_locator(mticker.MultipleLocator(0.005))
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1, decimals=2))
# Major x ticks keep matplotlib's default labels: hard-coded date strings here would
# drift out of sync with HISTORY_START / HISTORY_END.
plt.legend(fontsize=10.5)
plt.savefig("papertrading_stock.png")