Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/meta/env_cryptocurrency_trading/env_btc_ccxt.py
732 views
1
from __future__ import annotations
2
3
import numpy as np
4
5
6
class BitcoinEnv:  # custom env
    """Single-asset trading environment driven by price and indicator arrays.

    The environment walks row by row through ``price_ary`` / ``tech_ary``.
    Each ``step`` buys or sells according to ``action[0]``, advances one row
    and rewards the change in total asset value (cash plus holdings valued at
    price column 0), scaled by ``2**-16``.

    Args:
        data_cwd: Directory containing ``price_ary.npy`` and ``tech_ary.npy``.
            When given, it takes precedence over the array arguments.
        price_ary: 2-D price array (rows are time steps); used when
            ``data_cwd`` is ``None``.
        tech_ary: 2-D indicator array with at least seven columns (the state
            reads columns 0..6); used when ``data_cwd`` is ``None``.
        time_frequency: Keep every ``time_frequency``-th row of the slice
            selected by ``mode``.
        start, mid1, mid2, end: Row boundaries. ``"train"`` uses
            ``[start:mid1]``, ``"test"`` uses ``[mid1:mid2]`` and ``"trade"``
            uses ``[mid2:end]``.
        initial_account: Cash balance at the start of every episode.
        max_stock: Accepted for interface compatibility; currently unused.
        transaction_fee_percent: Proportional fee applied to each trade.
        mode: One of ``"train"``, ``"test"`` or ``"trade"``.
        gamma: Discount factor for the running discounted return.

    Raises:
        ValueError: If the data cannot be loaded or ``mode`` is invalid.
    """

    # Scale factors applied to technical-indicator columns 0..6 in the state.
    _TECH_SCALES = (2**-1, 2**-15, 2**-15, 2**-6, 2**-6, 2**-15, 2**-15)

    def __init__(
        self,
        data_cwd=None,
        price_ary=None,
        tech_ary=None,
        time_frequency=15,
        start=None,
        mid1=172197,
        mid2=216837,
        end=None,
        initial_account=1e6,
        max_stock=1e2,
        transaction_fee_percent=1e-3,
        mode="train",
        gamma=0.99,
    ):
        self.stock_dim = 1
        self.initial_account = initial_account
        self.transaction_fee_percent = transaction_fee_percent
        # NOTE(review): the ``max_stock`` argument is ignored and the
        # attribute is pinned to 1; kept as-is so existing behavior does not
        # change. Confirm whether the argument was meant to be honored.
        self.max_stock = 1
        self.gamma = gamma
        self.mode = mode
        self.load_data(
            data_cwd, price_ary, tech_ary, time_frequency, start, mid1, mid2, end
        )

        # reset
        self._reset_account_state()
        self.episode_return = 0.0
        self.gamma_return = 0.0

        # env information
        self.env_name = "BitcoinEnv4"
        self.state_dim = 1 + 1 + self.price_ary.shape[1] + self.tech_ary.shape[1]
        self.action_dim = 1
        self.if_discrete = False
        self.target_return = 10
        self.max_step = self.price_ary.shape[0]

    def _reset_account_state(self) -> None:
        """Rewind to row 0 and restore the initial cash / empty position."""
        self.day = 0
        self.day_price = self.price_ary[self.day]
        self.day_tech = self.tech_ary[self.day]
        self.initial_account__reset = self.initial_account
        self.account = self.initial_account__reset
        self.stocks = 0.0
        self.total_asset = self.account + self.day_price[0] * self.stocks

    def _get_state(self) -> np.ndarray:
        """Build the float32 observation for the current row.

        Layout: scaled cash, scaled price row, seven scaled indicators,
        scaled holdings.
        """
        # Index explicitly (rather than zip) so a too-short indicator row
        # still raises IndexError instead of being silently truncated.
        normalized_tech = [
            self.day_tech[i] * scale for i, scale in enumerate(self._TECH_SCALES)
        ]
        return np.hstack(
            (
                self.account * 2**-18,
                self.day_price * 2**-15,
                normalized_tech,
                self.stocks * 2**-4,
            )
        ).astype(np.float32)

    def reset(
        self,
        *,
        seed=None,
        options=None,
    ) -> np.ndarray:
        """Start a new episode and return the initial observation.

        ``seed`` and ``options`` are accepted but not used.
        """
        self._reset_account_state()
        return self._get_state()

    def step(self, action) -> tuple[np.ndarray, float, bool, None]:
        """Apply one trade, advance one row and return the transition.

        Args:
            action: Indexable whose first element is the signed trade size;
                negative sells, positive buys, zero holds.

        Returns:
            ``(state, reward, done, None)``. On the final step the
            accumulated discounted return is added to the reward.
        """
        stock_action = action[0]
        price = self.day_price[0]
        fee = self.transaction_fee_percent

        # buy or sell stock
        if stock_action < 0:
            # Selling may go short, limited to half the total asset value.
            sell_amount = max(
                0, min(-1 * stock_action, 0.5 * self.total_asset / price + self.stocks)
            )
            self.account += price * sell_amount * (1 - fee)
            self.stocks -= sell_amount
        elif stock_action > 0:
            # NOTE(review): the cap ignores the fee, so a maximal buy can
            # leave the cash balance slightly negative. Left unchanged to
            # preserve the environment's existing dynamics.
            buy_amount = min(stock_action, self.account / price)
            self.account -= price * buy_amount * (1 + fee)
            self.stocks += buy_amount

        # update day
        self.day += 1
        self.day_price = self.price_ary[self.day]
        self.day_tech = self.tech_ary[self.day]
        done = (self.day + 1) == self.max_step
        state = self._get_state()

        next_total_asset = self.account + self.day_price[0] * self.stocks
        reward = (next_total_asset - self.total_asset) * 2**-16
        self.total_asset = next_total_asset

        self.gamma_return = self.gamma_return * self.gamma + reward
        if done:
            reward += self.gamma_return
            self.gamma_return = 0.0
            self.episode_return = next_total_asset / self.initial_account
        return state, reward, done, None

    def draw_cumulative_return(self, args, _torch) -> tuple[list, list]:
        """Roll out a saved agent for one episode and plot its returns.

        Args:
            args: Object exposing ``agent``, ``net_dim`` and ``cwd``. The
                agent must provide ``init``, ``save_load_model``, ``act`` and
                ``device``.
            _torch: The ``torch`` module, passed in by the caller.

        Returns:
            ``(episode_returns, btc_returns)``: the agent's total asset and
            the buy-and-hold price, each as a multiple of its starting value.
            ``episode_returns`` starts with an extra leading ``1``.

        Side effects:
            Saves the figure to ``{cwd}/cumulative_return.jpg``.
        """
        agent = args.agent
        cwd = args.cwd

        agent.init(args.net_dim, self.state_dim, self.action_dim)
        agent.save_load_model(cwd=cwd, if_save=False)
        act = agent.act
        device = agent.device

        state = self.reset()
        init_price = self.day_price[0]
        episode_returns = [1]
        btc_returns = []  # the cumulative_return / initial_account
        with _torch.no_grad():
            for _ in range(self.max_step):
                # Always read price column 0; the step counter must not be
                # used as a column index.
                btc_returns.append(self.day_price[0] / init_price)
                s_tensor = _torch.as_tensor((state,), device=device)
                a_tensor = act(s_tensor)  # action_tanh = act.forward()
                action = a_tensor.detach().cpu().numpy()[0]
                state, reward, done, _ = self.step(action)

                # Normalize by the configured starting balance instead of a
                # hard-coded 1e6.
                episode_returns.append(self.total_asset / self.initial_account)
                if done:
                    break

        import matplotlib.pyplot as plt

        plt.plot(episode_returns, label="agent return")
        plt.plot(btc_returns, color="yellow", label="BTC return")
        plt.grid()
        plt.title("cumulative return")
        plt.xlabel("day")
        plt.ylabel("multiple of initial_account")
        plt.legend()
        plt.savefig(f"{cwd}/cumulative_return.jpg")
        return episode_returns, btc_returns

    def load_data(
        self, data_cwd, price_ary, tech_ary, time_frequency, start, mid1, mid2, end
    ):
        """Load the arrays, slice them for ``self.mode`` and down-sample.

        Sets ``self.price_ary`` and ``self.tech_ary`` to every
        ``time_frequency``-th row of the slice selected by the mode.

        Raises:
            ValueError: If the ``.npy`` files cannot be read, if no data
                source is supplied, or if ``self.mode`` is not one of
                ``"train"``, ``"test"``, ``"trade"``.
        """
        if data_cwd is not None:
            try:
                price_ary = np.load(f"{data_cwd}/price_ary.npy")
                tech_ary = np.load(f"{data_cwd}/tech_ary.npy")
            except (OSError, ValueError) as err:
                # Narrow catch: do not swallow KeyboardInterrupt/SystemExit.
                raise ValueError("Data files not found!") from err
        elif price_ary is None or tech_ary is None:
            raise ValueError(
                "Either data_cwd or both price_ary and tech_ary must be provided!"
            )

        if self.mode == "train":
            lower, upper = start, mid1
        elif self.mode == "test":
            lower, upper = mid1, mid2
        elif self.mode == "trade":
            lower, upper = mid2, end
        else:
            raise ValueError("Invalid Mode!")

        price_slice = price_ary[lower:upper]
        tech_slice = tech_ary[lower:upper]

        stride = int(time_frequency)
        kept = price_slice.shape[0] // stride
        ind = list(range(0, kept * stride, stride))
        self.price_ary = price_slice[ind]
        self.tech_ary = tech_slice[ind]
221
222