# Path: blob/master/unit_tests/environments/test_cash_penalty.py
from __future__ import annotations12import numpy as np3import pytest45from finrl.meta.env_stock_trading.env_stocktrading_cashpenalty import (6StockTradingEnvCashpenalty,7)8from finrl.meta.preprocessor.yahoodownloader import YahooDownloader91011@pytest.fixture(scope="session")12def ticker_list():13return ["AAPL", "GOOG"]141516@pytest.fixture(scope="session")17def indicator_list():18return ["open", "close", "high", "low", "volume"]192021@pytest.fixture(scope="session")22def data(ticker_list):23return YahooDownloader(24start_date="2019-01-01", end_date="2019-02-01", ticker_list=ticker_list25).fetch_data()262728def test_zero_step(data, ticker_list):29# Prove that zero actions results in zero stock buys, and no price changes30init_amt = 1e631env = StockTradingEnvCashpenalty(32df=data, initial_amount=init_amt, cache_indicator_data=False33)34_ = env.reset()3536# step with all zeros37for i in range(2):38actions = np.zeros(len(ticker_list))39next_state, _, _, _ = env.step(actions)40cash = next_state[0]41holdings = next_state[1 : 1 + len(ticker_list)]42asset_value = env.account_information["asset_value"][-1]43total_assets = env.account_information["total_assets"][-1]4445assert cash == init_amt46assert init_amt == total_assets4748assert np.sum(holdings) == 049assert asset_value == 05051assert env.current_step == i + 1525354def test_patient(data, ticker_list):55# Prove that we just not buying any new assets if running out of cash and the cycle is not ended56aapl_first_close = data[data["tic"] == "AAPL"].head(1)["close"].values[0]57init_amt = aapl_first_close58hmax = aapl_first_close * 10059env = StockTradingEnvCashpenalty(60df=data,61initial_amount=init_amt,62hmax=hmax,63cache_indicator_data=False,64patient=True,65random_start=False,66)67_ = env.reset()6869actions = np.array([1.0, 1.0])70next_state, _, is_done, _ = env.step(actions)71holdings = next_state[1 : 1 + len(ticker_list)]7273assert not is_done74assert np.sum(holdings) == 0757677@pytest.mark.xfail(reason="Not 
implemented")78def test_cost_penalties():79raise NotImplementedError808182@pytest.mark.xfail(reason="Not implemented")83def test_purchases():84raise NotImplementedError858687@pytest.mark.xfail(reason="Not implemented")88def test_gains():89raise NotImplementedError909192@pytest.mark.skip(reason="this test is not working correctly")93def test_validate_caching(data):94# prove that results with or without caching don't change anything95init_amt = 1e696env_uncached = StockTradingEnvCashpenalty(97df=data, initial_amount=init_amt, cache_indicator_data=False98)99env_cached = StockTradingEnvCashpenalty(100df=data, initial_amount=init_amt, cache_indicator_data=True101)102_ = env_uncached.reset()103_ = env_cached.reset()104for i in range(10):105actions = np.random.uniform(low=-1, high=1, size=2)106print(f"actions: {actions}")107un_state, un_reward, _, _ = env_uncached.step(actions)108ca_state, ca_reward, _, _ = env_cached.step(actions)109110assert un_state == ca_state111assert un_reward == ca_reward112113114