Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/meta/data_processors/processor_yahoofinance.py
732 views
1
"""Reference: https://github.com/AI4Finance-LLC/FinRL"""
2
3
from __future__ import annotations
4
5
import datetime
6
import time
7
from datetime import date
8
from datetime import timedelta
9
from sqlite3 import Timestamp
10
from typing import Any
11
from typing import Dict
12
from typing import List
13
from typing import Optional
14
from typing import Type
15
from typing import TypeVar
16
from typing import Union
17
18
import numpy as np
19
import pandas as pd
20
import pandas_market_calendars as tc
21
import pytz
22
import yfinance as yf
23
from bs4 import BeautifulSoup
24
from selenium import webdriver
25
from selenium.webdriver.chrome.options import Options
26
from selenium.webdriver.chrome.service import Service
27
from selenium.webdriver.common.action_chains import ActionChains
28
from selenium.webdriver.common.by import By
29
from stockstats import StockDataFrame as Sdf
30
from webdriver_manager.chrome import ChromeDriverManager
31
32
### Added by aymeric75 for scrap_data function
33
34
35
class YahooFinanceProcessor:
    """Provides methods for retrieving daily stock data from the
    Yahoo Finance API.

    The download methods share this contract:

    Param
    ----------
    start_date : str
        start date of the data
    end_date : str
        end date of the data
    ticker_list : list
        a list of stock tickers

    Example
    -------
    input:
    ticker_list = config_tickers.DOW_30_TICKER
    start_date = '2009-01-01'
    end_date = '2021-10-31'
    time_interval == "1D"

    output:
        date	    tic	    open	    high	    low	        close	    volume
    0	2009-01-02	AAPL	3.067143	3.251429	3.041429	2.767330	746015200.0
    1	2009-01-02	AMGN	58.590000	59.080002	57.750000	44.523766	6547900.0
    2	2009-01-02	AXP	    18.570000	19.520000	18.400000	15.477426	10955700.0
    3	2009-01-02	BA	    42.799999	45.560001	42.779999	33.941093	7010200.0
    ...
    """

    def __init__(self):
        # Stateless at construction; download_data() later records
        # self.start / self.end / self.time_interval for reuse by
        # clean_data() and add_vix().
        pass


######## Scraping helpers (added by aymeric75) ###################
def date_to_unix(self, date_str) -> int:
    """Convert a date string in yyyy-mm-dd format to a Unix timestamp (seconds)."""
    return int(
        datetime.datetime.strptime(date_str, "%Y-%m-%d").timestamp()
    )
def fetch_stock_data(self, stock_name, period1, period2) -> pd.DataFrame:
    """Scrape the Yahoo Finance history table for one ticker.

    Param
    ----------
    stock_name : str
        ticker symbol, e.g. "AAPL"
    period1, period2 : int
        Unix timestamps bounding the requested history

    Returns a DataFrame with columns
    date/open/high/low/close/adjcp/volume/tic/day, oldest row first.
    Raises if no history table is found on the rendered page.
    """
    url = f"https://finance.yahoo.com/quote/{stock_name}/history/?period1={period1}&period2={period2}&filter=history"

    # Selenium WebDriver setup
    options = Options()
    options.add_argument("--headless")  # Headless for performance
    options.add_argument("--disable-gpu")  # Disable GPU for compatibility
    driver = webdriver.Chrome(
        service=Service(ChromeDriverManager().install()), options=options
    )
    # BUG FIX: the driver was never quit, leaking a browser process on every
    # call (and on every scraping failure). Wrap the session in try/finally.
    try:
        driver.get(url)
        driver.maximize_window()
        time.sleep(5)  # Wait for redirection and page load

        # Handle potential cookie-consent popup
        try:
            RejectAll = driver.find_element(
                By.XPATH, '//button[@class="btn secondary reject-all"]'
            )
            action = ActionChains(driver)
            action.click(on_element=RejectAll)
            action.perform()
            time.sleep(5)
        except Exception as e:
            print("Popup not found or handled:", e)

        page_source = driver.page_source
    finally:
        driver.quit()

    # Parse the rendered page for the history table
    soup = BeautifulSoup(page_source, "html.parser")
    table = soup.find("table")
    if not table:
        raise Exception("No table found after handling redirection and popup.")

    # Yahoo's own header text is ignored; cells are mapped positionally onto
    # the FinRL column names (rows must have exactly 7 cells).
    # (Removed dead code that rewrote headers[4]/headers[5] before the list
    # below overwrote the whole thing anyway — it could IndexError on pages
    # with fewer than six <th> cells.)
    headers = ["date", "open", "high", "low", "close", "adjcp", "volume"]

    # Extract rows
    rows = []
    for tr in table.find_all("tr")[1:]:  # Skip header row
        cells = [td.text.strip() for td in tr.find_all("td")]
        if len(cells) == len(headers):  # Only add rows with correct column count
            rows.append(cells)

    df = pd.DataFrame(rows, columns=headers)

    # Convert numeric columns, tolerating non-numeric cells
    # (e.g. dividend/split rows keep their original text).
    def safe_convert(value, dtype):
        try:
            return dtype(value.replace(",", ""))
        except ValueError:
            return value

    for col in ("open", "high", "low", "close", "adjcp"):
        df[col] = df[col].apply(lambda x: safe_convert(x, float))
    df["volume"] = df["volume"].apply(lambda x: safe_convert(x, int))

    # Add 'tic' column
    df["tic"] = stock_name

    # Add 'day' column: offset from the requested start; drop earlier rows.
    start_date = datetime.datetime.fromtimestamp(period1)
    df["date"] = pd.to_datetime(df["date"])
    df["day"] = (df["date"] - start_date).dt.days
    df = df[df["day"] >= 0]  # Exclude rows with days before the start date

    # Yahoo lists newest first; reverse to chronological order.
    df = df.iloc[::-1].reset_index(drop=True)

    return df
157
def scrap_data(self, stock_names, start_date, end_date) -> pd.DataFrame:
    """Fetch and combine stock data for multiple stock names.

    Param
    ----------
    stock_names : list[str]
        tickers to scrape via fetch_stock_data
    start_date, end_date : str
        yyyy-mm-dd bounds converted to Unix timestamps

    Returns one DataFrame sorted by ["day", "tic"]. Tickers that fail to
    download are reported and skipped rather than aborting the batch.
    """
    period1 = self.date_to_unix(start_date)
    period2 = self.date_to_unix(end_date)

    all_dataframes = []
    total_stocks = len(stock_names)

    for i, stock_name in enumerate(stock_names):
        try:
            print(
                f"Processing {stock_name} ({i + 1}/{total_stocks})... {(i + 1) / total_stocks * 100:.2f}% complete."
            )
            df = self.fetch_stock_data(stock_name, period1, period2)
            all_dataframes.append(df)
        except Exception as e:
            print(f"Error fetching data for {stock_name}: {e}")

    combined_df = pd.concat(all_dataframes, ignore_index=True)
    # BUG FIX: the column created by fetch_stock_data is "tic", not "tick";
    # sorting by "tick" raised a KeyError on every call.
    combined_df = combined_df.sort_values(by=["day", "tic"]).reset_index(drop=True)

    return combined_df
180
######## END ADDED BY aymeric75 ###################
181
182
def convert_interval(self, time_interval: str) -> str:
    """Translate a FinRL-style interval string into Yahoo Finance notation.

    Intervals already in Yahoo format (1m, 2m, 5m, 15m, 30m, 60m, 90m,
    1h, 1d, 5d, 1wk, 1mo, 3mo) pass through unchanged; recognised FinRL
    spellings ("1Min", "1D", "1W", "1M", ...) are converted.
    Raises ValueError for anything else.
    """
    yahoo_intervals = {
        "1m", "2m", "5m", "15m", "30m", "60m", "90m",
        "1h", "1d", "5d", "1wk", "1mo", "3mo",
    }
    if time_interval in yahoo_intervals:
        return time_interval

    minute_style = {"1Min", "2Min", "5Min", "15Min", "30Min", "60Min", "90Min"}
    if time_interval in minute_style:
        return time_interval.replace("Min", "m")
    if time_interval in ("1H", "1D", "5D"):
        return time_interval.lower()
    if time_interval == "1W":
        return "1wk"
    if time_interval in ("1M", "3M"):
        return time_interval.replace("M", "mo")
    raise ValueError("wrong time_interval")
222
def download_data(
    self,
    ticker_list: list[str],
    start_date: str,
    end_date: str,
    time_interval: str,
    proxy: str | dict | None = None,
) -> pd.DataFrame:
    """Download OHLCV bars from Yahoo Finance, one ticker and one day at a time.

    Param
    ----------
    ticker_list : list[str]
        tickers to download
    start_date, end_date : str
        inclusive date bounds, parseable by pd.Timestamp
    time_interval : str
        FinRL or Yahoo interval string (converted via convert_interval)
    proxy : str | dict | None
        forwarded to yf.download

    Returns a DataFrame with columns
    timestamp/open/high/low/close/volume/tic.
    Side effect: records self.start / self.end / self.time_interval for
    later use by clean_data() and add_vix().
    """
    time_interval = self.convert_interval(time_interval)

    # Remembered for later pipeline steps.
    self.start = start_date
    self.end = end_date
    self.time_interval = time_interval

    # Download and save the data in a pandas DataFrame
    start_date = pd.Timestamp(start_date)
    end_date = pd.Timestamp(end_date)
    delta = timedelta(days=1)
    data_df = pd.DataFrame()
    for tic in ticker_list:
        current_tic_start_date = start_date
        while (
            current_tic_start_date <= end_date
        ):  # downloading daily to workaround yfinance only allowing max 7 calendar (not trading) days of 1 min data per single download
            temp_df = yf.download(
                tic,
                start=current_tic_start_date,
                end=current_tic_start_date + delta,
                interval=self.time_interval,
                proxy=proxy,
            )
            # Newer yfinance returns a (field, ticker) MultiIndex; flatten it.
            if temp_df.columns.nlevels != 1:
                temp_df.columns = temp_df.columns.droplevel(1)

            temp_df["tic"] = tic
            data_df = pd.concat([data_df, temp_df])
            current_tic_start_date += delta

    # NOTE(review): assumes every downloaded frame has an 'Adj Close'
    # column — verify against the installed yfinance version.
    data_df = data_df.reset_index().drop(columns=["Adj Close"])
    # convert the column names to match processor_alpaca.py as far as poss
    data_df.columns = [
        "timestamp",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "tic",
    ]

    return data_df
274
def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
    """Align downloaded bars onto a full trading-time index and fill gaps.

    For each ticker the raw rows are re-indexed onto every expected
    timestamp (NYSE trading days for "1d"; 390 one-minute bars per
    trading day for "1m"); missing bars are forward-filled from the
    previous close with volume set to 0.

    :param df: dataframe from download_data with columns
        timestamp/open/high/low/close/volume/tic
    :return: cleaned dataframe, one row per (timestamp, tic)
    :raises ValueError: if self.time_interval is neither "1d" nor "1m",
        or if a NaN close has no valid predecessor to fill from
    """
    tic_list = np.unique(df.tic.values)
    NY = "America/New_York"

    trading_days = self.get_trading_days(start=self.start, end=self.end)
    # produce full timestamp index
    if self.time_interval == "1d":
        times = trading_days
    elif self.time_interval == "1m":
        times = []
        for day in trading_days:
            # Regular session only: 09:30 plus 390 one-minute bars = 16:00 NY.
            current_time = pd.Timestamp(day + " 09:30:00").tz_localize(NY)
            for i in range(390):  # 390 minutes in trading day
                times.append(current_time)
                current_time += pd.Timedelta(minutes=1)
    else:
        raise ValueError(
            "Data clean at given time interval is not supported for YahooFinance data."
        )

    # create a new dataframe with full timestamp series
    new_df = pd.DataFrame()
    for tic in tic_list:
        # Empty frame indexed by every expected timestamp; downloaded rows
        # are copied into it below, leaving NaN where data is missing.
        tmp_df = pd.DataFrame(
            columns=["open", "high", "low", "close", "volume"], index=times
        )
        tic_df = df[
            df.tic == tic
        ]  # extract just the rows from downloaded data relating to this tic
        for i in range(tic_df.shape[0]):  # fill empty DataFrame using original data
            tmp_timestamp = tic_df.iloc[i]["timestamp"]
            # Normalize to NY time so the label matches the index built above.
            if tmp_timestamp.tzinfo is None:
                tmp_timestamp = tmp_timestamp.tz_localize(NY)
            else:
                tmp_timestamp = tmp_timestamp.tz_convert(NY)
            tmp_df.loc[tmp_timestamp] = tic_df.iloc[i][
                ["open", "high", "low", "close", "volume"]
            ]

        # if close on start date is NaN, fill data with first valid close
        # and set volume to 0.
        if str(tmp_df.iloc[0]["close"]) == "nan":
            print("NaN data on start date, fill using first valid data.")
            for i in range(tmp_df.shape[0]):
                if str(tmp_df.iloc[i]["close"]) != "nan":
                    first_valid_close = tmp_df.iloc[i]["close"]
                    tmp_df.iloc[0] = [
                        first_valid_close,
                        first_valid_close,
                        first_valid_close,
                        first_valid_close,
                        0.0,
                    ]
                    break

        # if the close price of the first row is still NaN (all the prices
        # are NaN in this case), fall back to zeros so the forward-fill
        # loop below has a non-NaN row 0 to start from.
        if str(tmp_df.iloc[0]["close"]) == "nan":
            print(
                "Missing data for ticker: ",
                tic,
                " . The prices are all NaN. Fill with 0.",
            )
            tmp_df.iloc[0] = [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ]

        # fill NaN data with previous close and set volume to 0.
        for i in range(tmp_df.shape[0]):
            if str(tmp_df.iloc[i]["close"]) == "nan":
                previous_close = tmp_df.iloc[i - 1]["close"]
                if str(previous_close) == "nan":
                    # Should be unreachable: row 0 was made non-NaN above.
                    raise ValueError
                tmp_df.iloc[i] = [
                    previous_close,
                    previous_close,
                    previous_close,
                    previous_close,
                    0.0,
                ]

        # merge single ticker data to new DataFrame
        tmp_df = tmp_df.astype(float)
        tmp_df["tic"] = tic
        new_df = pd.concat([new_df, tmp_df])

    # reset index and rename columns
    new_df = new_df.reset_index()
    new_df = new_df.rename(columns={"index": "timestamp"})

    return new_df
376
def add_technical_indicator(
    self, data: pd.DataFrame, tech_indicator_list: list[str]
):
    """
    calculate technical indicators
    use stockstats package to add technical indicators
    :param data: (df) pandas dataframe with timestamp/tic columns
    :param tech_indicator_list: stockstats indicator names, e.g. "macd"
    :return: (df) pandas dataframe with one extra column per indicator,
        sorted by ["timestamp", "tic"]
    """
    df = data.copy()
    df = df.sort_values(by=["tic", "timestamp"])
    # stockstats wraps a copy; indexing it by indicator name triggers the
    # lazy indicator computation.
    stock = Sdf.retype(df.copy())
    unique_ticker = stock.tic.unique()

    for indicator in tech_indicator_list:
        indicator_df = pd.DataFrame()
        for i in range(len(unique_ticker)):
            try:
                temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
                temp_indicator = pd.DataFrame(temp_indicator)
                temp_indicator["tic"] = unique_ticker[i]
                # Re-attach timestamps positionally; relies on both frames
                # being sorted by ["tic", "timestamp"] above.
                temp_indicator["timestamp"] = df[df.tic == unique_ticker[i]][
                    "timestamp"
                ].to_list()
                indicator_df = pd.concat(
                    [indicator_df, temp_indicator], ignore_index=True
                )
            except Exception as e:
                # Best-effort: a failed indicator for one ticker is reported
                # and skipped; the left-merge below leaves NaN for it.
                print(e)
        df = df.merge(
            indicator_df[["tic", "timestamp", indicator]],
            on=["tic", "timestamp"],
            how="left",
        )
    df = df.sort_values(by=["timestamp", "tic"])
    return df
413
def add_vix(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    add vix from yahoo finance
    Downloads VIXY (an ETF used here as the volatility proxy) over the
    same self.start/self.end range recorded by download_data, cleans it,
    and left-joins its close as a "VIXY" column on timestamp.
    :param data: (df) pandas dataframe
    :return: (df) pandas dataframe sorted by ["timestamp", "tic"]
    """
    vix_df = self.download_data(["VIXY"], self.start, self.end, self.time_interval)
    cleaned_vix = self.clean_data(vix_df)
    # NOTE(review): the prints below look like leftover debug output.
    print("cleaned_vix\n", cleaned_vix)
    vix = cleaned_vix[["timestamp", "close"]]
    print('cleaned_vix[["timestamp", "close"]\n', vix)
    vix = vix.rename(columns={"close": "VIXY"})
    print('vix.rename(columns={"close": "VIXY"}\n', vix)

    df = data.copy()
    print("df\n", df)
    df = df.merge(vix, on="timestamp")
    df = df.sort_values(["timestamp", "tic"]).reset_index(drop=True)
    return df
433
def calculate_turbulence(
    self, data: pd.DataFrame, time_period: int = 252
) -> pd.DataFrame:
    """Compute a turbulence index per timestamp: the Mahalanobis-style
    distance of the current cross-sectional returns from their rolling
    historical mean/covariance.

    :param data: (df) pandas dataframe with timestamp/tic/close columns
    :param time_period: rolling window length; the first time_period
        timestamps get turbulence 0
    :return: (df) dataframe with columns ["timestamp", "turbulence"]
    """
    # can add other market assets
    df = data.copy()
    df_price_pivot = df.pivot(index="timestamp", columns="tic", values="close")
    # use returns to calculate turbulence
    df_price_pivot = df_price_pivot.pct_change()

    unique_date = df.timestamp.unique()
    # start after a fixed timestamp period
    start = time_period
    turbulence_index = [0] * start
    count = 0
    for i in range(start, len(unique_date)):
        current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
        # use one year rolling window to calculate covariance
        hist_price = df_price_pivot[
            (df_price_pivot.index < unique_date[i])
            & (df_price_pivot.index >= unique_date[i - time_period])
        ]
        # Drop tickers which has number missing values more than the "oldest" ticker
        filtered_hist_price = hist_price.iloc[
            hist_price.isna().sum().min() :
        ].dropna(axis=1)

        cov_temp = filtered_hist_price.cov()
        current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(
            filtered_hist_price, axis=0
        )
        # Quadratic form via pseudo-inverse (tolerates singular covariance).
        temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
            current_temp.values.T
        )
        if temp > 0:
            count += 1
            # Only record a value once a few positive readings have been
            # seen, to avoid a large outlier when the calculation begins.
            if count > 2:
                turbulence_temp = temp[0][0]
            else:
                turbulence_temp = 0
        else:
            turbulence_temp = 0
        turbulence_index.append(turbulence_temp)

    turbulence_index = pd.DataFrame(
        {"timestamp": df_price_pivot.index, "turbulence": turbulence_index}
    )
    return turbulence_index
483
def add_turbulence(
    self, data: pd.DataFrame, time_period: int = 252
) -> pd.DataFrame:
    """
    add turbulence index from a precalculated dataframe
    :param data: (df) pandas dataframe
    :return: (df) pandas dataframe with an extra "turbulence" column,
        sorted by ["timestamp", "tic"]
    """
    result = data.copy()
    turbulence = self.calculate_turbulence(result, time_period=time_period)
    result = result.merge(turbulence, on="timestamp")
    return result.sort_values(["timestamp", "tic"]).reset_index(drop=True)
496
497
def df_to_array(
    self, df: pd.DataFrame, tech_indicator_list: list[str], if_vix: bool
) -> list[np.ndarray]:
    """Split a long-format frame into per-ticker numpy arrays.

    Returns (price_array, tech_array, turbulence_array): close prices and
    indicator values are stacked column-wise, one column group per ticker
    in first-appearance order; the risk column ("VIXY" if if_vix, else
    "turbulence") is taken from the first ticker only.
    """
    frame = df.copy()
    tickers = frame.tic.unique()
    risk_col = "VIXY" if if_vix else "turbulence"

    price_cols = []
    tech_cols = []
    for ticker in tickers:
        sub = frame[frame.tic == ticker]
        price_cols.append(sub[["close"]].values)
        tech_cols.append(sub[tech_indicator_list].values)

    price_array = np.hstack(price_cols)
    tech_array = np.hstack(tech_cols)
    turbulence_array = frame[frame.tic == tickers[0]][risk_col].values
    # print("Successfully transformed into array")
    return price_array, tech_array, turbulence_array
522
def get_trading_days(self, start: str, end: str) -> list[str]:
    """Return NYSE trading days between start and end as 'YYYY-MM-DD' strings."""
    nyse = tc.get_calendar("NYSE")
    schedule = nyse.date_range_htf("1D", pd.Timestamp(start), pd.Timestamp(end))
    return [str(day)[:10] for day in schedule]
529
530
# ****** NB: YAHOO FINANCE DATA MAY BE IN REAL-TIME OR DELAYED BY 15 MINUTES OR MORE, DEPENDING ON THE EXCHANGE ******
def fetch_latest_data(
    self,
    ticker_list: list[str],
    time_interval: str,
    tech_indicator_list: list[str],
    limit: int = 100,
) -> pd.DataFrame:
    """Download the most recent bars, clean them, and return the latest state.

    Param
    ----------
    ticker_list : list[str]
        tickers to fetch
    time_interval : str
        FinRL or Yahoo interval string
    tech_indicator_list : list[str]
        stockstats indicator names to compute
    limit : int
        approximate number of most-recent minutes to fetch

    Actually returns a 3-tuple (latest_price, latest_tech, latest_turb):
    the last row of the price and indicator arrays plus the latest VIXY
    close — despite the declared pd.DataFrame return annotation.
    """
    time_interval = self.convert_interval(time_interval)

    end_datetime = datetime.datetime.now()
    start_datetime = end_datetime - datetime.timedelta(
        minutes=limit + 1
    )  # get the last rows up to limit

    data_df = pd.DataFrame()
    for tic in ticker_list:
        barset = yf.download(
            tic, start_datetime, end_datetime, interval=time_interval
        )  # use start and end datetime to simulate the limit parameter
        barset["tic"] = tic
        data_df = pd.concat([data_df, barset])

    data_df = data_df.reset_index().drop(
        columns=["Adj Close"]
    )  # Alpaca data does not have 'Adj Close'

    data_df.columns = [  # convert to Alpaca column names lowercase
        "timestamp",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "tic",
    ]

    # Build a contiguous minute-by-minute index over the downloaded span.
    start_time = data_df.timestamp.min()
    end_time = data_df.timestamp.max()
    times = []
    current_time = start_time
    end = end_time + pd.Timedelta(minutes=1)
    # NOTE(review): assumes timestamps are minute-aligned; if end_time is
    # not an exact whole-minute offset from start_time, current_time can
    # step past `end` and this loop never terminates — verify.
    while current_time != end:
        times.append(current_time)
        current_time += pd.Timedelta(minutes=1)

    # Re-index each ticker onto the full minute grid and fill gaps, same
    # strategy as clean_data: first-valid backfill for row 0, zeros if the
    # ticker has no data at all, then forward-fill with volume 0.
    df = data_df.copy()
    new_df = pd.DataFrame()
    for tic in ticker_list:
        tmp_df = pd.DataFrame(
            columns=["open", "high", "low", "close", "volume"], index=times
        )
        tic_df = df[df.tic == tic]
        for i in range(tic_df.shape[0]):
            tmp_df.loc[tic_df.iloc[i]["timestamp"]] = tic_df.iloc[i][
                ["open", "high", "low", "close", "volume"]
            ]

        # Row 0 NaN: seed with the first valid close, volume 0.
        if str(tmp_df.iloc[0]["close"]) == "nan":
            for i in range(tmp_df.shape[0]):
                if str(tmp_df.iloc[i]["close"]) != "nan":
                    first_valid_close = tmp_df.iloc[i]["close"]
                    tmp_df.iloc[0] = [
                        first_valid_close,
                        first_valid_close,
                        first_valid_close,
                        first_valid_close,
                        0.0,
                    ]
                    break
        # Still NaN means the ticker had no data at all: fill row 0 with 0s.
        if str(tmp_df.iloc[0]["close"]) == "nan":
            print(
                "Missing data for ticker: ",
                tic,
                " . The prices are all NaN. Fill with 0.",
            )
            tmp_df.iloc[0] = [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ]

        # Forward-fill remaining NaN rows from the previous close, volume 0.
        for i in range(tmp_df.shape[0]):
            if str(tmp_df.iloc[i]["close"]) == "nan":
                previous_close = tmp_df.iloc[i - 1]["close"]
                if str(previous_close) == "nan":
                    previous_close = 0.0
                tmp_df.iloc[i] = [
                    previous_close,
                    previous_close,
                    previous_close,
                    previous_close,
                    0.0,
                ]
        tmp_df = tmp_df.astype(float)
        tmp_df["tic"] = tic
        new_df = pd.concat([new_df, tmp_df])

    new_df = new_df.reset_index()
    new_df = new_df.rename(columns={"index": "timestamp"})

    df = self.add_technical_indicator(new_df, tech_indicator_list)
    # Placeholder VIXY column so df_to_array(if_vix=True) has something to
    # read; the real latest VIXY value is fetched separately below.
    df["VIXY"] = 0

    price_array, tech_array, turbulence_array = self.df_to_array(
        df, tech_indicator_list, if_vix=True
    )
    latest_price = price_array[-1]
    latest_tech = tech_array[-1]
    start_datetime = end_datetime - datetime.timedelta(minutes=1)
    # NOTE(review): "limit" is not a documented yf.download parameter —
    # confirm the installed yfinance version accepts (or ignores) it.
    turb_df = yf.download("VIXY", start_datetime, limit=1)
    latest_turb = turb_df["Close"].values
    return latest_price, latest_tech, latest_turb
646