Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/array.rs
7889 views
1
use polars_core::utils::slice_offsets;
2
use polars_ops::chunked_array::array::*;
3
4
use super::*;
5
6
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
7
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
8
pub enum IRArrayFunction {
9
Length,
10
Min,
11
Max,
12
Sum,
13
ToList,
14
Unique(bool),
15
NUnique,
16
Std(u8),
17
Var(u8),
18
Mean,
19
Median,
20
#[cfg(feature = "array_any_all")]
21
Any,
22
#[cfg(feature = "array_any_all")]
23
All,
24
Sort(SortOptions),
25
Reverse,
26
ArgMin,
27
ArgMax,
28
Get(bool),
29
Join(bool),
30
#[cfg(feature = "is_in")]
31
Contains {
32
nulls_equal: bool,
33
},
34
#[cfg(feature = "array_count")]
35
CountMatches,
36
Shift,
37
Explode(ExplodeOptions),
38
Concat,
39
Slice(i64, i64),
40
#[cfg(feature = "array_to_struct")]
41
ToStruct(Option<DslNameGenerator>),
42
}
43
44
impl<'a> FieldsMapper<'a> {
45
/// Validate that the dtype is an array.
46
pub fn ensure_is_array(self) -> PolarsResult<Self> {
47
let dt = self.args()[0].dtype();
48
polars_ensure!(
49
dt.is_array(),
50
InvalidOperation: format!("expected Array datatype for array operation, got: {:?}", dt)
51
);
52
Ok(self)
53
}
54
}
55
56
impl IRArrayFunction {
57
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
58
use IRArrayFunction::*;
59
60
match self {
61
Concat => Ok(Field::new(
62
mapper
63
.args()
64
.first()
65
.map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
66
concat_arr_output_dtype(
67
&mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
68
)?,
69
)),
70
Length => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),
71
Min | Max => mapper
72
.ensure_is_array()?
73
.map_to_list_and_array_inner_dtype(),
74
Sum => mapper.ensure_is_array()?.nested_sum_type(),
75
ToList => mapper
76
.ensure_is_array()?
77
.try_map_dtype(map_array_dtype_to_list_dtype),
78
Unique(_) => mapper
79
.ensure_is_array()?
80
.try_map_dtype(map_array_dtype_to_list_dtype),
81
NUnique => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),
82
Std(_) => mapper.ensure_is_array()?.moment_dtype(),
83
Var(_) => mapper.ensure_is_array()?.var_dtype(),
84
Mean => mapper.ensure_is_array()?.moment_dtype(),
85
Median => mapper.ensure_is_array()?.moment_dtype(),
86
#[cfg(feature = "array_any_all")]
87
Any | All => mapper.ensure_is_array()?.with_dtype(DataType::Boolean),
88
Sort(_) => mapper.ensure_is_array()?.with_same_dtype(),
89
Reverse => mapper.ensure_is_array()?.with_same_dtype(),
90
ArgMin | ArgMax => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),
91
Get(_) => mapper
92
.ensure_is_array()?
93
.map_to_list_and_array_inner_dtype(),
94
Join(_) => mapper.ensure_is_array()?.with_dtype(DataType::String),
95
#[cfg(feature = "is_in")]
96
Contains { nulls_equal: _ } => mapper.ensure_is_array()?.with_dtype(DataType::Boolean),
97
#[cfg(feature = "array_count")]
98
CountMatches => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),
99
Shift => mapper.ensure_is_array()?.with_same_dtype(),
100
Explode { .. } => mapper.ensure_is_array()?.try_map_to_array_inner_dtype(),
101
Slice(offset, length) => mapper
102
.ensure_is_array()?
103
.try_map_dtype(map_to_array_fixed_length(offset, length)),
104
#[cfg(feature = "array_to_struct")]
105
ToStruct(name_generator) => mapper.ensure_is_array()?.try_map_dtype(|dtype| {
106
let DataType::Array(inner, width) = dtype else {
107
polars_bail!(InvalidOperation: "expected Array type, got: {dtype}")
108
};
109
110
(0..*width)
111
.map(|i| {
112
let name = match name_generator {
113
None => arr_default_struct_name_gen(i),
114
Some(ng) => PlSmallStr::from_string(ng.call(i)?),
115
};
116
Ok(Field::new(name, inner.as_ref().clone()))
117
})
118
.collect::<PolarsResult<Vec<Field>>>()
119
.map(DataType::Struct)
120
}),
121
}
122
}
123
124
pub fn function_options(&self) -> FunctionOptions {
125
use IRArrayFunction as A;
126
match self {
127
#[cfg(feature = "array_any_all")]
128
A::Any | A::All => FunctionOptions::elementwise(),
129
#[cfg(feature = "is_in")]
130
A::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),
131
#[cfg(feature = "array_count")]
132
A::CountMatches => FunctionOptions::elementwise(),
133
A::Concat => FunctionOptions::elementwise()
134
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
135
A::Length
136
| A::Min
137
| A::Max
138
| A::Sum
139
| A::ToList
140
| A::Unique(_)
141
| A::NUnique
142
| A::Std(_)
143
| A::Var(_)
144
| A::Mean
145
| A::Median
146
| A::Sort(_)
147
| A::Reverse
148
| A::ArgMin
149
| A::ArgMax
150
| A::Get(_)
151
| A::Join(_)
152
| A::Shift
153
| A::Slice(_, _) => FunctionOptions::elementwise(),
154
A::Explode { .. } => FunctionOptions::row_separable(),
155
#[cfg(feature = "array_to_struct")]
156
A::ToStruct(_) => FunctionOptions::elementwise(),
157
}
158
}
159
}
160
161
fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
162
if let DataType::Array(inner, _) = datatype {
163
Ok(DataType::List(inner.clone()))
164
} else {
165
polars_bail!(ComputeError: "expected array dtype")
166
}
167
}
168
169
fn map_to_array_fixed_length(
170
offset: &i64,
171
length: &i64,
172
) -> impl FnOnce(&DataType) -> PolarsResult<DataType> {
173
move |datatype: &DataType| {
174
if let DataType::Array(inner, array_len) = datatype {
175
let length: usize = if *length < 0 {
176
(*array_len as i64 + *length).max(0)
177
} else {
178
*length
179
}.try_into().map_err(|_| {
180
polars_err!(OutOfBounds: "length must be a non-negative integer, got: {}", length)
181
})?;
182
let (_, slice_offset) = slice_offsets(*offset, length, *array_len);
183
Ok(DataType::Array(inner.clone(), slice_offset))
184
} else {
185
polars_bail!(ComputeError: "expected array dtype, got {}", datatype);
186
}
187
}
188
}
189
190
impl Display for IRArrayFunction {
191
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192
use IRArrayFunction::*;
193
let name = match self {
194
Concat => "concat",
195
Length => "length",
196
Min => "min",
197
Max => "max",
198
Sum => "sum",
199
ToList => "to_list",
200
Unique(_) => "unique",
201
NUnique => "n_unique",
202
Std(_) => "std",
203
Var(_) => "var",
204
Mean => "mean",
205
Median => "median",
206
#[cfg(feature = "array_any_all")]
207
Any => "any",
208
#[cfg(feature = "array_any_all")]
209
All => "all",
210
Sort(_) => "sort",
211
Reverse => "reverse",
212
ArgMin => "arg_min",
213
ArgMax => "arg_max",
214
Get(_) => "get",
215
Join(_) => "join",
216
#[cfg(feature = "is_in")]
217
Contains { nulls_equal: _ } => "contains",
218
#[cfg(feature = "array_count")]
219
CountMatches => "count_matches",
220
Shift => "shift",
221
Slice(_, _) => "slice",
222
Explode { .. } => "explode",
223
#[cfg(feature = "array_to_struct")]
224
ToStruct(_) => "to_struct",
225
};
226
write!(f, "arr.{name}")
227
}
228
}
229
230
/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input
231
/// dtypes are compatible.
232
fn concat_arr_output_dtype(
233
inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
234
) -> PolarsResult<DataType> {
235
#[allow(clippy::len_zero)]
236
if inputs.len() == 0 {
237
// should not be reachable - we did not set ALLOW_EMPTY_INPUTS
238
panic!();
239
}
240
241
let mut inputs = inputs.map(|(name, dtype)| {
242
let (inner_dtype, width) = match dtype {
243
DataType::Array(inner, width) => (inner.as_ref(), *width),
244
dt => (dt, 1),
245
};
246
(name, dtype, inner_dtype, width)
247
});
248
let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
249
250
for (col_name, dtype, inner_dtype, width) in inputs {
251
out_width += width;
252
253
if inner_dtype != first_inner_dtype {
254
polars_bail!(
255
SchemaMismatch:
256
"concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
257
input column (name: {}, dtype: {}), got {} instead for column {}",
258
first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
259
)
260
}
261
}
262
263
Ok(DataType::Array(
264
Box::new(first_inner_dtype.clone()),
265
out_width,
266
))
267
}
268
269