Path: blob/main/crates/polars-expr/src/expressions/gather.rs
6940 views
use polars_core::chunked_array::cast::CastOptions;1use polars_core::prelude::arity::unary_elementwise_values;2use polars_core::prelude::*;3use polars_ops::prelude::lst_get;4use polars_ops::series::convert_to_unsigned_index;5use polars_utils::index::ToIdx;67use super::*;8use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};910pub struct GatherExpr {11pub(crate) phys_expr: Arc<dyn PhysicalExpr>,12pub(crate) idx: Arc<dyn PhysicalExpr>,13pub(crate) expr: Expr,14pub(crate) returns_scalar: bool,15}1617impl PhysicalExpr for GatherExpr {18fn as_expression(&self) -> Option<&Expr> {19Some(&self.expr)20}2122fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {23let series = self.phys_expr.evaluate(df, state)?;24let idx = self.idx.evaluate(df, state)?;25let idx = convert_to_unsigned_index(idx.as_materialized_series(), series.len())?;26series.take(&idx)27}2829#[allow(clippy::ptr_arg)]30fn evaluate_on_groups<'a>(31&self,32df: &DataFrame,33groups: &'a GroupPositions,34state: &ExecutionState,35) -> PolarsResult<AggregationContext<'a>> {36let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;37let mut idx = self.idx.evaluate_on_groups(df, groups, state)?;3839let ac_list = ac.aggregated_as_list();4041if self.returns_scalar {42polars_ensure!(43!matches!(idx.agg_state(), AggState::AggregatedList(_) | AggState::NotAggregated(_)),44ComputeError: "expected single index"45);4647// For returns_scalar=true, we can dispatch to `list.get`.48let idx = idx.flat_naive();49let idx = idx.cast(&DataType::Int64)?;50let idx = idx.i64().unwrap();51let taken = lst_get(ac_list.as_ref(), idx, true)?;5253ac.with_values_and_args(taken, true, Some(&self.expr), false, true)?;54ac.with_update_groups(UpdateGroups::No);55return Ok(ac);56}5758// Cast the indices to59// - IdxSize, if the idx only contains positive integers.60// - Int64, if the idx contains negative numbers.61// This may give false positives if there are masked out elements.62let idx = idx.aggregated_as_list();63let idx = idx.apply_to_inner(&|s| match s.dtype() {64dtype if dtype == &IDX_DTYPE => Ok(s),65dtype if dtype.is_unsigned_integer() => {66s.cast_with_options(&IDX_DTYPE, CastOptions::Strict)67},6869dtype if dtype.is_signed_integer() => {70let has_negative_integers = s.lt(0)?.any();71if has_negative_integers && dtype == &DataType::Int64 {72Ok(s)73} else if has_negative_integers {74s.cast_with_options(&DataType::Int64, CastOptions::Strict)75} else {76s.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)77}78},79_ => polars_bail!(80op = "gather/get",81got = s.dtype(),82expected = "integer type"83),84})?;8586let taken = if idx.inner_dtype() == &IDX_DTYPE {87// Fast path: all indices are positive.8889ac_list90.amortized_iter()91.zip(idx.amortized_iter())92.map(|(s, idx)| Some(s?.as_ref().take(idx?.as_ref().idx().unwrap())))93.map(|opt_res| opt_res.transpose())94.collect::<PolarsResult<ListChunked>>()?95.with_name(ac.get_values().name().clone())96} else {97// Slower path: some indices may be negative.98assert!(idx.inner_dtype() == &DataType::Int64);99100ac_list101.amortized_iter()102.zip(idx.amortized_iter())103.map(|(s, idx)| {104let s = s?;105let idx = idx?;106let idx = idx.as_ref().i64().unwrap();107let target_len = s.as_ref().len() as u64;108let idx = unary_elementwise_values(idx, |v| v.to_idx(target_len));109Some(s.as_ref().take(&idx))110})111.map(|opt_res| opt_res.transpose())112.collect::<PolarsResult<ListChunked>>()?113.with_name(ac.get_values().name().clone())114};115116ac.with_values(taken.into_column(), true, Some(&self.expr))?;117ac.with_update_groups(UpdateGroups::WithSeriesLen);118Ok(ac)119}120121fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {122self.phys_expr.to_field(input_schema)123}124125fn is_scalar(&self) -> bool {126self.returns_scalar127}128}129130131