GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/min_max/simd.rs

use std::simd::prelude::*;
use std::simd::{LaneCount, SimdElement, SupportedLaneCount};

use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::bitmap::bitmask::BitMask;
use arrow::types::NativeType;
use polars_utils::min_max::MinMax;

use super::MinMaxKernel;

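// Scalar tails of the NaN-propagating folds below: reduce the lanes of a SIMD
// register (viewed as a fixed-size array) with MinMax::{min,max}_propagate_nan.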
fn scalar_reduce_min_propagate_nan<T: MinMax + Copy, const N: usize>(arr: &[T; N]) -> T {
    let it = arr.iter().copied();
    it.reduce(MinMax::min_propagate_nan).unwrap()
}

fn scalar_reduce_max_propagate_nan<T: MinMax + Copy, const N: usize>(arr: &[T; N]) -> T {
    let it = arr.iter().copied();
    it.reduce(MinMax::max_propagate_nan).unwrap()
}

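/// Folds `simd_f` over `arr` in SIMD registers of `N` lanes.
///
/// Lanes that are null (according to `validity`) or that lie past the end of
/// the last full chunk are replaced with `scalar_identity`, so they do not
/// affect the result. Returns `None` if the array is empty or entirely null.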
fn fold_agg_kernel<const N: usize, T, F>(
    arr: &[T],
    validity: Option<&Bitmap>,
    scalar_identity: T,
    mut simd_f: F,
) -> Option<Simd<T, N>>
where
    T: SimdElement + NativeType,
    F: FnMut(Simd<T, N>, Simd<T, N>) -> Simd<T, N>,
    LaneCount<N>: SupportedLaneCount,
{
    if arr.is_empty() {
        return None;
    }

    let mut arr_chunks = arr.chunks_exact(N);

    let identity = Simd::splat(scalar_identity);
    let mut state = identity;
    if let Some(valid) = validity {
        if valid.unset_bits() == arr.len() {
            return None;
        }

        let mask = BitMask::from_bitmap(valid);
        let mut offset = 0;
        for c in arr_chunks.by_ref() {
            let m: Mask<_, N> = mask.get_simd(offset);
            state = simd_f(state, m.select(Simd::from_slice(c), identity));
            offset += N;
        }
        if !arr.len().is_multiple_of(N) {
            let mut rest: [T; N] = identity.to_array();
            let arr_rest = arr_chunks.remainder();
            rest[..arr_rest.len()].copy_from_slice(arr_rest);
            let m: Mask<_, N> = mask.get_simd(offset);
            state = simd_f(state, m.select(Simd::from_array(rest), identity));
        }
    } else {
        for c in arr_chunks.by_ref() {
            state = simd_f(state, Simd::from_slice(c));
        }
        if !arr.len().is_multiple_of(N) {
            let mut rest: [T; N] = identity.to_array();
            let arr_rest = arr_chunks.remainder();
            rest[..arr_rest.len()].copy_from_slice(arr_rest);
            state = simd_f(state, Simd::from_array(rest));
        }
    }

    Some(state)
}

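/// Like `fold_agg_kernel`, but folds a `(min, max)` pair of accumulators in a
/// single pass over the data, with separate identities for the min and max
/// sides. Returns `None` if the array is empty or entirely null.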
fn fold_agg_min_max_kernel<const N: usize, T, F>(
    arr: &[T],
    validity: Option<&Bitmap>,
    min_scalar_identity: T,
    max_scalar_identity: T,
    mut simd_f: F,
) -> Option<(Simd<T, N>, Simd<T, N>)>
where
    T: SimdElement + NativeType,
    F: FnMut((Simd<T, N>, Simd<T, N>), (Simd<T, N>, Simd<T, N>)) -> (Simd<T, N>, Simd<T, N>),
    LaneCount<N>: SupportedLaneCount,
{
    if arr.is_empty() {
        return None;
    }

    let mut arr_chunks = arr.chunks_exact(N);

    let min_identity = Simd::splat(min_scalar_identity);
    let max_identity = Simd::splat(max_scalar_identity);
    let mut state = (min_identity, max_identity);
    if let Some(valid) = validity {
        if valid.unset_bits() == arr.len() {
            return None;
        }

        let mask = BitMask::from_bitmap(valid);
        let mut offset = 0;
        for c in arr_chunks.by_ref() {
            let m: Mask<_, N> = mask.get_simd(offset);
            let slice = Simd::from_slice(c);
            state = simd_f(
                state,
                (m.select(slice, min_identity), m.select(slice, max_identity)),
            );
            offset += N;
        }
        if !arr.len().is_multiple_of(N) {
            let mut min_rest: [T; N] = min_identity.to_array();
            let mut max_rest: [T; N] = max_identity.to_array();

            let arr_rest = arr_chunks.remainder();
            min_rest[..arr_rest.len()].copy_from_slice(arr_rest);
            max_rest[..arr_rest.len()].copy_from_slice(arr_rest);

            let m: Mask<_, N> = mask.get_simd(offset);

            let min_rest = Simd::from_array(min_rest);
            let max_rest = Simd::from_array(max_rest);

            state = simd_f(
                state,
                (
                    m.select(min_rest, min_identity),
                    m.select(max_rest, max_identity),
                ),
            );
        }
    } else {
        for c in arr_chunks.by_ref() {
            let slice = Simd::from_slice(c);
            state = simd_f(state, (slice, slice));
        }
        if !arr.len().is_multiple_of(N) {
            let mut min_rest: [T; N] = min_identity.to_array();
            let mut max_rest: [T; N] = max_identity.to_array();

            let arr_rest = arr_chunks.remainder();
            min_rest[..arr_rest.len()].copy_from_slice(arr_rest);
            max_rest[..arr_rest.len()].copy_from_slice(arr_rest);

            let min_rest = Simd::from_array(min_rest);
            let max_rest = Simd::from_array(max_rest);

            state = simd_f(state, (min_rest, max_rest));
        }
    }

    Some(state)
}

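// Integer types have no NaN, so the propagate-NaN kernels simply forward to
// the ignore-NaN ones. Minimum folds use `<$T>::MAX` as the identity and
// maximum folds use `<$T>::MIN`.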
macro_rules! impl_min_max_kernel_int {
    ($T:ty, $N:literal) => {
        impl MinMaxKernel for PrimitiveArray<$T> {
            type Scalar<'a> = $T;

            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MAX, |a, b| {
                    a.simd_min(b)
                })
                .map(|s| s.reduce_min())
            }

            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MIN, |a, b| {
                    a.simd_max(b)
                })
                .map(|s| s.reduce_max())
            }

            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::MAX,
                    <$T>::MIN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }

            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                self.min_ignore_nan_kernel()
            }

            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                self.max_ignore_nan_kernel()
            }

            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                self.min_max_ignore_nan_kernel()
            }
        }

        impl MinMaxKernel for [$T] {
            type Scalar<'a> = $T;

            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MAX, |a, b| a.simd_min(b))
                    .map(|s| s.reduce_min())
            }

            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MIN, |a, b| a.simd_max(b))
                    .map(|s| s.reduce_max())
            }

            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self,
                    None,
                    <$T>::MAX,
                    <$T>::MIN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }

            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                self.min_ignore_nan_kernel()
            }

            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                self.max_ignore_nan_kernel()
            }

            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                self.min_max_ignore_nan_kernel()
            }
        }
    };
}

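// Lane counts per element width: 32 lanes for 8-bit, 16 for 16- and 32-bit,
// and 8 for 64-bit elements (the float impls below use the same widths).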
impl_min_max_kernel_int!(u8, 32);
impl_min_max_kernel_int!(u16, 16);
impl_min_max_kernel_int!(u32, 16);
impl_min_max_kernel_int!(u64, 8);
impl_min_max_kernel_int!(i8, 32);
impl_min_max_kernel_int!(i16, 16);
impl_min_max_kernel_int!(i32, 16);
impl_min_max_kernel_int!(i64, 8);

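// Floating-point types need two sets of kernels. The ignore-NaN variants use
// NaN as the identity and rely on `simd_min`/`simd_max` returning the non-NaN
// operand when exactly one input is NaN. The propagate-NaN variants use
// +/-infinity as identities and keep NaNs explicitly via the `x.simd_ne(x)`
// (is-NaN) test in the fold closure.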
macro_rules! impl_min_max_kernel_float {
    ($T:ty, $N:literal) => {
        impl MinMaxKernel for PrimitiveArray<$T> {
            type Scalar<'a> = $T;

            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| {
                    a.simd_min(b)
                })
                .map(|s| s.reduce_min())
            }

            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| {
                    a.simd_max(b)
                })
                .map(|s| s.reduce_max())
            }

            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::NAN,
                    <$T>::NAN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }

            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::INFINITY,
                    |a, b| (a.simd_lt(b) | a.simd_ne(a)).select(a, b),
                )
                .map(|s| scalar_reduce_min_propagate_nan(s.as_array()))
            }

            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::NEG_INFINITY,
                    |a, b| (a.simd_gt(b) | a.simd_ne(a)).select(a, b),
                )
                .map(|s| scalar_reduce_max_propagate_nan(s.as_array()))
            }

            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self.values(),
                    self.validity(),
                    <$T>::INFINITY,
                    <$T>::NEG_INFINITY,
                    |(cmin, cmax), (min, max)| {
                        (
                            (cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min),
                            (cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max),
                        )
                    },
                )
                .map(|(min, max)| {
                    (
                        scalar_reduce_min_propagate_nan(min.as_array()),
                        scalar_reduce_max_propagate_nan(max.as_array()),
                    )
                })
            }
        }

        impl MinMaxKernel for [$T] {
            type Scalar<'a> = $T;

            fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_min(b))
                    .map(|s| s.reduce_min())
            }

            fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_max(b))
                    .map(|s| s.reduce_max())
            }

            fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self,
                    None,
                    <$T>::NAN,
                    <$T>::NAN,
                    |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),
                )
                .map(|(min, max)| (min.reduce_min(), max.reduce_max()))
            }

            fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::INFINITY, |a, b| {
                    (a.simd_lt(b) | a.simd_ne(a)).select(a, b)
                })
                .map(|s| scalar_reduce_min_propagate_nan(s.as_array()))
            }

            fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {
                fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NEG_INFINITY, |a, b| {
                    (a.simd_gt(b) | a.simd_ne(a)).select(a, b)
                })
                .map(|s| scalar_reduce_max_propagate_nan(s.as_array()))
            }

            fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {
                fold_agg_min_max_kernel::<$N, $T, _>(
                    self,
                    None,
                    <$T>::INFINITY,
                    <$T>::NEG_INFINITY,
                    |(cmin, cmax), (min, max)| {
                        (
                            (cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min),
                            (cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max),
                        )
                    },
                )
                .map(|(min, max)| {
                    (
                        scalar_reduce_min_propagate_nan(min.as_array()),
                        scalar_reduce_max_propagate_nan(max.as_array()),
                    )
                })
            }
        }
    };
}

impl_min_max_kernel_float!(f32, 16);
impl_min_max_kernel_float!(f64, 8);
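
// Usage sketch (hypothetical, for illustration only): with `MinMaxKernel` in
// scope, the slice impls above can be called directly, e.g.
//
//     assert_eq!(
//         [3.0f64, f64::NAN, 1.0].as_slice().min_ignore_nan_kernel(),
//         Some(1.0)
//     );
//     assert!([3.0f64, f64::NAN, 1.0]
//         .as_slice()
//         .min_propagate_nan_kernel()
//         .unwrap()
//         .is_nan());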