Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/main.py
728 views
1
from __future__ import annotations
2
3
import os
4
from argparse import ArgumentParser
5
from typing import List
6
7
from finrl.config import ALPACA_API_BASE_URL
8
from finrl.config import DATA_SAVE_DIR
9
from finrl.config import ERL_PARAMS
10
from finrl.config import INDICATORS
11
from finrl.config import RESULTS_DIR
12
from finrl.config import TENSORBOARD_LOG_DIR
13
from finrl.config import TEST_END_DATE
14
from finrl.config import TEST_START_DATE
15
from finrl.config import TRADE_END_DATE
16
from finrl.config import TRADE_START_DATE
17
from finrl.config import TRAIN_END_DATE
18
from finrl.config import TRAIN_START_DATE
19
from finrl.config import TRAINED_MODEL_DIR
20
from finrl.config_tickers import DOW_30_TICKER
21
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
22
23
# Alpaca API credentials (ALPACA_API_KEY, ALPACA_API_SECRET) are expected in
# finrl/config_private.py. They are imported only inside the "trade" branch of
# main(), so the "train" and "test" modes can run without that file.
31
32
33
def build_parser():
    """Build the command-line parser for this entry point.

    Returns:
        An ``ArgumentParser`` with a single ``--mode`` option, stored as
        ``options.mode`` and defaulting to ``"train"``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--mode",
        dest="mode",
        # Keep in sync with the modes dispatched in main().
        help="start mode: train, test or trade",
        metavar="MODE",
        default="train",
    )
    return parser
43
44
45
def check_and_make_directories(directories: list[str]) -> None:
    """Ensure every path in *directories* exists as a directory.

    Missing intermediate directories are created too. Directories that
    already exist are left alone, so repeated calls are harmless.

    Args:
        directories: Directory paths to create if absent.

    Raises:
        FileExistsError: if a path already exists but is not a directory.
        OSError: if a directory cannot be created (e.g. permissions).
    """
    for directory in directories:
        # exist_ok=True closes the check-then-create race of calling
        # os.path.exists() before os.makedirs(): if another process creates
        # the directory in between, makedirs would otherwise raise.
        os.makedirs(directory, exist_ok=True)
49
50
51
def main() -> int:
    """Parse ``--mode`` and run the matching FinRL demo (train, test or trade).

    Every mode uses the ElegantRL PPO agent with ``StockTradingEnv`` on
    ``DOW_30_TICKER``, daily ("1D") "yahoofinance" data and ``INDICATORS``.
    The output directories from ``finrl.config`` are created first.

    Returns:
        0, used as the process exit code.

    Raises:
        FileNotFoundError: in "trade" mode when ``finrl.config_private``
            (providing ALPACA_API_KEY / ALPACA_API_SECRET) cannot be imported.
        ValueError: if the mode is not "train", "test" or "trade".
    """
    parser = build_parser()
    options = parser.parse_args()
    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )

    # Arguments identical across all three modes. kwargs is {} for the
    # yahoofinance data source; other sources (e.g. joinquant) need it filled.
    common = dict(
        ticker_list=DOW_30_TICKER,
        data_source="yahoofinance",
        time_interval="1D",
        technical_indicator_list=INDICATORS,
        drl_lib="elegantrl",
        env=StockTradingEnv,
        model_name="ppo",
        kwargs={},
    )

    if options.mode == "train":
        from finrl import train

        train(
            start_date=TRAIN_START_DATE,
            end_date=TRAIN_END_DATE,
            cwd="./test_ppo",
            erl_params=ERL_PARAMS,
            break_step=1e5,
            **common,
        )
    elif options.mode == "test":
        from finrl import test

        # The returned account values are not used by this demo.
        test(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
            cwd="./test_ppo",
            net_dimension=512,
            **common,
        )
    elif options.mode == "trade":
        from finrl import trade

        # Credentials are imported lazily so train/test work without them.
        try:
            from finrl.config_private import ALPACA_API_KEY, ALPACA_API_SECRET
        except ImportError as err:
            raise FileNotFoundError(
                "Please set your own ALPACA_API_KEY and ALPACA_API_SECRET in config_private.py"
            ) from err

        trade(
            start_date=TRADE_START_DATE,
            end_date=TRADE_END_DATE,
            API_KEY=ALPACA_API_KEY,
            API_SECRET=ALPACA_API_SECRET,
            API_BASE_URL=ALPACA_API_BASE_URL,
            trade_mode="paper_trading",
            if_vix=True,
            # PPO needs explicit dimensions:
            # state_dim = len(stocks) * len(INDICATORS) + 3 * len(stocks) + 3
            state_dim=len(DOW_30_TICKER) * (len(INDICATORS) + 3) + 3,
            # one action per ticker
            action_dim=len(DOW_30_TICKER),
            **common,
        )
    else:
        raise ValueError("Wrong mode.")
    return 0
141
142
143
# Run from a terminal with one of:
#   python main.py --mode=train
#   python main.py --mode=test
#   python main.py --mode=trade
if __name__ == "__main__":
    exit_code = main()
    raise SystemExit(exit_code)
149
150