Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/list/sets.rs
8411 views
1
use std::fmt::{Display, Formatter};
2
use std::hash::Hash;
3
4
use arrow::array::{
5
Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray,
6
PrimitiveArray, Utf8ViewArray,
7
};
8
use arrow::bitmap::Bitmap;
9
use arrow::compute::utils::combine_validities_and;
10
use arrow::offset::OffsetsBuffer;
11
use arrow::types::NativeType;
12
use polars_core::prelude::*;
13
use polars_core::with_match_physical_numeric_type;
14
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap};
15
#[cfg(feature = "serde")]
16
use serde::{Deserialize, Serialize};
17
18
/// Abstraction over the mutable output buffers used by the set kernels.
///
/// Implementors append every item yielded by `values` and report the
/// buffer's new total length, which the caller records as the end offset
/// of the current list element.
trait MaterializeValues<K> {
    /// Extend the buffer with `values`; returns the buffer length afterwards.
    fn extend_buf<I: Iterator<Item = K>>(&mut self, values: I) -> usize;
}
22
23
impl<T> MaterializeValues<Option<T>> for MutablePrimitiveArray<T>
24
where
25
T: NativeType,
26
{
27
fn extend_buf<I: Iterator<Item = Option<T>>>(&mut self, values: I) -> usize {
28
self.extend(values);
29
self.len()
30
}
31
}
32
33
impl<T> MaterializeValues<TotalOrdWrap<Option<T>>> for MutablePrimitiveArray<T>
34
where
35
T: NativeType,
36
{
37
fn extend_buf<I: Iterator<Item = TotalOrdWrap<Option<T>>>>(&mut self, values: I) -> usize {
38
self.extend(values.map(|x| x.0));
39
self.len()
40
}
41
}
42
43
/// Nullable byte slices are materialized into a mutable binary buffer.
impl<'a> MaterializeValues<Option<&'a [u8]>> for MutablePlBinary {
    fn extend_buf<I: Iterator<Item = Option<&'a [u8]>>>(&mut self, values: I) -> usize {
        self.extend(values);
        self.len()
    }
}
49
50
#[allow(clippy::too_many_arguments)]
51
fn set_operation<I, J, K, R>(
52
set: &mut PlIndexSet<K>,
53
set2: &mut PlIndexSet<K>,
54
a: &mut I,
55
b: &mut J,
56
out: &mut R,
57
set_op: SetOperation,
58
broadcast_rhs: bool,
59
) -> usize
60
where
61
K: Eq + Hash + Copy,
62
I: Iterator<Item = K>,
63
J: Iterator<Item = K>,
64
R: MaterializeValues<K>,
65
{
66
set.clear();
67
68
match set_op {
69
SetOperation::Intersection => {
70
set.extend(a);
71
// If broadcast `set2` should already be filled.
72
if !broadcast_rhs {
73
set2.clear();
74
set2.extend(b);
75
}
76
out.extend_buf(set.intersection(set2).copied())
77
},
78
SetOperation::Union => {
79
set.extend(a);
80
set.extend(b);
81
out.extend_buf(set.drain(..))
82
},
83
SetOperation::Difference => {
84
set.extend(a);
85
for v in b {
86
set.swap_remove(&v);
87
}
88
out.extend_buf(set.drain(..))
89
},
90
SetOperation::SymmetricDifference => {
91
// If broadcast `set2` should already be filled.
92
if !broadcast_rhs {
93
set2.clear();
94
set2.extend(b);
95
}
96
// We could speed this up, but implementing ourselves, but we need to have a cloneable
97
// iterator as we need 2 passes
98
set.extend(a);
99
out.extend_buf(set.symmetric_difference(set2).copied())
100
},
101
}
102
}
103
104
/// Copy an optional borrowed value and wrap it in its total-order item so
/// that it is hashable/comparable even for float types (NaN-aware).
fn copied_wrapper_opt<T: Copy + TotalEq + TotalHash>(
    v: Option<&T>,
) -> <Option<T> as ToTotalOrd>::TotalOrdItem {
    v.copied().to_total_ord()
}
109
110
/// The kind of set operation to apply between two list columns.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum SetOperation {
    /// Elements present in both lists.
    Intersection,
    /// Elements present in either list.
    Union,
    /// Elements of the left list that are not in the right list.
    Difference,
    /// Elements present in exactly one of the two lists.
    SymmetricDifference,
}
119
120
impl Display for SetOperation {
121
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122
let s = match self {
123
SetOperation::Intersection => "intersection",
124
SetOperation::Union => "union",
125
SetOperation::Difference => "difference",
126
SetOperation::SymmetricDifference => "symmetric_difference",
127
};
128
write!(f, "{s}")
129
}
130
}
131
132
fn primitive<T>(
133
a: &PrimitiveArray<T>,
134
b: &PrimitiveArray<T>,
135
offsets_a: &[i64],
136
offsets_b: &[i64],
137
set_op: SetOperation,
138
validity: Option<Bitmap>,
139
) -> PolarsResult<ListArray<i64>>
140
where
141
T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd,
142
<Option<T> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,
143
{
144
let broadcast_lhs = offsets_a.len() == 2;
145
let broadcast_rhs = offsets_b.len() == 2;
146
147
let mut set = Default::default();
148
let mut set2: PlIndexSet<<Option<T> as ToTotalOrd>::TotalOrdItem> = Default::default();
149
150
let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(
151
*offsets_a.last().unwrap(),
152
*offsets_b.last().unwrap(),
153
) as usize);
154
let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
155
offsets.push(0i64);
156
157
let offsets_slice = if offsets_a.len() > offsets_b.len() {
158
offsets_a
159
} else {
160
offsets_b
161
};
162
let first_a = offsets_a[0];
163
let second_a = offsets_a[1];
164
let first_b = offsets_b[0];
165
let second_b = offsets_b[1];
166
if broadcast_rhs {
167
set2.extend(
168
b.into_iter()
169
.skip(first_b as usize)
170
.take(second_b as usize - first_b as usize)
171
.map(copied_wrapper_opt),
172
);
173
}
174
175
let mut iter_a = a.into_iter().skip(first_a as usize);
176
let mut iter_b = b.into_iter().skip(first_b as usize);
177
178
for i in 1..offsets_slice.len() {
179
// If we go OOB we take the first element as we are then broadcasting.
180
let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
181
let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;
182
183
let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
184
let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;
185
186
let mut iter_a_broadcast = iter_a.clone();
187
let mut iter_b_broadcast = iter_b.clone();
188
189
// The branches are the same every loop.
190
// We rely on branch prediction here.
191
let mut iter_a = if broadcast_lhs {
192
iter_a_broadcast
193
.by_ref()
194
.take(second_a as usize - first_a as usize)
195
.map(copied_wrapper_opt)
196
} else {
197
iter_a
198
.by_ref()
199
.take(end_a - start_a)
200
.map(copied_wrapper_opt)
201
};
202
let mut iter_b = if broadcast_rhs {
203
iter_b_broadcast
204
.by_ref()
205
.take(second_b as usize - first_b as usize)
206
.map(copied_wrapper_opt)
207
} else {
208
iter_b
209
.by_ref()
210
.take(end_b - start_b)
211
.map(copied_wrapper_opt)
212
};
213
214
let offset = set_operation(
215
&mut set,
216
&mut set2,
217
&mut iter_a,
218
&mut iter_b,
219
&mut values_out,
220
set_op,
221
broadcast_rhs,
222
);
223
224
assert!(iter_a.next().is_none());
225
if !broadcast_rhs || matches!(set_op, SetOperation::Union | SetOperation::Difference) {
226
assert!(iter_b.next().is_none());
227
};
228
229
offsets.push(offset as i64);
230
}
231
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
232
let dtype = ListArray::<i64>::default_datatype(values_out.dtype().clone());
233
234
let values: PrimitiveArray<T> = values_out.into();
235
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
236
}
237
238
fn binary(
239
a: &BinaryViewArray,
240
b: &BinaryViewArray,
241
offsets_a: &[i64],
242
offsets_b: &[i64],
243
set_op: SetOperation,
244
validity: Option<Bitmap>,
245
as_utf8: bool,
246
) -> PolarsResult<ListArray<i64>> {
247
let broadcast_lhs = offsets_a.len() == 2;
248
let broadcast_rhs = offsets_b.len() == 2;
249
let mut set: PlIndexSet<Option<&[u8]>> = Default::default();
250
let mut set2: PlIndexSet<Option<&[u8]>> = Default::default();
251
252
let mut values_out = MutablePlBinary::with_capacity(std::cmp::max(
253
*offsets_a.last().unwrap(),
254
*offsets_b.last().unwrap(),
255
) as usize);
256
let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
257
offsets.push(0i64);
258
259
let offsets_slice = if offsets_a.len() > offsets_b.len() {
260
offsets_a
261
} else {
262
offsets_b
263
};
264
let first_a = offsets_a[0];
265
let second_a = offsets_a[1];
266
let first_b = offsets_b[0];
267
let second_b = offsets_b[1];
268
269
if broadcast_rhs {
270
// set2.extend(b_iter)
271
set2.extend(
272
b.into_iter()
273
.skip(first_b as usize)
274
.take(second_b as usize - first_b as usize),
275
);
276
}
277
278
let mut iter_a = a.into_iter().skip(first_a as usize);
279
let mut iter_b = b.into_iter().skip(first_b as usize);
280
281
for i in 1..offsets_slice.len() {
282
// If we go OOB we take the first element as we are then broadcasting.
283
let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
284
let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;
285
286
let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
287
let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;
288
289
let mut iter_a_broadcast = iter_a.clone();
290
let mut iter_b_broadcast = iter_b.clone();
291
292
// The branches are the same every loop.
293
// We rely on branch prediction here.
294
let mut iter_a = if broadcast_lhs {
295
iter_a_broadcast
296
.by_ref()
297
.take(second_a as usize - first_a as usize)
298
} else {
299
iter_a.by_ref().take(end_a - start_a)
300
};
301
let mut iter_b = if broadcast_rhs {
302
iter_b_broadcast
303
.by_ref()
304
.take(second_b as usize - first_b as usize)
305
} else {
306
iter_b.by_ref().take(end_b - start_b)
307
};
308
309
let offset = set_operation(
310
&mut set,
311
&mut set2,
312
&mut iter_a,
313
&mut iter_b,
314
&mut values_out,
315
set_op,
316
broadcast_rhs,
317
);
318
319
assert!(iter_a.next().is_none());
320
if !broadcast_rhs || matches!(set_op, SetOperation::Union | SetOperation::Difference) {
321
assert!(iter_b.next().is_none());
322
};
323
324
offsets.push(offset as i64);
325
}
326
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
327
let values = values_out.freeze();
328
329
if as_utf8 {
330
let values = unsafe { values.to_utf8view_unchecked() };
331
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
332
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
333
} else {
334
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
335
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
336
}
337
}
338
339
fn array_set_operation(
340
a: &ListArray<i64>,
341
b: &ListArray<i64>,
342
set_op: SetOperation,
343
) -> PolarsResult<ListArray<i64>> {
344
let offsets_a = a.offsets().as_slice();
345
let offsets_b = b.offsets().as_slice();
346
347
let values_a = a.values();
348
let values_b = b.values();
349
assert_eq!(values_a.dtype(), values_b.dtype());
350
351
let dtype = values_b.dtype();
352
let validity = combine_validities_and(a.validity(), b.validity());
353
354
match dtype {
355
ArrowDataType::Utf8View => {
356
let a = values_a
357
.as_any()
358
.downcast_ref::<Utf8ViewArray>()
359
.unwrap()
360
.to_binview();
361
let b = values_b
362
.as_any()
363
.downcast_ref::<Utf8ViewArray>()
364
.unwrap()
365
.to_binview();
366
367
binary(&a, &b, offsets_a, offsets_b, set_op, validity, true)
368
},
369
ArrowDataType::BinaryView => {
370
let a = values_a.as_any().downcast_ref::<BinaryViewArray>().unwrap();
371
let b = values_b.as_any().downcast_ref::<BinaryViewArray>().unwrap();
372
binary(a, b, offsets_a, offsets_b, set_op, validity, false)
373
},
374
ArrowDataType::Boolean => {
375
polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations")
376
},
377
_ => {
378
with_match_physical_numeric_type!(DataType::from_arrow_dtype(dtype), |$T| {
379
let a = values_a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
380
let b = values_b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
381
382
primitive(&a, &b, offsets_a, offsets_b, set_op, validity)
383
})
384
},
385
}
386
}
387
388
pub fn list_set_operation(
389
a: &ListChunked,
390
b: &ListChunked,
391
set_op: SetOperation,
392
) -> PolarsResult<ListChunked> {
393
polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match");
394
polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype());
395
let mut a = a.clone();
396
let mut b = b.clone();
397
if a.len() != b.len() {
398
a.rechunk_mut();
399
b.rechunk_mut();
400
}
401
402
// We will OOB in the kernel otherwise.
403
a.prune_empty_chunks();
404
b.prune_empty_chunks();
405
406
// we use the unsafe variant because we want to keep the nested logical types type.
407
unsafe {
408
arity::try_binary_unchecked_same_type(
409
&a,
410
&b,
411
|a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()),
412
false,
413
false,
414
)
415
}
416
}
417
418