Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/meta/env_stock_trading/env_stocktrading.py
732 views
1
from __future__ import annotations
2
3
from typing import List
4
5
import gymnasium as gym
6
import matplotlib
7
import matplotlib.pyplot as plt
8
import numpy as np
9
import pandas as pd
10
from gymnasium import spaces
11
from gymnasium.utils import seeding
12
from stable_baselines3.common.vec_env import DummyVecEnv
13
14
matplotlib.use("Agg")
15
16
# from stable_baselines3.common.logger import Logger, KVWriter, CSVOutputFormat
17
18
19
class StockTradingEnv(gym.Env):
20
"""
21
A stock trading environment for OpenAI gym
22
23
Parameters:
24
df (pandas.DataFrame): Dataframe containing data
25
hmax (int): Maximum number of shares to be traded in each trade per asset.
26
initial_amount (int): Amount of cash initially available
27
buy_cost_pct (float, array): Cost for buying shares, each index corresponds to each asset
28
sell_cost_pct (float, array): Cost for selling shares, each index corresponds to each asset
29
turbulence_threshold (float): Maximum turbulence allowed in market for purchases to occur. If exceeded, positions are liquidated
30
print_verbosity(int): When iterating (step), how often to print stats about state of env
31
"""
32
33
metadata = {"render.modes": ["human"]}
34
35
def __init__(
    self,
    df: pd.DataFrame,
    stock_dim: int,
    hmax: int,
    initial_amount: int,
    num_stock_shares: list[int],
    buy_cost_pct: list[float],
    sell_cost_pct: list[float],
    reward_scaling: float,
    state_space: int,
    action_space: int,
    tech_indicator_list: list[str],
    turbulence_threshold=None,
    risk_indicator_col="turbulence",
    make_plots: bool = False,
    print_verbosity=10,
    day=0,
    initial=True,
    previous_state=None,
    model_name="",
    mode="",
    iteration="",
):
    """Store the configuration and build the first state vector.

    Parameters:
        df: market data; rows for trading day ``d`` are read with ``df.loc[d, :]``.
        stock_dim: number of assets in the state vector.
        hmax: multiplier applied to the raw action to obtain a share count.
        initial_amount: starting cash.
        num_stock_shares: shares held per asset at the start.
        buy_cost_pct / sell_cost_pct: per-asset proportional trading costs.
        reward_scaling: factor applied to the per-step change in total assets.
        state_space: length of the observation vector.
        action_space: length of the action vector.
        tech_indicator_list: dataframe columns appended to the state.
        turbulence_threshold: when set, buying stops and positions are
            liquidated once ``turbulence`` reaches this value.
        risk_indicator_col: dataframe column that holds the turbulence value.
        make_plots: save an account-value plot when an episode ends.
        print_verbosity: print episode statistics every this many episodes.
        day: index of the first trading day.
        initial: build the state from ``initial_amount`` / ``num_stock_shares``
            (True) or from ``previous_state`` (False).
        previous_state: state vector carried over from an earlier run; a new
            empty list is used when omitted.
        model_name, mode, iteration: used in result file names written at the
            end of an episode (only when both ``model_name`` and ``mode`` are
            non-empty).
    """
    self.day = day
    self.df = df
    self.stock_dim = stock_dim
    self.hmax = hmax
    self.num_stock_shares = num_stock_shares
    self.initial_amount = initial_amount  # starting cash
    self.buy_cost_pct = buy_cost_pct
    self.sell_cost_pct = sell_cost_pct
    self.reward_scaling = reward_scaling
    self.state_space = state_space
    self.tech_indicator_list = tech_indicator_list

    # The integer arguments only fix the shapes of the two Box spaces.
    self.action_space = spaces.Box(low=-1, high=1, shape=(action_space,))
    self.observation_space = spaces.Box(
        low=-np.inf, high=np.inf, shape=(self.state_space,)
    )

    self.data = self.df.loc[self.day, :]
    self.terminal = False
    self.make_plots = make_plots
    self.print_verbosity = print_verbosity
    self.turbulence_threshold = turbulence_threshold
    self.risk_indicator_col = risk_indicator_col
    self.initial = initial
    # A fresh list per instance: a literal [] default would be one object
    # shared by every environment created without this argument.
    self.previous_state = [] if previous_state is None else previous_state
    self.model_name = model_name
    self.mode = mode
    self.iteration = iteration

    # initialize state
    self.state = self._initiate_state()

    # initialize reward and bookkeeping counters
    self.reward = 0
    self.turbulence = 0
    self.cost = 0
    self.trades = 0
    self.episode = 0

    # Total asset history; the first entry is
    # cash + sum(num_share_stock_i * price_stock_i).
    self.asset_memory = [
        self.initial_amount
        + np.sum(
            np.array(self.num_stock_shares)
            * np.array(self.state[1 : 1 + self.stock_dim])
        )
    ]
    self.rewards_memory = []
    self.actions_memory = []
    # States recorded after every step (see save_state_memory).
    self.state_memory = []
    self.date_memory = [self._get_date()]
    self._seed()
112
113
def _sell_stock(self, index, action):
    """Sell shares of asset ``index`` and return the number of shares sold.

    State layout used here: ``state[0]`` is cash, ``state[index + 1]`` the
    price, ``state[index + stock_dim + 1]`` the shares held, and
    ``state[index + 2 * stock_dim + 1]`` the first indicator slot, which is
    treated as a "cannot trade" flag.

    When a turbulence threshold is configured and reached, the whole
    position is sold (if the price is positive); otherwise at most
    ``abs(action)`` shares are sold. Cash, accumulated cost and the trade
    counter are updated in place.
    """
    price_pos = index + 1
    held_pos = index + self.stock_dim + 1

    def _book_sale(quantity):
        # Credit the proceeds net of fees and record the fee as cost.
        unit_price = self.state[price_pos]
        fee_rate = self.sell_cost_pct[index]
        self.state[0] += unit_price * quantity * (1 - fee_rate)
        self.cost += unit_price * quantity * fee_rate
        self.trades += 1

    liquidating = (
        self.turbulence_threshold is not None
        and self.turbulence >= self.turbulence_threshold
    )

    if liquidating:
        # Turbulence over the threshold: clear the position, but only when
        # the price is positive (no missing data) and something is held.
        if self.state[price_pos] > 0 and self.state[held_pos] > 0:
            everything = self.state[held_pos]
            _book_sale(everything)
            self.state[held_pos] = 0
            return everything
        return 0

    # Regular sale. The flag slot is compared by value, so any entry equal
    # to 1 marks the asset as not tradable.
    tradable = self.state[index + 2 * self.stock_dim + 1] != True  # noqa: E712
    if tradable and self.state[held_pos] > 0:
        quantity = min(abs(action), self.state[held_pos])
        _book_sale(quantity)
        self.state[held_pos] -= quantity
        return quantity
    return 0
181
182
def _buy_stock(self, index, action):
    """Buy shares of asset ``index`` and return the number of shares bought.

    Nothing is bought when a turbulence threshold is configured and the
    current turbulence is not below it, or when the asset's flag slot
    ``state[index + 2 * stock_dim + 1]`` equals 1. Otherwise the purchase is
    capped by what the cash balance can pay for, trading cost included, so
    the balance is not driven below zero by the fee.
    """
    buying_allowed = (
        self.turbulence_threshold is None
        or self.turbulence < self.turbulence_threshold
    )
    if not buying_allowed:
        return 0

    # Flag slot compared by value: an entry equal to 1 blocks the purchase.
    if self.state[index + 2 * self.stock_dim + 1] == True:  # noqa: E712
        return 0

    unit_price = self.state[index + 1]
    gross_factor = 1 + self.buy_cost_pct[index]

    # Whole shares affordable once the trading cost is priced in.
    affordable = self.state[0] // (unit_price * gross_factor)
    quantity = min(affordable, action)

    # update balance, holdings and bookkeeping
    self.state[0] -= unit_price * quantity * gross_factor
    self.state[index + self.stock_dim + 1] += quantity
    self.cost += unit_price * quantity * self.buy_cost_pct[index]
    self.trades += 1
    return quantity
225
226
def _make_plot(self):
    """Save the recorded account values as a red line plot for this episode."""
    target = f"results/account_value_trade_{self.episode}.png"
    plt.plot(self.asset_memory, "r")
    plt.savefig(target)
    plt.close()
230
231
def step(self, actions):
    """Advance the environment by one trading day.

    Parameters:
        actions: numpy array with one entry per asset. The action space is
            ``Box(-1, 1)``; each entry is multiplied by ``hmax`` and
            truncated to an integer share count (negative = sell,
            positive = buy).

    Returns:
        ``(state, reward, terminal, truncated, info)`` where ``truncated`` is
        always ``False`` and ``info`` is an empty dict.

    On the last available day no trade is executed: episode statistics are
    computed (and printed every ``print_verbosity`` episodes), result files
    are written when both ``model_name`` and ``mode`` are set, and the
    previous reward is returned.
    """

    def _portfolio_value():
        # cash + sum(price_i * shares_i), read from the current state vector
        prices = np.array(self.state[1 : (self.stock_dim + 1)])
        shares = np.array(
            self.state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]
        )
        return self.state[0] + sum(prices * shares)

    self.terminal = self.day >= len(self.df.index.unique()) - 1

    if self.terminal:
        if self.make_plots:
            self._make_plot()

        end_total_asset = _portfolio_value()
        # Measured against asset_memory[0] because initial_amount is only
        # the cash part of the initial asset.
        tot_reward = end_total_asset - self.asset_memory[0]

        df_total_value = pd.DataFrame(self.asset_memory, columns=["account_value"])
        df_total_value["date"] = self.date_memory
        df_total_value["daily_return"] = df_total_value["account_value"].pct_change(
            1
        )
        daily_std = df_total_value["daily_return"].std()
        sharpe = None
        if daily_std != 0:
            sharpe = (252**0.5) * df_total_value["daily_return"].mean() / daily_std

        # Passing the column name to the constructor keeps this working when
        # no reward has been recorded yet; assigning ``columns`` afterwards
        # raises on the resulting zero-column frame.
        df_rewards = pd.DataFrame(self.rewards_memory, columns=["account_rewards"])
        df_rewards["date"] = self.date_memory[:-1]

        if self.episode % self.print_verbosity == 0:
            print(f"day: {self.day}, episode: {self.episode}")
            print(f"begin_total_asset: {self.asset_memory[0]:0.2f}")
            print(f"end_total_asset: {end_total_asset:0.2f}")
            print(f"total_reward: {tot_reward:0.2f}")
            print(f"total_cost: {self.cost:0.2f}")
            print(f"total_trades: {self.trades}")
            if sharpe is not None:
                print(f"Sharpe: {sharpe:0.3f}")
            print("=================================")

        if (self.model_name != "") and (self.mode != ""):
            run_tag = f"{self.mode}_{self.model_name}_{self.iteration}"
            df_actions = self.save_action_memory()
            df_actions.to_csv(f"results/actions_{run_tag}.csv")
            df_total_value.to_csv(
                f"results/account_value_{run_tag}.csv", index=False
            )
            df_rewards.to_csv(
                f"results/account_rewards_{run_tag}.csv", index=False
            )
            plt.plot(self.asset_memory, "r")
            plt.savefig(f"results/account_value_{run_tag}.png")
            plt.close()

        return self.state, self.reward, self.terminal, False, {}

    # Scale the raw action to share counts; fractions of a share cannot be
    # traded, so truncate to integers.
    actions = actions * self.hmax
    actions = actions.astype(int)
    if (
        self.turbulence_threshold is not None
        and self.turbulence >= self.turbulence_threshold
    ):
        # Market too turbulent: ask to sell the maximum of every asset.
        actions = np.array([-self.hmax] * self.stock_dim)

    begin_total_asset = _portfolio_value()

    # Sells are processed from the most negative action upwards, buys from
    # the most positive action downwards; sells run first.
    argsort_actions = np.argsort(actions)
    sell_index = argsort_actions[: np.where(actions < 0)[0].shape[0]]
    buy_index = argsort_actions[::-1][: np.where(actions > 0)[0].shape[0]]

    # Each entry is overwritten with the quantity actually traded.
    for index in sell_index:
        actions[index] = self._sell_stock(index, actions[index]) * (-1)
    for index in buy_index:
        actions[index] = self._buy_stock(index, actions[index])

    self.actions_memory.append(actions)

    # state: s -> s+1
    self.day += 1
    self.data = self.df.loc[self.day, :]
    if self.turbulence_threshold is not None:
        ticker_count = len(self.df.tic.unique())
        if ticker_count == 1:
            self.turbulence = self.data[self.risk_indicator_col]
        elif ticker_count > 1:
            self.turbulence = self.data[self.risk_indicator_col].values[0]
    self.state = self._update_state()

    end_total_asset = _portfolio_value()
    self.asset_memory.append(end_total_asset)
    self.date_memory.append(self._get_date())

    # The unscaled reward is recorded; the scaled one is returned.
    self.reward = end_total_asset - begin_total_asset
    self.rewards_memory.append(self.reward)
    self.reward = self.reward * self.reward_scaling
    self.state_memory.append(self.state)

    return self.state, self.reward, self.terminal, False, {}
368
369
def reset(
    self,
    *,
    seed=None,
    options=None,
):
    """Rewind to day 0 and start a new episode.

    Parameters:
        seed: when given, the random generator is re-seeded through ``_seed``.
        options: accepted for API compatibility; not used.

    Returns:
        ``(state, info)`` with an empty ``info`` dict.

    All per-episode records (assets, rewards, actions, states, dates) start
    over and the episode counter is incremented.
    """
    if seed is not None:
        self._seed(seed)

    # initiate state
    self.day = 0
    self.data = self.df.loc[self.day, :]
    self.state = self._initiate_state()

    current_prices = np.array(self.state[1 : 1 + self.stock_dim])
    if self.initial:
        opening_asset = self.initial_amount + np.sum(
            np.array(self.num_stock_shares) * current_prices
        )
    else:
        # Carried-over cash and holdings, valued at the current prices.
        carried_shares = np.array(
            self.previous_state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]
        )
        opening_asset = self.previous_state[0] + sum(current_prices * carried_shares)
    self.asset_memory = [opening_asset]

    self.turbulence = 0
    self.cost = 0
    self.trades = 0
    self.terminal = False
    self.rewards_memory = []
    self.actions_memory = []
    # Cleared together with the other records so it stays aligned with
    # date_memory and does not grow across episodes.
    self.state_memory = []
    self.date_memory = [self._get_date()]

    self.episode += 1

    return self.state, {}
409
410
def render(self, mode="human", close=False):
    """Return the current state vector; ``mode`` and ``close`` are not used."""
    current_state = self.state
    return current_state
412
413
def _initiate_state(self):
    """Build the state vector for the current day.

    Layout: ``[cash] + prices + shares held + indicator values``, where the
    indicator values are grouped per indicator in ``tech_indicator_list``
    order.

    With ``initial`` set, cash and holdings come from ``initial_amount`` and
    ``num_stock_shares``; otherwise they are taken from ``previous_state``.
    Prices and indicators always come from ``self.data``.
    """
    if len(self.df.tic.unique()) > 1:
        # multiple stocks: self.data holds one row per ticker
        prices = self.data.close.values.tolist()
        indicators = [
            value
            for tech in self.tech_indicator_list
            for value in self.data[tech].values.tolist()
        ]
    else:
        # single stock: self.data is a single row
        prices = [self.data.close]
        indicators = [self.data[tech] for tech in self.tech_indicator_list]

    if self.initial:
        cash = self.initial_amount
        # Start from the configured holdings in both the single- and the
        # multi-stock case, matching how the opening asset value is computed.
        holdings = list(self.num_stock_shares)
    else:
        cash = self.previous_state[0]
        holdings = list(
            self.previous_state[(self.stock_dim + 1) : (self.stock_dim * 2 + 1)]
        )

    return [cash] + prices + holdings + indicators
467
468
def _update_state(self):
    """Return the state for the current day, keeping cash and holdings.

    Cash (``state[0]``) and the shares held are copied from the existing
    state; prices and indicator values are re-read from ``self.data``.
    """
    first_holding = self.stock_dim + 1
    carried_shares = list(self.state[first_holding : first_holding + self.stock_dim])

    refreshed = [self.state[0]]
    if len(self.df.tic.unique()) > 1:
        # multiple stocks: one price / indicator value per ticker
        refreshed.extend(self.data.close.values.tolist())
        refreshed.extend(carried_shares)
        for tech in self.tech_indicator_list:
            refreshed.extend(self.data[tech].values.tolist())
    else:
        # single stock: scalar price and indicator values
        refreshed.append(self.data.close)
        refreshed.extend(carried_shares)
        for tech in self.tech_indicator_list:
            refreshed.append(self.data[tech])
    return refreshed
494
495
def _get_date(self):
    """Return the date of the rows currently held in ``self.data``."""
    several_tickers = len(self.df.tic.unique()) > 1
    # With several tickers the date column repeats per ticker; take the first.
    return self.data.date.unique()[0] if several_tickers else self.data.date
501
502
# add save_state_memory to preserve state in the trading process
503
def save_state_memory(self):
    """Return the recorded states as a DataFrame.

    Single ticker: two columns, ``date`` and ``states`` (one state list per
    row).

    Several tickers: one column per state entry, indexed by date. Column
    names follow the state layout ``[cash] + prices + shares + indicators``:
    ``cash``, ``<tic>_price``, ``<tic>_num`` and ``<tic>_<indicator>``. When
    the state has exactly seven entries the historical two-asset labels
    (``Bitcoin_*`` / ``Gold_*``) are kept so existing consumers see the same
    columns as before.
    """
    # Dates and states must have the same length: the last recorded date has
    # no state yet.
    date_list = self.date_memory[:-1]
    state_list = self.state_memory

    if len(self.df.tic.unique()) <= 1:
        return pd.DataFrame({"date": date_list, "states": state_list})

    tickers = list(self.data.tic.values)
    columns = (
        ["cash"]
        + [f"{tic}_price" for tic in tickers]
        + [f"{tic}_num" for tic in tickers]
        + [f"{tic}_{tech}" for tech in self.tech_indicator_list for tic in tickers]
    )
    legacy_columns = [
        "cash",
        "Bitcoin_price",
        "Gold_price",
        "Bitcoin_num",
        "Gold_num",
        "Bitcoin_Disable",
        "Gold_Disable",
    ]
    if len(columns) == len(legacy_columns):
        columns = legacy_columns

    df_states = pd.DataFrame(state_list, columns=columns)
    # Built with an explicit column name so an empty date list is accepted.
    df_states.index = pd.DataFrame(date_list, columns=["date"]).date
    return df_states
531
532
def save_asset_memory(self):
    """Return the recorded account values as a ``date`` / ``account_value`` frame."""
    history = {"date": self.date_memory, "account_value": self.asset_memory}
    return pd.DataFrame(history)
541
542
def save_action_memory(self):
    """Return the recorded actions as a DataFrame.

    Single ticker: two columns, ``date`` and ``actions``. Several tickers:
    one column per ticker (taken from ``self.data.tic``), indexed by date.
    """
    # Dates must match the number of recorded actions: the last date has no
    # action yet.
    past_dates = self.date_memory[:-1]
    recorded = self.actions_memory

    if len(self.df.tic.unique()) <= 1:
        return pd.DataFrame({"date": past_dates, "actions": recorded})

    date_frame = pd.DataFrame(past_dates)
    date_frame.columns = ["date"]

    action_frame = pd.DataFrame(recorded)
    action_frame.columns = self.data.tic.values
    action_frame.index = date_frame.date
    return action_frame
559
560
def _seed(self, seed=None):
    """Create the environment's random generator and return ``[seed]``."""
    generator, used_seed = seeding.np_random(seed)
    self.np_random = generator
    return [used_seed]
563
564
def get_sb_env(self):
    """Wrap this environment in a ``DummyVecEnv`` and reset it.

    Returns:
        ``(vec_env, obs)``: the vectorized wrapper and the observation
        returned by its ``reset()``.
    """

    def _this_env():
        # DummyVecEnv expects factories; hand back this very instance.
        return self

    vec_env = DummyVecEnv([_this_env])
    first_obs = vec_env.reset()
    return vec_env, first_obs
568
569