// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-expr/src/expressions/gather.rs
1
use polars_core::chunked_array::cast::CastOptions;
2
use polars_core::prelude::arity::unary_elementwise_values;
3
use polars_core::prelude::*;
4
use polars_ops::prelude::lst_get;
5
use polars_ops::series::convert_to_unsigned_index;
6
use polars_utils::index::ToIdx;
7
8
use super::*;
9
use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
10
11
pub struct GatherExpr {
12
pub(crate) phys_expr: Arc<dyn PhysicalExpr>,
13
pub(crate) idx: Arc<dyn PhysicalExpr>,
14
pub(crate) expr: Expr,
15
pub(crate) returns_scalar: bool,
16
}
17
18
impl PhysicalExpr for GatherExpr {
19
fn as_expression(&self) -> Option<&Expr> {
20
Some(&self.expr)
21
}
22
23
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
24
let series = self.phys_expr.evaluate(df, state)?;
25
let idx = self.idx.evaluate(df, state)?;
26
let idx = convert_to_unsigned_index(idx.as_materialized_series(), series.len())?;
27
series.take(&idx)
28
}
29
30
#[allow(clippy::ptr_arg)]
31
fn evaluate_on_groups<'a>(
32
&self,
33
df: &DataFrame,
34
groups: &'a GroupPositions,
35
state: &ExecutionState,
36
) -> PolarsResult<AggregationContext<'a>> {
37
let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;
38
let mut idx = self.idx.evaluate_on_groups(df, groups, state)?;
39
40
let ac_list = ac.aggregated_as_list();
41
42
if self.returns_scalar {
43
polars_ensure!(
44
!matches!(idx.agg_state(), AggState::AggregatedList(_) | AggState::NotAggregated(_)),
45
ComputeError: "expected single index"
46
);
47
48
// For returns_scalar=true, we can dispatch to `list.get`.
49
let idx = idx.flat_naive();
50
let idx = idx.cast(&DataType::Int64)?;
51
let idx = idx.i64().unwrap();
52
let taken = lst_get(ac_list.as_ref(), idx, true)?;
53
54
ac.with_values_and_args(taken, true, Some(&self.expr), false, true)?;
55
ac.with_update_groups(UpdateGroups::No);
56
return Ok(ac);
57
}
58
59
// Cast the indices to
60
// - IdxSize, if the idx only contains positive integers.
61
// - Int64, if the idx contains negative numbers.
62
// This may give false positives if there are masked out elements.
63
let idx = idx.aggregated_as_list();
64
let idx = idx.apply_to_inner(&|s| match s.dtype() {
65
dtype if dtype == &IDX_DTYPE => Ok(s),
66
dtype if dtype.is_unsigned_integer() => {
67
s.cast_with_options(&IDX_DTYPE, CastOptions::Strict)
68
},
69
70
dtype if dtype.is_signed_integer() => {
71
let has_negative_integers = s.lt(0)?.any();
72
if has_negative_integers && dtype == &DataType::Int64 {
73
Ok(s)
74
} else if has_negative_integers {
75
s.cast_with_options(&DataType::Int64, CastOptions::Strict)
76
} else {
77
s.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)
78
}
79
},
80
_ => polars_bail!(
81
op = "gather/get",
82
got = s.dtype(),
83
expected = "integer type"
84
),
85
})?;
86
87
let taken = if idx.inner_dtype() == &IDX_DTYPE {
88
// Fast path: all indices are positive.
89
90
ac_list
91
.amortized_iter()
92
.zip(idx.amortized_iter())
93
.map(|(s, idx)| Some(s?.as_ref().take(idx?.as_ref().idx().unwrap())))
94
.map(|opt_res| opt_res.transpose())
95
.collect::<PolarsResult<ListChunked>>()?
96
.with_name(ac.get_values().name().clone())
97
} else {
98
// Slower path: some indices may be negative.
99
assert!(idx.inner_dtype() == &DataType::Int64);
100
101
ac_list
102
.amortized_iter()
103
.zip(idx.amortized_iter())
104
.map(|(s, idx)| {
105
let s = s?;
106
let idx = idx?;
107
let idx = idx.as_ref().i64().unwrap();
108
let target_len = s.as_ref().len() as u64;
109
let idx = unary_elementwise_values(idx, |v| v.to_idx(target_len));
110
Some(s.as_ref().take(&idx))
111
})
112
.map(|opt_res| opt_res.transpose())
113
.collect::<PolarsResult<ListChunked>>()?
114
.with_name(ac.get_values().name().clone())
115
};
116
117
ac.with_values(taken.into_column(), true, Some(&self.expr))?;
118
ac.with_update_groups(UpdateGroups::WithSeriesLen);
119
Ok(ac)
120
}
121
122
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
123
self.phys_expr.to_field(input_schema)
124
}
125
126
fn is_scalar(&self) -> bool {
127
self.returns_scalar
128
}
129
}
130
131