Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/if_then_else/simd.rs
6939 views
1
#[cfg(target_arch = "x86_64")]
2
use std::mem::MaybeUninit;
3
#[cfg(target_arch = "x86_64")]
4
use std::simd::{Mask, Simd, SimdElement};
5
6
use arrow::array::PrimitiveArray;
7
use arrow::bitmap::Bitmap;
8
use arrow::datatypes::ArrowDataType;
9
10
use super::{
11
IfThenElseKernel, if_then_else_loop, if_then_else_loop_broadcast_both,
12
if_then_else_loop_broadcast_false, if_then_else_validity, scalar,
13
};
14
15
#[cfg(target_arch = "x86_64")]
16
fn select_simd_64<T: Copy + SimdElement>(
17
mask: u64,
18
if_true: Simd<T, 64>,
19
if_false: Simd<T, 64>,
20
out: &mut [MaybeUninit<T>; 64],
21
) {
22
let mv = Mask::<<T as SimdElement>::Mask, 64>::from_bitmask(mask);
23
let ret = mv.select(if_true, if_false);
24
unsafe {
25
let src = ret.as_array().as_ptr() as *const MaybeUninit<T>;
26
core::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), 64);
27
}
28
}
29
30
#[cfg(target_arch = "x86_64")]
31
fn if_then_else_simd_64<T: Copy + SimdElement>(
32
mask: u64,
33
if_true: &[T; 64],
34
if_false: &[T; 64],
35
out: &mut [MaybeUninit<T>; 64],
36
) {
37
select_simd_64(
38
mask,
39
Simd::from_slice(if_true),
40
Simd::from_slice(if_false),
41
out,
42
)
43
}
44
45
#[cfg(target_arch = "x86_64")]
46
fn if_then_else_broadcast_false_simd_64<T: Copy + SimdElement>(
47
mask: u64,
48
if_true: &[T; 64],
49
if_false: T,
50
out: &mut [MaybeUninit<T>; 64],
51
) {
52
select_simd_64(mask, Simd::from_slice(if_true), Simd::splat(if_false), out)
53
}
54
55
#[cfg(target_arch = "x86_64")]
56
fn if_then_else_broadcast_both_simd_64<T: Copy + SimdElement>(
57
mask: u64,
58
if_true: T,
59
if_false: T,
60
out: &mut [MaybeUninit<T>; 64],
61
) {
62
select_simd_64(mask, Simd::splat(if_true), Simd::splat(if_false), out)
63
}
64
65
// Generates the `IfThenElseKernel` implementation for `PrimitiveArray<$T>`.
//
// Every method hands value selection to one of the `if_then_else_loop*` helpers imported from
// the parent module. On x86_64 the 64-element kernel passed to the helper is the explicit SIMD
// one defined in this file; on every other architecture the scalar kernel is passed instead.
macro_rules! impl_if_then_else {
    ($T: ty) => {
        impl IfThenElseKernel for PrimitiveArray<$T> {
            type Scalar<'a> = $T;

            // Both branches are arrays.
            fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
                // `if_then_else_scalar_rest` is passed alongside the 64-wide kernel; presumably
                // it covers elements that do not fill a whole 64-element block (confirm in
                // `if_then_else_loop`).
                let selected = if_then_else_loop(
                    mask,
                    if_true.values(),
                    if_false.values(),
                    scalar::if_then_else_scalar_rest,
                    // x86_64 only: auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_scalar_64,
                );
                let merged_validity =
                    if_then_else_validity(mask, if_true.validity(), if_false.validity());
                Self::from_vec(selected).with_validity(merged_validity)
            }

            // Scalar `if_true`, array `if_false`.
            fn if_then_else_broadcast_true(
                mask: &Bitmap,
                if_true: Self::Scalar<'_>,
                if_false: &Self,
            ) -> Self {
                // Reuses the broadcast-false loop with the operands swapped: the array goes in
                // the array slot and the scalar in the scalar slot. The leading `true`
                // presumably asks the loop to invert the mask (confirm in
                // `if_then_else_loop_broadcast_false`).
                let selected = if_then_else_loop_broadcast_false(
                    true,
                    mask,
                    if_false.values(),
                    if_true,
                    // x86_64 only: auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_broadcast_false_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_broadcast_false_scalar_64,
                );
                // The scalar branch contributes no validity bitmap.
                let merged_validity = if_then_else_validity(mask, None, if_false.validity());
                Self::from_vec(selected).with_validity(merged_validity)
            }

            // Array `if_true`, scalar `if_false`.
            fn if_then_else_broadcast_false(
                mask: &Bitmap,
                if_true: &Self,
                if_false: Self::Scalar<'_>,
            ) -> Self {
                let selected = if_then_else_loop_broadcast_false(
                    false,
                    mask,
                    if_true.values(),
                    if_false,
                    // x86_64 only: auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_broadcast_false_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_broadcast_false_scalar_64,
                );
                // The scalar branch contributes no validity bitmap.
                let merged_validity = if_then_else_validity(mask, if_true.validity(), None);
                Self::from_vec(selected).with_validity(merged_validity)
            }

            // Both branches are scalars; `_dtype` is not used for primitive arrays and the
            // result carries no validity bitmap.
            fn if_then_else_broadcast_both(
                _dtype: ArrowDataType,
                mask: &Bitmap,
                if_true: Self::Scalar<'_>,
                if_false: Self::Scalar<'_>,
            ) -> Self {
                let selected = if_then_else_loop_broadcast_both(
                    mask,
                    if_true,
                    if_false,
                    // x86_64 only: auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_broadcast_both_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_broadcast_both_scalar_64,
                );
                Self::from_vec(selected)
            }
        }
    };
}
147
148
impl_if_then_else!(i8);
149
impl_if_then_else!(i16);
150
impl_if_then_else!(i32);
151
impl_if_then_else!(i64);
152
impl_if_then_else!(u8);
153
impl_if_then_else!(u16);
154
impl_if_then_else!(u32);
155
impl_if_then_else!(u64);
156
impl_if_then_else!(f32);
157
impl_if_then_else!(f64);
158
159