Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/list/sets.rs
6939 views
1
use std::fmt::{Display, Formatter};
2
use std::hash::Hash;
3
4
use arrow::array::{
5
Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray,
6
PrimitiveArray, Utf8ViewArray,
7
};
8
use arrow::bitmap::Bitmap;
9
use arrow::compute::utils::combine_validities_and;
10
use arrow::offset::OffsetsBuffer;
11
use arrow::types::NativeType;
12
use polars_core::prelude::*;
13
use polars_core::with_match_physical_numeric_type;
14
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap};
15
#[cfg(feature = "serde")]
16
use serde::{Deserialize, Serialize};
17
18
/// Abstraction over mutable output buffers that can consume an iterator of
/// set-operation results and report how many values they hold afterwards.
trait MaterializeValues<K> {
    /// Extends `self` with `values` and returns the buffer's new total
    /// length, which the caller uses as the next list offset.
    fn extend_buf<I: Iterator<Item = K>>(&mut self, values: I) -> usize;
}
22
23
impl<T> MaterializeValues<Option<T>> for MutablePrimitiveArray<T>
24
where
25
T: NativeType,
26
{
27
fn extend_buf<I: Iterator<Item = Option<T>>>(&mut self, values: I) -> usize {
28
self.extend(values);
29
self.len()
30
}
31
}
32
33
impl<T> MaterializeValues<TotalOrdWrap<Option<T>>> for MutablePrimitiveArray<T>
34
where
35
T: NativeType,
36
{
37
fn extend_buf<I: Iterator<Item = TotalOrdWrap<Option<T>>>>(&mut self, values: I) -> usize {
38
self.extend(values.map(|x| x.0));
39
self.len()
40
}
41
}
42
43
impl<'a> MaterializeValues<Option<&'a [u8]>> for MutablePlBinary {
    /// Appends every (optional) byte slice to the binary builder and reports
    /// its new length.
    fn extend_buf<I: Iterator<Item = Option<&'a [u8]>>>(&mut self, items: I) -> usize {
        self.extend(items);
        self.len()
    }
}
49
50
/// Computes `set_op` between the element iterators `a` and `b` of a single
/// row and materializes the result into `out`.
///
/// Returns the total length of `out` afterwards, i.e. the end offset of the
/// list just written.
///
/// `set` and `set2` are scratch buffers reused across rows so we don't
/// allocate a fresh hash set per list element. When `broadcast_rhs` is true,
/// the caller has already filled `set2` with the (single) right-hand row and
/// it must not be cleared here.
fn set_operation<K, I, J, R>(
    set: &mut PlIndexSet<K>,
    set2: &mut PlIndexSet<K>,
    a: I,
    b: J,
    out: &mut R,
    set_op: SetOperation,
    broadcast_rhs: bool,
) -> usize
where
    K: Eq + Hash + Copy,
    I: IntoIterator<Item = K>,
    J: IntoIterator<Item = K>,
    R: MaterializeValues<K>,
{
    set.clear();
    let a = a.into_iter();
    let b = b.into_iter();

    match set_op {
        SetOperation::Intersection => {
            set.extend(a);
            // If broadcasting, `set2` was already filled once by the caller.
            if !broadcast_rhs {
                set2.clear();
                set2.extend(b);
            }
            out.extend_buf(set.intersection(set2).copied())
        },
        SetOperation::Union => {
            // The index set preserves insertion order and deduplicates.
            set.extend(a);
            set.extend(b);
            out.extend_buf(set.drain(..))
        },
        SetOperation::Difference => {
            set.extend(a);
            for v in b {
                set.swap_remove(&v);
            }
            out.extend_buf(set.drain(..))
        },
        SetOperation::SymmetricDifference => {
            // If broadcasting, `set2` was already filled once by the caller.
            if !broadcast_rhs {
                set2.clear();
                set2.extend(b);
            }
            // We could speed this up by implementing it ourselves, but we
            // would need a cloneable iterator as that requires 2 passes.
            set.extend(a);
            out.extend_buf(set.symmetric_difference(set2).copied())
        },
    }
}
104
105
/// Copies an optional borrowed value out of the array iterator and converts
/// it into its total-order representation so it can live in a hash set.
fn copied_wrapper_opt<T: Copy + TotalEq + TotalHash>(
    v: Option<&T>,
) -> <Option<T> as ToTotalOrd>::TotalOrdItem {
    let owned: Option<T> = v.copied();
    owned.to_total_ord()
}
110
111
/// The list set operations supported by [`list_set_operation`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum SetOperation {
    /// Elements present in both lists.
    Intersection,
    /// Elements present in either list (deduplicated).
    Union,
    /// Elements of the left list that are not in the right list.
    Difference,
    /// Elements present in exactly one of the two lists.
    SymmetricDifference,
}
120
121
impl Display for SetOperation {
122
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
123
let s = match self {
124
SetOperation::Intersection => "intersection",
125
SetOperation::Union => "union",
126
SetOperation::Difference => "difference",
127
SetOperation::SymmetricDifference => "symmetric_difference",
128
};
129
write!(f, "{s}")
130
}
131
}
132
133
/// Applies `set_op` row-wise between two list columns of primitive values,
/// producing a new `ListArray<i64>`.
///
/// `offsets_a`/`offsets_b` are the list offsets of each side; a side with
/// exactly two offsets holds a single list and is broadcast against every
/// row of the other side. `validity` is the pre-combined outer validity of
/// the result.
fn primitive<T>(
    a: &PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    offsets_a: &[i64],
    offsets_b: &[i64],
    set_op: SetOperation,
    validity: Option<Bitmap>,
) -> PolarsResult<ListArray<i64>>
where
    T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd,
    <Option<T> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,
{
    // Two offsets == one list: that side is broadcast to every row.
    let broadcast_lhs = offsets_a.len() == 2;
    let broadcast_rhs = offsets_b.len() == 2;

    // Scratch hash sets reused across all rows.
    let mut set = Default::default();
    let mut set2: PlIndexSet<<Option<T> as ToTotalOrd>::TotalOrdItem> = Default::default();

    // Upper bound on the output size: the larger of the two value buffers.
    let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(
        *offsets_a.last().unwrap(),
        *offsets_b.last().unwrap(),
    ) as usize);
    let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
    offsets.push(0i64);

    // Drive the loop with the longer (i.e. non-broadcast) offsets slice.
    let offsets_slice = if offsets_a.len() > offsets_b.len() {
        offsets_a
    } else {
        offsets_b
    };
    let first_a = offsets_a[0];
    let second_a = offsets_a[1];
    let first_b = offsets_b[0];
    let second_b = offsets_b[1];
    if broadcast_rhs {
        // Fill the RHS set once up front; it is reused for every row.
        set2.extend(
            b.into_iter()
                .skip(first_b as usize)
                .take(second_b as usize - first_b as usize)
                .map(copied_wrapper_opt),
        );
    }
    for i in 1..offsets_slice.len() {
        // If we go OOB we take the first element as we are then broadcasting.
        let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
        let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;

        let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
        let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;

        // The branches are the same every loop.
        // We rely on branch prediction here.
        let offset = if broadcast_rhs {
            // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount
            let a_iter = a
                .into_iter()
                .skip(start_a)
                .take(end_a - start_a)
                .map(copied_wrapper_opt);
            let b_iter = b
                .into_iter()
                .skip(first_b as usize)
                .take(second_b as usize - first_b as usize)
                .map(copied_wrapper_opt);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                true,
            )
        } else if broadcast_lhs {
            let a_iter = a
                .into_iter()
                .skip(first_a as usize)
                .take(second_a as usize - first_a as usize)
                .map(copied_wrapper_opt);

            let b_iter = b
                .into_iter()
                .skip(start_b)
                .take(end_b - start_b)
                .map(copied_wrapper_opt);

            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                false,
            )
        } else {
            // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount
            let a_iter = a
                .into_iter()
                .skip(start_a)
                .take(end_a - start_a)
                .map(copied_wrapper_opt);

            let b_iter = b
                .into_iter()
                .skip(start_b)
                .take(end_b - start_b)
                .map(copied_wrapper_opt);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                false,
            )
        };

        // `offset` is the running length of `values_out`, i.e. this row's end.
        offsets.push(offset as i64);
    }
    // SAFETY: offsets start at 0 and are monotonically non-decreasing because
    // each entry is the running length of `values_out`.
    let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
    let dtype = ListArray::<i64>::default_datatype(values_out.dtype().clone());

    let values: PrimitiveArray<T> = values_out.into();
    Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
}
260
261
fn binary(
262
a: &BinaryViewArray,
263
b: &BinaryViewArray,
264
offsets_a: &[i64],
265
offsets_b: &[i64],
266
set_op: SetOperation,
267
validity: Option<Bitmap>,
268
as_utf8: bool,
269
) -> PolarsResult<ListArray<i64>> {
270
let broadcast_lhs = offsets_a.len() == 2;
271
let broadcast_rhs = offsets_b.len() == 2;
272
let mut set = Default::default();
273
let mut set2: PlIndexSet<Option<&[u8]>> = Default::default();
274
275
let mut values_out = MutablePlBinary::with_capacity(std::cmp::max(
276
*offsets_a.last().unwrap(),
277
*offsets_b.last().unwrap(),
278
) as usize);
279
let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
280
offsets.push(0i64);
281
282
let offsets_slice = if offsets_a.len() > offsets_b.len() {
283
offsets_a
284
} else {
285
offsets_b
286
};
287
let first_a = offsets_a[0];
288
let second_a = offsets_a[1];
289
let first_b = offsets_b[0];
290
let second_b = offsets_b[1];
291
292
if broadcast_rhs {
293
// set2.extend(b_iter)
294
set2.extend(
295
b.into_iter()
296
.skip(first_b as usize)
297
.take(second_b as usize - first_b as usize),
298
);
299
}
300
301
for i in 1..offsets_slice.len() {
302
// If we go OOB we take the first element as we are then broadcasting.
303
let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
304
let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;
305
306
let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
307
let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;
308
309
// The branches are the same every loop.
310
// We rely on branch prediction here.
311
let offset = if broadcast_rhs {
312
// going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount
313
let a_iter = a.into_iter().skip(start_a).take(end_a - start_a);
314
let b_iter = b
315
.into_iter()
316
.skip(first_b as usize)
317
.take(second_b as usize - first_b as usize);
318
set_operation(
319
&mut set,
320
&mut set2,
321
a_iter,
322
b_iter,
323
&mut values_out,
324
set_op,
325
true,
326
)
327
} else if broadcast_lhs {
328
let a_iter = a
329
.into_iter()
330
.skip(first_a as usize)
331
.take(second_a as usize - first_a as usize);
332
let b_iter = b.into_iter().skip(start_b).take(end_b - start_b);
333
set_operation(
334
&mut set,
335
&mut set2,
336
a_iter,
337
b_iter,
338
&mut values_out,
339
set_op,
340
false,
341
)
342
} else {
343
// going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount
344
let a_iter = a.into_iter().skip(start_a).take(end_a - start_a);
345
let b_iter = b.into_iter().skip(start_b).take(end_b - start_b);
346
set_operation(
347
&mut set,
348
&mut set2,
349
a_iter,
350
b_iter,
351
&mut values_out,
352
set_op,
353
false,
354
)
355
};
356
offsets.push(offset as i64);
357
}
358
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
359
let values = values_out.freeze();
360
361
if as_utf8 {
362
let values = unsafe { values.to_utf8view_unchecked() };
363
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
364
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
365
} else {
366
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
367
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
368
}
369
}
370
371
fn array_set_operation(
372
a: &ListArray<i64>,
373
b: &ListArray<i64>,
374
set_op: SetOperation,
375
) -> PolarsResult<ListArray<i64>> {
376
let offsets_a = a.offsets().as_slice();
377
let offsets_b = b.offsets().as_slice();
378
379
let values_a = a.values();
380
let values_b = b.values();
381
assert_eq!(values_a.dtype(), values_b.dtype());
382
383
let dtype = values_b.dtype();
384
let validity = combine_validities_and(a.validity(), b.validity());
385
386
match dtype {
387
ArrowDataType::Utf8View => {
388
let a = values_a
389
.as_any()
390
.downcast_ref::<Utf8ViewArray>()
391
.unwrap()
392
.to_binview();
393
let b = values_b
394
.as_any()
395
.downcast_ref::<Utf8ViewArray>()
396
.unwrap()
397
.to_binview();
398
399
binary(&a, &b, offsets_a, offsets_b, set_op, validity, true)
400
},
401
ArrowDataType::BinaryView => {
402
let a = values_a.as_any().downcast_ref::<BinaryViewArray>().unwrap();
403
let b = values_b.as_any().downcast_ref::<BinaryViewArray>().unwrap();
404
binary(a, b, offsets_a, offsets_b, set_op, validity, false)
405
},
406
ArrowDataType::Boolean => {
407
polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations")
408
},
409
_ => {
410
with_match_physical_numeric_type!(DataType::from_arrow_dtype(dtype), |$T| {
411
let a = values_a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
412
let b = values_b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
413
414
primitive(&a, &b, offsets_a, offsets_b, set_op, validity)
415
})
416
},
417
}
418
}
419
420
/// Applies `set_op` row-wise between two list columns.
///
/// The columns must have equal length, or one side must have length 1, in
/// which case that side is broadcast against every row of the other.
///
/// # Errors
/// - `ShapeMismatch` when the lengths differ and neither side has length 1.
/// - `InvalidOperation` when the dtypes differ.
pub fn list_set_operation(
    a: &ListChunked,
    b: &ListChunked,
    set_op: SetOperation,
) -> PolarsResult<ListChunked> {
    polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match");
    polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype());
    let mut a = a.clone();
    let mut b = b.clone();
    // Broadcasting in the kernel assumes a single chunk per side.
    if a.len() != b.len() {
        a.rechunk_mut();
        b.rechunk_mut();
    }

    // We will OOB in the kernel otherwise.
    a.prune_empty_chunks();
    b.prune_empty_chunks();

    // we use the unsafe variant because we want to keep the nested logical types type.
    unsafe {
        arity::try_binary_unchecked_same_type(
            &a,
            &b,
            |a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()),
            false,
            false,
        )
    }
}
449
450