Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/fused.rs
6939 views
1
use arrow::array::PrimitiveArray;
2
use arrow::compute::utils::combine_validities_and3;
3
use polars_core::prelude::*;
4
use polars_core::utils::align_chunks_ternary;
5
use polars_core::with_match_physical_numeric_polars_type;
6
7
// a + (b * c)
8
fn fma_arr<T: NumericNative>(
9
a: &PrimitiveArray<T>,
10
b: &PrimitiveArray<T>,
11
c: &PrimitiveArray<T>,
12
) -> PrimitiveArray<T> {
13
assert_eq!(a.len(), b.len());
14
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
15
let a = a.values().as_slice();
16
let b = b.values().as_slice();
17
let c = c.values().as_slice();
18
19
assert_eq!(a.len(), b.len());
20
assert_eq!(b.len(), c.len());
21
let out = a
22
.iter()
23
.zip(b.iter())
24
.zip(c.iter())
25
.map(|((a, b), c)| *a + (*b * *c))
26
.collect::<Vec<_>>();
27
PrimitiveArray::from_data_default(out.into(), validity)
28
}
29
30
fn fma_ca<T: PolarsNumericType>(
31
a: &ChunkedArray<T>,
32
b: &ChunkedArray<T>,
33
c: &ChunkedArray<T>,
34
) -> ChunkedArray<T> {
35
let (a, b, c) = align_chunks_ternary(a, b, c);
36
let chunks = a
37
.downcast_iter()
38
.zip(b.downcast_iter())
39
.zip(c.downcast_iter())
40
.map(|((a, b), c)| fma_arr(a, b, c));
41
ChunkedArray::from_chunk_iter(a.name().clone(), chunks)
42
}
43
44
pub fn fma_columns(a: &Column, b: &Column, c: &Column) -> Column {
45
if a.len() == b.len() && a.len() == c.len() {
46
with_match_physical_numeric_polars_type!(a.dtype(), |$T| {
47
let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref();
48
let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref();
49
let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref();
50
51
fma_ca(a, b, c).into_column()
52
})
53
} else {
54
(a.as_materialized_series()
55
+ &(b.as_materialized_series() * c.as_materialized_series()).unwrap())
56
.unwrap()
57
.into()
58
}
59
}
60
61
// a - (b * c)
62
fn fsm_arr<T: NumericNative>(
63
a: &PrimitiveArray<T>,
64
b: &PrimitiveArray<T>,
65
c: &PrimitiveArray<T>,
66
) -> PrimitiveArray<T> {
67
assert_eq!(a.len(), b.len());
68
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
69
let a = a.values().as_slice();
70
let b = b.values().as_slice();
71
let c = c.values().as_slice();
72
73
assert_eq!(a.len(), b.len());
74
assert_eq!(b.len(), c.len());
75
let out = a
76
.iter()
77
.zip(b.iter())
78
.zip(c.iter())
79
.map(|((a, b), c)| *a - (*b * *c))
80
.collect::<Vec<_>>();
81
PrimitiveArray::from_data_default(out.into(), validity)
82
}
83
84
fn fsm_ca<T: PolarsNumericType>(
85
a: &ChunkedArray<T>,
86
b: &ChunkedArray<T>,
87
c: &ChunkedArray<T>,
88
) -> ChunkedArray<T> {
89
let (a, b, c) = align_chunks_ternary(a, b, c);
90
let chunks = a
91
.downcast_iter()
92
.zip(b.downcast_iter())
93
.zip(c.downcast_iter())
94
.map(|((a, b), c)| fsm_arr(a, b, c));
95
ChunkedArray::from_chunk_iter(a.name().clone(), chunks)
96
}
97
98
pub fn fsm_columns(a: &Column, b: &Column, c: &Column) -> Column {
99
if a.len() == b.len() && a.len() == c.len() {
100
with_match_physical_numeric_polars_type!(a.dtype(), |$T| {
101
let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref();
102
let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref();
103
let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref();
104
105
fsm_ca(a, b, c).into_column()
106
})
107
} else {
108
(a.as_materialized_series()
109
- &(b.as_materialized_series() * c.as_materialized_series()).unwrap())
110
.unwrap()
111
.into()
112
}
113
}
114
115
fn fms_arr<T: NumericNative>(
116
a: &PrimitiveArray<T>,
117
b: &PrimitiveArray<T>,
118
c: &PrimitiveArray<T>,
119
) -> PrimitiveArray<T> {
120
assert_eq!(a.len(), b.len());
121
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
122
let a = a.values().as_slice();
123
let b = b.values().as_slice();
124
let c = c.values().as_slice();
125
126
assert_eq!(a.len(), b.len());
127
assert_eq!(b.len(), c.len());
128
let out = a
129
.iter()
130
.zip(b.iter())
131
.zip(c.iter())
132
.map(|((a, b), c)| (*a * *b) - *c)
133
.collect::<Vec<_>>();
134
PrimitiveArray::from_data_default(out.into(), validity)
135
}
136
137
fn fms_ca<T: PolarsNumericType>(
138
a: &ChunkedArray<T>,
139
b: &ChunkedArray<T>,
140
c: &ChunkedArray<T>,
141
) -> ChunkedArray<T> {
142
let (a, b, c) = align_chunks_ternary(a, b, c);
143
let chunks = a
144
.downcast_iter()
145
.zip(b.downcast_iter())
146
.zip(c.downcast_iter())
147
.map(|((a, b), c)| fms_arr(a, b, c));
148
ChunkedArray::from_chunk_iter(a.name().clone(), chunks)
149
}
150
151
pub fn fms_columns(a: &Column, b: &Column, c: &Column) -> Column {
152
if a.len() == b.len() && a.len() == c.len() {
153
with_match_physical_numeric_polars_type!(a.dtype(), |$T| {
154
let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref();
155
let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref();
156
let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref();
157
158
fms_ca(a, b, c).into_column()
159
})
160
} else {
161
(&(a.as_materialized_series() * b.as_materialized_series()).unwrap()
162
- c.as_materialized_series())
163
.unwrap()
164
.into()
165
}
166
}
167
168