Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/list.rs
7884 views
1
use std::sync::Arc;
2
3
use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
4
use polars_core::prelude::{
5
ChunkExpandAtIndex, Column, DataType, IDX_DTYPE, IntoColumn, ListChunked, SortOptions,
6
};
7
use polars_core::utils::CustomIterTools;
8
use polars_ops::prelude::ListNameSpaceImpl;
9
use polars_plan::dsl::{ColumnsUdf, ReshapeDimension, SpecialEq};
10
use polars_plan::plans::IRListFunction;
11
use polars_utils::pl_str::PlSmallStr;
12
13
pub fn function_expr_to_udf(func: IRListFunction) -> SpecialEq<Arc<dyn ColumnsUdf>> {
14
use IRListFunction::*;
15
match func {
16
Concat => wrap!(concat),
17
#[cfg(feature = "is_in")]
18
Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
19
#[cfg(feature = "list_drop_nulls")]
20
DropNulls => map!(drop_nulls),
21
#[cfg(feature = "list_sample")]
22
Sample {
23
is_fraction,
24
with_replacement,
25
shuffle,
26
seed,
27
} => {
28
if is_fraction {
29
map_as_slice!(sample_fraction, with_replacement, shuffle, seed)
30
} else {
31
map_as_slice!(sample_n, with_replacement, shuffle, seed)
32
}
33
},
34
Slice => wrap!(slice),
35
Shift => map_as_slice!(shift),
36
Get(null_on_oob) => wrap!(get, null_on_oob),
37
#[cfg(feature = "list_gather")]
38
Gather(null_on_oob) => map_as_slice!(gather, null_on_oob),
39
#[cfg(feature = "list_gather")]
40
GatherEvery => map_as_slice!(gather_every),
41
#[cfg(feature = "list_count")]
42
CountMatches => map_as_slice!(count_matches),
43
Sum => map!(sum),
44
Length => map!(length),
45
Max => map!(max),
46
Min => map!(min),
47
Mean => map!(mean),
48
Median => map!(median),
49
Std(ddof) => map!(std, ddof),
50
Var(ddof) => map!(var, ddof),
51
ArgMin => map!(arg_min),
52
ArgMax => map!(arg_max),
53
#[cfg(feature = "diff")]
54
Diff { n, null_behavior } => map!(diff, n, null_behavior),
55
Sort(options) => map!(sort, options),
56
Reverse => map!(reverse),
57
Unique(is_stable) => map!(unique, is_stable),
58
#[cfg(feature = "list_sets")]
59
SetOperation(s) => map_as_slice!(set_operation, s),
60
#[cfg(feature = "list_any_all")]
61
Any => map!(lst_any),
62
#[cfg(feature = "list_any_all")]
63
All => map!(lst_all),
64
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
65
#[cfg(feature = "dtype-array")]
66
ToArray(width) => map!(to_array, width),
67
NUnique => map!(n_unique),
68
#[cfg(feature = "list_to_struct")]
69
ToStruct(names) => map!(to_struct, &names),
70
}
71
}
72
73
#[cfg(feature = "is_in")]
/// Per-row membership test of the item in `args[1]` against the list in `args[0]`.
///
/// Delegates to `polars_ops::prelude::is_in`, using the item column as the probe and the
/// list column as the haystack. `nulls_equal` is forwarded unchanged. The output column
/// takes the name of the list column.
///
/// # Errors
/// Returns `SchemaMismatch` if `args[0]` is not of `List` dtype, and propagates any
/// error raised by `is_in`.
pub(super) fn contains(args: &mut [Column], nulls_equal: bool) -> PolarsResult<Column> {
    let (list, item) = (&args[0], &args[1]);
    let list_dtype = list.dtype();
    polars_ensure!(
        matches!(list_dtype, DataType::List(_)),
        SchemaMismatch: "invalid series dtype: expected `List`, got `{}`", list_dtype,
    );

    let haystack = list.as_materialized_series();
    let probe = item.as_materialized_series();
    let mut found = polars_ops::prelude::is_in(probe, haystack, nulls_equal)?;
    found.rename(list.name().clone());
    Ok(found.into_column())
}
88
89
#[cfg(feature = "list_drop_nulls")]
/// Remove null entries from every sub-list (delegates to `lst_drop_nulls`).
///
/// # Errors
/// Returns an error if `s` is not a list column.
pub(super) fn drop_nulls(s: &Column) -> PolarsResult<Column> {
    let without_nulls = s.list()?.lst_drop_nulls();
    Ok(without_nulls.into_column())
}
94
95
#[cfg(feature = "list_sample")]
/// Sample a number of elements from every sub-list.
///
/// `s[0]` is the list column and `s[1]` holds the sample size(s). Both are forwarded to
/// `lst_sample_n` together with the replacement/shuffle flags and the optional seed.
///
/// # Errors
/// Returns an error if `s[0]` is not a list column, and propagates any error from
/// `lst_sample_n`.
pub(super) fn sample_n(
    s: &[Column],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Column> {
    let list = s[0].list()?;
    let sizes = s[1].as_materialized_series();
    let sampled = list.lst_sample_n(sizes, with_replacement, shuffle, seed)?;
    Ok(sampled.into_column())
}
107
108
#[cfg(feature = "list_sample")]
/// Sample a fraction of the elements of every sub-list.
///
/// `s[0]` is the list column and `s[1]` holds the fraction(s). Both are forwarded to
/// `lst_sample_fraction` together with the replacement/shuffle flags and the optional seed.
///
/// # Errors
/// Returns an error if `s[0]` is not a list column, and propagates any error from
/// `lst_sample_fraction`.
pub(super) fn sample_fraction(
    s: &[Column],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Column> {
    let list = s[0].list()?;
    let fractions = s[1].as_materialized_series();
    let sampled = list.lst_sample_fraction(fractions, with_replacement, shuffle, seed)?;
    Ok(sampled.into_column())
}
125
126
fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> {
127
polars_ensure!(
128
slice_len == ca_len,
129
ComputeError:
130
"shape of the slice '{}' argument: {} does not match that of the list column: {}",
131
name, slice_len, ca_len
132
);
133
Ok(())
134
}
135
136
pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
137
let list = s[0].list()?;
138
let periods = &s[1];
139
140
list.lst_shift(periods).map(|ok| ok.into_column())
141
}
142
143
pub(super) fn slice(args: &mut [Column]) -> PolarsResult<Column> {
144
let s = &args[0];
145
let list_ca = s.list()?;
146
let offset_s = &args[1];
147
let length_s = &args[2];
148
149
let mut out: ListChunked = match (offset_s.len(), length_s.len()) {
150
(1, 1) => {
151
let offset = offset_s.get(0).unwrap().try_extract::<i64>()?;
152
let slice_len = length_s
153
.get(0)
154
.unwrap()
155
.extract::<usize>()
156
.unwrap_or(usize::MAX);
157
return Ok(list_ca.lst_slice(offset, slice_len).into_column());
158
},
159
(1, length_slice_len) => {
160
check_slice_arg_shape(length_slice_len, list_ca.len(), "length")?;
161
let offset = offset_s.get(0).unwrap().try_extract::<i64>()?;
162
// cast to i64 as it is more likely that it is that dtype
163
// instead of usize/u64 (we never need that max length)
164
let length_ca = length_s.cast(&DataType::Int64)?;
165
let length_ca = length_ca.i64().unwrap();
166
167
list_ca
168
.amortized_iter()
169
.zip(length_ca)
170
.map(|(opt_s, opt_length)| match (opt_s, opt_length) {
171
(Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)),
172
_ => None,
173
})
174
.collect_trusted()
175
},
176
(offset_len, 1) => {
177
check_slice_arg_shape(offset_len, list_ca.len(), "offset")?;
178
let length_slice = length_s
179
.get(0)
180
.unwrap()
181
.extract::<usize>()
182
.unwrap_or(usize::MAX);
183
let offset_ca = offset_s.cast(&DataType::Int64)?;
184
let offset_ca = offset_ca.i64().unwrap();
185
list_ca
186
.amortized_iter()
187
.zip(offset_ca)
188
.map(|(opt_s, opt_offset)| match (opt_s, opt_offset) {
189
(Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)),
190
_ => None,
191
})
192
.collect_trusted()
193
},
194
_ => {
195
check_slice_arg_shape(offset_s.len(), list_ca.len(), "offset")?;
196
check_slice_arg_shape(length_s.len(), list_ca.len(), "length")?;
197
let offset_ca = offset_s.cast(&DataType::Int64)?;
198
let offset_ca = offset_ca.i64()?;
199
// cast to i64 as it is more likely that it is that dtype
200
// instead of usize/u64 (we never need that max length)
201
let length_ca = length_s.cast(&DataType::Int64)?;
202
let length_ca = length_ca.i64().unwrap();
203
204
list_ca
205
.amortized_iter()
206
.zip(offset_ca)
207
.zip(length_ca)
208
.map(
209
|((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) {
210
(Some(s), Some(offset), Some(length)) => {
211
Some(s.as_ref().slice(offset, length as usize))
212
},
213
_ => None,
214
},
215
)
216
.collect_trusted()
217
},
218
};
219
out.rename(s.name().clone());
220
Ok(out.into_column())
221
}
222
223
pub(super) fn concat(s: &mut [Column]) -> PolarsResult<Column> {
224
let mut first = std::mem::take(&mut s[0]);
225
let other = &s[1..];
226
227
// TODO! don't auto cast here, but implode beforehand.
228
let mut first_ca = match first.try_list() {
229
Some(ca) => ca,
230
None => {
231
first = first
232
.reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)])
233
.unwrap();
234
first.list().unwrap()
235
},
236
}
237
.clone();
238
239
if first_ca.len() == 1 && !other.is_empty() {
240
let max_len = other.iter().map(|s| s.len()).max().unwrap();
241
if max_len > 1 {
242
first_ca = first_ca.new_from_index(0, max_len)
243
}
244
}
245
246
first_ca.lst_concat(other).map(IntoColumn::into_column)
247
}
248
249
pub(super) fn get(s: &mut [Column], null_on_oob: bool) -> PolarsResult<Column> {
250
let ca = s[0].list()?;
251
let index = s[1].cast(&DataType::Int64)?;
252
let index = index.i64().unwrap();
253
254
polars_ops::prelude::lst_get(ca, index, null_on_oob)
255
}
256
257
#[cfg(feature = "list_gather")]
/// Gather elements from every sub-list of `args[0]` at the indices in `args[1]`.
///
/// A single, non-null numeric index with `null_on_oob` takes a fast path: `lst_get`
/// followed by a reshape to `[Infer, 1]`, so the output is still a list column. Every
/// other case goes through `lst_gather`.
///
/// # Errors
/// Returns an error if `args[0]` is not a list column, and propagates errors from the
/// chosen gather path.
pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
    let ca = &args[0];
    let idx = &args[1];
    let ca = ca.list()?;

    // A null scalar index cannot be extracted as `i64` (the extraction would error), so
    // it is left to the general `lst_gather` path like any other index column.
    let scalar_fast_path = idx.len() == 1
        && idx.dtype().is_primitive_numeric()
        && null_on_oob
        && idx.null_count() == 0;

    if scalar_fast_path {
        let idx = idx.get(0)?.try_extract::<i64>()?;
        let out = ca.lst_get(idx, null_on_oob).map(Column::from)?;
        // make sure we return a list
        out.reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)])
    } else {
        ca.lst_gather(idx.as_materialized_series(), null_on_oob)
            .map(Column::from)
    }
}
274
275
#[cfg(feature = "list_gather")]
/// Take every `n`-th element of each sub-list, starting at `offset`.
///
/// `args[1]` (`n`) and `args[2]` (`offset`) are strictly cast to the index dtype before
/// being passed to `lst_gather_every`.
///
/// # Errors
/// Returns an error if either strict cast fails, if `args[0]` is not a list column, or
/// if `lst_gather_every` fails.
pub(super) fn gather_every(args: &[Column]) -> PolarsResult<Column> {
    let step = args[1].strict_cast(&IDX_DTYPE)?;
    let start = args[2].strict_cast(&IDX_DTYPE)?;
    let list = args[0].list()?;

    list.lst_gather_every(step.idx()?, start.idx()?)
        .map(Column::from)
}
285
286
#[cfg(feature = "list_count")]
/// Count, per sub-list of `args[0]`, the occurrences of the single value in `args[1]`.
///
/// # Errors
/// Returns `ComputeError` unless `args[1]` has exactly one element. Also errors if
/// `args[0]` is not a list column, and propagates errors from `list_count_matches`.
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let (list_col, needle) = (&args[0], &args[1]);
    let needle_len = needle.len();
    polars_ensure!(
        needle_len == 1,
        ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}",
        needle_len
    );

    let list = list_col.list()?;
    let value = needle.get(0).unwrap();
    polars_ops::prelude::list_count_matches(list, value).map(Column::from)
}
298
299
pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
300
s.list()?.lst_sum().map(Column::from)
301
}
302
303
/// Number of elements in every sub-list (delegates to `lst_lengths`).
pub(super) fn length(s: &Column) -> PolarsResult<Column> {
    let lengths = s.list()?.lst_lengths();
    Ok(lengths.into_column())
}
306
307
pub(super) fn max(s: &Column) -> PolarsResult<Column> {
308
s.list()?.lst_max().map(Column::from)
309
}
310
311
pub(super) fn min(s: &Column) -> PolarsResult<Column> {
312
s.list()?.lst_min().map(Column::from)
313
}
314
315
/// Per-row mean of a list column (delegates to `lst_mean`).
pub(super) fn mean(s: &Column) -> PolarsResult<Column> {
    let list = s.list()?;
    let means: Column = list.lst_mean().into();
    Ok(means)
}
318
319
/// Per-row median of a list column (delegates to `lst_median`).
pub(super) fn median(s: &Column) -> PolarsResult<Column> {
    let list = s.list()?;
    let medians: Column = list.lst_median().into();
    Ok(medians)
}
322
323
/// Per-row standard deviation with `ddof` delta degrees of freedom (delegates to `lst_std`).
pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
    let list = s.list()?;
    let deviations: Column = list.lst_std(ddof).into();
    Ok(deviations)
}
326
327
/// Per-row variance with `ddof` delta degrees of freedom (delegates to `lst_var`).
pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
    let list = s.list()?;
    let variances: Column = list.lst_var(ddof)?.into();
    Ok(variances)
}
330
331
/// Per-row position of the minimum element (delegates to `lst_arg_min`).
pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
    let positions = s.list()?.lst_arg_min();
    Ok(positions.into_column())
}
334
335
/// Per-row position of the maximum element (delegates to `lst_arg_max`).
pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
    let positions = s.list()?.lst_arg_max();
    Ok(positions.into_column())
}
338
339
#[cfg(feature = "diff")]
/// Per-row `n`-step difference of list elements (delegates to `lst_diff`).
///
/// `null_behavior` is forwarded unchanged.
pub(super) fn diff(
    s: &Column,
    n: i64,
    null_behavior: polars_core::series::ops::NullBehavior,
) -> PolarsResult<Column> {
    let list = s.list()?;
    let differences = list.lst_diff(n, null_behavior)?;
    Ok(differences.into_column())
}
347
348
/// Sort the elements of every sub-list according to `options` (delegates to `lst_sort`).
pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
    let sorted = s.list()?.lst_sort(options)?;
    Ok(sorted.into_column())
}
351
352
/// Reverse the order of the elements in every sub-list (delegates to `lst_reverse`).
pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
    let reversed = s.list()?.lst_reverse();
    Ok(reversed.into_column())
}
355
356
pub(super) fn unique(s: &Column, is_stable: bool) -> PolarsResult<Column> {
357
if is_stable {
358
Ok(s.list()?.lst_unique_stable()?.into_column())
359
} else {
360
Ok(s.list()?.lst_unique()?.into_column())
361
}
362
}
363
364
#[cfg(feature = "list_sets")]
/// Apply a per-row set operation between the list columns `s[0]` and `s[1]`.
///
/// When either input has zero rows, one of the inputs is returned directly instead of
/// computing anything:
/// - intersection: `s[0]` if it is empty, otherwise `s[1]` renamed to `s[0]`'s name;
/// - difference: `s[0]`;
/// - union / symmetric difference: `s[1]` renamed to `s[0]`'s name if `s[0]` is empty,
///   otherwise `s[0]`.
///
/// Otherwise the work is delegated to `list_set_operation`.
pub(super) fn set_operation(
    s: &[Column],
    set_type: polars_ops::prelude::SetOperation,
) -> PolarsResult<Column> {
    let (lhs, rhs) = (&s[0], &s[1]);

    if lhs.is_empty() || rhs.is_empty() {
        use polars_ops::prelude::SetOperation;

        let rhs_named_as_lhs = || rhs.clone().with_name(lhs.name().clone());
        let shortcut = match set_type {
            SetOperation::Intersection if lhs.is_empty() => lhs.clone(),
            SetOperation::Intersection => rhs_named_as_lhs(),
            SetOperation::Difference => lhs.clone(),
            SetOperation::Union | SetOperation::SymmetricDifference if lhs.is_empty() => {
                rhs_named_as_lhs()
            },
            SetOperation::Union | SetOperation::SymmetricDifference => lhs.clone(),
        };
        return Ok(shortcut);
    }

    polars_ops::prelude::list_set_operation(lhs.list()?, rhs.list()?, set_type)
        .map(IntoColumn::into_column)
}
397
398
#[cfg(feature = "list_any_all")]
/// Per-row "any element is true" reduction (delegates to `lst_any`).
pub(super) fn lst_any(s: &Column) -> PolarsResult<Column> {
    let reduced = s.list()?.lst_any()?;
    Ok(Column::from(reduced))
}
402
403
#[cfg(feature = "list_any_all")]
/// Per-row "all elements are true" reduction (delegates to `lst_all`).
pub(super) fn lst_all(s: &Column) -> PolarsResult<Column> {
    let reduced = s.list()?.lst_all()?;
    Ok(Column::from(reduced))
}
407
408
/// Join the string elements of every sub-list in `s[0]` with the separator(s) in `s[1]`.
///
/// # Errors
/// Returns an error if `s[0]` is not a list column or `s[1]` is not a string column,
/// and propagates errors from `lst_join`.
pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
    let list = s[0].list()?;
    let separators = s[1].str()?;
    let joined = list.lst_join(separators, ignore_nulls)?;
    Ok(joined.into_column())
}
413
414
#[cfg(feature = "dtype-array")]
/// Cast a `List` column to a fixed-width `Array` of `width` with the same inner dtype.
///
/// # Errors
/// Returns `ComputeError` if `s` is not of `List` dtype, and propagates cast errors.
pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult<Column> {
    let DataType::List(inner) = s.dtype() else {
        polars_bail!(ComputeError: "expected List dtype")
    };
    let target = DataType::Array(inner.clone(), width);
    s.cast(&target)
}
422
423
#[cfg(feature = "list_to_struct")]
/// Convert every sub-list into a struct whose fields are named by `names`.
///
/// Uses `ListToStructArgs::FixedWidth` with a shared (`Arc`) copy of `names`.
pub(super) fn to_struct(s: &Column, names: &Arc<[PlSmallStr]>) -> PolarsResult<Column> {
    use polars_ops::prelude::{ListToStructArgs, ToStruct};

    let field_names = Arc::clone(names);
    let args = ListToStructArgs::FixedWidth(field_names);
    let structs = s.list()?.to_struct(&args)?;
    Ok(structs.into_column())
}
430
431
/// Number of distinct elements in every sub-list (delegates to `lst_n_unique`).
pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
    let counts = s.list()?.lst_n_unique()?;
    Ok(counts.into_column())
}
434
435