GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/comparisons/simd.rs
use std::ptr;
use std::simd::prelude::{Simd, SimdPartialEq, SimdPartialOrd};

use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use bytemuck::Pod;

use super::{TotalEqKernel, TotalOrdKernel};

/// Applies `f` to `N`-lane chunks of `lhs` and `rhs`, packing each returned
/// comparison mask `M` into a `Bitmap` of length `lhs.len()`.
fn apply_binary_kernel<const N: usize, M: Pod, T, F>(
    lhs: &PrimitiveArray<T>,
    rhs: &PrimitiveArray<T>,
    mut f: F,
) -> Bitmap
where
    T: NativeType,
    F: FnMut(&[T; N], &[T; N]) -> M,
{
    // The mask type M must have exactly one bit per SIMD lane.
    assert_eq!(N, size_of::<M>() * 8);
    assert!(lhs.len() == rhs.len());
    let n = lhs.len();

    let lhs_buf = lhs.values().as_slice();
    let rhs_buf = rhs.values().as_slice();
    let lhs_chunks = lhs_buf.chunks_exact(N);
    let rhs_chunks = rhs_buf.chunks_exact(N);
    let lhs_rest = lhs_chunks.remainder();
    let rhs_rest = rhs_chunks.remainder();

    let num_masks = n.div_ceil(N);
    let mut v: Vec<u8> = Vec::with_capacity(num_masks * size_of::<M>());
    let mut p = v.as_mut_ptr() as *mut M;
    for (l, r) in lhs_chunks.zip(rhs_chunks) {
        unsafe {
            let mask = f(
                l.try_into().unwrap_unchecked(),
                r.try_into().unwrap_unchecked(),
            );
            p.write_unaligned(mask);
            p = p.wrapping_add(1);
        }
    }

    if !n.is_multiple_of(N) {
        // Zero-pad the trailing partial chunk up to a full SIMD register;
        // the excess lanes are trimmed off by the bitmap length below.
        let mut l: [T; N] = [T::zeroed(); N];
        let mut r: [T; N] = [T::zeroed(); N];
        unsafe {
            ptr::copy_nonoverlapping(lhs_rest.as_ptr(), l.as_mut_ptr(), n % N);
            ptr::copy_nonoverlapping(rhs_rest.as_ptr(), r.as_mut_ptr(), n % N);
            p.write_unaligned(f(&l, &r));
        }
    }

    unsafe {
        v.set_len(num_masks * size_of::<M>());
    }

    Bitmap::from_u8_vec(v, n)
}
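
// Packing example (grounded in the instantiations at the bottom of this
// file): for T = u8 the macros pick N = 32 and M = u32, so each loop
// iteration compares 32 lanes and writes one 32-bit mask. ceil(n / 32)
// masks are written in total, and `Bitmap::from_u8_vec(v, n)` trims the
// bits from the zero-padded trailing lanes back to length n.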

/// Unary counterpart of `apply_binary_kernel`: applies `f` to `N`-lane
/// chunks of `arg` and packs the masks into a `Bitmap` of length `arg.len()`.
fn apply_unary_kernel<const N: usize, M: Pod, T, F>(arg: &PrimitiveArray<T>, mut f: F) -> Bitmap
where
    T: NativeType,
    F: FnMut(&[T; N]) -> M,
{
    assert_eq!(N, size_of::<M>() * 8);
    let n = arg.len();

    let arg_buf = arg.values().as_slice();
    let arg_chunks = arg_buf.chunks_exact(N);
    let arg_rest = arg_chunks.remainder();

    let num_masks = n.div_ceil(N);
    let mut v: Vec<u8> = Vec::with_capacity(num_masks * size_of::<M>());
    let mut p = v.as_mut_ptr() as *mut M;
    for a in arg_chunks {
        unsafe {
            let mask = f(a.try_into().unwrap_unchecked());
            p.write_unaligned(mask);
            p = p.wrapping_add(1);
        }
    }

    if !n.is_multiple_of(N) {
        // Zero-pad the trailing partial chunk, as in the binary kernel.
        let mut a: [T; N] = [T::zeroed(); N];
        unsafe {
            ptr::copy_nonoverlapping(arg_rest.as_ptr(), a.as_mut_ptr(), n % N);
            p.write_unaligned(f(&a));
        }
    }

    unsafe {
        v.set_len(num_masks * size_of::<M>());
    }

    Bitmap::from_u8_vec(v, n)
}

/// Implements `TotalEqKernel` and `TotalOrdKernel` for an integer type `$T`,
/// comparing `$width` lanes at a time and packing each result into a `$mask`.
macro_rules! impl_int_total_ord_kernel {
    ($T: ty, $width: literal, $mask: ty) => {
        impl TotalEqKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    // `to_bitmask` packs the lane mask into an integer;
                    // the cast keeps the low $width bits.
                    Simd::from(*l).simd_eq(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_ne(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_eq(r).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_ne(r).to_bitmask() as $mask
                })
            }
        }

        impl TotalOrdKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_lt(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_le(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_lt(r).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_le(r).to_bitmask() as $mask
                })
            }

            fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_gt(r).to_bitmask() as $mask
                })
            }

            fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_ge(r).to_bitmask() as $mask
                })
            }
        }
    };
}

/// Implements `TotalEqKernel` and `TotalOrdKernel` for a float type `$T`.
/// The total order differs from IEEE-754: NaN equals NaN and sorts above
/// every other value, which the explicit lane-wise NaN checks below encode.
macro_rules! impl_float_total_ord_kernel {
    ($T: ty, $width: literal, $mask: ty) => {
        impl TotalEqKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    // A lane is NaN iff it compares unequal to itself.
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    ((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs)).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask
                })
            }

            fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    ((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs)).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask
                })
            }
        }

        impl TotalOrdKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let lhs_is_nan = ls.simd_ne(ls);
                    // lhs < rhs unless lhs is NaN (the maximum) or lhs >= rhs.
                    (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let rhs_is_nan = rs.simd_ne(rs);
                    // Everything is <= a NaN rhs, since NaN is the maximum.
                    (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask
                })
            }

            fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let rhs_is_nan = rs.simd_ne(rs);
                    (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask
                })
            }

            fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let rhs_is_nan = rs.simd_ne(rs);
                    (!(rhs_is_nan | rs.simd_ge(ls))).to_bitmask() as $mask
                })
            }

            fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    (lhs_is_nan | rs.simd_le(ls)).to_bitmask() as $mask
                })
            }
        }
    };
}
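
// Worked example of the total-order semantics above (a sketch derived from
// the lane expressions, not taken from upstream): with l = [1.0, f32::NAN]
// and r = [f32::NAN, f32::NAN],
//   tot_eq -> [false, true]  (NaN equals NaN under total equality)
//   tot_lt -> [true,  false] (non-NaN sorts before NaN; NaN is not < NaN)
//   tot_le -> [true,  true]  (NaN is the maximum, so x <= NaN always holds)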

// Each instantiation pairs a SIMD lane count with a mask type of exactly
// that many bits, as required by the asserts in the kernels above.
impl_int_total_ord_kernel!(u8, 32, u32);
impl_int_total_ord_kernel!(u16, 16, u16);
impl_int_total_ord_kernel!(u32, 8, u8);
impl_int_total_ord_kernel!(u64, 8, u8);
impl_int_total_ord_kernel!(i8, 32, u32);
impl_int_total_ord_kernel!(i16, 16, u16);
impl_int_total_ord_kernel!(i32, 8, u8);
impl_int_total_ord_kernel!(i64, 8, u8);
impl_float_total_ord_kernel!(f32, 8, u8);
impl_float_total_ord_kernel!(f64, 8, u8);
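
// A minimal usage sketch, not part of the upstream file: it exercises the
// u32 instantiation above. `PrimitiveArray::from_slice` and `Bitmap::iter`
// are assumed to be available in polars' vendored arrow crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tot_eq_kernel_basic() {
        let a = PrimitiveArray::<u32>::from_slice([1, 2, 3]);
        let b = PrimitiveArray::<u32>::from_slice([1, 5, 3]);
        let mask = a.tot_eq_kernel(&b);
        // One bit per element: equal lanes set, unequal lanes cleared.
        assert_eq!(mask.iter().collect::<Vec<bool>>(), vec![true, false, true]);
    }
}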