Path: blob/main/crates/polars-expr/src/expressions/structeval.rs
8406 views
use std::sync::Arc;12use polars_core::POOL;3use polars_core::error::{PolarsResult, polars_ensure};4use polars_core::frame::DataFrame;5use polars_core::prelude::*;6use polars_core::schema::Schema;7use polars_plan::dsl::Expr;8use rayon::prelude::*;910use super::PhysicalExpr;11#[cfg(feature = "dtype-struct")]12use crate::dispatch::struct_::with_fields;13use crate::prelude::{AggState, AggregationContext, UpdateGroups};14use crate::state::ExecutionState;1516#[derive(Clone)]17pub struct StructEvalExpr {18input: Arc<dyn PhysicalExpr>,19evaluation: Vec<Arc<dyn PhysicalExpr>>,20expr: Expr,21output_field: Field,22operates_on_scalar: bool,23allow_threading: bool,24}2526impl StructEvalExpr {27pub(crate) fn new(28input: Arc<dyn PhysicalExpr>,29evaluation: Vec<Arc<dyn PhysicalExpr>>,30expr: Expr,31output_field: Field,32operates_on_scalar: bool,33allow_threading: bool,34) -> Self {35Self {36input,37evaluation,38expr,39output_field,40operates_on_scalar,41allow_threading,42}43}44}4546impl StructEvalExpr {47fn apply_all_literal_elementwise<'a>(48&self,49mut acs: Vec<AggregationContext<'a>>,50) -> PolarsResult<AggregationContext<'a>> {51let cols = acs52.iter()53.map(|ac| ac.get_values().clone())54.collect::<Vec<_>>();55let out = with_fields(&cols)?;56polars_ensure!(57out.len() == 1,58ComputeError: "elementwise expression {:?} must return exactly 1 value on literals, got {}",59&self.expr, out.len()60);61let mut ac = acs.pop().unwrap();62ac.with_literal(out);63Ok(ac)64}6566fn apply_elementwise<'a>(67&self,68mut acs: Vec<AggregationContext<'a>>,69must_aggregate: bool,70) -> PolarsResult<AggregationContext<'a>> {71// At this stage, we either have (with or without LiteralScalars):72// - one or more AggregatedList or NotAggregated ACs73// - one or more AggregatedScalar ACs7475let mut previous = None;76for ac in acs.iter_mut() {77if matches!(78ac.state,79AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)80) {81continue;82}8384if must_aggregate {85ac.aggregated();86}8788if matches!(ac.state, AggState::AggregatedList(_)) {89if let Some(p) = previous {90ac.groups().check_lengths(p)?;91}92previous = Some(ac.groups());93}94}9596// At this stage, we do not have both AggregatedList and NotAggregated ACs9798// The first AC represents the `input` and will be used as the base AC.99let base_ac_idx = 0;100101match acs[base_ac_idx].agg_state() {102AggState::AggregatedList(s) => {103let aggregated = acs.iter().any(|ac| ac.is_aggregated());104let ca = s.list().unwrap();105let input_len = s.len();106107let out = ca.apply_to_inner(&|_| {108let cols = acs109.iter()110.map(|ac| ac.flat_naive().into_owned())111.collect::<Vec<_>>();112Ok(with_fields(&cols)?.as_materialized_series().clone())113})?;114115let out = out.into_column();116assert!(input_len == out.len());117118let mut ac = acs.swap_remove(base_ac_idx);119ac.with_values_and_args(120out,121aggregated,122Some(&self.expr),123false,124self.is_scalar(),125)?;126Ok(ac)127},128_ => {129let aggregated = acs.iter().any(|ac| ac.is_aggregated());130assert!(aggregated == self.is_scalar());131132let cols = acs133.iter()134.map(|ac| ac.flat_naive().into_owned())135.collect::<Vec<_>>();136137let input_len = cols[base_ac_idx].len();138let out = with_fields(&cols)?;139assert!(input_len == out.len());140141let mut ac = acs.swap_remove(base_ac_idx);142ac.with_values_and_args(143out,144aggregated,145Some(&self.expr),146false,147self.is_scalar(),148)?;149Ok(ac)150},151}152}153154fn apply_group_aware<'a>(155&self,156mut acs: Vec<AggregationContext<'a>>,157) -> PolarsResult<AggregationContext<'a>> {158let len = acs[0].groups.len();159let mut iters = acs160.iter_mut()161.map(|ac| ac.iter_groups(true))162.collect::<Vec<_>>();163let ca = (0..len)164.map(|_| {165let mut cols = Vec::with_capacity(iters.len());166for i in &mut iters {167match i.next().unwrap() {168None => return Ok(None),169Some(s) => cols.push(s.as_ref().clone().into_column()),170}171}172let out = with_fields(&cols)?;173Ok(Some(out))174})175.collect::<PolarsResult<ListChunked>>()?;176drop(iters);177178// Finish apply groups; see also ApplyExpr for the reference solution.179let ac = acs.swap_remove(0);180self.finish_apply_groups(ac, ca)181}182183fn finish_apply_groups<'a>(184&self,185mut ac: AggregationContext<'a>,186ca: ListChunked,187) -> PolarsResult<AggregationContext<'a>> {188let col = if self.is_scalar() {189let out = ca190.explode(ExplodeOptions {191empty_as_null: true,192keep_nulls: true,193})194.unwrap();195// if the explode doesn't return the same len, it wasn't scalar.196polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out);197ac.update_groups = UpdateGroups::No;198out.into_column()199} else {200ac.with_update_groups(UpdateGroups::WithSeriesLen);201ca.into_series().into()202};203204ac.with_values_and_args(col, true, self.as_expression(), false, self.is_scalar())?;205206Ok(ac)207}208}209210impl PhysicalExpr for StructEvalExpr {211fn as_expression(&self) -> Option<&Expr> {212Some(&self.expr)213}214215fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {216let input = self.input.evaluate(df, state)?;217218// Set ExecutionState.219let mut state = state.clone();220let mut eval = Vec::with_capacity(self.evaluation.len() + 1);221let input_len = input.len();222223state.with_fields = Some(Arc::new(input.struct_()?.clone()));224225// Collect evaluation fields; input goes first.226eval.push(input);227228let f = |e: &Arc<dyn PhysicalExpr>| {229let result = e.evaluate(df, &state)?;230polars_ensure!(231result.len() == input_len || result.len() == 1,232ShapeMismatch: "struct.with_fields expressions must have matching or unit length"233);234Ok(result)235};236let cols = if self.allow_threading {237POOL.install(|| {238self.evaluation239.par_iter()240.map(f)241.collect::<PolarsResult<Vec<_>>>()242})243} else {244self.evaluation245.iter()246.map(f)247.collect::<PolarsResult<Vec<_>>>()248};249for col in cols? {250eval.push(col);251}252253// Apply with_fields.254with_fields(&eval)255}256257fn evaluate_on_groups<'a>(258&self,259df: &DataFrame,260groups: &'a GroupPositions,261state: &ExecutionState,262) -> PolarsResult<AggregationContext<'a>> {263// The evaluation is similar to a regular Function, with the modification that the input264// is evaluated first, and retained for future use in the ExecutionState.265266// Evaluate input.267let mut ac = self.input.evaluate_on_groups(df, groups, state)?;268269ac.groups();270ac.set_groups_for_undefined_agg_states();271272// Snap the AC into the ExecutionState for re-use when Field is evaluated.273let mut state = state.clone();274state.with_fields_ac = Some(Arc::new(ac.into_static()));275276// Collect evaluation fields.277let mut acs = Vec::with_capacity(self.evaluation.len() + 1);278acs.push(ac);279280let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate_on_groups(df, groups, &state);281let acs_eval = if self.allow_threading {282POOL.install(|| {283self.evaluation284.par_iter()285.map(f)286.collect::<PolarsResult<Vec<_>>>()287})288} else {289self.evaluation290.iter()291.map(f)292.collect::<PolarsResult<Vec<_>>>()293};294for ac in acs_eval? {295acs.push(ac)296}297298// Revert ExecutionState.299state.with_fields_ac = None;300301// Merge the `evaluation` back into the `input` struct.302// @NOTE. From this point on, we are dealing with a regular Function `with_fields`, which is303// elementwise top-level and not fallible. We leverage the reference dispatch for ApplyExpr,304// but simplified.305306// Collect statistics on input aggstates307let mut has_agg_list = false;308let mut has_agg_scalar = false;309let mut has_not_agg = false;310let mut has_not_agg_with_overlapping_groups = false;311let mut not_agg_groups_may_diverge = false;312313let mut previous: Option<&AggregationContext<'_>> = None;314for ac in &acs {315match ac.state {316AggState::AggregatedList(_) => {317has_agg_list = true;318},319AggState::AggregatedScalar(_) => has_agg_scalar = true,320AggState::NotAggregated(_) => {321has_not_agg = true;322if let Some(p) = previous {323not_agg_groups_may_diverge |= !p.groups.is_same(&ac.groups)324}325previous = Some(ac);326if ac.groups.is_overlapping() {327has_not_agg_with_overlapping_groups = true;328}329},330AggState::LiteralScalar(_) => {},331}332}333334let all_literal = !(has_agg_list || has_agg_scalar || has_not_agg);335let elementwise_must_aggregate =336has_not_agg && (has_agg_list || not_agg_groups_may_diverge);337338if all_literal {339// Fast path340self.apply_all_literal_elementwise(acs)341} else if has_agg_scalar && (has_agg_list || has_not_agg) {342// Not compatible343self.apply_group_aware(acs)344} else if elementwise_must_aggregate && has_not_agg_with_overlapping_groups {345// Compatible but calling aggregated() is too expensive346self.apply_group_aware(acs)347} else {348// Broadcast in NotAgg or AggList requires group_aware349acs.iter_mut().filter(|ac| !ac.is_literal()).for_each(|ac| {350ac.groups();351});352let has_broadcast =353if let Some(base_ac_idx) = acs.iter().position(|ac| !ac.is_literal()) {354acs.iter()355.enumerate()356.filter(|(i, ac)| *i != base_ac_idx && !ac.is_literal())357.any(|(_, ac)| {358acs[base_ac_idx]359.groups360.iter()361.zip(ac.groups.iter())362.any(|(l, r)| l.len() != r.len() && (l.len() == 1 || r.len() == 1))363})364} else {365false366};367if has_broadcast {368// Broadcast fall-back.369self.apply_group_aware(acs)370} else {371self.apply_elementwise(acs, elementwise_must_aggregate)372}373}374}375376fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {377Ok(self.output_field.clone())378}379380fn is_scalar(&self) -> bool {381self.operates_on_scalar382}383}384385386