// Source: GitHub repository pola-rs/polars, path blob/main/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs
1
use std::ops::{AddAssign, DivAssign, MulAssign};
2
3
use num_traits::Float;
4
5
use crate::array::PrimitiveArray;
6
use crate::legacy::utils::CustomIterTools;
7
use crate::trusted_len::TrustedLen;
8
use crate::types::NativeType;
9
10
#[allow(clippy::too_many_arguments)]
11
fn ewm_cov_internal<I, T>(
12
xs: I,
13
ys: I,
14
alpha: T,
15
adjust: bool,
16
bias: bool,
17
min_periods: usize,
18
ignore_nulls: bool,
19
do_sqrt: bool,
20
) -> PrimitiveArray<T>
21
where
22
I: IntoIterator<Item = Option<T>>,
23
I::IntoIter: TrustedLen,
24
T: Float + NativeType + AddAssign + MulAssign + DivAssign,
25
{
26
let old_wt_factor = T::one() - alpha;
27
let new_wt = if adjust { T::one() } else { alpha };
28
let mut sum_wt = T::one();
29
let mut sum_wt2 = T::one();
30
let mut old_wt = T::one();
31
32
let mut opt_mean_x = None;
33
let mut opt_mean_y = None;
34
let mut cov = T::zero();
35
let mut non_na_cnt = 0usize;
36
let min_periods_fixed = if min_periods == 0 { 1 } else { min_periods };
37
38
let res = xs
39
.into_iter()
40
.zip(ys)
41
.enumerate()
42
.map(|(i, (opt_x, opt_y))| {
43
let is_observation = opt_x.is_some() && opt_y.is_some();
44
if is_observation {
45
non_na_cnt += 1;
46
}
47
match (i, opt_mean_x, opt_mean_y) {
48
(0, _, _) => {
49
if is_observation {
50
opt_mean_x = opt_x;
51
opt_mean_y = opt_y;
52
}
53
},
54
(_, Some(mean_x), Some(mean_y)) => {
55
if is_observation || !ignore_nulls {
56
sum_wt *= old_wt_factor;
57
sum_wt2 *= old_wt_factor * old_wt_factor;
58
old_wt *= old_wt_factor;
59
if is_observation {
60
let x = opt_x.unwrap();
61
let y = opt_y.unwrap();
62
let old_mean_x = mean_x;
63
let old_mean_y = mean_y;
64
65
// avoid numerical errors on constant series
66
if mean_x != x {
67
opt_mean_x =
68
Some((old_wt * old_mean_x + new_wt * x) / (old_wt + new_wt));
69
}
70
71
// avoid numerical errors on constant series
72
if mean_y != y {
73
opt_mean_y =
74
Some((old_wt * old_mean_y + new_wt * y) / (old_wt + new_wt));
75
}
76
77
cov = ((old_wt
78
* (cov
79
+ ((old_mean_x - opt_mean_x.unwrap())
80
* (old_mean_y - opt_mean_y.unwrap()))))
81
+ (new_wt
82
* ((x - opt_mean_x.unwrap()) * (y - opt_mean_y.unwrap()))))
83
/ (old_wt + new_wt);
84
85
sum_wt += new_wt;
86
sum_wt2 += new_wt * new_wt;
87
old_wt += new_wt;
88
if !adjust {
89
sum_wt /= old_wt;
90
sum_wt2 /= old_wt * old_wt;
91
old_wt = T::one();
92
}
93
}
94
}
95
},
96
_ => {
97
if is_observation {
98
opt_mean_x = opt_x;
99
opt_mean_y = opt_y;
100
}
101
},
102
}
103
match (non_na_cnt >= min_periods_fixed, bias, is_observation) {
104
(_, _, false) => None,
105
(false, _, true) => None,
106
(true, false, true) => {
107
if non_na_cnt == 1 {
108
Some(cov)
109
} else {
110
let numerator = sum_wt * sum_wt;
111
let denominator = numerator - sum_wt2;
112
if denominator > T::zero() {
113
Some((numerator / denominator) * cov)
114
} else {
115
None
116
}
117
}
118
},
119
(true, true, true) => Some(cov),
120
}
121
});
122
123
if do_sqrt {
124
res.map(|opt_x| opt_x.map(|x| x.sqrt())).collect_trusted()
125
} else {
126
res.collect_trusted()
127
}
128
}
129
130
pub fn ewm_cov<I, T>(
131
xs: I,
132
ys: I,
133
alpha: T,
134
adjust: bool,
135
bias: bool,
136
min_periods: usize,
137
ignore_nulls: bool,
138
) -> PrimitiveArray<T>
139
where
140
I: IntoIterator<Item = Option<T>>,
141
I::IntoIter: TrustedLen,
142
T: Float + NativeType + AddAssign + MulAssign + DivAssign,
143
{
144
ewm_cov_internal(
145
xs,
146
ys,
147
alpha,
148
adjust,
149
bias,
150
min_periods,
151
ignore_nulls,
152
false,
153
)
154
}
155
156
pub fn ewm_var<I, T>(
157
xs: I,
158
alpha: T,
159
adjust: bool,
160
bias: bool,
161
min_periods: usize,
162
ignore_nulls: bool,
163
) -> PrimitiveArray<T>
164
where
165
I: IntoIterator<Item = Option<T>> + Clone,
166
I::IntoIter: TrustedLen,
167
T: Float + NativeType + AddAssign + MulAssign + DivAssign,
168
{
169
ewm_cov_internal(
170
xs.clone(),
171
xs,
172
alpha,
173
adjust,
174
bias,
175
min_periods,
176
ignore_nulls,
177
false,
178
)
179
}
180
181
pub fn ewm_std<I, T>(
182
xs: I,
183
alpha: T,
184
adjust: bool,
185
bias: bool,
186
min_periods: usize,
187
ignore_nulls: bool,
188
) -> PrimitiveArray<T>
189
where
190
I: IntoIterator<Item = Option<T>> + Clone,
191
I::IntoIter: TrustedLen,
192
T: Float + NativeType + AddAssign + MulAssign + DivAssign,
193
{
194
ewm_cov_internal(
195
xs.clone(),
196
xs,
197
alpha,
198
adjust,
199
bias,
200
min_periods,
201
ignore_nulls,
202
true,
203
)
204
}
205
206
#[cfg(test)]
mod test {
    use std::f64::consts::SQRT_2;

    use super::super::assert_allclose;
    use super::*;

    const ALPHA: f64 = 0.5;
    const EPS: f64 = 1e-15;

    /// Series without nulls.
    const XS: [Option<f64>; 7] = [
        Some(1.0),
        Some(5.0),
        Some(7.0),
        Some(1.0),
        Some(2.0),
        Some(1.0),
        Some(4.0),
    ];
    /// Series with a leading null and an interior run of nulls. Wherever it is non-null it holds
    /// the same value as `XS` at that position.
    const YS: [Option<f64>; 7] = [None, Some(5.0), Some(7.0), None, None, Some(1.0), Some(4.0)];

    #[test]
    fn test_ewm_var() {
        // XS contains no nulls, so `ignore_nulls` must not influence the result.
        for ignore_nulls in [true, false] {
            assert_allclose!(
                ewm_var(XS.to_vec(), ALPHA, true, true, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(3.555_555_555_555_556),
                    Some(4.244_897_959_183_674),
                    Some(7.182_222_222_222_221),
                    Some(3.796_045_785_639_958),
                    Some(2.467_120_181_405_896),
                    Some(2.476_036_952_073_904_3),
                ]),
                EPS
            );
            assert_allclose!(
                ewm_var(XS.to_vec(), ALPHA, true, false, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(8.0),
                    Some(7.428_571_428_571_429),
                    Some(11.542_857_142_857_143),
                    Some(5.883_870_967_741_934_5),
                    Some(3.760_368_663_594_470_6),
                    Some(3.743_532_058_492_688_6),
                ]),
                EPS
            );
            assert_allclose!(
                ewm_var(XS.to_vec(), ALPHA, false, true, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(4.0),
                    Some(6.0),
                    Some(7.0),
                    Some(3.75),
                    Some(2.437_5),
                    Some(2.484_375),
                ]),
                EPS
            );
            assert_allclose!(
                ewm_var(XS.to_vec(), ALPHA, false, false, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(8.0),
                    Some(9.600_000_000_000_001),
                    Some(10.666_666_666_666_666),
                    Some(5.647_058_823_529_411),
                    Some(3.659_824_046_920_821),
                    Some(3.727_472_527_472_527_6),
                ]),
                EPS
            );
        }

        // YS contains nulls, so every (adjust, bias, ignore_nulls) combination is pinned.
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(7.346_938_775_510_203),
                Some(3.555_555_555_555_555_4),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.857_142_857_142_856),
                Some(5.714_285_714_285_714),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(14.159_999_999_999_997),
                Some(5.039_513_677_811_549_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(6.75),
                Some(3.437_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(10.8),
                Some(5.238_095_238_095_238),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.352_941_176_470_589),
                Some(5.299_145_299_145_3),
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_cov() {
        // XS equals YS at every position where YS is non-null, so the expected covariances are
        // the same numbers as the YS variances.
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(7.346_938_775_510_203),
                Some(3.555_555_555_555_555_4)
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3)
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.857_142_857_142_856),
                Some(5.714_285_714_285_714)
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(14.159_999_999_999_997),
                Some(5.039_513_677_811_549_5)
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(6.75),
                Some(3.437_5)
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(10.8),
                Some(5.238_095_238_095_238)
            ]),
            EPS
        );
        assert_allclose!(
            ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(2.0),
                None,
                None,
                Some(12.352_941_176_470_589),
                Some(5.299_145_299_145_3)
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_std() {
        // XS contains no nulls, so `ignore_nulls` must not influence the result.
        for ignore_nulls in [true, false] {
            assert_allclose!(
                ewm_std(XS.to_vec(), ALPHA, true, true, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(1.885_618_083_164_126_7),
                    Some(2.060_315_014_550_851_3),
                    Some(2.679_966_832_298_904),
                    Some(1.948_344_370_392_451_5),
                    Some(1.570_706_904_997_204_2),
                    Some(1.573_542_802_746_053_2),
                ]),
                EPS
            );
            assert_allclose!(
                ewm_std(XS.to_vec(), ALPHA, true, false, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(2.828_427_124_746_190_3),
                    Some(2.725_540_575_476_987_5),
                    Some(3.397_478_056_273_085_3),
                    Some(2.425_669_179_369_259),
                    Some(1.939_167_002_502_484_5),
                    Some(1.934_820_937_061_796_6),
                ]),
                EPS
            );
            assert_allclose!(
                ewm_std(XS.to_vec(), ALPHA, false, true, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(2.0),
                    Some(2.449_489_742_783_178),
                    Some(2.645_751_311_064_590_7),
                    Some(1.936_491_673_103_708_5),
                    Some(1.561_249_499_599_599_6),
                    Some(1.576_190_026_614_811_4),
                ]),
                EPS
            );
            assert_allclose!(
                ewm_std(XS.to_vec(), ALPHA, false, false, 0, ignore_nulls),
                PrimitiveArray::from([
                    Some(0.0),
                    Some(2.828_427_124_746_190_3),
                    Some(3.098_386_676_965_933_6),
                    Some(3.265_986_323_710_904),
                    Some(2.376_354_103_144_018_3),
                    Some(1.913_066_660_344_281_2),
                    Some(1.930_666_342_865_210_7),
                ]),
                EPS
            );
        }

        // YS contains nulls, so every (adjust, bias, ignore_nulls) combination is pinned.
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.942_809_041_582_063_4),
                None,
                None,
                Some(2.710_523_708_715_753_4),
                Some(1.885_618_083_164_126_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.942_809_041_582_063_4),
                None,
                None,
                Some(1.980_514_497_076_503),
                Some(1.596_805_731_098_222),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.585_685_828_003_181),
                Some(2.390_457_218_668_787),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, true, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.762_977_544_445_355_3),
                Some(2.244_886_116_891_356),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, true, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(2.598_076_211_353_316),
                Some(1.854_049_621_773_915_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(1.0),
                None,
                None,
                Some(2.049_390_153_191_92),
                Some(1.760_681_686_165_901),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, false, 0, true),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.286_335_345_030_997),
                Some(2.288_688_541_085_317_5),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_std(YS.to_vec(), ALPHA, false, false, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(SQRT_2),
                None,
                None,
                Some(3.514_675_116_774_036_7),
                Some(2.301_987_249_996_250_4),
            ]),
            EPS
        );
    }

    #[test]
    fn test_ewm_min_periods() {
        // `min_periods == 0` behaves exactly like `min_periods == 1`.
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 1, false),
            PrimitiveArray::from([
                None,
                Some(0.0),
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
        // With two required observations the first non-null slot is suppressed.
        assert_allclose!(
            ewm_var(YS.to_vec(), ALPHA, true, true, 2, false),
            PrimitiveArray::from([
                None,
                None,
                Some(0.888_888_888_888_889),
                None,
                None,
                Some(3.922_437_673_130_193_3),
                Some(2.549_788_542_868_127_3),
            ]),
            EPS
        );
    }
}
797
798