Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/if_then_else/view.rs
6939 views
1
use std::mem::MaybeUninit;
2
use std::ops::Deref;
3
use std::sync::Arc;
4
5
use arrow::array::{Array, BinaryViewArray, MutablePlBinary, Utf8ViewArray, View};
6
use arrow::bitmap::Bitmap;
7
use arrow::buffer::Buffer;
8
use arrow::datatypes::ArrowDataType;
9
use polars_utils::aliases::{InitHashMaps, PlHashSet};
10
11
use super::IfThenElseKernel;
12
use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64;
13
14
// Makes a buffer and a set of views into that buffer from a set of strings.
15
// Does not allocate a buffer if not necessary.
16
fn make_buffer_and_views<const N: usize>(
17
strings: [&[u8]; N],
18
buffer_idx: u32,
19
) -> ([View; N], Option<Buffer<u8>>) {
20
let mut buf_data = Vec::new();
21
let views = strings.map(|s| {
22
let offset = buf_data.len().try_into().unwrap();
23
if s.len() > 12 {
24
buf_data.extend(s);
25
}
26
View::new_from_bytes(s, buffer_idx, offset)
27
});
28
let buf = (!buf_data.is_empty()).then(|| buf_data.into());
29
(views, buf)
30
}
31
32
fn has_duplicate_buffers(bufs: &[Buffer<u8>]) -> bool {
33
let mut has_duplicate_buffers = false;
34
let mut bufset = PlHashSet::new();
35
for buf in bufs {
36
if !bufset.insert(buf.as_ptr()) {
37
has_duplicate_buffers = true;
38
break;
39
}
40
}
41
has_duplicate_buffers
42
}
43
44
impl IfThenElseKernel for BinaryViewArray {
45
type Scalar<'a> = &'a [u8];
46
47
fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
48
let combined_buffers: Arc<_>;
49
let false_buffer_idx_offset: u32;
50
let mut has_duplicate_bufs = false;
51
if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) {
52
// Share exact same buffers, no need to combine.
53
combined_buffers = if_true.data_buffers().clone();
54
false_buffer_idx_offset = 0;
55
} else {
56
// Put false buffers after true buffers.
57
let true_buffers = if_true.data_buffers().iter().cloned();
58
let false_buffers = if_false.data_buffers().iter().cloned();
59
60
combined_buffers = true_buffers.chain(false_buffers).collect();
61
has_duplicate_bufs = has_duplicate_buffers(&combined_buffers);
62
false_buffer_idx_offset = if_true.data_buffers().len() as u32;
63
}
64
65
let views = super::if_then_else_loop(
66
mask,
67
if_true.views(),
68
if_false.views(),
69
|m, t, f, o| if_then_else_view_rest(m, t, f, o, false_buffer_idx_offset),
70
|m, t, f, o| if_then_else_view_64(m, t, f, o, false_buffer_idx_offset),
71
);
72
73
let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity());
74
75
let mut builder = MutablePlBinary::with_capacity(views.len());
76
77
if has_duplicate_bufs {
78
unsafe {
79
builder.extend_non_null_views_unchecked_dedupe(
80
views.into_iter(),
81
combined_buffers.deref(),
82
)
83
};
84
} else {
85
unsafe {
86
builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
87
};
88
}
89
builder
90
.freeze_with_dtype(if_true.dtype().clone())
91
.with_validity(validity)
92
}
93
94
fn if_then_else_broadcast_true(
95
mask: &Bitmap,
96
if_true: Self::Scalar<'_>,
97
if_false: &Self,
98
) -> Self {
99
// It's cheaper if we put the false buffers first, that way we don't need to modify any views in the loop.
100
let false_buffers = if_false.data_buffers().iter().cloned();
101
let true_buffer_idx_offset: u32 = if_false.data_buffers().len() as u32;
102
let ([true_view], true_buffer) = make_buffer_and_views([if_true], true_buffer_idx_offset);
103
let combined_buffers: Arc<_> = false_buffers.chain(true_buffer).collect();
104
105
let views = super::if_then_else_loop_broadcast_false(
106
true, // Invert the mask so we effectively broadcast true.
107
mask,
108
if_false.views(),
109
true_view,
110
if_then_else_broadcast_false_view_64,
111
);
112
113
let validity = super::if_then_else_validity(mask, None, if_false.validity());
114
115
let mut builder = MutablePlBinary::with_capacity(views.len());
116
117
unsafe {
118
if has_duplicate_buffers(&combined_buffers) {
119
builder.extend_non_null_views_unchecked_dedupe(
120
views.into_iter(),
121
combined_buffers.deref(),
122
)
123
} else {
124
builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
125
}
126
}
127
builder
128
.freeze_with_dtype(if_false.dtype().clone())
129
.with_validity(validity)
130
}
131
132
fn if_then_else_broadcast_false(
133
mask: &Bitmap,
134
if_true: &Self,
135
if_false: Self::Scalar<'_>,
136
) -> Self {
137
// It's cheaper if we put the true buffers first, that way we don't need to modify any views in the loop.
138
let true_buffers = if_true.data_buffers().iter().cloned();
139
let false_buffer_idx_offset: u32 = if_true.data_buffers().len() as u32;
140
let ([false_view], false_buffer) =
141
make_buffer_and_views([if_false], false_buffer_idx_offset);
142
let combined_buffers: Arc<_> = true_buffers.chain(false_buffer).collect();
143
144
let views = super::if_then_else_loop_broadcast_false(
145
false,
146
mask,
147
if_true.views(),
148
false_view,
149
if_then_else_broadcast_false_view_64,
150
);
151
152
let validity = super::if_then_else_validity(mask, if_true.validity(), None);
153
154
let mut builder = MutablePlBinary::with_capacity(views.len());
155
unsafe {
156
if has_duplicate_buffers(&combined_buffers) {
157
builder.extend_non_null_views_unchecked_dedupe(
158
views.into_iter(),
159
combined_buffers.deref(),
160
)
161
} else {
162
builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())
163
}
164
};
165
builder
166
.freeze_with_dtype(if_true.dtype().clone())
167
.with_validity(validity)
168
}
169
170
fn if_then_else_broadcast_both(
171
dtype: ArrowDataType,
172
mask: &Bitmap,
173
if_true: Self::Scalar<'_>,
174
if_false: Self::Scalar<'_>,
175
) -> Self {
176
let ([true_view, false_view], buffer) = make_buffer_and_views([if_true, if_false], 0);
177
let buffers: Arc<_> = buffer.into_iter().collect();
178
let views = super::if_then_else_loop_broadcast_both(
179
mask,
180
true_view,
181
false_view,
182
if_then_else_broadcast_both_scalar_64,
183
);
184
185
let mut builder = MutablePlBinary::with_capacity(views.len());
186
unsafe {
187
if has_duplicate_buffers(&buffers) {
188
builder.extend_non_null_views_unchecked_dedupe(views.into_iter(), buffers.deref())
189
} else {
190
builder.extend_non_null_views_unchecked(views.into_iter(), buffers.deref())
191
}
192
};
193
builder.freeze_with_dtype(dtype)
194
}
195
}
196
197
impl IfThenElseKernel for Utf8ViewArray {
198
type Scalar<'a> = &'a str;
199
200
fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
201
let ret =
202
IfThenElseKernel::if_then_else(mask, &if_true.to_binview(), &if_false.to_binview());
203
unsafe { ret.to_utf8view_unchecked() }
204
}
205
206
fn if_then_else_broadcast_true(
207
mask: &Bitmap,
208
if_true: Self::Scalar<'_>,
209
if_false: &Self,
210
) -> Self {
211
let ret = IfThenElseKernel::if_then_else_broadcast_true(
212
mask,
213
if_true.as_bytes(),
214
&if_false.to_binview(),
215
);
216
unsafe { ret.to_utf8view_unchecked() }
217
}
218
219
fn if_then_else_broadcast_false(
220
mask: &Bitmap,
221
if_true: &Self,
222
if_false: Self::Scalar<'_>,
223
) -> Self {
224
let ret = IfThenElseKernel::if_then_else_broadcast_false(
225
mask,
226
&if_true.to_binview(),
227
if_false.as_bytes(),
228
);
229
unsafe { ret.to_utf8view_unchecked() }
230
}
231
232
fn if_then_else_broadcast_both(
233
dtype: ArrowDataType,
234
mask: &Bitmap,
235
if_true: Self::Scalar<'_>,
236
if_false: Self::Scalar<'_>,
237
) -> Self {
238
let ret: BinaryViewArray = IfThenElseKernel::if_then_else_broadcast_both(
239
dtype,
240
mask,
241
if_true.as_bytes(),
242
if_false.as_bytes(),
243
);
244
unsafe { ret.to_utf8view_unchecked() }
245
}
246
}
247
248
pub fn if_then_else_view_rest(
249
mask: u64,
250
if_true: &[View],
251
if_false: &[View],
252
out: &mut [MaybeUninit<View>],
253
false_buffer_idx_offset: u32,
254
) {
255
assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
256
let true_it = if_true.iter();
257
let false_it = if_false.iter();
258
for (i, (t, f)) in true_it.zip(false_it).enumerate() {
259
// Written like this, this loop *should* be branchless.
260
// Unfortunately we're still dependent on the compiler.
261
let m = (mask >> i) & 1 != 0;
262
let src = if m { t } else { f };
263
let mut v = *src;
264
let offset = if m | (v.length <= 12) {
265
// Yes, | instead of || is intentional.
266
0
267
} else {
268
false_buffer_idx_offset
269
};
270
v.buffer_idx += offset;
271
out[i] = MaybeUninit::new(v);
272
}
273
}
274
275
pub fn if_then_else_view_64(
276
mask: u64,
277
if_true: &[View; 64],
278
if_false: &[View; 64],
279
out: &mut [MaybeUninit<View>; 64],
280
false_buffer_idx_offset: u32,
281
) {
282
if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset)
283
}
284
285
// Using the scalar variant of this works, but was slower, we want to select a source pointer and
286
// then copy it. Using this version for the integers results in branches.
287
pub fn if_then_else_broadcast_false_view_64(
288
mask: u64,
289
if_true: &[View; 64],
290
if_false: View,
291
out: &mut [MaybeUninit<View>; 64],
292
) {
293
assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
294
for (i, t) in if_true.iter().enumerate() {
295
let src = if (mask >> i) & 1 != 0 { t } else { &if_false };
296
out[i] = MaybeUninit::new(*src);
297
}
298
}
299
300