Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/float_sum.rs
8406 views
1
use std::ops::{Add, IndexMut};
2
#[cfg(feature = "simd")]
3
use std::simd::{prelude::*, *};
4
5
use arrow::array::{Array, PrimitiveArray};
6
use arrow::bitmap::Bitmap;
7
use arrow::bitmap::bitmask::BitMask;
8
use arrow::types::NativeType;
9
use num_traits::{AsPrimitive, Float};
10
#[cfg(feature = "simd")]
11
use polars_utils::float16::pf16;
12
13
/// Number of independent accumulator lanes used by the block kernels
/// (`Simd<_, STRIPE>` with the `simd` feature, `[F; STRIPE]` without it).
const STRIPE: usize = 16;
/// Length of the leaf blocks at which `pairwise_sum` stops recursing and hands
/// the data to a `SumBlock` kernel.
const PAIRWISE_RECURSION_LIMIT: usize = 128;

// The block kernels walk a `[T; PAIRWISE_RECURSION_LIMIT]` with
// `chunks_exact(STRIPE)`, which silently skips any tail that does not fill a
// whole stripe, so a block must consist of whole stripes.
const _: () = assert!(PAIRWISE_RECURSION_LIMIT % STRIPE == 0);
// `vector_horizontal_sum` repeatedly folds the upper half of the lanes onto the
// lower half and finishes by combining lanes 0..4; every lane is visited exactly
// once only when the width is a power of two no smaller than four.
const _: () = assert!(STRIPE.is_power_of_two() && STRIPE >= 4);
15
16
/// Uniform "cast this vector to another element type" entry point.
///
/// The block kernels are generic over both integer and float element types, so
/// they go through this helper trait (implemented per element type by
/// `impl_cast_custom!`) to convert a `Simd<_, LANES>` into the accumulator type.
#[cfg(feature = "simd")]
pub trait SimdCastGeneric<const LANES: usize>
where
    LaneCount<LANES>: SupportedLaneCount,
{
    /// Element-wise cast of `self` into a vector of `Target` with the same lane count.
    fn cast_generic<Target: SimdCast>(self) -> Simd<Target, LANES>;
}
24
25
// Implements `SimdCastGeneric` for `Simd<$elem, LANES>` by forwarding to the
// inherent `Simd::cast`. The generated impl carries its own
// `#[cfg(feature = "simd")]`, so invoking the macro expands to nothing when the
// feature is disabled.
macro_rules! impl_cast_custom {
    ($elem:ty) => {
        #[cfg(feature = "simd")]
        impl<const LANES: usize> SimdCastGeneric<LANES> for Simd<$elem, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            fn cast_generic<Target: SimdCast>(self) -> Simd<Target, LANES> {
                self.cast::<Target>()
            }
        }
    };
}
38
39
impl_cast_custom!(u8);
40
impl_cast_custom!(u16);
41
impl_cast_custom!(u32);
42
impl_cast_custom!(u64);
43
impl_cast_custom!(i8);
44
impl_cast_custom!(i16);
45
impl_cast_custom!(i32);
46
impl_cast_custom!(i64);
47
impl_cast_custom!(f32);
48
impl_cast_custom!(f64);
49
50
fn vector_horizontal_sum<V, T>(mut v: V) -> T
51
where
52
V: IndexMut<usize, Output = T>,
53
T: Add<T, Output = T> + Sized + Copy,
54
{
55
// We have to be careful about this reduction, floating
56
// point math is NOT associative so we have to write this
57
// in a form that maps to good shuffle instructions.
58
// We fold the vector onto itself, halved, until we are down to
59
// four elements which we add in a shuffle-friendly way.
60
let mut width = STRIPE;
61
while width > 4 {
62
for j in 0..width / 2 {
63
v[j] = v[j] + v[width / 2 + j];
64
}
65
width /= 2;
66
}
67
68
(v[0] + v[2]) + (v[1] + v[3])
69
}
70
71
// As a trait to not proliferate SIMD bounds.
72
pub trait SumBlock<F> {
73
fn sum_block_vectorized(&self) -> F;
74
fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F;
75
}
76
77
// SIMD kernel: load the block `STRIPE` elements at a time, cast each vector to
// the accumulator type, add the vectors lane-wise and finally reduce the lanes.
#[cfg(feature = "simd")]
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
where
    T: SimdElement,
    F: SimdElement + SimdCast + Add<Output = F> + Default,
    Simd<T, STRIPE>: SimdCastGeneric<STRIPE>,
    Simd<F, STRIPE>: std::iter::Sum,
{
    fn sum_block_vectorized(&self) -> F {
        let converted = self
            .chunks_exact(STRIPE)
            .map(|chunk| Simd::<T, STRIPE>::from_slice(chunk).cast_generic::<F>());
        let lane_totals: Simd<F, STRIPE> = converted.sum();
        vector_horizontal_sum(lane_totals)
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        let zeros = Simd::<F, STRIPE>::default();
        let lane_totals: Simd<F, STRIPE> = self
            .chunks_exact(STRIPE)
            .zip((0..).step_by(STRIPE))
            .map(|(chunk, offset)| {
                // Lanes whose validity bit is clear contribute zero.
                let keep: Mask<_, STRIPE> = mask.get_simd(offset);
                keep.select(Simd::<T, STRIPE>::from_slice(chunk).cast_generic::<F>(), zeros)
            })
            .sum();
        vector_horizontal_sum(lane_totals)
    }
}
106
107
/// Scalar `SumBlock` impls for element types summed without `std::simd`
/// vectors (presumably because they are not SIMD element types).
///
/// Each element is converted with `AsPrimitive` and accumulated sequentially
/// with `Iterator::sum`. In the masked variant, masked-out elements contribute
/// `F::zero()` instead of being filtered out, so every slot is visited.
/// A single macro keeps the i128/u128/pf16 impls from drifting apart; they were
/// previously three verbatim copies.
#[cfg(feature = "simd")]
macro_rules! impl_scalar_sum_block {
    ($($elem:ty),+ $(,)?) => {
        $(
            impl<F> SumBlock<F> for [$elem; PAIRWISE_RECURSION_LIMIT]
            where
                $elem: AsPrimitive<F>,
                F: Float + std::iter::Sum + 'static,
            {
                fn sum_block_vectorized(&self) -> F {
                    self.iter().map(|x| x.as_()).sum()
                }

                fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
                    self.iter()
                        .enumerate()
                        .map(|(idx, x)| if mask.get(idx) { x.as_() } else { F::zero() })
                        .sum()
                }
            }
        )+
    };
}

#[cfg(feature = "simd")]
impl_scalar_sum_block!(i128, u128, pf16);
160
161
#[cfg(not(feature = "simd"))]
162
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
163
where
164
T: AsPrimitive<F> + 'static,
165
F: Default + Add<Output = F> + Copy + 'static,
166
{
167
fn sum_block_vectorized(&self) -> F {
168
let mut vsum = [F::default(); STRIPE];
169
for chunk in self.chunks_exact(STRIPE) {
170
for j in 0..STRIPE {
171
vsum[j] = vsum[j] + chunk[j].as_();
172
}
173
}
174
vector_horizontal_sum(vsum)
175
}
176
177
fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
178
let mut vsum = [F::default(); STRIPE];
179
for (i, chunk) in self.chunks_exact(STRIPE).enumerate() {
180
for j in 0..STRIPE {
181
// Unconditional add with select for better branch-free opts.
182
let addend = if mask.get(i * STRIPE + j) {
183
chunk[j].as_()
184
} else {
185
F::default()
186
};
187
vsum[j] = vsum[j] + addend;
188
}
189
}
190
vector_horizontal_sum(vsum)
191
}
192
}
193
194
/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
195
unsafe fn pairwise_sum<F, T>(f: &[T]) -> F
196
where
197
[T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
198
F: Add<Output = F>,
199
{
200
debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
201
202
let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
203
if let Some(block) = block {
204
return block.sum_block_vectorized();
205
}
206
207
// SAFETY: we maintain the invariant. `try_into` array of len PAIRWISE_RECURSION_LIMIT
208
// failed so we know f.len() >= 2*PAIRWISE_RECURSION_LIMIT, and thus blocks >= 2.
209
// This means 0 < left_len < f.len() and left_len is divisible by PAIRWISE_RECURSION_LIMIT,
210
// maintaining the invariant for both recursive calls.
211
unsafe {
212
let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
213
let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
214
let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
215
pairwise_sum(left) + pairwise_sum(right)
216
}
217
}
218
219
/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
220
/// Also, f.len() == mask.len().
221
unsafe fn pairwise_sum_with_mask<F, T>(f: &[T], mask: BitMask<'_>) -> F
222
where
223
[T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
224
F: Add<Output = F>,
225
{
226
debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
227
debug_assert!(f.len() == mask.len());
228
229
let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
230
if let Some(block) = block {
231
return block.sum_block_vectorized_with_mask(mask);
232
}
233
234
// SAFETY: see pairwise_sum.
235
unsafe {
236
let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
237
let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
238
let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
239
let (left_mask, right_mask) = mask.split_at_unchecked(left_len);
240
pairwise_sum_with_mask(left, left_mask) + pairwise_sum_with_mask(right, right_mask)
241
}
242
}
243
244
pub trait FloatSum<F>: Sized {
245
fn sum(f: &[Self]) -> F;
246
fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F;
247
}
248
249
impl<T, F> FloatSum<F> for T
250
where
251
F: Float + std::iter::Sum + 'static,
252
T: AsPrimitive<F>,
253
[T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
254
{
255
fn sum(f: &[Self]) -> F {
256
let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
257
let (rest, main) = f.split_at(remainder);
258
let mainsum = if f.len() > remainder {
259
unsafe { pairwise_sum(main) }
260
} else {
261
F::zero()
262
};
263
// TODO: faster remainder.
264
let restsum: F = rest.iter().map(|x| x.as_()).sum();
265
mainsum + restsum
266
}
267
268
fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F {
269
let mask = BitMask::from_bitmap(validity);
270
assert!(f.len() == mask.len());
271
272
let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
273
let (rest, main) = f.split_at(remainder);
274
let (rest_mask, main_mask) = mask.split_at(remainder);
275
let mainsum = if f.len() > remainder {
276
unsafe { pairwise_sum_with_mask(main, main_mask) }
277
} else {
278
F::zero()
279
};
280
// TODO: faster remainder.
281
let restsum: F = rest
282
.iter()
283
.enumerate()
284
.map(|(i, x)| {
285
// No filter but rather select of 0.0 for cmov opt.
286
if rest_mask.get(i) { x.as_() } else { F::zero() }
287
})
288
.sum();
289
mainsum + restsum
290
}
291
}
292
293
pub fn sum_arr_as_f32<T>(arr: &PrimitiveArray<T>) -> f32
294
where
295
T: NativeType + FloatSum<f32>,
296
{
297
let validity = arr.validity().filter(|_| arr.null_count() > 0);
298
if let Some(mask) = validity {
299
FloatSum::sum_with_validity(arr.values(), mask)
300
} else {
301
FloatSum::sum(arr.values())
302
}
303
}
304
305
pub fn sum_arr_as_f64<T>(arr: &PrimitiveArray<T>) -> f64
306
where
307
T: NativeType + FloatSum<f64>,
308
{
309
let validity = arr.validity().filter(|_| arr.null_count() > 0);
310
if let Some(mask) = validity {
311
FloatSum::sum_with_validity(arr.values(), mask)
312
} else {
313
FloatSum::sum(arr.values())
314
}
315
}
316
317