Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/array.rs
7884 views
1
use polars_core::error::{PolarsResult, polars_bail, polars_ensure, polars_err};
2
use polars_core::prelude::{Column, DataType, ExplodeOptions, IntoColumn, SortOptions};
3
use polars_ops::prelude::array::ArrayNameSpace;
4
#[cfg(feature = "array_to_struct")]
5
use polars_plan::dsl::DslNameGenerator;
6
use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7
use polars_plan::plans::IRArrayFunction;
8
use polars_utils::pl_str::PlSmallStr;
9
10
use super::*;
11
12
pub fn function_expr_to_udf(func: IRArrayFunction) -> SpecialEq<Arc<dyn ColumnsUdf>> {
13
use IRArrayFunction::*;
14
match func {
15
Concat => map_as_slice!(concat_arr),
16
Length => map!(length),
17
Min => map!(min),
18
Max => map!(max),
19
Sum => map!(sum),
20
ToList => map!(to_list),
21
Unique(stable) => map!(unique, stable),
22
NUnique => map!(n_unique),
23
Std(ddof) => map!(std, ddof),
24
Var(ddof) => map!(var, ddof),
25
Mean => map!(mean),
26
Median => map!(median),
27
#[cfg(feature = "array_any_all")]
28
Any => map!(any),
29
#[cfg(feature = "array_any_all")]
30
All => map!(all),
31
Sort(options) => map!(sort, options),
32
Reverse => map!(reverse),
33
ArgMin => map!(arg_min),
34
ArgMax => map!(arg_max),
35
Get(null_on_oob) => map_as_slice!(get, null_on_oob),
36
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
37
#[cfg(feature = "is_in")]
38
Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
39
#[cfg(feature = "array_count")]
40
CountMatches => map_as_slice!(count_matches),
41
Shift => map_as_slice!(shift),
42
Explode(options) => map_as_slice!(explode, options),
43
Slice(offset, length) => map!(slice, offset, length),
44
#[cfg(feature = "array_to_struct")]
45
ToStruct(ng) => map!(arr_to_struct, ng.clone()),
46
}
47
}
48
49
pub(super) fn length(s: &Column) -> PolarsResult<Column> {
50
let array = s.array()?;
51
let width = array.width();
52
let width = IdxSize::try_from(width)
53
.map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;
54
55
let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
56
if let Some(validity) = array.rechunk_validity() {
57
let mut series = c.into_materialized_series().clone();
58
59
// SAFETY: We keep datatypes intact and call compute_len afterwards.
60
let chunks = unsafe { series.chunks_mut() };
61
assert_eq!(chunks.len(), 1);
62
63
chunks[0] = chunks[0].with_validity(Some(validity));
64
65
series.compute_len();
66
c = series.into_column();
67
}
68
69
Ok(c)
70
}
71
72
/// Calls `array_max` on the array view of `s` and wraps the result in a [`Column`].
///
/// # Errors
/// Fails if `s.array()` fails.
pub(super) fn max(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    let out: Column = ca.array_max().into();
    Ok(out)
}
75
76
/// Calls `array_min` on the array view of `s` and wraps the result in a [`Column`].
///
/// # Errors
/// Fails if `s.array()` fails.
pub(super) fn min(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    let out: Column = ca.array_min().into();
    Ok(out)
}
79
80
pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
81
s.array()?.array_sum().map(Column::from)
82
}
83
84
pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
85
s.array()?.array_std(ddof).map(Column::from)
86
}
87
88
pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
89
s.array()?.array_var(ddof).map(Column::from)
90
}
91
92
pub(super) fn mean(s: &Column) -> PolarsResult<Column> {
93
s.array()?.array_mean().map(Column::from)
94
}
95
96
pub(super) fn median(s: &Column) -> PolarsResult<Column> {
97
s.array()?.array_median().map(Column::from)
98
}
99
100
/// Runs either `array_unique_stable` or `array_unique` on the array view of `s`,
/// selected by `stable`, and converts the result into a [`Column`].
///
/// # Errors
/// Propagates any error from `s.array()` or from the chosen unique routine.
pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
    let arr = s.array()?;
    let deduped = match stable {
        true => arr.array_unique_stable(),
        false => arr.array_unique(),
    }?;
    Ok(deduped.into_column())
}
109
110
/// Calls `array_n_unique` on the array view of `s` and converts the result into a
/// [`Column`].
///
/// # Errors
/// Propagates any error from `s.array()` or from `array_n_unique`.
pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
    let arr = s.array()?;
    let counts = arr.array_n_unique()?;
    Ok(counts.into_column())
}
113
114
pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
115
if let DataType::Array(inner, _) = s.dtype() {
116
s.cast(&DataType::List(inner.clone()))
117
} else {
118
polars_bail!(ComputeError: "expected array dtype")
119
}
120
}
121
122
/// Calls `array_any` on the array view of `s` and wraps the result in a [`Column`].
///
/// # Errors
/// Propagates any error from `s.array()` or from `array_any`.
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    let flags = ca.array_any()?;
    Ok(Column::from(flags))
}
126
127
/// Calls `array_all` on the array view of `s` and wraps the result in a [`Column`].
///
/// # Errors
/// Propagates any error from `s.array()` or from `array_all`.
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    let flags = ca.array_all()?;
    Ok(Column::from(flags))
}
131
132
/// Calls `array_sort` with `options` on the array view of `s` and converts the result
/// into a [`Column`].
///
/// # Errors
/// Propagates any error from `s.array()` or from `array_sort`.
pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
    let arr = s.array()?;
    let sorted = arr.array_sort(options)?;
    Ok(sorted.into_column())
}
135
136
/// Calls `array_reverse` on the array view of `s` and converts the result into a
/// [`Column`].
///
/// # Errors
/// Fails if `s.array()` fails.
pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
    let arr = s.array()?;
    let reversed = arr.array_reverse();
    Ok(reversed.into_column())
}
139
140
/// Calls `array_arg_min` on the array view of `s` and converts the result into a
/// [`Column`].
///
/// # Errors
/// Fails if `s.array()` fails.
pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
    let arr = s.array()?;
    let positions = arr.array_arg_min();
    Ok(positions.into_column())
}
143
144
/// Calls `array_arg_max` on the array view of `s` and converts the result into a
/// [`Column`].
///
/// # Errors
/// Fails if `s.array()` fails.
pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
    let arr = s.array()?;
    let positions = arr.array_arg_max();
    Ok(positions.into_column())
}
147
148
pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
149
let ca = s[0].array()?;
150
let index = s[1].cast(&DataType::Int64)?;
151
let index = index.i64().unwrap();
152
ca.array_get(index, null_on_oob).map(Column::from)
153
}
154
155
pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
156
let ca = s[0].array()?;
157
let separator = s[1].str()?;
158
ca.array_join(separator, ignore_nulls).map(Column::from)
159
}
160
161
/// Runs `polars_ops::series::is_in` with `s[1]` as the item and `s[0]` as the array
/// column, forwarding `nulls_equal`. The result is renamed to the array column's name.
///
/// # Errors
/// Returns a `SchemaMismatch` error when `s[0]` is not of `Array` dtype, and
/// propagates any error from `is_in`.
///
/// # Panics
/// Panics if `s` has fewer than two columns.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column], nulls_equal: bool) -> PolarsResult<Column> {
    let haystack = &s[0];
    let needle = &s[1];

    if !matches!(haystack.dtype(), DataType::Array(_, _)) {
        polars_bail!(
            SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", haystack.dtype(),
        );
    }

    let mut mask = polars_ops::series::is_in(
        needle.as_materialized_series(),
        haystack.as_materialized_series(),
        nulls_equal,
    )?;
    mask.rename(haystack.name().clone());

    Ok(mask.into_column())
}
176
177
/// Calls `array_count_matches` on the array column `args[0]` with the single value
/// held by `args[1]`.
///
/// # Errors
/// Returns a `ComputeError` unless `args[1]` has exactly one element, and propagates
/// errors from `args[0].array()` and `array_count_matches`.
///
/// # Panics
/// Panics if `args` has fewer than two columns.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let haystack = &args[0];
    let needle = &args[1];

    if needle.len() != 1 {
        polars_bail!(
            ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
            needle.len()
        );
    }

    let arr = haystack.array()?;
    // The length check above guarantees that index 0 exists.
    let value = needle.get(0).unwrap();
    let counts = arr.array_count_matches(value)?;
    Ok(Column::from(counts))
}
190
191
pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
192
let ca = s[0].array()?;
193
let n = &s[1];
194
195
ca.array_shift(n.as_materialized_series()).map(Column::from)
196
}
197
198
pub(super) fn slice(s: &Column, offset: i64, length: i64) -> PolarsResult<Column> {
199
let ca = s.array()?;
200
ca.array_slice(offset, length).map(Column::from)
201
}
202
203
/// Calls `Column::explode` with `options` on the first column of `c`.
///
/// # Panics
/// Panics if `c` is empty.
fn explode(c: &[Column], options: ExplodeOptions) -> PolarsResult<Column> {
    let target = &c[0];
    target.explode(options)
}
206
207
fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
208
let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
209
210
polars_ops::series::concat_arr::concat_arr(args, &dtype)
211
}
212
213
/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input
214
/// dtypes are compatible.
215
fn concat_arr_output_dtype(
216
inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
217
) -> PolarsResult<DataType> {
218
#[allow(clippy::len_zero)]
219
if inputs.len() == 0 {
220
// should not be reachable - we did not set ALLOW_EMPTY_INPUTS
221
panic!();
222
}
223
224
let mut inputs = inputs.map(|(name, dtype)| {
225
let (inner_dtype, width) = match dtype {
226
DataType::Array(inner, width) => (inner.as_ref(), *width),
227
dt => (dt, 1),
228
};
229
(name, dtype, inner_dtype, width)
230
});
231
let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
232
233
for (col_name, dtype, inner_dtype, width) in inputs {
234
out_width += width;
235
236
if inner_dtype != first_inner_dtype {
237
polars_bail!(
238
SchemaMismatch:
239
"concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
240
input column (name: {}, dtype: {}), got {} instead for column {}",
241
first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
242
)
243
}
244
}
245
246
Ok(DataType::Array(
247
Box::new(first_inner_dtype.clone()),
248
out_width,
249
))
250
}
251
252
/// Converts the array view of `s` into a struct column via `ToStruct::to_struct`.
///
/// When a `name_generator` is supplied it is wrapped in an `Arc`'d closure that calls
/// the generator for an index and converts the produced name into a `PlSmallStr`.
///
/// # Errors
/// Propagates errors from `s.array()` and from `to_struct`.
#[cfg(feature = "array_to_struct")]
fn arr_to_struct(s: &Column, name_generator: Option<DslNameGenerator>) -> PolarsResult<Column> {
    use polars_ops::prelude::array::ToStruct;

    // Adapt the DSL-level generator to the callback shape `to_struct` takes.
    let field_namer =
        name_generator.map(|f| Arc::new(move |i| f.call(i).map(PlSmallStr::from)) as Arc<_>);

    let arr = s.array()?;
    let as_struct = arr.to_struct(field_namer)?;
    Ok(as_struct.into_column())
}
262
263