// Path: blob/main/crates/polars-expr/src/expressions/mod.rs
// 6940 views
mod aggregation;1mod alias;2mod apply;3mod binary;4mod cast;5mod column;6mod count;7mod eval;8mod filter;9mod gather;10mod group_iter;11mod literal;12#[cfg(feature = "dynamic_group_by")]13mod rolling;14mod slice;15mod sort;16mod sortby;17mod ternary;18mod window;1920use std::borrow::Cow;21use std::fmt::{Display, Formatter};2223pub(crate) use aggregation::*;24pub(crate) use alias::*;25pub(crate) use apply::*;26use arrow::array::ArrayRef;27use arrow::legacy::utils::CustomIterTools;28pub(crate) use binary::*;29pub(crate) use cast::*;30pub(crate) use column::*;31pub(crate) use count::*;32pub(crate) use eval::*;33pub(crate) use filter::*;34pub(crate) use gather::*;35pub(crate) use literal::*;36use polars_core::prelude::*;37use polars_io::predicates::PhysicalIoExpr;38use polars_plan::prelude::*;39#[cfg(feature = "dynamic_group_by")]40pub(crate) use rolling::RollingExpr;41pub(crate) use slice::*;42pub(crate) use sort::*;43pub(crate) use sortby::*;44pub(crate) use ternary::*;45pub use window::window_function_format_order_by;46pub(crate) use window::*;4748use crate::state::ExecutionState;4950#[derive(Clone, Debug)]51pub enum AggState {52/// Already aggregated: `.agg_list(group_tuples)` is called53/// and produced a `Series` of dtype `List`54AggregatedList(Column),55/// Already aggregated: `.agg` is called on an aggregation56/// that produces a scalar.57/// think of `sum`, `mean`, `variance` like aggregations.58AggregatedScalar(Column),59/// Not yet aggregated: `agg_list` still has to be called.60NotAggregated(Column),61/// A literal scalar value.62LiteralScalar(Column),63}6465impl AggState {66fn try_map<F>(&self, func: F) -> PolarsResult<Self>67where68F: FnOnce(&Column) -> PolarsResult<Column>,69{70Ok(match self {71AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),72AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),73AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),74AggState::NotAggregated(c) => 
AggState::NotAggregated(func(c)?),75})76}7778fn is_scalar(&self) -> bool {79matches!(self, Self::AggregatedScalar(_))80}81}8283// lazy update strategy84#[cfg_attr(debug_assertions, derive(Debug))]85#[derive(PartialEq, Clone, Copy)]86pub(crate) enum UpdateGroups {87/// don't update groups88No,89/// use the length of the current groups to determine new sorted indexes, preferred90/// for performance91WithGroupsLen,92/// use the series list offsets to determine the new group lengths93/// this one should be used when the length has changed. Note that94/// the series should be aggregated state or else it will panic.95WithSeriesLen,96}9798#[cfg_attr(debug_assertions, derive(Debug))]99pub struct AggregationContext<'a> {100/// Can be in one of two states101/// 1. already aggregated as list102/// 2. flat (still needs the grouptuples to aggregate)103state: AggState,104/// group tuples for AggState105groups: Cow<'a, GroupPositions>,106/// This is used to determined if we need to update the groups107/// into a sorted groups. We do this lazily, so that this work only is108/// done when the groups are needed109update_groups: UpdateGroups,110/// This is true when the Series and Groups still have all111/// their original values. 
Not the case when filtered112original_len: bool,113}114115impl<'a> AggregationContext<'a> {116pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {117match self.update_groups {118UpdateGroups::No => {},119UpdateGroups::WithGroupsLen => {120// the groups are unordered121// and the series is aggregated with this groups122// so we need to recreate new grouptuples that123// match the exploded Series124let mut offset = 0 as IdxSize;125126match self.groups.as_ref().as_ref() {127GroupsType::Idx(groups) => {128let groups = groups129.iter()130.map(|g| {131let len = g.1.len() as IdxSize;132let new_offset = offset + len;133let out = [offset, len];134offset = new_offset;135out136})137.collect();138self.groups = Cow::Owned(139GroupsType::Slice {140groups,141rolling: false,142}143.into_sliceable(),144)145},146// sliced groups are already in correct order147GroupsType::Slice { .. } => {},148}149self.update_groups = UpdateGroups::No;150},151UpdateGroups::WithSeriesLen => {152let s = self.get_values().clone();153self.det_groups_from_list(s.as_materialized_series());154},155}156&self.groups157}158159pub(crate) fn get_values(&self) -> &Column {160match &self.state {161AggState::NotAggregated(s)162| AggState::AggregatedScalar(s)163| AggState::AggregatedList(s) => s,164AggState::LiteralScalar(s) => s,165}166}167168pub fn agg_state(&self) -> &AggState {169&self.state170}171172pub(crate) fn is_not_aggregated(&self) -> bool {173matches!(174&self.state,175AggState::NotAggregated(_) | AggState::LiteralScalar(_)176)177}178179pub(crate) fn is_aggregated(&self) -> bool {180!self.is_not_aggregated()181}182183pub(crate) fn is_literal(&self) -> bool {184matches!(self.state, AggState::LiteralScalar(_))185}186187/// # Arguments188/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its189/// the columns dtype)190fn new(191column: Column,192groups: Cow<'a, GroupPositions>,193aggregated: bool,194) -> AggregationContext<'a> {195let series = if 
aggregated {196assert_eq!(column.len(), groups.len());197AggState::AggregatedScalar(column)198} else {199AggState::NotAggregated(column)200};201202Self {203state: series,204groups,205update_groups: UpdateGroups::No,206original_len: true,207}208}209210fn with_agg_state(&mut self, agg_state: AggState) {211self.state = agg_state;212}213214fn from_agg_state(215agg_state: AggState,216groups: Cow<'a, GroupPositions>,217) -> AggregationContext<'a> {218Self {219state: agg_state,220groups,221update_groups: UpdateGroups::No,222original_len: true,223}224}225226pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {227self.original_len = original_len;228self229}230231pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {232self.update_groups = update;233self234}235236fn det_groups_from_list(&mut self, s: &Series) {237let mut offset = 0 as IdxSize;238let list = s239.list()240.expect("impl error, should be a list at this point");241242match list.chunks().len() {2431 => {244let arr = list.downcast_iter().next().unwrap();245let offsets = arr.offsets().as_slice();246247let mut previous = 0i64;248let groups = offsets[1..]249.iter()250.map(|&o| {251let len = (o - previous) as IdxSize;252let new_offset = offset + len;253254previous = o;255let out = [offset, len];256offset = new_offset;257out258})259.collect_trusted();260self.groups = Cow::Owned(261GroupsType::Slice {262groups,263rolling: false,264}265.into_sliceable(),266);267},268_ => {269let groups = {270self.get_values()271.list()272.expect("impl error, should be a list at this point")273.amortized_iter()274.map(|s| {275if let Some(s) = s {276let len = s.as_ref().len() as IdxSize;277let new_offset = offset + len;278let out = [offset, len];279offset = new_offset;280out281} else {282[offset, 0]283}284})285.collect_trusted()286};287self.groups = Cow::Owned(288GroupsType::Slice {289groups,290rolling: false,291}292.into_sliceable(),293);294},295}296self.update_groups = 
UpdateGroups::No;297}298299/// # Arguments300/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its301/// the columns dtype)302pub(crate) fn with_values(303&mut self,304column: Column,305aggregated: bool,306expr: Option<&Expr>,307) -> PolarsResult<&mut Self> {308self.with_values_and_args(309column,310aggregated,311expr,312false,313self.agg_state().is_scalar(),314)315}316317pub(crate) fn with_values_and_args(318&mut self,319column: Column,320aggregated: bool,321expr: Option<&Expr>,322// if the applied function was a `map` instead of an `apply`323// this will keep functions applied over literals as literals: F(lit) = lit324mapped: bool,325returns_scalar: bool,326) -> PolarsResult<&mut Self> {327self.state = match (aggregated, column.dtype()) {328(true, &DataType::List(_)) if !returns_scalar => {329if column.len() != self.groups.len() {330let fmt_expr = if let Some(e) = expr {331format!("'{e:?}' ")332} else {333String::new()334};335polars_bail!(336ComputeError:337"aggregation expression '{}' produced a different number of elements: {} \338than the number of groups: {} (this is likely invalid)",339fmt_expr, column.len(), self.groups.len(),340);341}342AggState::AggregatedList(column)343},344(true, _) => AggState::AggregatedScalar(column),345_ => {346match self.state {347// already aggregated to sum, min even this series was flattened it never could348// retrieve the length before grouping, so it stays in this state.349AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),350// applying a function on a literal, keeps the literal state351AggState::LiteralScalar(_) if column.len() == 1 && mapped => {352AggState::LiteralScalar(column)353},354_ => AggState::NotAggregated(column.into_column()),355}356},357};358Ok(self)359}360361pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {362self.state = AggState::LiteralScalar(column);363self364}365366/// Update the group tuples367pub(crate) fn 
with_groups(&mut self, groups: GroupPositions) -> &mut Self {368if let AggState::AggregatedList(_) = self.agg_state() {369// In case of new groups, a series always needs to be flattened370self.with_values(self.flat_naive().into_owned(), false, None)371.unwrap();372}373self.groups = Cow::Owned(groups);374// make sure that previous setting is not used375self.update_groups = UpdateGroups::No;376self377}378379pub(crate) fn _implode_no_agg(&mut self) {380match self.state.clone() {381AggState::NotAggregated(_) => {382let _ = self.aggregated();383let AggState::AggregatedList(s) = self.state.clone() else {384unreachable!()385};386self.state = AggState::AggregatedScalar(s);387},388AggState::AggregatedList(s) => {389self.state = AggState::AggregatedScalar(s);390},391_ => unreachable!("should only be called in non-agg/list-agg state by aggregation.rs"),392}393}394395/// Aggregate into `ListChunked`.396pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {397self.aggregated();398let out = self.get_values();399match self.agg_state() {400AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),401_ => Cow::Borrowed(out.list().unwrap()),402}403}404405/// Get the aggregated version of the series.406pub fn aggregated(&mut self) -> Column {407// we clone, because we only want to call `self.groups()` if needed.408// self groups may instantiate new groups and thus can be expensive.409match self.state.clone() {410AggState::NotAggregated(s) => {411// The groups are determined lazily and in case of a flat/non-aggregated412// series we use the groups to aggregate the list413// because this is lazy, we first must to update the groups414// by calling .groups()415self.groups();416#[cfg(debug_assertions)]417{418if self.groups.len() > s.len() {419polars_warn!(420"groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"421)422}423}424425// SAFETY:426// groups are in bounds427let out = unsafe { s.agg_list(&self.groups) 
};428self.state = AggState::AggregatedList(out.clone());429430self.update_groups = UpdateGroups::WithGroupsLen;431out432},433AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),434AggState::LiteralScalar(s) => {435self.groups();436let rows = self.groups.len();437let s = s.implode().unwrap();438let s = s.new_from_index(0, rows);439let s = s.into_column();440self.state = AggState::AggregatedList(s.clone());441self.with_update_groups(UpdateGroups::WithSeriesLen);442s.clone()443},444}445}446447/// Get the final aggregated version of the series.448pub fn finalize(&mut self) -> Column {449// we clone, because we only want to call `self.groups()` if needed.450// self groups may instantiate new groups and thus can be expensive.451match &self.state {452AggState::LiteralScalar(c) => {453let c = c.clone();454self.groups();455let rows = self.groups.len();456c.new_from_index(0, rows)457},458_ => self.aggregated(),459}460}461462// If a binary or ternary function has both of these branches true, it should463// flatten the list464fn arity_should_explode(&self) -> bool {465use AggState::*;466match self.agg_state() {467LiteralScalar(s) => s.len() == 1,468AggregatedScalar(_) => true,469_ => false,470}471}472473pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {474let _ = self.groups();475let groups = self.groups;476match self.state {477AggState::NotAggregated(c) => (c, groups),478AggState::AggregatedScalar(c) => (c, groups),479AggState::LiteralScalar(c) => (c, groups),480AggState::AggregatedList(c) => {481let flattened = c.explode(true).unwrap();482let groups = groups.into_owned();483// unroll the possible flattened state484// say we have groups with overlapping windows:485//486// offset, len487// 0, 1488// 0, 2489// 0, 4490//491// gets aggregation492//493// [0]494// [0, 1],495// [0, 1, 2, 3]496//497// before aggregation the column was498// [0, 1, 2, 3]499// but explode on this list yields500// [0, 0, 1, 0, 1, 2, 3]501//502// so we 
unroll the groups as503//504// [0, 1]505// [1, 2]506// [3, 4]507let groups = groups.unroll();508(flattened, Cow::Owned(groups))509},510}511}512513/// Get the not-aggregated version of the series.514/// Note that we call it naive, because if a previous expr515/// has filtered or sorted this, this information is in the516/// group tuples not the flattened series.517pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {518match &self.state {519AggState::NotAggregated(c) => Cow::Borrowed(c),520AggState::AggregatedList(c) => {521#[cfg(debug_assertions)]522{523// panic so we find cases where we accidentally explode overlapping groups524// we don't want this as this can create a lot of data525if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() {526panic!(527"implementation error, polars should not hit this branch for overlapping groups"528)529}530}531532// We should not insert nulls, otherwise the offsets in the groups will not be correct.533Cow::Owned(c.explode(true).unwrap())534},535AggState::AggregatedScalar(c) => Cow::Borrowed(c),536AggState::LiteralScalar(c) => Cow::Borrowed(c),537}538}539540/// Take the series.541pub(crate) fn take(&mut self) -> Column {542let c = match &mut self.state {543AggState::NotAggregated(c)544| AggState::AggregatedScalar(c)545| AggState::AggregatedList(c) => c,546AggState::LiteralScalar(c) => c,547};548std::mem::take(c)549}550}551552/// Take a DataFrame and evaluate the expressions.553/// Implement this for Column, lt, eq, etc554pub trait PhysicalExpr: Send + Sync {555fn as_expression(&self) -> Option<&Expr> {556None557}558559/// Take a DataFrame and evaluate the expression.560fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;561562/// Some expression that are not aggregations can be done per group563/// Think of sort, slice, filter, shift, etc.564/// defaults to ignoring the group565///566/// This method is called by an aggregation function.567///568/// In case of a simple expr, 
like 'column', the groups are ignored and the column is returned.569/// In case of an expr where group behavior makes sense, this method is called.570/// For a filter operation for instance, a Series is created per groups and filtered.571///572/// An implementation of this method may apply an aggregation on the groups only. For instance573/// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per574/// group. The implementation then has to return the `Series` exploded (because a later aggregation575/// will use the group tuples to aggregate). The group tuples also have to be updated, because576/// aggregation to a list sorts the exploded `Series` by group.577///578/// This has some gotcha's. An implementation may also change the group tuples instead of579/// the `Series`.580///581// we allow this because we pass the vec to the Cow582// Note to self: Don't be smart and dispatch to evaluate as default implementation583// this means filters will be incorrect and lead to invalid results down the line584#[allow(clippy::ptr_arg)]585fn evaluate_on_groups<'a>(586&self,587df: &DataFrame,588groups: &'a GroupPositions,589state: &ExecutionState,590) -> PolarsResult<AggregationContext<'a>>;591592/// Get the output field of this expr593fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;594595/// Convert to a partitioned aggregator.596fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {597None598}599600fn is_literal(&self) -> bool {601false602}603fn is_scalar(&self) -> bool;604}605606impl Display for &dyn PhysicalExpr {607fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {608match self.as_expression() {609None => Ok(()),610Some(e) => write!(f, "{e:?}"),611}612}613}614615/// Wrapper struct that allow us to use a PhysicalExpr in polars-io.616///617/// This is used to filter rows during the scan of file.618pub struct PhysicalIoHelper {619pub expr: Arc<dyn PhysicalExpr>,620pub has_window_function: 
bool,621}622623impl PhysicalIoExpr for PhysicalIoHelper {624fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {625let mut state: ExecutionState = Default::default();626if self.has_window_function {627state.insert_has_window_function_flag();628}629self.expr.evaluate(df, &state).map(|c| {630// IO expression result should be boolean-typed.631debug_assert_eq!(c.dtype(), &DataType::Boolean);632(if c.len() == 1 && df.height() != 1 {633// filter(lit(True)) will hit here.634c.new_from_index(0, df.height())635} else {636c637})638.take_materialized_series()639})640}641}642643pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {644let has_window_function = if let Some(expr) = expr.as_expression() {645expr.into_iter()646.any(|expr| matches!(expr, Expr::Window { .. }))647} else {648false649};650Arc::new(PhysicalIoHelper {651expr,652has_window_function,653}) as Arc<dyn PhysicalIoExpr>654}655656pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {657/// This is called in partitioned aggregation.658/// Partitioned results may differ from aggregation results.659/// For instance, for a `mean` operation a partitioned result660/// needs to return the `sum` and the `valid_count` (length - null count).661///662/// A final aggregation can then take the sum of sums and sum of valid_counts663/// to produce a final mean.664#[allow(clippy::ptr_arg)]665fn evaluate_partitioned(666&self,667df: &DataFrame,668groups: &GroupPositions,669state: &ExecutionState,670) -> PolarsResult<Column>;671672/// Called to merge all the partitioned results in a final aggregate.673#[allow(clippy::ptr_arg)]674fn finalize(675&self,676partitioned: Column,677groups: &GroupPositions,678state: &ExecutionState,679) -> PolarsResult<Column>;680}681682683