from __future__ import annotations
from finrl.config import INDICATORS
from finrl.config import RLlib_PARAMS
from finrl.config import TEST_END_DATE
from finrl.config import TEST_START_DATE
from finrl.config_tickers import DOW_30_TICKER
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
def test(
    start_date,
    end_date,
    ticker_list,
    data_source,
    time_interval,
    technical_indicator_list,
    drl_lib,
    env,
    model_name,
    if_vix=True,
    **kwargs,
):
    """Run a saved DRL agent on processed market data and return its result.

    The data is obtained through ``DataProcessor`` (download, clean, add
    technical indicators, optionally add VIX), converted to arrays, wrapped
    in an environment config with ``if_train`` set to ``False``, and handed
    to the prediction entry point of the selected DRL library.

    Args:
        start_date: Passed to ``DataProcessor.download_data``.
        end_date: Passed to ``DataProcessor.download_data``.
        ticker_list: Passed to ``DataProcessor.download_data``.
        data_source: First positional argument of ``DataProcessor``.
        time_interval: Passed to ``DataProcessor.download_data``.
        technical_indicator_list: Passed to
            ``DataProcessor.add_technical_indicator``.
        drl_lib: One of ``"elegantrl"``, ``"rllib"`` or
            ``"stable_baselines3"``.
        env: Callable invoked as ``env(config=env_config)`` to build the
            environment. For ``"rllib"`` the callable itself, not the built
            instance, is passed on to the agent.
        model_name: Forwarded to the agent; also used to build the default
            ``cwd``.
        if_vix: When truthy, ``DataProcessor.add_vix`` is applied. The flag
            is also forwarded to ``DataProcessor.df_to_array``.
        **kwargs: Forwarded in full to the ``DataProcessor`` constructor.
            Two keys are additionally read here: ``net_dimension``
            (default ``2**7``) and ``cwd`` (default
            ``"./" + str(model_name)``).

    Returns:
        Whatever the selected library's prediction function returns.

    Raises:
        ValueError: If ``drl_lib`` is not one of the supported names. This
            is checked before any data is fetched or processed.
    """
    # Fail fast: reject an unsupported library before the download and
    # preprocessing pipeline runs, instead of after it has completed.
    supported_libs = ("elegantrl", "rllib", "stable_baselines3")
    if drl_lib not in supported_libs:
        raise ValueError("DRL library input is NOT supported. Please check.")

    from finrl.meta.data_processor import DataProcessor

    # Fetch and preprocess the data.
    dp = DataProcessor(data_source, **kwargs)
    data = dp.download_data(ticker_list, start_date, end_date, time_interval)
    data = dp.clean_data(data)
    data = dp.add_technical_indicator(data, technical_indicator_list)
    if if_vix:
        data = dp.add_vix(data)
    price_array, tech_array, turbulence_array = dp.df_to_array(data, if_vix)

    # Build the evaluation environment.
    env_config = {
        "price_array": price_array,
        "tech_array": tech_array,
        "turbulence_array": turbulence_array,
        "if_train": False,
    }
    env_instance = env(config=env_config)

    # Agent settings; both can be overridden through kwargs.
    net_dimension = kwargs.get("net_dimension", 2**7)
    cwd = kwargs.get("cwd", "./" + str(model_name))
    print("price_array: ", len(price_array))

    # Each back end is imported inside its own branch, so only the selected
    # library is imported.
    if drl_lib == "elegantrl":
        from finrl.agents.elegantrl.models import DRLAgent as DRLAgent_erl

        return DRLAgent_erl.DRL_prediction(
            model_name=model_name,
            cwd=cwd,
            net_dimension=net_dimension,
            environment=env_instance,
        )

    if drl_lib == "rllib":
        from finrl.agents.rllib.models import DRLAgent as DRLAgent_rllib

        # This call receives the env callable and the raw arrays rather
        # than the already-built env_instance.
        return DRLAgent_rllib.DRL_prediction(
            model_name=model_name,
            env=env,
            price_array=price_array,
            tech_array=tech_array,
            turbulence_array=turbulence_array,
            agent_path=cwd,
        )

    # Only "stable_baselines3" remains after the validation above.
    from finrl.agents.stablebaselines3.models import DRLAgent as DRLAgent_sb3

    return DRLAgent_sb3.DRL_prediction_load_from_file(
        model_name=model_name, environment=env_instance, cwd=cwd
    )
if __name__ == "__main__":
    # Script entry point: run test() on the Dow 30 tickers with a PPO model
    # through the ElegantRL back end.
    env = StockTradingEnv

    # Extra options for test(). Left empty here; anything added is unpacked
    # into the call below.
    kwargs = {}

    account_value_erl = test(
        start_date=TEST_START_DATE,
        end_date=TEST_END_DATE,
        ticker_list=DOW_30_TICKER,
        data_source="yahoofinance",
        time_interval="1D",
        technical_indicator_list=INDICATORS,
        drl_lib="elegantrl",
        env=env,
        model_name="ppo",
        cwd="./test_ppo",
        net_dimension=512,
        # Unpack the dict. Writing `kwargs=kwargs` would instead pass the
        # whole dict under the single key "kwargs", so none of its entries
        # would arrive as individual keyword arguments.
        **kwargs,
    )