Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AI4Finance-Foundation
GitHub Repository: AI4Finance-Foundation/FinRL
Path: blob/master/finrl/plot.py
728 views
1
from __future__ import annotations
2
3
import copy
4
import datetime
5
from copy import deepcopy
6
7
import matplotlib.dates as mdates
8
import matplotlib.pyplot as plt
9
import numpy as np
10
import pandas as pd
11
import pyfolio
12
from pyfolio import timeseries
13
14
from finrl import config
15
from finrl.meta.data_processors.func import date2str
16
from finrl.meta.data_processors.func import str2date
17
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
18
19
20
def get_daily_return(df, value_col_name="account_value"):
    """Compute the period-over-period percentage change of a value column.

    Args:
        df: DataFrame containing a ``date`` column and the column named by
            ``value_col_name``. The input frame is left untouched.
        value_col_name: Name of the column whose percentage change is taken.

    Returns:
        A ``pd.Series`` named ``daily_return`` whose index is the ``date``
        column converted to a UTC-aware ``DatetimeIndex``. The first element
        is NaN because it has no predecessor.
    """
    dates = pd.DatetimeIndex(pd.to_datetime(df["date"]), name="date")
    # tz_localize() raises TypeError on timestamps that already carry a
    # timezone, so aware inputs are converted to UTC instead.
    if dates.tz is None:
        dates = dates.tz_localize("UTC")
    else:
        dates = dates.tz_convert("UTC")

    daily_return = df[value_col_name].pct_change(1)
    return pd.Series(daily_return.to_numpy(), index=dates, name="daily_return")
27
28
29
def convert_daily_return_to_pyfolio_ts(df):
    """Turn a frame of daily returns into a UTC-indexed return series.

    Args:
        df: DataFrame with a ``date`` column and a ``daily_return`` column.
            The input frame is left untouched.

    Returns:
        An unnamed ``pd.Series`` holding the ``daily_return`` values, indexed
        by the ``date`` column as a UTC-aware ``DatetimeIndex``.
    """
    dates = pd.DatetimeIndex(pd.to_datetime(df["date"]), name="date")
    # tz_localize() raises TypeError on timestamps that already carry a
    # timezone, so aware inputs are converted to UTC instead.
    if dates.tz is None:
        dates = dates.tz_localize("UTC")
    else:
        dates = dates.tz_convert("UTC")
    return pd.Series(df["daily_return"].values, index=dates)
36
37
38
def backtest_stats(account_value, value_col_name="account_value"):
    """Print and return pyfolio performance statistics for an account curve.

    Args:
        account_value: DataFrame passed to ``get_daily_return`` (needs a
            ``date`` column and the ``value_col_name`` column).
        value_col_name: Column holding the account value.

    Returns:
        Whatever ``pyfolio.timeseries.perf_stats`` returns for the daily
        return series (no positions or transactions are supplied).
    """
    daily_returns = get_daily_return(account_value, value_col_name=value_col_name)
    perf_kwargs = {
        "returns": daily_returns,
        "positions": None,
        "transactions": None,
        "turnover_denom": "AGB",
    }
    stats = timeseries.perf_stats(**perf_kwargs)
    print(stats)
    return stats
48
49
50
def backtest_plot(
    account_value,
    baseline_start=config.TRADE_START_DATE,
    baseline_end=config.TRADE_END_DATE,
    baseline_ticker="^DJI",
    value_col_name="account_value",
):
    """Render a pyfolio full tear sheet of a strategy against a baseline.

    Args:
        account_value: DataFrame with a ``date`` column and the
            ``value_col_name`` column. It is deep-copied, not modified.
        baseline_start: Start date forwarded to ``get_baseline``.
        baseline_end: End date forwarded to ``get_baseline``.
        baseline_ticker: Ticker forwarded to ``get_baseline``.
        value_col_name: Column holding the strategy's account value.

    Returns:
        None. The tear sheet is produced by ``pyfolio.create_full_tear_sheet``.
    """
    df = deepcopy(account_value)
    df["date"] = pd.to_datetime(df["date"])
    test_returns = get_daily_return(df, value_col_name=value_col_name)

    baseline_df = get_baseline(
        ticker=baseline_ticker, start=baseline_start, end=baseline_end
    )
    baseline_df["date"] = pd.to_datetime(baseline_df["date"], format="%Y-%m-%d")

    # Align the baseline to the strategy's dates; dates missing from the
    # baseline become NaN rows that are filled from neighbouring rows.
    baseline_df = pd.merge(df[["date"]], baseline_df, how="left", on="date")
    # fillna(method=...) is deprecated in recent pandas; ffill()/bfill() are
    # the direct equivalents.
    baseline_df = baseline_df.ffill().bfill()
    baseline_returns = get_daily_return(baseline_df, value_col_name="close")

    with pyfolio.plotting.plotting_context(font_scale=1.1):
        pyfolio.create_full_tear_sheet(
            returns=test_returns, benchmark_rets=baseline_returns, set_context=False
        )
74
75
76
def get_baseline(ticker, start, end):
    """Fetch data for a single ticker through ``YahooDownloader``.

    Args:
        ticker: Ticker symbol; wrapped in a one-element ``ticker_list``.
        start: Passed to the downloader as ``start_date``.
        end: Passed to the downloader as ``end_date``.

    Returns:
        The result of ``YahooDownloader.fetch_data()``.
    """
    downloader = YahooDownloader(
        start_date=start,
        end_date=end,
        ticker_list=[ticker],
    )
    return downloader.fetch_data()
80
81
82
def trx_plot(df_trade, df_actions, ticker_list):
    """Plot each ticker's close price with its buy/sell actions marked.

    One figure is shown per ticker. Positive transaction values are drawn as
    buy markers ("^"), negative values as sell markers ("v").

    Args:
        df_trade: DataFrame with ``tic``, ``date`` and ``close`` columns.
        df_actions: DataFrame with a ``date`` column and a ``transactions``
            column whose entries are per-ticker sequences of action values.
        ticker_list: Column labels for the transaction matrix, in the same
            order as the entries inside each ``transactions`` value.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    df_trx = pd.DataFrame(np.array(df_actions["transactions"].to_list()))
    df_trx.columns = ticker_list
    df_trx.index = df_actions["date"]
    df_trx.index.name = ""

    # The date filter does not depend on the ticker, so compute it once.
    on_action_date = df_trade["date"].isin(df_trx.index)

    for _, actions in df_trx.items():
        # Vectorized sign tests; NaN compares False on both sides.
        buying_signal = (actions > 0).to_numpy()
        selling_signal = (actions < 0).to_numpy()
        num_transactions = int(buying_signal.sum() + selling_signal.sum())

        tic_plot = df_trade[(df_trade["tic"] == actions.name) & on_action_date][
            "close"
        ]
        # Requires exactly one price row per action date for this ticker;
        # otherwise pandas raises a length-mismatch ValueError here.
        tic_plot.index = actions.index

        plt.figure(figsize=(10, 8))
        plt.plot(tic_plot, color="g", lw=2.0)
        plt.plot(
            tic_plot,
            "^",
            markersize=10,
            color="m",
            label="buying signal",
            markevery=buying_signal.tolist(),
        )
        plt.plot(
            tic_plot,
            "v",
            markersize=10,
            color="k",
            label="selling signal",
            markevery=selling_signal.tolist(),
        )
        plt.title(f"{actions.name} Num Transactions: {num_transactions}")
        plt.legend()
        plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=25))
        plt.xticks(rotation=45, ha="right")
        plt.show()
125
126
127
def transfer_date(str_dat):
    """Reformat a ``YYYY-MM-DD`` string as ``MM/DD/YYYY``.

    Example: ``2022-01-15`` -> ``01/15/2022``.

    Raises:
        ValueError: If ``str_dat`` does not match ``%Y-%m-%d``.
    """
    parsed = datetime.datetime.strptime(str_dat, "%Y-%m-%d")
    return parsed.strftime("%m/%d/%Y")
130
131
132
def plot_result_from_csv(
    csv_file: str,
    column_as_x: str,
    savefig_filename: str = "fig/result.png",
    xlabel: str = "Date",
    ylabel: str = "Result",
    num_days_xticks: int = 20,
    xrotation: int = 0,
):
    """Load ``csv_file`` with pandas and forward it to ``plot_result``.

    All remaining arguments are passed through to ``plot_result`` unchanged.
    """
    frame = pd.read_csv(csv_file)
    plot_result(
        result=frame,
        column_as_x=column_as_x,
        savefig_filename=savefig_filename,
        xlabel=xlabel,
        ylabel=ylabel,
        num_days_xticks=num_days_xticks,
        xrotation=xrotation,
    )
151
152
153
# select_start_date: included
# select_end_date: included
# if if_need_calc_return is True, the input holds account values, which are then converted to returns
# it is better that column_as_x is the first column, and the other columns are strategies
# xrotation: the rotation of the x-axis tick labels, may be used for dates. Default=0 (adaptive adjustment)
158
def plot_result(
    result: pd.DataFrame,
    column_as_x: str,
    savefig_filename: str = "fig/result.png",
    xlabel: str = "Date",
    ylabel: str = "Result",
    num_days_xticks: int = 20,
    xrotation: int = 0,
):
    """Plot every strategy column of ``result`` against ``column_as_x``.

    Every column other than ``column_as_x`` whose name does not contain
    "Unnamed" is drawn as one line. The figure is saved to
    ``savefig_filename`` and then shown.

    Args:
        result: Data to plot; one x column plus one column per strategy.
        column_as_x: Name of the column used for the x axis.
        savefig_filename: Output path handed to ``plt.savefig``. A missing
            parent directory is created first.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        num_days_xticks: Step between x values that receive a tick.
        xrotation: Rotation of the x tick labels. With 0, labels are rotated
            automatically only if they would overlap.

    Returns:
        None.
    """
    import os

    strategy_columns = [
        col
        for col in result.columns
        if "Unnamed" not in str(col) and col != column_as_x
    ]

    x = result[column_as_x].values.tolist()
    # NOTE: this changes the process-wide default figure size, not just this figure.
    plt.rcParams["figure.figsize"] = (15, 6)

    fig, ax = plt.subplots()
    colors = [
        "black",
        "red",
        "green",
        "blue",
        "cyan",
        "magenta",
        "yellow",
        "aliceblue",
        "coral",
        "darksalmon",
        "firebrick",
        "honeydew",
    ]
    for i, col in enumerate(strategy_columns):
        ax.plot(
            x,
            result[col],
            # Wrap around so more strategies than colors does not raise IndexError.
            color=colors[i % len(colors)],
            linewidth=1,
            linestyle="-",
        )

    plt.title("", fontsize=20)
    plt.xlabel(xlabel, fontsize=20)
    plt.ylabel(ylabel, fontsize=20)

    plt.legend(labels=strategy_columns, loc="best", fontsize=16)

    plt.grid()

    # Tick label size.
    plt.xticks(size=22)
    plt.yticks(size=22)

    # Place one x tick every ``num_days_xticks`` values.
    plt.xticks(x[::num_days_xticks])

    plt.setp(ax.get_xticklabels(), rotation=xrotation, horizontalalignment="center")

    # With no explicit rotation, rotate the labels automatically if they overlap.
    if xrotation == 0 and get_if_overlap(fig, ax):
        plt.gcf().autofmt_xdate(ha="right")

    plt.tight_layout()

    # savefig does not create directories, and the default path lives in "fig/".
    if isinstance(savefig_filename, (str, os.PathLike)):
        out_dir = os.path.dirname(savefig_filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    plt.savefig(savefig_filename)

    plt.show()
241
242
243
def get_if_overlap(fig, ax):
    """Report whether any two neighbouring x tick labels of ``ax`` overlap.

    The canvas is drawn first so the labels have window extents. Two
    neighbouring labels overlap when the next label's left edge (``x0``)
    lies strictly before the previous label's right edge (``x1``).

    Returns:
        True if at least one neighbouring pair overlaps, otherwise False.
    """
    fig.canvas.draw()
    # Bounding boxes of the tick labels, in tick order.
    extents = [tick.get_window_extent() for tick in ax.get_xticklabels()]
    # Horizontal gap between each label and its successor.
    gaps = [nxt.x0 - prev.x1 for prev, nxt in zip(extents, extents[1:])]
    # A negative gap means the two labels overlap.
    return any(gap < 0 for gap in gaps)
256
257
258
def plot_return(
    result: pd.DataFrame,
    column_as_x: str,
    if_need_calc_return: bool,
    savefig_filename: str = "fig/result.png",
    xlabel: str = "Date",
    ylabel: str = "Return",
    if_transfer_date: bool = True,
    select_start_date: str = None,
    select_end_date: str = None,
    num_days_xticks: int = 20,
    xrotation: int = 0,
):
    """Plot strategy returns over a selected date window.

    Every column other than ``column_as_x`` whose name does not contain
    "Unnamed" is treated as a strategy. The final value of each strategy
    inside the window is printed, then the window is drawn via
    ``plot_result``.

    Args:
        result: One date column plus one column per strategy. The caller's
            frame is not modified.
        column_as_x: Name of the date column (string dates accepted by
            ``str2date``).
        if_need_calc_return: If True, the strategy columns hold account
            values and are converted to returns relative to the value at
            ``select_start_date``.
        savefig_filename: Forwarded to ``plot_result``.
        xlabel: Forwarded to ``plot_result``.
        ylabel: Forwarded to ``plot_result``.
        if_transfer_date: If True, dates containing "-" are rewritten with
            ``transfer_date`` before plotting.
        select_start_date: First date to include; defaults to the first row.
            Must be a value present in ``column_as_x``.
        select_end_date: Last date to include; defaults to the last row.
        num_days_xticks: Forwarded to ``plot_result``.
        xrotation: Forwarded to ``plot_result``.

    Raises:
        ValueError: If ``select_start_date`` is not a value of ``column_as_x``.
    """
    # Work on a copy so the return conversion below never touches the caller's frame.
    result = result.copy()
    x_values = result[column_as_x].tolist()

    # Each bound defaults independently, so a caller may supply only one of them.
    if select_start_date is None:
        select_start_date = x_values[0]
    if select_end_date is None:
        select_end_date = x_values[-1]

    # Position (not label) of the start row; raises ValueError when absent.
    start_pos = x_values.index(select_start_date)

    strategy_columns = [
        col
        for col in result.columns
        if col != column_as_x and "Unnamed" not in str(col)
    ]
    if if_need_calc_return:
        for col in strategy_columns:
            # iloc keeps this correct even when the frame has a non-default index.
            result[col] = result[col] / result[col].iloc[start_pos] - 1

    # Keep only rows whose date lies inside the inclusive window.
    start = str2date(select_start_date)
    end = str2date(select_end_date)
    in_window = [start <= str2date(value) <= end for value in x_values]
    result = result.loc[in_window].reset_index(drop=True)

    # e.g. 2020-01-15 -> 01/15/2020
    if if_transfer_date:
        result[column_as_x] = [
            transfer_date(value) if "-" in value else value
            for value in result[column_as_x]
        ]

    # Print the final return of each strategy.
    final_return = {}
    for col in strategy_columns:
        last_value = result[col].iloc[-1]
        # Unwrap NumPy scalars so the printed dict shows plain numbers.
        if isinstance(last_value, np.generic):
            last_value = last_value.item()
        final_return[col] = last_value
    print("final return: ", final_return)

    plot_result(
        result=result,
        column_as_x=column_as_x,
        savefig_filename=savefig_filename,
        xlabel=xlabel,
        ylabel=ylabel,
        num_days_xticks=num_days_xticks,
        xrotation=xrotation,
    )
335
336
337
def plot_return_from_csv(
    csv_file: str,
    column_as_x: str,
    if_need_calc_return: bool,
    savefig_filename: str = "fig/result.png",
    xlabel: str = "Date",
    ylabel: str = "Return",
    if_transfer_date: bool = True,
    select_start_date: str = None,
    select_end_date: str = None,
    num_days_xticks: int = 20,
    xrotation: int = 0,
):
    """Load ``csv_file`` with pandas and forward it to ``plot_return``.

    All remaining arguments are passed through to ``plot_return`` unchanged.
    """
    frame = pd.read_csv(csv_file)
    plot_return(
        result=frame,
        column_as_x=column_as_x,
        if_need_calc_return=if_need_calc_return,
        savefig_filename=savefig_filename,
        xlabel=xlabel,
        ylabel=ylabel,
        if_transfer_date=if_transfer_date,
        select_start_date=select_start_date,
        select_end_date=select_end_date,
        num_days_xticks=num_days_xticks,
        xrotation=xrotation,
    )
364
365