Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/meta/env_cryptocurrency_trading/env_multiple_crypto.py
732 views
1
from __future__ import annotations
2
3
import numpy as np
4
5
6
class CryptoEnv:  # custom env
    """Trading environment for a portfolio of multiple cryptocurrencies.

    The agent observes [scaled cash, scaled holdings, scaled technical
    indicators over ``lookback`` steps] and emits one continuous action per
    asset: positive values buy, negative values sell.

    Parameters
    ----------
    config : dict
        Must contain ``"price_array"`` (T x n_crypto prices) and
        ``"tech_array"`` (T x n_tech technical indicators).
    lookback : int
        Number of past rows of technical indicators included in the state.
    initial_capital : float
        Starting cash balance.
    buy_cost_pct, sell_cost_pct : float
        Proportional transaction costs applied to buys / sells.
    gamma : float
        Discount factor for the cumulative reward returned on the final step.
    """

    def __init__(
        self,
        config,
        lookback=1,
        initial_capital=1e6,
        buy_cost_pct=1e-3,
        sell_cost_pct=1e-3,
        gamma=0.99,
    ):
        self.lookback = lookback
        self.initial_total_asset = initial_capital
        self.initial_cash = initial_capital
        self.buy_cost_pct = buy_cost_pct
        self.sell_cost_pct = sell_cost_pct
        self.max_stock = 1
        self.gamma = gamma
        self.price_array = config["price_array"]
        self.tech_array = config["tech_array"]
        self._generate_action_normalizer()
        self.crypto_num = self.price_array.shape[1]
        # Terminal time index; the episode ends when self.time reaches it.
        self.max_step = self.price_array.shape[0] - lookback - 1

        # reset
        self.time = lookback - 1
        self.cash = self.initial_cash
        self.current_price = self.price_array[self.time]
        self.current_tech = self.tech_array[self.time]
        self.stocks = np.zeros(self.crypto_num, dtype=np.float32)

        self.total_asset = self.cash + (self.stocks * self.price_array[self.time]).sum()
        self.episode_return = 0.0
        self.gamma_return = 0.0

        """env information"""
        self.env_name = "MulticryptoEnv"
        # NOTE(review): get_state() actually returns
        # 1 + crypto_num + tech_dim * lookback entries, which matches this
        # formula only when lookback == 1 — confirm before using lookback > 1.
        self.state_dim = (
            1 + (self.price_array.shape[1] + self.tech_array.shape[1]) * lookback
        )
        self.action_dim = self.price_array.shape[1]
        self.if_discrete = False
        self.target_return = 10

    def reset(
        self,
        *,
        seed=None,
        options=None,
    ) -> np.ndarray:
        """Reset to the start of a new episode and return the initial state.

        ``seed`` and ``options`` are accepted for gym/gymnasium API
        compatibility but are currently unused (the data is deterministic).
        """
        self.time = self.lookback - 1
        self.current_price = self.price_array[self.time]
        self.current_tech = self.tech_array[self.time]
        self.cash = self.initial_cash  # reset()
        self.stocks = np.zeros(self.crypto_num, dtype=np.float32)
        self.total_asset = self.cash + (self.stocks * self.price_array[self.time]).sum()
        # Fix: clear the per-episode reward accumulators. Previously
        # gamma_return leaked across episodes, corrupting the terminal
        # reward of every episode after the first.
        self.gamma_return = 0.0
        self.episode_return = 0.0

        return self.get_state()

    def step(self, actions) -> (np.ndarray, float, bool, None):
        """Advance one time step.

        Scales ``actions`` by the per-asset normalizer, executes sells then
        buys at the new price row (with proportional costs), and returns
        ``(state, reward, done, None)``. The reward is the scaled change in
        total assets; on the terminal step it is replaced by the
        gamma-discounted cumulative return.
        """
        self.time += 1

        price = self.price_array[self.time]
        # Fix: scale into a fresh array instead of mutating the caller's
        # ``actions`` in place (the old element-wise loop overwrote the
        # agent's action buffer as a side effect).
        actions = np.asarray(actions) * self.action_norm_vector

        for index in np.where(actions < 0)[0]:  # sell_index:
            if price[index] > 0:  # Sell only if current asset is > 0
                sell_num_shares = min(self.stocks[index], -actions[index])
                self.stocks[index] -= sell_num_shares
                self.cash += price[index] * sell_num_shares * (1 - self.sell_cost_pct)

        for index in np.where(actions > 0)[0]:  # buy_index:
            if (
                price[index] > 0
            ):  # Buy only if the price is > 0 (no missing data in this particular date)
                buy_num_shares = min(
                    self.cash // (price[index] * (1 + self.buy_cost_pct)),
                    actions[index],
                )
                self.stocks[index] += buy_num_shares
                self.cash -= price[index] * buy_num_shares * (1 + self.buy_cost_pct)

        """update time"""
        done = self.time == self.max_step
        state = self.get_state()
        next_total_asset = self.cash + (self.stocks * self.price_array[self.time]).sum()
        # Power-of-two scaling keeps the reward magnitude network-friendly.
        reward = (next_total_asset - self.total_asset) * 2**-16
        self.total_asset = next_total_asset
        self.gamma_return = self.gamma_return * self.gamma + reward
        self.cumu_return = self.total_asset / self.initial_cash
        if done:
            reward = self.gamma_return
            self.episode_return = self.total_asset / self.initial_cash
        return state, reward, done, None

    def get_state(self):
        """Build the observation: scaled cash, scaled holdings, then the
        last ``lookback`` rows of technical indicators (most recent first)."""
        state = np.hstack((self.cash * 2**-18, self.stocks * 2**-3))
        for i in range(self.lookback):
            tech_i = self.tech_array[self.time - i]
            normalized_tech_i = tech_i * 2**-15
            state = np.hstack((state, normalized_tech_i)).astype(np.float32)
        return state

    def close(self):
        """No external resources to release."""
        pass

    def _generate_action_normalizer(self):
        """Derive one scale factor per asset from its initial price so that
        raw actions of similar magnitude trade comparable dollar amounts.

        NOTE(review): ``len(str(price))`` counts printed characters, so the
        scale depends on how the float happens to print (e.g. scientific
        notation for very small prices) — fragile; confirm for assets
        priced far below 1.0 before relying on it.
        """
        action_norm_vector = []
        price_0 = self.price_array[0]
        for price in price_0:
            x = len(str(price)) - 7
            action_norm_vector.append(1 / ((10) ** x))

        self.action_norm_vector = np.asarray(action_norm_vector)
123