Path: blob/main/crates/polars-expr/src/expressions/gather.rs
8424 views
use polars_core::chunked_array::cast::CastOptions;1use polars_core::prelude::arity::unary_elementwise_values;2use polars_core::prelude::*;3use polars_ops::prelude::lst_get;4use polars_ops::series::convert_and_bound_index;5use polars_utils::index::ToIdx;67use super::*;8use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};910pub struct GatherExpr {11pub(crate) phys_expr: Arc<dyn PhysicalExpr>,12pub(crate) idx: Arc<dyn PhysicalExpr>,13pub(crate) expr: Expr,14pub(crate) returns_scalar: bool,15pub(crate) null_on_oob: bool,16}1718impl PhysicalExpr for GatherExpr {19fn as_expression(&self) -> Option<&Expr> {20Some(&self.expr)21}2223fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {24let series = self.phys_expr.evaluate(df, state)?;25let idx = self.idx.evaluate(df, state)?;26let idx =27convert_and_bound_index(idx.as_materialized_series(), series.len(), self.null_on_oob)?;28series.take(&idx)29}3031#[allow(clippy::ptr_arg)]32fn evaluate_on_groups<'a>(33&self,34df: &DataFrame,35groups: &'a GroupPositions,36state: &ExecutionState,37) -> PolarsResult<AggregationContext<'a>> {38let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;39let mut idx = self.idx.evaluate_on_groups(df, groups, state)?;4041let ac_list = ac.aggregated_as_list();4243if self.returns_scalar {44polars_ensure!(45!matches!(idx.agg_state(), AggState::AggregatedList(_) | AggState::NotAggregated(_)),46ComputeError: "expected single index"47);4849// For returns_scalar=true, we can dispatch to `list.get`.50let idx = idx.flat_naive();51let idx = idx.cast(&DataType::Int64)?;52let idx = idx.i64().unwrap();53let taken = lst_get(ac_list.as_ref(), idx, true)?;5455ac.with_values_and_args(taken, true, Some(&self.expr), false, true)?;56ac.with_update_groups(UpdateGroups::No);57return Ok(ac);58}5960// Cast the indices to61// - IdxSize, if the idx only contains positive integers.62// - Int64, if the idx contains negative numbers.63// This may give false positives if there are masked out elements.64let idx = idx.aggregated_as_list();65let idx = idx.apply_to_inner(&|s| match s.dtype() {66dtype if dtype == &IDX_DTYPE => Ok(s),67dtype if dtype.is_unsigned_integer() => {68s.cast_with_options(&IDX_DTYPE, CastOptions::Strict)69},7071dtype if dtype.is_signed_integer() => {72let has_negative_integers = s.lt(0)?.any();73if has_negative_integers && dtype == &DataType::Int64 {74Ok(s)75} else if has_negative_integers {76s.cast_with_options(&DataType::Int64, CastOptions::Strict)77} else {78s.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing)79}80},81_ => polars_bail!(82op = "gather/get",83got = s.dtype(),84expected = "integer type"85),86})?;8788let taken = if idx.inner_dtype() == &IDX_DTYPE {89// Fast path: all indices are positive.9091ac_list92.amortized_iter()93.zip(idx.amortized_iter())94.map(|(s, idx)| Some(s?.as_ref().take(idx?.as_ref().idx().unwrap())))95.map(|opt_res| opt_res.transpose())96.collect::<PolarsResult<ListChunked>>()?97.with_name(ac.get_values().name().clone())98} else {99// Slower path: some indices may be negative.100assert!(idx.inner_dtype() == &DataType::Int64);101102ac_list103.amortized_iter()104.zip(idx.amortized_iter())105.map(|(s, idx)| {106let s = s?;107let idx = idx?;108let idx = idx.as_ref().i64().unwrap();109let target_len = s.as_ref().len() as u64;110let idx = unary_elementwise_values(idx, |v| v.to_idx(target_len));111Some(s.as_ref().take(&idx))112})113.map(|opt_res| opt_res.transpose())114.collect::<PolarsResult<ListChunked>>()?115.with_name(ac.get_values().name().clone())116};117118ac.with_agg_state(AggState::AggregatedList(taken.into_column()));119ac.with_update_groups(UpdateGroups::WithSeriesLen);120Ok(ac)121}122123fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {124self.phys_expr.to_field(input_schema)125}126127fn is_scalar(&self) -> bool {128self.returns_scalar129}130}131132133