Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/slice.rs
6940 views
1
use AnyValue::Null;
2
use polars_core::POOL;
3
use polars_core::prelude::*;
4
use polars_core::utils::{CustomIterTools, slice_offsets};
5
use rayon::prelude::*;
6
7
use super::*;
8
use crate::expressions::{AggregationContext, PhysicalExpr};
9
10
pub struct SliceExpr {
11
pub(crate) input: Arc<dyn PhysicalExpr>,
12
pub(crate) offset: Arc<dyn PhysicalExpr>,
13
pub(crate) length: Arc<dyn PhysicalExpr>,
14
pub(crate) expr: Expr,
15
}
16
17
fn extract_offset(offset: &Column, expr: &Expr) -> PolarsResult<i64> {
18
polars_ensure!(
19
offset.len() <= 1, expr = expr, ComputeError:
20
"invalid argument to slice; expected an offset literal, got series of length {}",
21
offset.len()
22
);
23
offset.get(0).unwrap().extract().ok_or_else(
24
|| polars_err!(expr = expr, ComputeError: "unable to extract offset from {:?}", offset),
25
)
26
}
27
28
fn extract_length(length: &Column, expr: &Expr) -> PolarsResult<usize> {
29
polars_ensure!(
30
length.len() <= 1, expr = expr, ComputeError:
31
"invalid argument to slice; expected a length literal, got series of length {}",
32
length.len()
33
);
34
match length.get(0).unwrap() {
35
Null => Ok(usize::MAX),
36
v => v.extract().ok_or_else(
37
|| polars_err!(expr = expr, ComputeError: "unable to extract length from {:?}", length),
38
),
39
}
40
}
41
42
fn extract_args(offset: &Column, length: &Column, expr: &Expr) -> PolarsResult<(i64, usize)> {
43
Ok((extract_offset(offset, expr)?, extract_length(length, expr)?))
44
}
45
46
/// Validates a per-group slice argument (`name` is `"offset"` or `"length"` at
/// the call sites in this file).
///
/// # Errors
///
/// Returns a `ComputeError` when the argument has a `List` dtype, when it does
/// not hold exactly one value per group, or when it contains nulls.
fn check_argument(arg: &Column, groups: &GroupsType, name: &str, expr: &Expr) -> PolarsResult<()> {
    polars_ensure!(
        !matches!(arg.dtype(), DataType::List(_)), expr = expr, ComputeError:
        "invalid slice argument: cannot use an array as {} argument", name,
    );
    // The placeholder names the offending argument; the mismatch is in its length.
    polars_ensure!(
        arg.len() == groups.len(), expr = expr, ComputeError:
        "invalid slice argument: the evaluated {} expression was \
        of different length than the number of groups", name
    );
    polars_ensure!(
        arg.null_count() == 0, expr = expr, ComputeError:
        "invalid slice argument: the {} expression has nulls", name
    );
    Ok(())
}
62
63
fn slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem {
64
let (offset, len) = slice_offsets(offset, length, idx.len());
65
66
// If slice isn't out of bounds, we replace first.
67
// If slice is oob, the `idx` vec will be empty and `first` will be ignored
68
if let Some(f) = idx.get(offset) {
69
first = *f;
70
}
71
// This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day.
72
(first, idx[offset..offset + len].into())
73
}
74
75
/// Slices one contiguous group described by `[first, len]` and returns the
/// new `[first, len]` pair.
///
/// `offset` and `length` are resolved against the group length by
/// `slice_offsets`; the resolved start is added to the group's `first`.
fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] {
    let (start, count) = slice_offsets(offset, length, len as usize);
    let new_first = first + start as IdxSize;
    [new_first, count as IdxSize]
}
79
80
impl PhysicalExpr for SliceExpr {
81
fn as_expression(&self) -> Option<&Expr> {
82
Some(&self.expr)
83
}
84
85
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
86
let results = POOL.install(|| {
87
[&self.offset, &self.length, &self.input]
88
.par_iter()
89
.map(|e| e.evaluate(df, state))
90
.collect::<PolarsResult<Vec<_>>>()
91
})?;
92
let offset = &results[0];
93
let length = &results[1];
94
let series = &results[2];
95
let (offset, length) = extract_args(offset, length, &self.expr)?;
96
97
Ok(series.slice(offset, length))
98
}
99
100
fn evaluate_on_groups<'a>(
101
&self,
102
df: &DataFrame,
103
groups: &'a GroupPositions,
104
state: &ExecutionState,
105
) -> PolarsResult<AggregationContext<'a>> {
106
let mut results = POOL.install(|| {
107
[&self.offset, &self.length, &self.input]
108
.par_iter()
109
.map(|e| e.evaluate_on_groups(df, groups, state))
110
.collect::<PolarsResult<Vec<_>>>()
111
})?;
112
let mut ac = results.pop().unwrap();
113
114
if let AggState::AggregatedScalar(_) = ac.agg_state() {
115
polars_bail!(InvalidOperation: "cannot slice() an aggregated scalar value")
116
}
117
118
let mut ac_length = results.pop().unwrap();
119
let mut ac_offset = results.pop().unwrap();
120
121
// Fast path:
122
// When `input` (ac) is a LiteralValue, and both `offset` and `length` are LiteralScalar,
123
// we slice the LiteralValue and avoid calling groups().
124
// TODO: When `input` (ac) is a LiteralValue, and `offset` or `length` is not a LiteralScalar,
125
// we can simplify the groups calculation since we have a List containing one scalar for
126
// each group.
127
128
use AggState::*;
129
let groups = match (&ac_offset.state, &ac_length.state) {
130
(LiteralScalar(offset), LiteralScalar(length)) => {
131
let (offset, length) = extract_args(offset, length, &self.expr)?;
132
133
if let LiteralScalar(s) = ac.agg_state() {
134
let s1 = s.slice(offset, length);
135
ac.with_literal(s1);
136
ac.aggregated();
137
return Ok(ac);
138
}
139
let groups = ac.groups();
140
141
match groups.as_ref().as_ref() {
142
GroupsType::Idx(groups) => {
143
let groups = groups
144
.iter()
145
.map(|(first, idx)| slice_groups_idx(offset, length, first, idx))
146
.collect();
147
GroupsType::Idx(groups)
148
},
149
GroupsType::Slice { groups, .. } => {
150
let groups = groups
151
.iter()
152
.map(|&[first, len]| slice_groups_slice(offset, length, first, len))
153
.collect_trusted();
154
GroupsType::Slice {
155
groups,
156
rolling: false,
157
}
158
},
159
}
160
},
161
(LiteralScalar(offset), _) => {
162
if matches!(ac.state, LiteralScalar(_)) {
163
ac.aggregated();
164
}
165
let groups = ac.groups();
166
let offset = extract_offset(offset, &self.expr)?;
167
let length = ac_length.aggregated();
168
check_argument(&length, groups, "length", &self.expr)?;
169
170
let length = length.cast(&IDX_DTYPE)?;
171
let length = length.idx().unwrap();
172
173
match groups.as_ref().as_ref() {
174
GroupsType::Idx(groups) => {
175
let groups = groups
176
.iter()
177
.zip(length.into_no_null_iter())
178
.map(|((first, idx), length)| {
179
slice_groups_idx(offset, length as usize, first, idx)
180
})
181
.collect();
182
GroupsType::Idx(groups)
183
},
184
GroupsType::Slice { groups, .. } => {
185
let groups = groups
186
.iter()
187
.zip(length.into_no_null_iter())
188
.map(|(&[first, len], length)| {
189
slice_groups_slice(offset, length as usize, first, len)
190
})
191
.collect_trusted();
192
GroupsType::Slice {
193
groups,
194
rolling: false,
195
}
196
},
197
}
198
},
199
(_, LiteralScalar(length)) => {
200
if matches!(ac.state, LiteralScalar(_)) {
201
ac.aggregated();
202
}
203
let groups = ac.groups();
204
let length = extract_length(length, &self.expr)?;
205
let offset = ac_offset.aggregated();
206
check_argument(&offset, groups, "offset", &self.expr)?;
207
208
let offset = offset.cast(&DataType::Int64)?;
209
let offset = offset.i64().unwrap();
210
211
match groups.as_ref().as_ref() {
212
GroupsType::Idx(groups) => {
213
let groups = groups
214
.iter()
215
.zip(offset.into_no_null_iter())
216
.map(|((first, idx), offset)| {
217
slice_groups_idx(offset, length, first, idx)
218
})
219
.collect();
220
GroupsType::Idx(groups)
221
},
222
GroupsType::Slice { groups, .. } => {
223
let groups = groups
224
.iter()
225
.zip(offset.into_no_null_iter())
226
.map(|(&[first, len], offset)| {
227
slice_groups_slice(offset, length, first, len)
228
})
229
.collect_trusted();
230
GroupsType::Slice {
231
groups,
232
rolling: false,
233
}
234
},
235
}
236
},
237
_ => {
238
if matches!(ac.state, LiteralScalar(_)) {
239
ac.aggregated();
240
}
241
242
let groups = ac.groups();
243
let length = ac_length.aggregated();
244
let offset = ac_offset.aggregated();
245
check_argument(&length, groups, "length", &self.expr)?;
246
check_argument(&offset, groups, "offset", &self.expr)?;
247
248
let offset = offset.cast(&DataType::Int64)?;
249
let offset = offset.i64().unwrap();
250
251
let length = length.cast(&IDX_DTYPE)?;
252
let length = length.idx().unwrap();
253
254
match groups.as_ref().as_ref() {
255
GroupsType::Idx(groups) => {
256
let groups = groups
257
.iter()
258
.zip(offset.into_no_null_iter())
259
.zip(length.into_no_null_iter())
260
.map(|(((first, idx), offset), length)| {
261
slice_groups_idx(offset, length as usize, first, idx)
262
})
263
.collect();
264
GroupsType::Idx(groups)
265
},
266
GroupsType::Slice { groups, .. } => {
267
let groups = groups
268
.iter()
269
.zip(offset.into_no_null_iter())
270
.zip(length.into_no_null_iter())
271
.map(|((&[first, len], offset), length)| {
272
slice_groups_slice(offset, length as usize, first, len)
273
})
274
.collect_trusted();
275
GroupsType::Slice {
276
groups,
277
rolling: false,
278
}
279
},
280
}
281
},
282
};
283
284
ac.with_groups(groups.into_sliceable())
285
.set_original_len(false);
286
287
Ok(ac)
288
}
289
290
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
291
self.input.to_field(input_schema)
292
}
293
294
fn is_scalar(&self) -> bool {
295
false
296
}
297
}
298
299