Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/meta/env_stock_trading/env_stock_papertrading.py
732 views
1
from __future__ import annotations
2
3
import datetime
4
import threading
5
import time
6
7
import alpaca_trade_api as tradeapi
8
import gymnasium as gym
9
import numpy as np
10
import pandas as pd
11
import torch
12
13
from finrl.meta.data_processors.processor_alpaca import AlpacaProcessor
14
15
16
class AlpacaPaperTrading:
    """Trade on an Alpaca account with a pre-trained PPO agent.

    The constructor loads the agent through one of three DRL libraries
    (``elegantrl``, ``rllib`` or ``stable_baselines3``), connects to the
    Alpaca REST API and initialises the local account bookkeeping.
    :meth:`run` then waits for the market to open and calls :meth:`trade`
    once every ``time_interval`` until less than a minute remains before
    the market closes.
    """

    # Supported ``time_interval`` strings mapped to the number of seconds
    # slept between two trading steps.
    _INTERVAL_SECONDS = {
        "1s": 1,
        "5s": 5,
        "1Min": 60,
        "5Min": 60 * 5,
        "15Min": 60 * 15,
        "1D": 24 * 60 * 60,
    }

    def __init__(
        self,
        ticker_list,
        time_interval,
        drl_lib,
        agent,
        cwd,
        net_dim,
        state_dim,
        action_dim,
        API_KEY,
        API_SECRET,
        API_BASE_URL,
        tech_indicator_list,
        turbulence_thresh=30,
        max_stock=1e2,
        latency=None,
    ):
        """Load the agent, connect to Alpaca and reset the bookkeeping.

        Args:
            ticker_list: Symbols to trade; their order fixes the layout of
                the per-stock parts of the state and of the action vector.
            time_interval: One of ``"1s"``, ``"5s"``, ``"1Min"``,
                ``"5Min"``, ``"15Min"`` or ``"1D"``.
            drl_lib: ``"elegantrl"``, ``"rllib"`` or ``"stable_baselines3"``.
            agent: Agent name; only ``"ppo"`` is supported.
            cwd: Path of the saved model / checkpoint to load.
            net_dim: Network width passed to elegantrl (``args.net_dim``).
            state_dim: Observation size used to build ``StockEnvEmpty``.
            action_dim: Action size used to build ``StockEnvEmpty``.
            API_KEY: Alpaca API key.
            API_SECRET: Alpaca API secret.
            API_BASE_URL: Alpaca REST endpoint.
            tech_indicator_list: Technical indicators requested from the
                data processor in :meth:`get_state`.
            turbulence_thresh: Turbulence value at or above which
                :meth:`trade` liquidates every position.
            max_stock: Multiplier applied to elegantrl actions to turn
                them into share counts.
            latency: Accepted for backward compatibility; not used.

        Raises:
            ValueError: If the agent, DRL library or time interval is not
                supported, if the agent cannot be loaded, or if the Alpaca
                client cannot be created.
        """
        self.drl_lib = drl_lib
        self._load_agent(agent, drl_lib, cwd, net_dim, state_dim, action_dim)

        # connect to Alpaca trading API
        try:
            self.alpaca = tradeapi.REST(API_KEY, API_SECRET, API_BASE_URL, "v2")
        except Exception as err:
            raise ValueError(
                "Fail to connect Alpaca. Please check account info and internet connection."
            ) from err

        # read trading time interval
        try:
            self.time_interval = self._INTERVAL_SECONDS[time_interval]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which are unsupported too.
            raise ValueError("Time interval input is NOT supported yet.") from None

        # read trading settings
        self.tech_indicator_list = tech_indicator_list
        self.turbulence_thresh = turbulence_thresh
        self.max_stock = max_stock

        # initialize account
        n_tickers = len(ticker_list)
        self.stocks = np.zeros(n_tickers, dtype=int)  # stocks holding
        self.stocks_cd = np.zeros_like(self.stocks)  # steps since last trade
        self.cash = None  # cash record
        self.stocks_df = pd.DataFrame(
            self.stocks, columns=["stocks"], index=ticker_list
        )
        self.asset_list = []
        self.price = np.zeros(n_tickers, dtype=int)
        self.stockUniverse = ticker_list
        self.turbulence_bool = 0
        self.equities = []

    def _load_agent(self, agent, drl_lib, cwd, net_dim, state_dim, action_dim):
        """Load the trained policy for ``drl_lib`` and store it on ``self``.

        Sets ``self.act``/``self.device`` (elegantrl), ``self.agent``
        (rllib) or ``self.model`` (stable_baselines3).  The DRL libraries
        are imported lazily so that only the selected one must be installed.

        Raises:
            ValueError: For an unsupported agent or library, or when the
                saved model cannot be loaded.
        """
        if agent != "ppo":
            raise ValueError("Agent input is NOT supported yet.")

        if drl_lib == "elegantrl":
            from elegantrl.agents import AgentPPO
            from elegantrl.train.run import init_agent

            # ``Arguments`` lives in elegantrl.train.config; importing it
            # from ``elegantrl.run`` raises ModuleNotFoundError.
            from elegantrl.train.config import Arguments

            config = {
                "state_dim": state_dim,
                "action_dim": action_dim,
            }
            args = Arguments(agent_class=AgentPPO, env=StockEnvEmpty(config))
            args.cwd = cwd
            args.net_dim = net_dim
            try:
                loaded = init_agent(args, gpu_id=0)
                self.act = loaded.act
                self.device = loaded.device
            except Exception as err:
                raise ValueError("Fail to load agent!") from err

        elif drl_lib == "rllib":
            from ray.rllib.agents import ppo
            from ray.rllib.agents.ppo.ppo import PPOTrainer

            config = ppo.DEFAULT_CONFIG.copy()
            config["env"] = StockEnvEmpty
            config["log_level"] = "WARN"
            config["env_config"] = {
                "state_dim": state_dim,
                "action_dim": action_dim,
            }
            trainer = PPOTrainer(env=StockEnvEmpty, config=config)
            # Restore exactly once, inside the try, so that a bad checkpoint
            # is always reported as ValueError.
            try:
                trainer.restore(cwd)
            except Exception as err:
                raise ValueError("Fail to load agent!") from err
            self.agent = trainer
            print("Restoring from checkpoint path", cwd)

        elif drl_lib == "stable_baselines3":
            from stable_baselines3 import PPO

            try:
                self.model = PPO.load(cwd)
            except Exception as err:
                raise ValueError("Fail to load agent!") from err
            print("Successfully load model", cwd)

        else:
            raise ValueError(
                "The DRL library input is NOT supported yet. Please check your input."
            )

    def test_latency(self, test_times=10):
        """Return the mean wall-clock time of ``test_times`` :meth:`get_state` calls.

        The result (in seconds) is also printed.
        """
        total_time = 0
        for _ in range(test_times):
            started = time.time()
            self.get_state()
            total_time += time.time() - started
        latency = total_time / test_times
        print("latency for data processing: ", latency)
        return latency

    def run(self):
        """Cancel open orders, wait for the open, then trade until the close.

        After every trading step the account equity is appended to
        ``self.equities`` as ``[unix_time, equity]`` and the loop sleeps
        for ``self.time_interval`` seconds.  The loop stops once fewer
        than 60 seconds remain before the market closes.
        """
        for order in self.alpaca.list_orders(status="open"):
            self.alpaca.cancel_order(order.id)

        # Wait for market to open.
        print("Waiting for market to open...")
        tAMO = threading.Thread(target=self.awaitMarketOpen)
        tAMO.start()
        tAMO.join()
        print("Market opened.")

        while True:
            # Figure out when the market will close so we can stop beforehand.
            clock = self.alpaca.get_clock()
            closingTime = clock.next_close.replace(
                tzinfo=datetime.timezone.utc
            ).timestamp()
            currTime = clock.timestamp.replace(tzinfo=datetime.timezone.utc).timestamp()
            self.timeToClose = closingTime - currTime

            if self.timeToClose < 60:
                print("Market closing soon. Stop trading.")
                break

            # The step runs in its own thread, so an exception raised inside
            # trade() is reported by the thread and does not end this loop.
            trade = threading.Thread(target=self.trade)
            trade.start()
            trade.join()

            # NOTE(review): Alpaca documents ``last_equity`` as the equity at
            # the previous trading day's close; ``equity`` may be what was
            # intended here -- confirm before changing.
            last_equity = float(self.alpaca.get_account().last_equity)
            cur_time = time.time()
            self.equities.append([cur_time, last_equity])
            time.sleep(self.time_interval)

    def awaitMarketOpen(self):
        """Block until Alpaca reports the market as open, polling every minute."""
        isOpen = self.alpaca.get_clock().is_open
        while not isOpen:
            clock = self.alpaca.get_clock()
            openingTime = clock.next_open.replace(
                tzinfo=datetime.timezone.utc
            ).timestamp()
            currTime = clock.timestamp.replace(tzinfo=datetime.timezone.utc).timestamp()
            timeToOpen = int((openingTime - currTime) / 60)
            print(str(timeToOpen) + " minutes til market open.")
            time.sleep(60)
            isOpen = self.alpaca.get_clock().is_open

    def _place_order(self, qty, symbol, side):
        """Submit one market order synchronously and return the response list.

        The previous code wrapped ``self.submitOrder(...)`` -- the *result*
        of the call -- in ``threading.Thread(target=...)``, so the order was
        already placed in the calling thread and the started thread did
        nothing.  Calling directly keeps that behaviour without the no-op
        thread.
        """
        resp = []
        self.submitOrder(qty, symbol, side, resp)
        return resp

    def trade(self):
        """Run one trading step: observe, query the agent, place orders.

        When turbulence is below the threshold, tickers whose action is
        below ``-min_action`` are sold (capped by the current holding) and
        tickers whose action is above ``min_action`` are bought (capped by
        available cash).  Otherwise every open position is closed.

        Raises:
            ValueError: If ``self.drl_lib`` is not a supported library.
        """
        state = self.get_state()

        if self.drl_lib == "elegantrl":
            with torch.no_grad():
                s_tensor = torch.as_tensor((state,), device=self.device)
                a_tensor = self.act(s_tensor)
                action = a_tensor.detach().cpu().numpy()[0]
            action = (action * self.max_stock).astype(int)

        elif self.drl_lib == "rllib":
            # NOTE(review): unlike the elegantrl branch, this action is not
            # scaled by max_stock before the min_action comparison below --
            # confirm this is intended.
            action = self.agent.compute_single_action(state)

        elif self.drl_lib == "stable_baselines3":
            # NOTE(review): not scaled by max_stock either -- confirm.
            action = self.model.predict(state)[0]

        else:
            raise ValueError(
                "The DRL library input is NOT supported yet. Please check your input."
            )

        self.stocks_cd += 1
        if self.turbulence_bool == 0:
            min_action = 10  # smallest absolute action that triggers an order

            for index in np.where(action < -min_action)[0]:  # sell_index
                sell_num_shares = min(self.stocks[index], -action[index])
                qty = abs(int(sell_num_shares))
                self._place_order(qty, self.stockUniverse[index], "sell")
                # Refresh cash so later orders in this step see the proceeds.
                self.cash = float(self.alpaca.get_account().cash)
                self.stocks_cd[index] = 0

            for index in np.where(action > min_action)[0]:  # buy_index
                tmp_cash = max(self.cash, 0)  # never size from negative cash
                buy_num_shares = min(
                    tmp_cash // self.price[index], abs(int(action[index]))
                )
                qty = abs(int(buy_num_shares))
                self._place_order(qty, self.stockUniverse[index], "buy")
                self.cash = float(self.alpaca.get_account().cash)
                self.stocks_cd[index] = 0

        else:  # sell all when turbulence
            for position in self.alpaca.list_positions():
                # Long positions are sold, anything else is bought back.
                orderSide = "sell" if position.side == "long" else "buy"
                qty = abs(int(float(position.qty)))
                self._place_order(qty, position.symbol, orderSide)

            self.stocks_cd[:] = 0

    def get_state(self):
        """Fetch market and account data and build the agent observation.

        Updates ``self.cash``, ``self.stocks``, ``self.price`` and
        ``self.turbulence_bool`` as a side effect.

        Returns:
            A flat ``float32`` array made of, in order: scaled cash, scaled
            turbulence, the turbulence flag, scaled prices, scaled holdings,
            ``self.stocks_cd`` and the scaled technical indicators.
        """
        alpaca = AlpacaProcessor(api=self.alpaca)
        price, tech, turbulence = alpaca.fetch_latest_data(
            ticker_list=self.stockUniverse,
            time_interval="1Min",
            tech_indicator_list=self.tech_indicator_list,
        )
        turbulence_bool = 1 if turbulence >= self.turbulence_thresh else 0

        turbulence = (
            self.sigmoid_sign(turbulence, self.turbulence_thresh) * 2**-5
        ).astype(np.float32)

        tech = tech * 2**-7

        holdings = [0] * len(self.stockUniverse)
        for position in self.alpaca.list_positions():
            try:
                ind = self.stockUniverse.index(position.symbol)
            except ValueError:
                # The account holds a symbol outside the traded universe;
                # it has no slot in the state vector, so ignore it.
                continue
            holdings[ind] = abs(int(float(position.qty)))

        self.cash = float(self.alpaca.get_account().cash)
        self.stocks = np.asarray(holdings, dtype=float)
        self.turbulence_bool = turbulence_bool
        self.price = price

        amount = np.array(self.cash * (2**-12), dtype=np.float32)
        scale = np.array(2**-6, dtype=np.float32)
        state = np.hstack(
            (
                amount,
                turbulence,
                self.turbulence_bool,
                price * scale,
                self.stocks * scale,
                self.stocks_cd,
                tech,
            )
        ).astype(np.float32)
        return state

    def submitOrder(self, qty, stock, side, resp):
        """Submit a day market order and record the outcome in ``resp``.

        Args:
            qty: Number of shares; nothing is submitted unless it is > 0.
            stock: Ticker symbol.
            side: ``"buy"`` or ``"sell"``.
            resp: List that receives ``True`` when the order was submitted
                or skipped because ``qty`` is not positive, ``False`` when
                the submission raised.
        """
        if qty > 0:
            try:
                self.alpaca.submit_order(stock, qty, side, "market", "day")
            except Exception:
                # Best effort: one rejected order must not abort the step.
                print(f"Order of | {qty} {stock} {side} | did not go through.")
                resp.append(False)
            else:
                print(f"Market order of | {qty} {stock} {side} | completed.")
                resp.append(True)
        else:
            print(f"Quantity is 0, order of | {qty} {stock} {side} | not completed.")
            resp.append(True)

    @staticmethod
    def sigmoid_sign(ary, thresh):
        """Squash ``ary`` with a zero-centred sigmoid scaled by ``thresh``.

        Computes ``(1 / (1 + exp(-e * ary / thresh)) - 0.5) * thresh``, so
        the result is 0 at 0 and stays within ``(-thresh / 2, thresh / 2)``.
        """

        def sigmoid(x):
            return 1 / (1 + np.exp(-x * np.e)) - 0.5

        return sigmoid(ary / thresh) * thresh
383
384
385
class StockEnvEmpty(gym.Env):
    """Placeholder environment used only while loading a trained agent.

    It exposes the metadata attributes and the observation/action spaces
    derived from ``config["state_dim"]`` and ``config["action_dim"]``;
    ``reset`` and ``step`` do nothing and return ``None``.
    """

    def __init__(self, config):
        n_state = config["state_dim"]
        n_action = config["action_dim"]

        # Static metadata.
        self.env_name = "StockEnvEmpty"
        self.env_num = 1
        self.max_step = 10000
        self.target_return = 9999
        self.if_discrete = False

        # Dimensions and the matching continuous spaces.
        self.state_dim = n_state
        self.action_dim = n_action
        self.observation_space = gym.spaces.Box(
            low=-3000, high=3000, shape=(n_state,), dtype=np.float32
        )
        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=(n_action,), dtype=np.float32
        )

    def reset(self, *, seed=None, options=None):
        """No-op; the placeholder has no episode state to reset."""
        return None

    def step(self, actions):
        """No-op; the placeholder does not simulate transitions."""
        return None
414
415