Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/slice.rs
8416 views
1
use AnyValue::Null;
2
use polars_core::POOL;
3
use polars_core::prelude::*;
4
use polars_core::utils::{CustomIterTools, slice_offsets};
5
use polars_utils::idx_vec::IdxVec;
6
use rayon::prelude::*;
7
8
use super::*;
9
use crate::expressions::{AggregationContext, PhysicalExpr};
10
11
/// Physical expression that slices the output of `input`.
///
/// The slice start comes from evaluating `offset` and the number of elements
/// from evaluating `length`. Outside a group-by both must yield a single
/// value. Inside a group-by they may instead yield one value per group.
pub struct SliceExpr {
    /// Expression whose output is sliced.
    pub(crate) input: Arc<dyn PhysicalExpr>,
    /// Expression producing the (signed, `i64`) slice offset.
    pub(crate) offset: Arc<dyn PhysicalExpr>,
    /// Expression producing the slice length; a null literal is treated as an unbounded length.
    pub(crate) length: Arc<dyn PhysicalExpr>,
    /// Logical expression returned by `as_expression` and attached to error messages.
    pub(crate) expr: Expr,
}
17
18
fn extract_offset(offset: &Column, expr: &Expr) -> PolarsResult<i64> {
19
polars_ensure!(
20
offset.len() <= 1, expr = expr, ComputeError:
21
"invalid argument to slice; expected an offset literal, got series of length {}",
22
offset.len()
23
);
24
offset.get(0).unwrap().extract().ok_or_else(
25
|| polars_err!(expr = expr, ComputeError: "unable to extract offset from {:?}", offset),
26
)
27
}
28
29
fn extract_length(length: &Column, expr: &Expr) -> PolarsResult<usize> {
30
polars_ensure!(
31
length.len() <= 1, expr = expr, ComputeError:
32
"invalid argument to slice; expected a length literal, got series of length {}",
33
length.len()
34
);
35
match length.get(0).unwrap() {
36
Null => Ok(usize::MAX),
37
v => v.extract().ok_or_else(
38
|| polars_err!(expr = expr, ComputeError: "unable to extract length from {:?}", length),
39
),
40
}
41
}
42
43
fn extract_args(offset: &Column, length: &Column, expr: &Expr) -> PolarsResult<(i64, usize)> {
44
Ok((extract_offset(offset, expr)?, extract_length(length, expr)?))
45
}
46
47
/// Validates a per-group slice argument (`offset` or `length`) against the groups.
///
/// `name` is the argument's name and is interpolated into the error messages.
///
/// # Errors
///
/// Returns a `ComputeError` in any of these cases:
/// - the argument is a list column;
/// - the argument does not have exactly one value per group;
/// - the argument contains nulls.
fn check_argument(arg: &Column, groups: &GroupsType, name: &str, expr: &Expr) -> PolarsResult<()> {
    polars_ensure!(
        !matches!(arg.dtype(), DataType::List(_)), expr = expr, ComputeError:
        "invalid slice argument: cannot use an array as {} argument", name,
    );
    // The message previously read "the evaluated length expression was of
    // different {name} ...": the argument name and the word "length" were swapped.
    polars_ensure!(
        arg.len() == groups.len(), expr = expr, ComputeError:
        "invalid slice argument: the evaluated {} expression was \
        of different length than the number of groups", name
    );
    // Callers iterate the argument with `into_no_null_iter`, so nulls must be rejected here.
    polars_ensure!(
        arg.null_count() == 0, expr = expr, ComputeError:
        "invalid slice argument: the {} expression has nulls", name
    );
    Ok(())
}
63
64
fn slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem {
65
let (offset, len) = slice_offsets(offset, length, idx.len());
66
67
// If slice isn't out of bounds, we replace first.
68
// If slice is oob, the `idx` vec will be empty and `first` will be ignored
69
if let Some(f) = idx.get(offset) {
70
first = *f;
71
}
72
// This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day.
73
(first, IdxVec::from_slice(&idx[offset..offset + len]))
74
}
75
76
/// Applies a slice to a single `[first, len]` slice group and returns the new `[first, len]`.
fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] {
    let (relative_start, new_len) = slice_offsets(offset, length, len as usize);
    // `relative_start` is measured from the group's start, so shift it by `first`.
    let new_first = first + relative_start as IdxSize;
    [new_first, new_len as IdxSize]
}
80
81
impl PhysicalExpr for SliceExpr {
    /// Returns the logical expression this physical expression was built from.
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Evaluates `offset`, `length` and `input` on the whole frame and slices the input.
    ///
    /// The three sub-expressions are evaluated in parallel on `POOL`. `offset`
    /// and `length` must each evaluate to a single value (see `extract_args`).
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let results = POOL.install(|| {
            [&self.offset, &self.length, &self.input]
                .par_iter()
                .map(|e| e.evaluate(df, state))
                .collect::<PolarsResult<Vec<_>>>()
        })?;
        // `results` follows the order of the array above: offset, length, input.
        let offset = &results[0];
        let length = &results[1];
        let series = &results[2];
        let (offset, length) = extract_args(offset, length, &self.expr)?;

        Ok(series.slice(offset, length))
    }

    /// Slices every group of `input` in a group-by context.
    ///
    /// `offset` and `length` may each be a literal scalar or a per-group value:
    /// - A literal scalar applies the same value to every group.
    /// - A per-group value must have one non-null entry per group (see `check_argument`).
    ///
    /// Apart from the all-literal fast path, no data is sliced here. Instead,
    /// the groups of the input's aggregation context are rewritten, and that
    /// context is returned.
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Evaluate offset, length and input in parallel on the global pool.
        let mut results = POOL.install(|| {
            [&self.offset, &self.length, &self.input]
                .par_iter()
                .map(|e| e.evaluate_on_groups(df, groups, state))
                .collect::<PolarsResult<Vec<_>>>()
        })?;
        // Pop in reverse order of the array above: input, then length, then offset.
        let mut ac = results.pop().unwrap();

        let mut ac_length = results.pop().unwrap();
        let mut ac_offset = results.pop().unwrap();

        // Fast path:
        // When `input` (ac) is a LiteralValue, and both `offset` and `length` are LiteralScalar,
        // we slice the LiteralValue and avoid calling groups().
        // TODO: When `input` (ac) is a LiteralValue, and `offset` or `length` is not a LiteralScalar,
        // we can simplify the groups calculation since we have a List containing one scalar for
        // each group.

        use AggState::*;
        let groups = match (&ac_offset.state, &ac_length.state) {
            // Scalar offset and scalar length: the same slice is applied to every group.
            (LiteralScalar(offset), LiteralScalar(length)) => {
                let (offset, length) = extract_args(offset, length, &self.expr)?;

                // The input is a literal too: slice it directly and return without touching groups.
                if let LiteralScalar(s) = ac.agg_state() {
                    let s1 = s.slice(offset, length);
                    ac.with_literal(s1);
                    ac.aggregated();
                    return Ok(ac);
                }
                // An aggregated scalar is converted to a list column, and the groups
                // are marked to be rebuilt from the series lengths.
                if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }
                let groups = ac.groups();

                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .map(|(first, idx)| slice_groups_idx(offset, length, first, idx))
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic,
                    } => {
                        let groups = groups
                            .iter()
                            .map(|&[first, len]| slice_groups_slice(offset, length, first, len))
                            .collect_trusted();
                        // NOTE(review): the `monotonic` flag is kept here, whereas the
                        // per-group arms below always pass `false`.
                        GroupsType::new_slice(groups, *overlapping, *monotonic)
                    },
                }
            },
            // Scalar offset, per-group length.
            (LiteralScalar(offset), _) => {
                // Bring the input into an aggregated/list state before reading its groups.
                if matches!(ac.state, LiteralScalar(_)) {
                    ac.aggregated();
                } else if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }
                let groups = ac.groups();
                let offset = extract_offset(offset, &self.expr)?;
                let length = ac_length.aggregated();
                check_argument(&length, groups, "length", &self.expr)?;

                let length = length.cast(&IDX_DTYPE)?;
                let length = length.idx().unwrap();

                // `check_argument` rejected nulls, so `into_no_null_iter` is safe to zip with the groups.
                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .zip(length.into_no_null_iter())
                            .map(|((first, idx), length)| {
                                slice_groups_idx(offset, length as usize, first, idx)
                            })
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic: _,
                    } => {
                        let groups = groups
                            .iter()
                            .zip(length.into_no_null_iter())
                            .map(|(&[first, len], length)| {
                                slice_groups_slice(offset, length as usize, first, len)
                            })
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, false)
                    },
                }
            },
            // Per-group offset, scalar length.
            (_, LiteralScalar(length)) => {
                // Bring the input into an aggregated/list state before reading its groups.
                if matches!(ac.state, LiteralScalar(_)) {
                    ac.aggregated();
                } else if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }
                let groups = ac.groups();
                let length = extract_length(length, &self.expr)?;
                let offset = ac_offset.aggregated();
                check_argument(&offset, groups, "offset", &self.expr)?;

                let offset = offset.cast(&DataType::Int64)?;
                let offset = offset.i64().unwrap();

                // `check_argument` rejected nulls, so `into_no_null_iter` is safe to zip with the groups.
                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .map(|((first, idx), offset)| {
                                slice_groups_idx(offset, length, first, idx)
                            })
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic: _,
                    } => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .map(|(&[first, len], offset)| {
                                slice_groups_slice(offset, length, first, len)
                            })
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, false)
                    },
                }
            },
            // Per-group offset and per-group length.
            _ => {
                // Bring the input into an aggregated/list state before reading its groups.
                if matches!(ac.state, LiteralScalar(_)) {
                    ac.aggregated();
                } else if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }

                let groups = ac.groups();
                let length = ac_length.aggregated();
                let offset = ac_offset.aggregated();
                check_argument(&length, groups, "length", &self.expr)?;
                check_argument(&offset, groups, "offset", &self.expr)?;

                let offset = offset.cast(&DataType::Int64)?;
                let offset = offset.i64().unwrap();

                let length = length.cast(&IDX_DTYPE)?;
                let length = length.idx().unwrap();

                // Both arguments were null-checked by `check_argument`, so the
                // no-null iterators are safe to zip with the groups.
                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .zip(length.into_no_null_iter())
                            .map(|(((first, idx), offset), length)| {
                                slice_groups_idx(offset, length as usize, first, idx)
                            })
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic: _,
                    } => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .zip(length.into_no_null_iter())
                            .map(|((&[first, len], offset), length)| {
                                slice_groups_slice(offset, length as usize, first, len)
                            })
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, false)
                    },
                }
            },
        };

        // Install the sliced groups on the input's context. `set_original_len(false)`
        // presumably records that the result no longer has the input's original
        // length (TODO confirm against AggregationContext).
        ac.with_groups(groups.into_sliceable())
            .set_original_len(false);

        Ok(ac)
    }

    /// The output field is the input's field; slicing does not change name or dtype.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.input.to_field(input_schema)
    }

    /// A slice is never reported as a scalar expression.
    fn is_scalar(&self) -> bool {
        false
    }
}
312
313