Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/if_then_else/mod.rs
6939 views
1
use std::mem::MaybeUninit;
2
3
use arrow::array::{Array, PrimitiveArray};
4
use arrow::bitmap::utils::SlicesIterator;
5
use arrow::bitmap::{self, Bitmap};
6
use arrow::datatypes::ArrowDataType;
7
8
use crate::NotSimdPrimitive;
9
10
mod array;
11
mod boolean;
12
mod list;
13
mod scalar;
14
#[cfg(feature = "simd")]
15
mod simd;
16
mod view;
17
18
pub trait IfThenElseKernel: Sized + Array {
19
type Scalar<'a>;
20
21
fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self;
22
fn if_then_else_broadcast_true(
23
mask: &Bitmap,
24
if_true: Self::Scalar<'_>,
25
if_false: &Self,
26
) -> Self;
27
fn if_then_else_broadcast_false(
28
mask: &Bitmap,
29
if_true: &Self,
30
if_false: Self::Scalar<'_>,
31
) -> Self;
32
fn if_then_else_broadcast_both(
33
dtype: ArrowDataType,
34
mask: &Bitmap,
35
if_true: Self::Scalar<'_>,
36
if_false: Self::Scalar<'_>,
37
) -> Self;
38
}
39
40
impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
41
type Scalar<'a> = T;
42
43
fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
44
let values = if_then_else_loop(
45
mask,
46
if_true.values(),
47
if_false.values(),
48
scalar::if_then_else_scalar_rest,
49
scalar::if_then_else_scalar_64,
50
);
51
let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());
52
PrimitiveArray::from_vec(values).with_validity(validity)
53
}
54
55
fn if_then_else_broadcast_true(
56
mask: &Bitmap,
57
if_true: Self::Scalar<'_>,
58
if_false: &Self,
59
) -> Self {
60
let values = if_then_else_loop_broadcast_false(
61
true,
62
mask,
63
if_false.values(),
64
if_true,
65
scalar::if_then_else_broadcast_false_scalar_64,
66
);
67
let validity = if_then_else_validity(mask, None, if_false.validity());
68
PrimitiveArray::from_vec(values).with_validity(validity)
69
}
70
71
fn if_then_else_broadcast_false(
72
mask: &Bitmap,
73
if_true: &Self,
74
if_false: Self::Scalar<'_>,
75
) -> Self {
76
let values = if_then_else_loop_broadcast_false(
77
false,
78
mask,
79
if_true.values(),
80
if_false,
81
scalar::if_then_else_broadcast_false_scalar_64,
82
);
83
let validity = if_then_else_validity(mask, if_true.validity(), None);
84
PrimitiveArray::from_vec(values).with_validity(validity)
85
}
86
87
fn if_then_else_broadcast_both(
88
_dtype: ArrowDataType,
89
mask: &Bitmap,
90
if_true: Self::Scalar<'_>,
91
if_false: Self::Scalar<'_>,
92
) -> Self {
93
let values = if_then_else_loop_broadcast_both(
94
mask,
95
if_true,
96
if_false,
97
scalar::if_then_else_broadcast_both_scalar_64,
98
);
99
PrimitiveArray::from_vec(values)
100
}
101
}
102
103
pub fn if_then_else_validity(
104
mask: &Bitmap,
105
if_true: Option<&Bitmap>,
106
if_false: Option<&Bitmap>,
107
) -> Option<Bitmap> {
108
match (if_true, if_false) {
109
(None, None) => None,
110
(None, Some(f)) => Some(mask | f),
111
(Some(t), None) => Some(bitmap::binary(mask, t, |m, t| !m | t)),
112
(Some(t), Some(f)) => Some(bitmap::ternary(mask, t, f, |m, t, f| (m & t) | (!m & f))),
113
}
114
}
115
116
fn if_then_else_extend<B, ET: Fn(&mut B, usize, usize), EF: Fn(&mut B, usize, usize)>(
117
builder: &mut B,
118
mask: &Bitmap,
119
extend_true: ET,
120
extend_false: EF,
121
) {
122
let mut last_true_end = 0;
123
for (start, len) in SlicesIterator::new(mask) {
124
if start != last_true_end {
125
extend_false(builder, last_true_end, start - last_true_end);
126
};
127
extend_true(builder, start, len);
128
last_true_end = start + len;
129
}
130
if last_true_end != mask.len() {
131
extend_false(builder, last_true_end, mask.len() - last_true_end)
132
}
133
}
134
135
fn if_then_else_loop<T, F, F64>(
136
mask: &Bitmap,
137
if_true: &[T],
138
if_false: &[T],
139
process_var: F,
140
process_chunk: F64,
141
) -> Vec<T>
142
where
143
T: Copy,
144
F: Fn(u64, &[T], &[T], &mut [MaybeUninit<T>]),
145
F64: Fn(u64, &[T; 64], &[T; 64], &mut [MaybeUninit<T>; 64]),
146
{
147
assert_eq!(mask.len(), if_true.len());
148
assert_eq!(mask.len(), if_false.len());
149
150
let mut ret = Vec::with_capacity(mask.len());
151
let out = &mut ret.spare_capacity_mut()[..mask.len()];
152
153
// Handle prefix.
154
let aligned = mask.aligned::<u64>();
155
let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
156
let (start_false, rest_false) = if_false.split_at(aligned.prefix_bitlen());
157
let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
158
if aligned.prefix_bitlen() > 0 {
159
process_var(aligned.prefix(), start_true, start_false, start_out);
160
}
161
162
// Handle bulk.
163
let mut true_chunks = rest_true.chunks_exact(64);
164
let mut false_chunks = rest_false.chunks_exact(64);
165
let mut out_chunks = rest_out.chunks_exact_mut(64);
166
let combined = true_chunks
167
.by_ref()
168
.zip(false_chunks.by_ref())
169
.zip(out_chunks.by_ref());
170
for (i, ((tc, fc), oc)) in combined.enumerate() {
171
let m = unsafe { *aligned.bulk().get_unchecked(i) };
172
process_chunk(
173
m,
174
tc.try_into().unwrap(),
175
fc.try_into().unwrap(),
176
oc.try_into().unwrap(),
177
);
178
}
179
180
// Handle suffix.
181
if aligned.suffix_bitlen() > 0 {
182
process_var(
183
aligned.suffix(),
184
true_chunks.remainder(),
185
false_chunks.remainder(),
186
out_chunks.into_remainder(),
187
);
188
}
189
190
unsafe {
191
ret.set_len(mask.len());
192
}
193
ret
194
}
195
196
fn if_then_else_loop_broadcast_false<T, F64>(
197
invert_mask: bool, // Allows code reuse for both false and true broadcasts.
198
mask: &Bitmap,
199
if_true: &[T],
200
if_false: T,
201
process_chunk: F64,
202
) -> Vec<T>
203
where
204
T: Copy,
205
F64: Fn(u64, &[T; 64], T, &mut [MaybeUninit<T>; 64]),
206
{
207
assert_eq!(mask.len(), if_true.len());
208
209
let mut ret = Vec::with_capacity(mask.len());
210
let out = &mut ret.spare_capacity_mut()[..mask.len()];
211
212
// XOR with all 1's inverts the mask.
213
let xor_inverter = if invert_mask { u64::MAX } else { 0 };
214
215
// Handle prefix.
216
let aligned = mask.aligned::<u64>();
217
let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
218
let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
219
if aligned.prefix_bitlen() > 0 {
220
scalar::if_then_else_broadcast_false_scalar_rest(
221
aligned.prefix() ^ xor_inverter,
222
start_true,
223
if_false,
224
start_out,
225
);
226
}
227
228
// Handle bulk.
229
let mut true_chunks = rest_true.chunks_exact(64);
230
let mut out_chunks = rest_out.chunks_exact_mut(64);
231
let combined = true_chunks.by_ref().zip(out_chunks.by_ref());
232
for (i, (tc, oc)) in combined.enumerate() {
233
let m = unsafe { *aligned.bulk().get_unchecked(i) } ^ xor_inverter;
234
process_chunk(m, tc.try_into().unwrap(), if_false, oc.try_into().unwrap());
235
}
236
237
// Handle suffix.
238
if aligned.suffix_bitlen() > 0 {
239
scalar::if_then_else_broadcast_false_scalar_rest(
240
aligned.suffix() ^ xor_inverter,
241
true_chunks.remainder(),
242
if_false,
243
out_chunks.into_remainder(),
244
);
245
}
246
247
unsafe {
248
ret.set_len(mask.len());
249
}
250
ret
251
}
252
253
fn if_then_else_loop_broadcast_both<T, F64>(
254
mask: &Bitmap,
255
if_true: T,
256
if_false: T,
257
generate_chunk: F64,
258
) -> Vec<T>
259
where
260
T: Copy,
261
F64: Fn(u64, T, T, &mut [MaybeUninit<T>; 64]),
262
{
263
let mut ret = Vec::with_capacity(mask.len());
264
let out = &mut ret.spare_capacity_mut()[..mask.len()];
265
266
// Handle prefix.
267
let aligned = mask.aligned::<u64>();
268
let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
269
scalar::if_then_else_broadcast_both_scalar_rest(aligned.prefix(), if_true, if_false, start_out);
270
271
// Handle bulk.
272
let mut out_chunks = rest_out.chunks_exact_mut(64);
273
for (i, oc) in out_chunks.by_ref().enumerate() {
274
let m = unsafe { *aligned.bulk().get_unchecked(i) };
275
generate_chunk(m, if_true, if_false, oc.try_into().unwrap());
276
}
277
278
// Handle suffix.
279
if aligned.suffix_bitlen() > 0 {
280
scalar::if_then_else_broadcast_both_scalar_rest(
281
aligned.suffix(),
282
if_true,
283
if_false,
284
out_chunks.into_remainder(),
285
);
286
}
287
288
unsafe {
289
ret.set_len(mask.len());
290
}
291
ret
292
}
293
294