Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/test.py
728 views
1
from __future__ import annotations
2
3
from finrl.config import INDICATORS
4
from finrl.config import RLlib_PARAMS
5
from finrl.config import TEST_END_DATE
6
from finrl.config import TEST_START_DATE
7
from finrl.config_tickers import DOW_30_TICKER
8
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
9
10
11
def test(
12
start_date,
13
end_date,
14
ticker_list,
15
data_source,
16
time_interval,
17
technical_indicator_list,
18
drl_lib,
19
env,
20
model_name,
21
if_vix=True,
22
**kwargs,
23
):
24
# import data processor
25
from finrl.meta.data_processor import DataProcessor
26
27
# fetch data
28
dp = DataProcessor(data_source, **kwargs)
29
data = dp.download_data(ticker_list, start_date, end_date, time_interval)
30
data = dp.clean_data(data)
31
data = dp.add_technical_indicator(data, technical_indicator_list)
32
33
if if_vix:
34
data = dp.add_vix(data)
35
price_array, tech_array, turbulence_array = dp.df_to_array(data, if_vix)
36
37
env_config = {
38
"price_array": price_array,
39
"tech_array": tech_array,
40
"turbulence_array": turbulence_array,
41
"if_train": False,
42
}
43
env_instance = env(config=env_config)
44
45
# load elegantrl needs state dim, action dim and net dim
46
net_dimension = kwargs.get("net_dimension", 2**7)
47
cwd = kwargs.get("cwd", "./" + str(model_name))
48
print("price_array: ", len(price_array))
49
50
if drl_lib == "elegantrl":
51
from finrl.agents.elegantrl.models import DRLAgent as DRLAgent_erl
52
53
episode_total_assets = DRLAgent_erl.DRL_prediction(
54
model_name=model_name,
55
cwd=cwd,
56
net_dimension=net_dimension,
57
environment=env_instance,
58
)
59
return episode_total_assets
60
elif drl_lib == "rllib":
61
from finrl.agents.rllib.models import DRLAgent as DRLAgent_rllib
62
63
episode_total_assets = DRLAgent_rllib.DRL_prediction(
64
model_name=model_name,
65
env=env,
66
price_array=price_array,
67
tech_array=tech_array,
68
turbulence_array=turbulence_array,
69
agent_path=cwd,
70
)
71
return episode_total_assets
72
elif drl_lib == "stable_baselines3":
73
from finrl.agents.stablebaselines3.models import DRLAgent as DRLAgent_sb3
74
75
episode_total_assets = DRLAgent_sb3.DRL_prediction_load_from_file(
76
model_name=model_name, environment=env_instance, cwd=cwd
77
)
78
return episode_total_assets
79
else:
80
raise ValueError("DRL library input is NOT supported. Please check.")
81
82
83
if __name__ == "__main__":
    env = StockTradingEnv

    # demo for elegantrl
    # With the yahoofinance data source no extra arguments are needed, so
    # kwargs is empty. Other data sources (e.g. joinquant) need extra
    # arguments; put them here and they are unpacked into test()'s **kwargs,
    # which forwards them to DataProcessor. Do not repeat a key that is
    # already passed explicitly below (e.g. cwd, net_dimension): Python
    # would raise TypeError for the duplicate keyword.
    kwargs = {}

    account_value_erl = test(
        start_date=TEST_START_DATE,
        end_date=TEST_END_DATE,
        ticker_list=DOW_30_TICKER,
        data_source="yahoofinance",
        time_interval="1D",
        technical_indicator_list=INDICATORS,
        drl_lib="elegantrl",
        env=env,
        model_name="ppo",
        cwd="./test_ppo",
        net_dimension=512,
        # Unpack rather than passing kwargs=kwargs, which would nest the whole
        # dict under a single "kwargs" key instead of forwarding its entries.
        **kwargs,
    )

    # To use rllib or stable-baselines3 instead, uncomment the matching
    # demo below.

    # # demo for rllib
    # import ray
    # ray.shutdown()  # always shutdown previous session if any
    # account_value_rllib = test(
    #     start_date=TEST_START_DATE,
    #     end_date=TEST_END_DATE,
    #     ticker_list=DOW_30_TICKER,
    #     data_source="yahoofinance",
    #     time_interval="1D",
    #     technical_indicator_list=INDICATORS,
    #     drl_lib="rllib",
    #     env=env,
    #     model_name="ppo",
    #     cwd="./test_ppo/checkpoint_000030/checkpoint-30",
    #     rllib_params=RLlib_PARAMS,
    # )
    #
    # # demo for stable baselines3
    # account_value_sb3 = test(
    #     start_date=TEST_START_DATE,
    #     end_date=TEST_END_DATE,
    #     ticker_list=DOW_30_TICKER,
    #     data_source="yahoofinance",
    #     time_interval="1D",
    #     technical_indicator_list=INDICATORS,
    #     drl_lib="stable_baselines3",
    #     env=env,
    #     model_name="sac",
    #     cwd="./test_sac.zip",
    # )
138
139