Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/unit_tests/environments/test_cash_penalty.py
728 views
1
from __future__ import annotations
2
3
import numpy as np
4
import pytest
5
6
from finrl.meta.env_stock_trading.env_stocktrading_cashpenalty import (
7
StockTradingEnvCashpenalty,
8
)
9
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
10
11
12
@pytest.fixture(scope="session")
def ticker_list():
    """Provide the ticker symbols exercised by this module (built once per session)."""
    tickers = ["AAPL", "GOOG"]
    return tickers
15
16
17
@pytest.fixture(scope="session")
def indicator_list():
    """Provide the open/close/high/low/volume column names (built once per session)."""
    columns = ["open", "close", "high", "low", "volume"]
    return columns
20
21
22
@pytest.fixture(scope="session")
def data(ticker_list):
    """Fetch market data for ``ticker_list`` from 2019-01-01 to 2019-02-01.

    Session-scoped so the download happens only once for the whole module.
    NOTE(review): YahooDownloader presumably needs network access — confirm
    before running in an offline CI environment.
    """
    downloader = YahooDownloader(
        start_date="2019-01-01",
        end_date="2019-02-01",
        ticker_list=ticker_list,
    )
    return downloader.fetch_data()
27
28
29
def test_zero_step(data, ticker_list):
    """All-zero actions must leave cash untouched and buy no shares.

    Over two steps, checks that cash and total assets stay at the initial
    amount, holdings and asset value stay at zero, and ``current_step``
    advances by one each step.
    """
    starting_cash = 1e6
    env = StockTradingEnvCashpenalty(
        df=data, initial_amount=starting_cash, cache_indicator_data=False
    )
    env.reset()
    n_assets = len(ticker_list)

    for step_idx in range(1, 3):
        state, _, _, _ = env.step(np.zeros(n_assets))
        # State layout used here: [cash, holdings per asset, ...].
        cash, holdings = state[0], state[1 : 1 + n_assets]
        info = env.account_information

        assert cash == starting_cash
        assert starting_cash == info["total_assets"][-1]

        assert np.sum(holdings) == 0
        assert info["asset_value"][-1] == 0

        assert env.current_step == step_idx
53
54
55
def test_patient(data, ticker_list):
    """A patient env skips buys it cannot afford instead of ending the episode.

    Starting cash is set to a single AAPL close price. A maximal buy action is
    then issued for every ticker. With ``patient=True`` the step must not
    finish the episode, and no shares may be acquired.
    NOTE(review): ``hmax`` is presumably the per-trade limit, set far above the
    available cash so the buy cannot be filled — confirm against the env.
    """
    aapl_first_close = data[data["tic"] == "AAPL"].head(1)["close"].values[0]
    init_amt = aapl_first_close
    hmax = aapl_first_close * 100
    env = StockTradingEnvCashpenalty(
        df=data,
        initial_amount=init_amt,
        hmax=hmax,
        cache_indicator_data=False,
        patient=True,
        random_start=False,
    )
    _ = env.reset()

    # One maximal buy per ticker; sized from the fixture rather than hard-coded
    # to two assets so the test stays valid if ``ticker_list`` changes.
    actions = np.ones(len(ticker_list))
    next_state, _, is_done, _ = env.step(actions)
    holdings = next_state[1 : 1 + len(ticker_list)]

    assert not is_done
    assert np.sum(holdings) == 0
76
77
78
@pytest.mark.xfail(raises=NotImplementedError, reason="Not implemented")
def test_cost_penalties():
    """Placeholder for a cost-penalty test that has not been written yet.

    ``raises=NotImplementedError`` limits the expected failure to this
    placeholder, so any other error (e.g. a broken import) is reported.
    """
    raise NotImplementedError
81
82
83
@pytest.mark.xfail(raises=NotImplementedError, reason="Not implemented")
def test_purchases():
    """Placeholder for a purchases test that has not been written yet.

    ``raises=NotImplementedError`` limits the expected failure to this
    placeholder, so any other error (e.g. a broken import) is reported.
    """
    raise NotImplementedError
86
87
88
@pytest.mark.xfail(raises=NotImplementedError, reason="Not implemented")
def test_gains():
    """Placeholder for a gains test that has not been written yet.

    ``raises=NotImplementedError`` limits the expected failure to this
    placeholder, so any other error (e.g. a broken import) is reported.
    """
    raise NotImplementedError
91
92
93
@pytest.mark.skip(reason="this test is not working correctly")
def test_validate_caching(data):
    """Check that ``cache_indicator_data`` changes neither states nor rewards.

    Two environments that differ only in ``cache_indicator_data`` are driven
    with the same (seeded) random action sequence. They must report identical
    states and rewards at every step.
    """
    init_amt = 1e6
    # Pin the start date so both envs follow the same trajectory.
    # NOTE(review): random_start is assumed to default to True (test_patient
    # disables it explicitly). If so, the unpinned envs could start on
    # different dates, which may be why this test was marked as not working.
    # Re-check before removing the skip marker.
    env_uncached = StockTradingEnvCashpenalty(
        df=data,
        initial_amount=init_amt,
        cache_indicator_data=False,
        random_start=False,
    )
    env_cached = StockTradingEnvCashpenalty(
        df=data,
        initial_amount=init_amt,
        cache_indicator_data=True,
        random_start=False,
    )
    _ = env_uncached.reset()
    _ = env_cached.reset()

    # Seeded local generator: reproducible failures, no global RNG mutation.
    rng = np.random.default_rng(seed=0)
    n_assets = data["tic"].nunique()
    for _ in range(10):
        actions = rng.uniform(low=-1, high=1, size=n_assets)
        un_state, un_reward, _, _ = env_uncached.step(actions)
        ca_state, ca_reward, _, _ = env_cached.step(actions)

        # Compare element-wise: a bare ``==`` on NumPy arrays has an ambiguous
        # truth value and would raise instead of reporting the mismatch.
        np.testing.assert_array_equal(un_state, ca_state)
        np.testing.assert_array_equal(un_reward, ca_reward)
113
114