// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-compute/src/float_sum.rs
1
use std::ops::{Add, IndexMut};
2
#[cfg(feature = "simd")]
3
use std::simd::{prelude::*, *};
4
5
use arrow::array::{Array, PrimitiveArray};
6
use arrow::bitmap::Bitmap;
7
use arrow::bitmap::bitmask::BitMask;
8
use arrow::types::NativeType;
9
use num_traits::{AsPrimitive, Float};
10
11
/// Number of lanes in the accumulator that a block is summed into before the
/// final horizontal reduction.
const STRIPE: usize = 16;
/// Length of the fixed-size blocks at which the pairwise recursion stops and
/// the block is summed directly.
const PAIRWISE_RECURSION_LIMIT: usize = 128;

// Compile-time guard for the assumptions the block kernels rely on:
// * blocks are walked with `chunks_exact(STRIPE)`, which would silently drop a
//   tail if the block length were not a multiple of `STRIPE`;
// * the horizontal reduction halves the lane count down to exactly four lanes,
//   which only works for a power of two that is at least 4.
const _: () = assert!(
    STRIPE >= 4 && STRIPE.is_power_of_two() && PAIRWISE_RECURSION_LIMIT.is_multiple_of(STRIPE)
);
13
14
// We want to be generic over both integers and floats, requiring this helper trait.
15
#[cfg(feature = "simd")]
16
pub trait SimdCastGeneric<const N: usize>
17
where
18
LaneCount<N>: SupportedLaneCount,
19
{
20
fn cast_generic<U: SimdCast>(self) -> Simd<U, N>;
21
}
22
23
/// Implements `SimdCastGeneric<N>` for `Simd<$elem, N>` by forwarding to the
/// vector's own `cast` method.
///
/// The generated impl is gated on the `simd` feature, so invoking the macro
/// expands to nothing when that feature is disabled.
macro_rules! impl_cast_custom {
    ($elem:ty) => {
        #[cfg(feature = "simd")]
        impl<const N: usize> SimdCastGeneric<N> for Simd<$elem, N>
        where
            LaneCount<N>: SupportedLaneCount,
        {
            fn cast_generic<U: SimdCast>(self) -> Simd<U, N> {
                self.cast::<U>()
            }
        }
    };
}
36
37
impl_cast_custom!(u8);
38
impl_cast_custom!(u16);
39
impl_cast_custom!(u32);
40
impl_cast_custom!(u64);
41
impl_cast_custom!(i8);
42
impl_cast_custom!(i16);
43
impl_cast_custom!(i32);
44
impl_cast_custom!(i64);
45
impl_cast_custom!(f32);
46
impl_cast_custom!(f64);
47
48
fn vector_horizontal_sum<V, T>(mut v: V) -> T
49
where
50
V: IndexMut<usize, Output = T>,
51
T: Add<T, Output = T> + Sized + Copy,
52
{
53
// We have to be careful about this reduction, floating
54
// point math is NOT associative so we have to write this
55
// in a form that maps to good shuffle instructions.
56
// We fold the vector onto itself, halved, until we are down to
57
// four elements which we add in a shuffle-friendly way.
58
let mut width = STRIPE;
59
while width > 4 {
60
for j in 0..width / 2 {
61
v[j] = v[j] + v[width / 2 + j];
62
}
63
width /= 2;
64
}
65
66
(v[0] + v[2]) + (v[1] + v[3])
67
}
68
69
// As a trait to not proliferate SIMD bounds.
70
pub trait SumBlock<F> {
71
fn sum_block_vectorized(&self) -> F;
72
fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F;
73
}
74
75
#[cfg(feature = "simd")]
76
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
77
where
78
T: SimdElement,
79
F: SimdElement + SimdCast + Add<Output = F> + Default,
80
Simd<T, STRIPE>: SimdCastGeneric<STRIPE>,
81
Simd<F, STRIPE>: std::iter::Sum,
82
{
83
fn sum_block_vectorized(&self) -> F {
84
let vsum = self
85
.chunks_exact(STRIPE)
86
.map(|a| Simd::<T, STRIPE>::from_slice(a).cast_generic::<F>())
87
.sum::<Simd<F, STRIPE>>();
88
vector_horizontal_sum(vsum)
89
}
90
91
fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
92
let zero = Simd::default();
93
let vsum = self
94
.chunks_exact(STRIPE)
95
.enumerate()
96
.map(|(i, a)| {
97
let m: Mask<_, STRIPE> = mask.get_simd(i * STRIPE);
98
m.select(Simd::from_slice(a).cast_generic::<F>(), zero)
99
})
100
.sum::<Simd<F, STRIPE>>();
101
vector_horizontal_sum(vsum)
102
}
103
}
104
105
/// Block kernel for `i128` when the `simd` feature is enabled: every element
/// is converted to `F` with `as_()` and the results are summed sequentially.
#[cfg(feature = "simd")]
impl<F> SumBlock<F> for [i128; PAIRWISE_RECURSION_LIMIT]
where
    i128: AsPrimitive<F>,
    F: Float + std::iter::Sum + 'static,
{
    fn sum_block_vectorized(&self) -> F {
        self.iter().map(|x| x.as_()).sum()
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        self.iter()
            .enumerate()
            // Masked-out elements are replaced by zero rather than skipped.
            .map(|(idx, x)| if mask.get(idx) { x.as_() } else { F::zero() })
            .sum()
    }
}
122
123
#[cfg(not(feature = "simd"))]
124
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
125
where
126
T: AsPrimitive<F> + 'static,
127
F: Default + Add<Output = F> + Copy + 'static,
128
{
129
fn sum_block_vectorized(&self) -> F {
130
let mut vsum = [F::default(); STRIPE];
131
for chunk in self.chunks_exact(STRIPE) {
132
for j in 0..STRIPE {
133
vsum[j] = vsum[j] + chunk[j].as_();
134
}
135
}
136
vector_horizontal_sum(vsum)
137
}
138
139
fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
140
let mut vsum = [F::default(); STRIPE];
141
for (i, chunk) in self.chunks_exact(STRIPE).enumerate() {
142
for j in 0..STRIPE {
143
// Unconditional add with select for better branch-free opts.
144
let addend = if mask.get(i * STRIPE + j) {
145
chunk[j].as_()
146
} else {
147
F::default()
148
};
149
vsum[j] = vsum[j] + addend;
150
}
151
}
152
vector_horizontal_sum(vsum)
153
}
154
}
155
156
/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
157
unsafe fn pairwise_sum<F, T>(f: &[T]) -> F
158
where
159
[T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
160
F: Add<Output = F>,
161
{
162
debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
163
164
let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
165
if let Some(block) = block {
166
return block.sum_block_vectorized();
167
}
168
169
// SAFETY: we maintain the invariant. `try_into` array of len PAIRWISE_RECURSION_LIMIT
170
// failed so we know f.len() >= 2*PAIRWISE_RECURSION_LIMIT, and thus blocks >= 2.
171
// This means 0 < left_len < f.len() and left_len is divisible by PAIRWISE_RECURSION_LIMIT,
172
// maintaining the invariant for both recursive calls.
173
unsafe {
174
let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
175
let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
176
let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
177
pairwise_sum(left) + pairwise_sum(right)
178
}
179
}
180
181
/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
182
/// Also, f.len() == mask.len().
183
unsafe fn pairwise_sum_with_mask<F, T>(f: &[T], mask: BitMask<'_>) -> F
184
where
185
[T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
186
F: Add<Output = F>,
187
{
188
debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
189
debug_assert!(f.len() == mask.len());
190
191
let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
192
if let Some(block) = block {
193
return block.sum_block_vectorized_with_mask(mask);
194
}
195
196
// SAFETY: see pairwise_sum.
197
unsafe {
198
let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
199
let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
200
let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
201
let (left_mask, right_mask) = mask.split_at_unchecked(left_len);
202
pairwise_sum_with_mask(left, left_mask) + pairwise_sum_with_mask(right, right_mask)
203
}
204
}
205
206
pub trait FloatSum<F>: Sized {
207
fn sum(f: &[Self]) -> F;
208
fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F;
209
}
210
211
impl<T, F> FloatSum<F> for T
212
where
213
F: Float + std::iter::Sum + 'static,
214
T: AsPrimitive<F>,
215
[T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
216
{
217
fn sum(f: &[Self]) -> F {
218
let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
219
let (rest, main) = f.split_at(remainder);
220
let mainsum = if f.len() > remainder {
221
unsafe { pairwise_sum(main) }
222
} else {
223
F::zero()
224
};
225
// TODO: faster remainder.
226
let restsum: F = rest.iter().map(|x| x.as_()).sum();
227
mainsum + restsum
228
}
229
230
fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F {
231
let mask = BitMask::from_bitmap(validity);
232
assert!(f.len() == mask.len());
233
234
let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
235
let (rest, main) = f.split_at(remainder);
236
let (rest_mask, main_mask) = mask.split_at(remainder);
237
let mainsum = if f.len() > remainder {
238
unsafe { pairwise_sum_with_mask(main, main_mask) }
239
} else {
240
F::zero()
241
};
242
// TODO: faster remainder.
243
let restsum: F = rest
244
.iter()
245
.enumerate()
246
.map(|(i, x)| {
247
// No filter but rather select of 0.0 for cmov opt.
248
if rest_mask.get(i) { x.as_() } else { F::zero() }
249
})
250
.sum();
251
mainsum + restsum
252
}
253
}
254
255
pub fn sum_arr_as_f32<T>(arr: &PrimitiveArray<T>) -> f32
256
where
257
T: NativeType + FloatSum<f32>,
258
{
259
let validity = arr.validity().filter(|_| arr.null_count() > 0);
260
if let Some(mask) = validity {
261
FloatSum::sum_with_validity(arr.values(), mask)
262
} else {
263
FloatSum::sum(arr.values())
264
}
265
}
266
267
pub fn sum_arr_as_f64<T>(arr: &PrimitiveArray<T>) -> f64
268
where
269
T: NativeType + FloatSum<f64>,
270
{
271
let validity = arr.validity().filter(|_| arr.null_count() > 0);
272
if let Some(mask) = validity {
273
FloatSum::sum_with_validity(arr.values(), mask)
274
} else {
275
FloatSum::sum(arr.values())
276
}
277
}
278
279