Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
polakowo
GitHub Repository: polakowo/vectorbt
Path: blob/master/tests/test_data.py
1071 views
1
from datetime import datetime, timedelta, timezone
2
3
import numpy as np
4
import pandas as pd
5
import pytest
6
import pytz
7
8
import vectorbt as vbt
9
from vectorbt.utils.config import merge_dicts
10
from vectorbt.utils.datetime_ import to_timezone
11
12
# Default seed of MyData.download_symbol; MyData.update_symbol draws with seed + 1.
seed = 42
13
14
15
# ############# Global ############# #
16
17
def setup_module():
    """Adjust vectorbt's global settings before the tests of this module run.

    Turns on numba's ``check_func_suffix`` option and disables caching with
    empty white- and blacklists.
    """
    settings = vbt.settings
    settings.numba['check_func_suffix'] = True
    caching = settings.caching
    caching.enabled = False
    caching.whitelist = []
    caching.blacklist = []
22
23
24
def teardown_module():
    """Reset vectorbt's global settings once this module's tests are done (undoing setup_module)."""
    settings = vbt.settings
    settings.reset()
26
27
28
# ############# base.py ############# #
29
30
31
class MyData(vbt.Data):
    """Synthetic ``vbt.Data`` subclass used to exercise the data API without any real data source.

    Each symbol's data is a reproducible uniform random draw shifted by the
    (numeric) symbol, so data of different symbols is easy to tell apart.
    """

    @classmethod
    def download_symbol(cls, symbol, shape=(5, 3), start_date=datetime(2020, 1, 1), columns=None, index_mask=None,
                        column_mask=None, return_arr=False, tz_localize=None, seed=seed):
        """Generate the data of a single symbol.

        Args:
            symbol: Number added to every generated value.
            shape: Shape of the draw; a 1-dimensional shape yields a Series,
                any other shape a DataFrame.
            start_date: Label of the first row; each following row is one day later.
            columns: Series name (1-dim) or column labels (2-dim).
            index_mask: Optional row selector applied with ``.loc``.
            column_mask: Optional column selector applied with ``.loc``
                (only used for DataFrames).
            return_arr: If True, return the raw NumPy array (no index, no masks).
            tz_localize: Optional timezone the index is localized to.
            seed: Seed of the random generator.

        Returns:
            np.ndarray, pd.Series or pd.DataFrame.
        """
        # A private legacy RandomState produces exactly the same numbers as
        # ``np.random.seed(seed); np.random.uniform(...)`` but does not clobber
        # the global NumPy random state shared with all other code and tests.
        values = np.random.RandomState(seed).uniform(size=shape) + symbol
        if return_arr:
            return values
        index = [start_date + timedelta(days=i) for i in range(values.shape[0])]
        if values.ndim == 1:
            obj = pd.Series(values, index=index, name=columns)
        else:
            obj = pd.DataFrame(values, index=index, columns=columns)
        if index_mask is not None:
            obj = obj.loc[index_mask]
        # A Series has no columns to mask, so column_mask only applies to DataFrames.
        if column_mask is not None and isinstance(obj, pd.DataFrame):
            obj = obj.loc[:, column_mask]
        if tz_localize is not None:
            obj = obj.tz_localize(tz_localize)
        return obj

    def update_symbol(self, symbol, n=1, **kwargs):
        """Generate ``n`` rows for ``symbol``, starting at its last existing index label.

        The stored download arguments of ``symbol`` are reused; the column count
        of the original shape is kept; the seed is incremented by one so the new
        rows differ from the initial draw; ``kwargs`` are merged on top of the
        stored arguments with ``merge_dicts``.
        """
        # Copy before mutating so the popped/overwritten keys cannot leak back
        # into whatever mapping select_symbol_kwargs handed out.
        download_kwargs = dict(self.select_symbol_kwargs(symbol, self.download_kwargs))
        # Start at the last existing label so the first new row overlaps it.
        download_kwargs['start_date'] = self.data[symbol].index[-1]
        shape = download_kwargs.pop('shape', (5, 3))
        new_shape = (n, shape[1]) if len(shape) > 1 else (n,)
        new_seed = download_kwargs.pop('seed', seed) + 1
        kwargs = merge_dicts(download_kwargs, kwargs)
        return self.download_symbol(symbol, shape=new_shape, seed=new_seed, **kwargs)
64
65
66
class TestData:
    """Tests for the ``vbt.Data`` base class, exercised through ``MyData``."""

    # Labels of the expected daily indexes; each test takes the prefix it needs.
    _DATES = [
        '2020-01-01 00:00:00',
        '2020-01-02 00:00:00',
        '2020-01-03 00:00:00',
        '2020-01-04 00:00:00',
        '2020-01-05 00:00:00',
        '2020-01-06 00:00:00'
    ]

    # Expected shape-(5,) draws (seed 42) for symbol 0 and symbol 1.
    _SR0 = [0.3745401188473625, 0.9507143064099162, 0.7319939418114051, 0.5986584841970366, 0.15601864044243652]
    _SR1 = [1.3745401188473625, 1.9507143064099162, 1.7319939418114051, 1.5986584841970366, 1.15601864044243652]

    # Expected shape-(5, 3) draws (seed 42) for symbol 0 and symbol 1.
    _DF0 = [
        [0.3745401188473625, 0.9507143064099162, 0.7319939418114051],
        [0.5986584841970366, 0.15601864044243652, 0.15599452033620265],
        [0.05808361216819946, 0.8661761457749352, 0.6011150117432088],
        [0.7080725777960455, 0.020584494295802447, 0.9699098521619943],
        [0.8324426408004217, 0.21233911067827616, 0.18182496720710062]
    ]
    _DF1 = [
        [1.3745401188473625, 1.9507143064099162, 1.7319939418114051],
        [1.5986584841970366, 1.15601864044243652, 1.15599452033620265],
        [1.05808361216819946, 1.8661761457749352, 1.6011150117432088],
        [1.7080725777960455, 1.020584494295802447, 1.9699098521619943],
        [1.8324426408004217, 1.21233911067827616, 1.18182496720710062]
    ]

    # Expected rows drawn by an update of symbol 0 (seed 43): first and second new row.
    _UPD0 = [0.11505456638977896, 0.6090665392794814, 0.13339096418598828]
    _UPD0_NEXT = [0.24058961996534878, 0.3271390558111398, 0.8591374909485977]

    @property
    def _feature_names(self):
        """A fresh list of feature labels on every access (never shared between calls)."""
        return ['feat0', 'feat1', 'feat2']

    def _daily_index(self, periods, tz=timezone.utc):
        """Daily DatetimeIndex with ``periods`` labels starting at 2020-01-01."""
        return pd.DatetimeIndex(self._DATES[:periods], freq='D', tz=tz)

    @staticmethod
    def _index_mask():
        """Row masks: symbol 0 lacks the first row, symbol 1 lacks the last row."""
        return vbt.symbol_dict({
            0: [False, True, True, True, True],
            1: [True, True, True, True, False]
        })

    @staticmethod
    def _column_mask():
        """Column masks: symbol 0 lacks the first column, symbol 1 lacks the last column."""
        return vbt.symbol_dict({
            0: [False, True, True],
            1: [True, True, False]
        })

    @staticmethod
    def _symbol_columns():
        """Column index of frames that hold one column per symbol."""
        return pd.Index([0, 1], dtype='int64', name='symbol')

    @staticmethod
    def _column(rows, k):
        """Column ``k`` of a row-major list of rows."""
        return [row[k] for row in rows]

    @staticmethod
    def _pairs(left, right):
        """Zip two equally long sequences into two-element rows."""
        return [[a, b] for a, b in zip(left, right)]

    def test_config(self, tmp_path):
        data = MyData.download([0, 1], shape=(5, 3), columns=self._feature_names)
        # Round-trip through the in-memory and the on-disk serialization.
        assert MyData.loads(data.dumps()) == data
        path = tmp_path / 'data'
        data.save(path)
        assert MyData.load(path) == data

    def test_download(self):
        nan = np.nan

        # return_arr=True: the raw arrays get wrapped with a default RangeIndex.
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), return_arr=True).data[0],
            pd.Series(self._SR0)
        )
        pd.testing.assert_frame_equal(
            MyData.download(0, shape=(5, 3), return_arr=True).data[0],
            pd.DataFrame(self._DF0)
        )

        index = self._daily_index(5)
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,)).data[0],
            pd.Series(self._SR0, index=index)
        )
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), columns='feat0').data[0],
            pd.Series(self._SR0, index=index, name='feat0')
        )
        pd.testing.assert_frame_equal(
            MyData.download(0, shape=(5, 3)).data[0],
            pd.DataFrame(self._DF0, index=index)
        )
        pd.testing.assert_frame_equal(
            MyData.download(0, shape=(5, 3), columns=self._feature_names).data[0],
            pd.DataFrame(self._DF0, index=index, columns=pd.Index(self._feature_names, dtype='object'))
        )

        # Every symbol's draw is shifted by the symbol itself.
        for symbol, expected in ((0, self._SR0), (1, self._SR1)):
            pd.testing.assert_series_equal(
                MyData.download([0, 1], shape=(5,)).data[symbol],
                pd.Series(expected, index=index)
            )
        for symbol, expected in ((0, self._DF0), (1, self._DF1)):
            pd.testing.assert_frame_equal(
                MyData.download([0, 1], shape=(5, 3)).data[symbol],
                pd.DataFrame(expected, index=index)
            )

        berlin_index = self._daily_index(5, tz=pytz.utc).tz_convert(to_timezone('Europe/Berlin'))
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), tz_localize='UTC', tz_convert='Europe/Berlin').data[0],
            pd.Series(self._SR0, index=berlin_index)
        )

        # Rows missing for one symbol are either NaN-filled or dropped for all symbols.
        index_mask = self._index_mask()
        for symbol, expected in ((0, [nan] + self._SR0[1:]), (1, self._SR1[:4] + [nan])):
            pd.testing.assert_series_equal(
                MyData.download([0, 1], shape=(5,), index_mask=index_mask, missing_index='nan').data[symbol],
                pd.Series(expected, index=index)
            )
        for symbol, values in ((0, self._SR0), (1, self._SR1)):
            pd.testing.assert_series_equal(
                MyData.download([0, 1], shape=(5,), index_mask=index_mask, missing_index='drop').data[symbol],
                pd.Series(values[1:4], index=index[1:4])
            )

        column_mask = self._column_mask()
        nan_row = [nan, nan, nan]
        masked_frames = (
            (0, [nan_row] + [[nan] + row[1:] for row in self._DF0[1:]]),
            (1, [row[:2] + [nan] for row in self._DF1[:4]] + [nan_row]),
        )
        for symbol, expected in masked_frames:
            pd.testing.assert_frame_equal(
                MyData.download([0, 1], shape=(5, 3), index_mask=index_mask, column_mask=column_mask,
                                missing_index='nan', missing_columns='nan').data[symbol],
                pd.DataFrame(expected, index=index)
            )
        for symbol, rows in ((0, self._DF0), (1, self._DF1)):
            pd.testing.assert_frame_equal(
                MyData.download([0, 1], shape=(5, 3), index_mask=index_mask, column_mask=column_mask,
                                missing_index='drop', missing_columns='drop').data[symbol],
                pd.DataFrame([[row[1]] for row in rows[1:4]], index=index[1:4],
                             columns=pd.Index([1], dtype='int64'))
            )

        # 'raise' must raise on a mismatch, and unknown options must be rejected.
        invalid_options = (('raise', 'nan'), ('nan', 'raise'), ('test', 'nan'), ('nan', 'test'))
        for missing_index, missing_columns in invalid_options:
            with pytest.raises(Exception):
                MyData.download([0, 1], shape=(5, 3), index_mask=index_mask, column_mask=column_mask,
                                missing_index=missing_index, missing_columns=missing_columns)

    def test_update(self):
        nan = np.nan
        head0 = self._SR0[:4]

        # An update redraws the last existing row and appends n - 1 further rows.
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), return_arr=True).update().data[0],
            pd.Series(head0 + self._UPD0[:1])
        )
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), return_arr=True).update(n=2).data[0],
            pd.Series(head0 + self._UPD0[:2])
        )
        pd.testing.assert_frame_equal(
            MyData.download(0, shape=(5, 3), return_arr=True).update().data[0],
            pd.DataFrame(self._DF0[:4] + [self._UPD0])
        )
        pd.testing.assert_frame_equal(
            MyData.download(0, shape=(5, 3), return_arr=True).update(n=2).data[0],
            pd.DataFrame(self._DF0[:4] + [self._UPD0, self._UPD0_NEXT])
        )

        index = self._daily_index(5)
        updated_index = self._daily_index(6)
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,)).update().data[0],
            pd.Series(head0 + self._UPD0[:1], index=index)
        )
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,)).update(n=2).data[0],
            pd.Series(head0 + self._UPD0[:2], index=updated_index)
        )
        berlin_index = self._daily_index(5, tz=pytz.utc).tz_convert(to_timezone('Europe/Berlin'))
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), tz_localize='UTC', tz_convert='Europe/Berlin')
            .update(tz_localize=None).data[0],
            pd.Series(head0 + self._UPD0[:1], index=berlin_index)
        )

        index_mask = self._index_mask()
        one_row = dict(index_mask=vbt.symbol_dict({
            0: [True],
            1: [False]
        }))
        two_rows = dict(n=2, index_mask=vbt.symbol_dict({
            0: [True, False],
            1: [False, True]
        }))

        series_nan_cases = (
            (one_row, 0, [nan] + self._SR0[1:4] + self._UPD0[:1], index),
            (one_row, 1, self._SR1[:4] + [nan], index),
            (two_rows, 0, [nan] + self._SR0[1:4] + [self._UPD0[0], nan], updated_index),
            (two_rows, 1, self._SR1[:4] + [nan, 1.6090665392794814], updated_index),
        )
        for update_kwargs, symbol, expected, expected_index in series_nan_cases:
            pd.testing.assert_series_equal(
                MyData.download([0, 1], shape=(5,), index_mask=index_mask, missing_index='nan')
                .update(**update_kwargs).data[symbol],
                pd.Series(expected, index=expected_index)
            )
        for update_kwargs in (one_row, two_rows):
            for symbol, values in ((0, self._SR0), (1, self._SR1)):
                pd.testing.assert_series_equal(
                    MyData.download([0, 1], shape=(5,), index_mask=index_mask, missing_index='drop')
                    .update(**update_kwargs).data[symbol],
                    pd.Series(values[1:4], index=index[1:4])
                )

        column_mask = self._column_mask()
        nan_row = [nan, nan, nan]
        frame0_one = [nan_row] + [[nan] + row[1:] for row in self._DF0[1:4]] + [[nan] + self._UPD0[1:]]
        frame1_one = [row[:2] + [nan] for row in self._DF1[:4]] + [nan_row]
        frame_nan_cases = (
            (one_row, 0, frame0_one, index),
            (one_row, 1, frame1_one, index),
            (two_rows, 0, frame0_one + [nan_row], updated_index),
            (two_rows, 1, frame1_one + [[1.2405896199653488, 1.3271390558111398, nan]], updated_index),
        )
        for update_kwargs, symbol, expected, expected_index in frame_nan_cases:
            pd.testing.assert_frame_equal(
                MyData.download([0, 1], shape=(5, 3), index_mask=index_mask, column_mask=column_mask,
                                missing_index='nan', missing_columns='nan')
                .update(**update_kwargs).data[symbol],
                pd.DataFrame(expected, index=expected_index)
            )
        for update_kwargs in (one_row, two_rows):
            for symbol, rows in ((0, self._DF0), (1, self._DF1)):
                pd.testing.assert_frame_equal(
                    MyData.download([0, 1], shape=(5, 3), index_mask=index_mask, column_mask=column_mask,
                                    missing_index='drop', missing_columns='drop')
                    .update(**update_kwargs).data[symbol],
                    pd.DataFrame([[row[1]] for row in rows[1:4]], index=index[1:4],
                                 columns=pd.Index([1], dtype='int64'))
                )

    def test_concat(self):
        index = self._daily_index(5)
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), columns='feat0').concat()['feat0'],
            pd.Series(self._SR0, index=index, name=0)
        )
        pd.testing.assert_frame_equal(
            MyData.download([0, 1], shape=(5,), columns='feat0').concat()['feat0'],
            pd.DataFrame(self._pairs(self._SR0, self._SR1), index=index, columns=self._symbol_columns())
        )
        for k, feature in enumerate(self._feature_names):
            pd.testing.assert_series_equal(
                MyData.download(0, shape=(5, 3), columns=self._feature_names).concat()[feature],
                pd.Series(self._column(self._DF0, k), index=index, name=0)
            )
        for k, feature in enumerate(self._feature_names):
            pd.testing.assert_frame_equal(
                MyData.download([0, 1], shape=(5, 3), columns=self._feature_names).concat()[feature],
                pd.DataFrame(
                    self._pairs(self._column(self._DF0, k), self._column(self._DF1, k)),
                    index=index,
                    columns=self._symbol_columns()
                )
            )

    def test_get(self):
        index = self._daily_index(5)
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5,), columns='feat0').get(),
            pd.Series(self._SR0, index=index, name='feat0')
        )
        pd.testing.assert_frame_equal(
            MyData.download(0, shape=(5, 3), columns=self._feature_names).get(),
            pd.DataFrame(self._DF0, index=index, columns=pd.Index(self._feature_names, dtype='object'))
        )
        pd.testing.assert_series_equal(
            MyData.download(0, shape=(5, 3), columns=self._feature_names).get('feat0'),
            pd.Series(self._column(self._DF0, 0), index=index, name='feat0')
        )
        pd.testing.assert_frame_equal(
            MyData.download([0, 1], shape=(5,), columns='feat0').get(),
            pd.DataFrame(self._pairs(self._SR0, self._SR1), index=index, columns=self._symbol_columns())
        )
        # With several symbols, feat0 comes back as one column per symbol no matter
        # whether it is selected alone, as part of a list, or via a full get().
        feat0_by_symbol = self._pairs(self._column(self._DF0, 0), self._column(self._DF1, 0))
        selectors = (
            lambda data: data.get('feat0'),
            lambda data: data.get(['feat0', 'feat1'])[0],
            lambda data: data.get()[0],
        )
        for select in selectors:
            pd.testing.assert_frame_equal(
                select(MyData.download([0, 1], shape=(5, 3), columns=self._feature_names)),
                pd.DataFrame(feat0_by_symbol, index=index, columns=self._symbol_columns())
            )

    def test_indexing(self):
        sliced = MyData.download([0, 1], shape=(5,), columns='feat0').iloc[:3]
        assert sliced.wrapper == MyData.download([0, 1], shape=(3,), columns='feat0').wrapper
        sliced = MyData.download([0, 1], shape=(5, 3), columns=self._feature_names).iloc[:3]
        assert sliced.wrapper == MyData.download([0, 1], shape=(3, 3), columns=self._feature_names).wrapper
        picked = MyData.download([0, 1], shape=(5, 3), columns=self._feature_names)['feat0']
        assert picked.wrapper == MyData.download([0, 1], shape=(5,), columns='feat0').wrapper
        picked = MyData.download([0, 1], shape=(5, 3), columns=self._feature_names)[['feat0']]
        assert picked.wrapper == MyData.download([0, 1], shape=(5, 1), columns=['feat0']).wrapper

    def test_stats(self):
        data = MyData.download(
            [0, 1], shape=(5, 3), index_mask=self._index_mask(), column_mask=self._column_mask(),
            missing_index='nan', missing_columns='nan', columns=self._feature_names)

        stats_index = pd.Index([
            'Start', 'End', 'Period', 'Total Symbols', 'Null Counts: 0', 'Null Counts: 1'
        ], dtype='object')

        def expected_stats(tail, name):
            # Start, End and Period are the same for every aggregation; only the counts differ.
            return pd.Series([
                pd.Timestamp('2020-01-01 00:00:00+0000', tz='UTC'),
                pd.Timestamp('2020-01-05 00:00:00+0000', tz='UTC'),
                pd.Timedelta('5 days 00:00:00'),
                *tail
            ], index=stats_index, name=name)

        pd.testing.assert_series_equal(
            data.stats(),
            expected_stats([2, 2.3333333333333335, 2.3333333333333335], 'agg_func_mean')
        )
        pd.testing.assert_series_equal(
            data.stats(column='feat0'),
            expected_stats([2, 5, 1], 'feat0')
        )
        pd.testing.assert_series_equal(
            data.stats(group_by=True),
            expected_stats([2, 7, 7], 'group')
        )
        pd.testing.assert_series_equal(
            data['feat0'].stats(),
            data.stats(column='feat0')
        )
        pd.testing.assert_series_equal(
            data.replace(wrapper=data.wrapper.replace(group_by=True)).stats(),
            data.stats(group_by=True)
        )
        stats_df = data.stats(agg_func=None)
        assert stats_df.shape == (3, 6)
        pd.testing.assert_index_equal(stats_df.index, data.wrapper.columns)
        pd.testing.assert_index_equal(stats_df.columns, stats_index)
1038
1039
1040
# ############# updater.py ############# #
1041
1042
class TestDataUpdater:
    """Tests for ``vbt.DataUpdater`` driving ``MyData.update``."""

    def test_update(self):
        original = MyData.download(0, shape=(5,), return_arr=True)
        updater = vbt.DataUpdater(original)
        updater.update()
        # The updater must hold what a manual update of the original produces,
        # both as its attribute and inside its config.
        assert updater.data == original.update()
        assert updater.config['data'] == original.update()

    def test_update_every(self):
        data = MyData.download(0, shape=(5,), return_arr=True)
        counter = dict(call_count=0)

        class CountingUpdater(vbt.DataUpdater):
            # The parameter must be called ``kwargs``: update_every below passes it by that keyword.
            def update(self, kwargs):
                super().update()
                kwargs['call_count'] += 1
                # Raising CancelledError on the fifth call is expected to end update_every.
                if kwargs['call_count'] == 5:
                    raise vbt.CancelledError

        updater = CountingUpdater(data)
        updater.update_every(kwargs=counter)
        # Replay the same five updates by hand and compare.
        for _ in range(5):
            data = data.update()
        assert updater.data == data
        assert updater.config['data'] == data
1067
1068