Path: blob/main/crates/polars-expr/src/expressions/mod.rs
8421 views
//! Physical expression implementations and the aggregation-context machinery
//! that tracks how a column relates to its group tuples during group-by
//! evaluation.

mod aggregation;
mod alias;
mod apply;
mod binary;
mod cast;
mod column;
mod count;
mod element;
mod eval;
#[cfg(feature = "dtype-struct")]
mod field;
mod filter;
mod gather;
mod group_iter;
mod literal;
#[cfg(feature = "dynamic_group_by")]
mod rolling;
mod slice;
mod sort;
mod sortby;
#[cfg(feature = "dtype-struct")]
mod structeval;
mod ternary;
mod window;

use std::borrow::Cow;
use std::fmt::{Display, Formatter};

pub(crate) use aggregation::*;
pub(crate) use alias::*;
pub(crate) use apply::*;
use arrow::array::ArrayRef;
use arrow::bitmap::MutableBitmap;
use arrow::legacy::utils::CustomIterTools;
pub(crate) use binary::*;
pub(crate) use cast::*;
pub(crate) use column::*;
pub(crate) use count::*;
pub(crate) use element::*;
pub(crate) use eval::*;
#[cfg(feature = "dtype-struct")]
pub(crate) use field::*;
pub(crate) use filter::*;
pub(crate) use gather::*;
pub(crate) use literal::*;
use polars_core::prelude::*;
use polars_io::predicates::PhysicalIoExpr;
use polars_plan::prelude::*;
#[cfg(feature = "dynamic_group_by")]
pub(crate) use rolling::RollingExpr;
pub(crate) use slice::*;
pub(crate) use sort::*;
pub(crate) use sortby::*;
#[cfg(feature = "dtype-struct")]
pub(crate) use structeval::*;
pub(crate) use ternary::*;
pub use window::window_function_format_order_by;
pub(crate) use window::*;

use crate::state::ExecutionState;

/// The aggregation state of a [`Column`] while an expression is evaluated
/// per group.
#[derive(Clone, Debug)]
pub enum AggState {
    /// Already aggregated: `.agg_list(group_tuples)` is called
    /// and produced a `Series` of dtype `List`
    AggregatedList(Column),
    /// Already aggregated: `.agg` is called on an aggregation
    /// that produces a scalar.
    /// think of `sum`, `mean`, `variance` like aggregations.
    AggregatedScalar(Column),
    /// Not yet aggregated: `agg_list` still has to be called.
    NotAggregated(Column),
    /// A literal scalar value.
    LiteralScalar(Column),
}

impl AggState {
    /// Apply a fallible function to the inner [`Column`] while keeping the
    /// same state variant.
    fn try_map<F>(&self, func: F) -> PolarsResult<Self>
    where
        F: FnOnce(&Column) -> PolarsResult<Column>,
    {
        Ok(match self {
            AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
            AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
            AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),
            AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
        })
    }

    // Note: only `AggregatedScalar` counts as scalar here; a `LiteralScalar`
    // is handled separately by callers.
    fn is_scalar(&self) -> bool {
        matches!(self, Self::AggregatedScalar(_))
    }

    /// Name of the inner column, regardless of aggregation state.
    pub fn name(&self) -> &PlSmallStr {
        match self {
            AggState::AggregatedList(s)
            | AggState::NotAggregated(s)
            | AggState::LiteralScalar(s)
            | AggState::AggregatedScalar(s) => s.name(),
        }
    }

    /// The element dtype: for an aggregated list this is the list's inner
    /// dtype, otherwise the column's dtype itself.
    pub fn flat_dtype(&self) -> &DataType {
        match self {
            AggState::AggregatedList(s) => s.dtype().inner_dtype().unwrap(),
            AggState::NotAggregated(s)
            | AggState::LiteralScalar(s)
            | AggState::AggregatedScalar(s) => s.dtype(),
        }
    }
}

// lazy update strategy
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// don't update groups
    No,
    /// use the length of the current groups to determine new sorted indexes, preferred
    /// for performance
    WithGroupsLen,
    /// use the series list offsets to determine the new group lengths
    /// this one should be used when the length has changed. Note that
    /// the series should be aggregated state or else it will panic.
    WithSeriesLen,
}

#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// Can be in one of two states
    /// 1. already aggregated as list
    /// 2. flat (still needs the grouptuples to aggregate)
    ///
    /// When aggregation state is LiteralScalar or AggregatedScalar, the group values are not
    /// related to the state data anymore. The number of groups is still accurate.
    pub(crate) state: AggState,
    /// group tuples for AggState
    pub(crate) groups: Cow<'a, GroupPositions>,
    /// This is used to determine if we need to update the groups
    /// into sorted groups. We do this lazily, so that this work is only
    /// done when the groups are needed
    pub(crate) update_groups: UpdateGroups,
    /// This is true when the Series and Groups still have all
    /// their original values. Not the case when filtered
    pub(crate) original_len: bool,
}

impl<'a> AggregationContext<'a> {
    /// Return the group tuples, applying any pending (lazy) group update
    /// first. See [`UpdateGroups`] for the update strategies.
    pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
        match self.update_groups {
            UpdateGroups::No => {},
            UpdateGroups::WithGroupsLen => {
                // the groups are unordered
                // and the series is aggregated with these groups
                // so we need to recreate new grouptuples that
                // match the exploded Series
                let mut offset = 0 as IdxSize;

                match self.groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        // Rebuild contiguous [offset, len] slices from the
                        // per-group index lists; only the lengths carry over.
                        let groups = groups
                            .iter()
                            .map(|g| {
                                let len = g.1.len() as IdxSize;
                                let new_offset = offset + len;
                                let out = [offset, len];
                                offset = new_offset;
                                out
                            })
                            .collect();
                        self.groups =
                            Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
                    },
                    // sliced groups are already in correct order,
                    // Update offsets in the case of overlapping groups
                    // e.g. [0,2], [1,3], [2,4] becomes [0,2], [2,3], [5,4]
                    GroupsType::Slice { groups, .. } => {
                        // unroll
                        let groups = groups
                            .iter()
                            .map(|g| {
                                let len = g[1];
                                let new = [offset, g[1]];
                                offset += len;
                                new
                            })
                            .collect();
                        self.groups =
                            Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
                    },
                }
                self.update_groups = UpdateGroups::No;
            },
            UpdateGroups::WithSeriesLen => {
                let s = self.get_values().clone();
                self.det_groups_from_list(s.as_materialized_series());
            },
        }
        &self.groups
    }

    /// Borrow the inner column, whatever the aggregation state.
    pub(crate) fn get_values(&self) -> &Column {
        match &self.state {
            AggState::NotAggregated(s)
            | AggState::AggregatedScalar(s)
            | AggState::AggregatedList(s) => s,
            AggState::LiteralScalar(s) => s,
        }
    }

    pub fn agg_state(&self) -> &AggState {
        &self.state
    }

    // A literal also counts as not aggregated here: `agg_list` was never
    // called on it.
    pub(crate) fn is_not_aggregated(&self) -> bool {
        matches!(
            &self.state,
            AggState::NotAggregated(_) | AggState::LiteralScalar(_)
        )
    }

    pub(crate) fn is_aggregated(&self) -> bool {
        !self.is_not_aggregated()
    }

    pub(crate) fn is_literal(&self) -> bool {
        matches!(self.state, AggState::LiteralScalar(_))
    }

    /// # Arguments
    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because it's
    ///   the column's dtype)
    fn new(
        column: Column,
        groups: Cow<'a, GroupPositions>,
        aggregated: bool,
    ) -> AggregationContext<'a> {
        let series = if aggregated {
            // an aggregated column must have one value per group
            assert_eq!(column.len(), groups.len());
            AggState::AggregatedScalar(column)
        } else {
            AggState::NotAggregated(column)
        };

        Self {
            state: series,
            groups,
            update_groups: UpdateGroups::No,
            original_len: true,
        }
    }

    fn with_agg_state(&mut self, agg_state: AggState) {
        self.state = agg_state;
    }

    fn from_agg_state(
        agg_state: AggState,
        groups: Cow<'a, GroupPositions>,
    ) -> AggregationContext<'a> {
        Self {
            state: agg_state,
            groups,
            update_groups: UpdateGroups::No,
            original_len: true,
        }
    }

    pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
        self.original_len = original_len;
        self
    }

    pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
        self.update_groups = update;
        self
    }

    /// Derive new group tuples from the offsets/lengths of a list series
    /// (used by the `WithSeriesLen` lazy update strategy).
    fn det_groups_from_list(&mut self, s: &Series) {
        let mut offset = 0 as IdxSize;
        let list = s
            .list()
            .expect("impl error, should be a list at this point");

        match list.chunks().len() {
            1 => {
                // Single chunk: read group lengths straight from the arrow
                // offsets buffer.
                let arr = list.downcast_iter().next().unwrap();
                let offsets = arr.offsets().as_slice();

                let mut previous = 0i64;
                let groups = offsets[1..]
                    .iter()
                    .map(|&o| {
                        let len = (o - previous) as IdxSize;
                        let new_offset = offset + len;

                        previous = o;
                        let out = [offset, len];
                        offset = new_offset;
                        out
                    })
                    .collect_trusted();
                self.groups =
                    Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
            },
            _ => {
                // Multiple chunks: iterate the list elements; a null element
                // becomes an empty group.
                let groups = {
                    self.get_values()
                        .list()
                        .expect("impl error, should be a list at this point")
                        .amortized_iter()
                        .map(|s| {
                            if let Some(s) = s {
                                let len = s.as_ref().len() as IdxSize;
                                let new_offset = offset + len;
                                let out = [offset, len];
                                offset = new_offset;
                                out
                            } else {
                                [offset, 0]
                            }
                        })
                        .collect_trusted()
                };
                self.groups =
                    Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
            },
        }
        self.update_groups = UpdateGroups::No;
    }

    /// # Arguments
    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because it's
    ///   the column's dtype)
    pub(crate) fn with_values(
        &mut self,
        column: Column,
        aggregated: bool,
        expr: Option<&Expr>,
    ) -> PolarsResult<&mut Self> {
        self.with_values_and_args(
            column,
            aggregated,
            expr,
            false,
            self.agg_state().is_scalar(),
        )
    }

    /// Replace the inner column and recompute the aggregation state.
    /// `expr` is only used to enrich the error message on a length mismatch.
    pub(crate) fn with_values_and_args(
        &mut self,
        column: Column,
        aggregated: bool,
        expr: Option<&Expr>,
        // if the applied function was a `map` instead of an `apply`
        // this will keep functions applied over literals as literals: F(lit) = lit
        preserve_literal: bool,
        returns_scalar: bool,
    ) -> PolarsResult<&mut Self> {
        self.state = match (aggregated, column.dtype()) {
            (true, &DataType::List(_)) if !returns_scalar => {
                // an aggregated list must have exactly one list per group
                if column.len() != self.groups.len() {
                    let fmt_expr = if let Some(e) = expr {
                        format!("'{e:?}' ")
                    } else {
                        String::new()
                    };
                    polars_bail!(
                        ComputeError:
                        "aggregation expression '{}' produced a different number of elements: {} \
                        than the number of groups: {} (this is likely invalid)",
                        fmt_expr, column.len(), self.groups.len(),
                    );
                }
                AggState::AggregatedList(column)
            },
            (true, _) => AggState::AggregatedScalar(column),
            _ => {
                match self.state {
                    // already aggregated to sum, min even this series was flattened it never could
                    // retrieve the length before grouping, so it stays in this state.
                    AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
                    // applying a function on a literal, keeps the literal state
                    AggState::LiteralScalar(_) if column.len() == 1 && preserve_literal => {
                        AggState::LiteralScalar(column)
                    },
                    _ => AggState::NotAggregated(column.into_column()),
                }
            },
        };
        Ok(self)
    }

    pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
        self.state = AggState::LiteralScalar(column);
        self
    }

    /// Update the group tuples
    pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
        if let AggState::AggregatedList(_) = self.agg_state() {
            // In case of new groups, a series always needs to be flattened
            self.with_values(self.flat_naive().into_owned(), false, None)
                .unwrap();
        }
        self.groups = Cow::Owned(groups);
        // make sure that previous setting is not used
        self.update_groups = UpdateGroups::No;
        self
    }

    /// Ensure that each group is represented by contiguous values in memory.
    pub fn normalize_values(&mut self) {
        self.set_original_len(false);
        self.groups();
        let values = self.flat_naive();
        // SAFETY contract is inherited from `agg_list`; groups were just
        // materialized above.
        let values = unsafe { values.agg_list(&self.groups) };
        self.state = AggState::AggregatedList(values);
        self.with_update_groups(UpdateGroups::WithGroupsLen);
    }

    /// Aggregate into `ListChunked`.
    pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
        self.aggregated();
        let out = self.get_values();
        match self.agg_state() {
            // a scalar per group is wrapped into single-element lists
            AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),
            _ => Cow::Borrowed(out.list().unwrap()),
        }
    }

    /// Get the aggregated version of the series.
    pub fn aggregated(&mut self) -> Column {
        // we clone, because we only want to call `self.groups()` if needed.
        // self groups may instantiate new groups and thus can be expensive.
        match self.state.clone() {
            AggState::NotAggregated(s) => {
                // The groups are determined lazily and in case of a flat/non-aggregated
                // series we use the groups to aggregate the list
                // because this is lazy, we first must update the groups
                // by calling .groups()
                self.groups();
                #[cfg(debug_assertions)]
                {
                    if self.groups.len() > s.len() {
                        polars_warn!(
                            "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
                        )
                    }
                }

                // SAFETY:
                // groups are in bounds
                let out = unsafe { s.agg_list(&self.groups) };
                self.state = AggState::AggregatedList(out.clone());

                self.update_groups = UpdateGroups::WithGroupsLen;
                out
            },
            AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
            AggState::LiteralScalar(s) => {
                // broadcast the literal into one single-element list per group
                let rows = self.groups.len();
                let s = s.implode().unwrap();
                let s = s.new_from_index(0, rows);
                let s = s.into_column();
                self.state = AggState::AggregatedList(s.clone());
                self.with_update_groups(UpdateGroups::WithSeriesLen);
                s.clone()
            },
        }
    }

    /// Get the final aggregated version of the series.
    pub fn finalize(&mut self) -> Column {
        // we clone, because we only want to call `self.groups()` if needed.
        // self groups may instantiate new groups and thus can be expensive.
        match &self.state {
            AggState::LiteralScalar(c) => {
                // a literal is broadcast flat (one value per group),
                // not imploded into lists like in `aggregated`
                let c = c.clone();
                self.groups();
                let rows = self.groups.len();
                c.new_from_index(0, rows)
            },
            _ => self.aggregated(),
        }
    }

    // If a binary or ternary function has both of these branches true, it should
    // flatten the list
    fn arity_should_explode(&self) -> bool {
        use AggState::*;
        match self.agg_state() {
            LiteralScalar(s) => s.len() == 1,
            AggregatedScalar(_) => true,
            _ => false,
        }
    }

    /// Consume the context and return the column together with its groups,
    /// flattening an aggregated list (and unrolling the groups to match).
    pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
        let _ = self.groups();
        let groups = self.groups;
        match self.state {
            AggState::NotAggregated(c) => (c, groups),
            AggState::AggregatedScalar(c) => (c, groups),
            AggState::LiteralScalar(c) => (c, groups),
            AggState::AggregatedList(c) => {
                let flattened = c
                    .explode(ExplodeOptions {
                        empty_as_null: false,
                        keep_nulls: true,
                    })
                    .unwrap();
                let groups = groups.into_owned();
                // unroll the possible flattened state
                // say we have groups with overlapping windows:
                //
                // offset, len
                // 0, 1
                // 0, 2
                // 0, 4
                //
                // gets aggregation
                //
                // [0]
                // [0, 1],
                // [0, 1, 2, 3]
                //
                // before aggregation the column was
                // [0, 1, 2, 3]
                // but explode on this list yields
                // [0, 0, 1, 0, 1, 2, 3]
                //
                // so we unroll the groups as
                //
                // [0, 1]
                // [1, 2]
                // [3, 4]
                let groups = groups.unroll();
                (flattened, Cow::Owned(groups))
            },
        }
    }

    /// Get the not-aggregated version of the series.
    /// Note that we call it naive, because if a previous expr
    /// has filtered or sorted this, this information is in the
    /// group tuples not the flattened series.
    pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
        match &self.state {
            AggState::NotAggregated(c) => Cow::Borrowed(c),
            AggState::AggregatedList(c) => {
                if cfg!(debug_assertions) {
                    // Warning, so we find cases where we accidentally explode overlapping groups
                    // We don't want this as this can create a lot of data
                    if self.groups.is_overlapping() {
                        polars_warn!(
                            "performance - an aggregated list with overlapping groups may consume excessive memory"
                        )
                    }
                }

                // We should not insert nulls, otherwise the offsets in the groups will not be correct.
                Cow::Owned(
                    c.explode(ExplodeOptions {
                        empty_as_null: false,
                        keep_nulls: true,
                    })
                    .unwrap(),
                )
            },
            AggState::AggregatedScalar(c) => Cow::Borrowed(c),
            AggState::LiteralScalar(c) => Cow::Borrowed(c),
        }
    }

    /// Length of `flat_naive` without materializing the flattened column.
    fn flat_naive_length(&self) -> usize {
        match &self.state {
            AggState::NotAggregated(c) => c.len(),
            AggState::AggregatedList(c) => c.list().unwrap().inner_length(),
            AggState::AggregatedScalar(c) => c.len(),
            AggState::LiteralScalar(_) => 1,
        }
    }

    /// Take the series.
    pub(crate) fn take(&mut self) -> Column {
        let c = match &mut self.state {
            AggState::NotAggregated(c)
            | AggState::AggregatedScalar(c)
            | AggState::AggregatedList(c) => c,
            AggState::LiteralScalar(c) => c,
        };
        std::mem::take(c)
    }

    /// Do the group indices reference all values in the aggregation state.
    fn groups_cover_all_values(&mut self) -> bool {
        if matches!(
            self.state,
            AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
        ) {
            // scalar states are unrelated to group indices by definition
            return true;
        }

        let num_values = self.flat_naive_length();
        match self.groups().as_ref().as_ref() {
            GroupsType::Idx(groups) => {
                // mark every referenced index; all must be seen
                let mut seen = MutableBitmap::from_len_zeroed(num_values);
                for (_, g) in groups {
                    for i in g.iter() {
                        // SAFETY: group indices are in bounds of the values
                        unsafe { seen.set_unchecked(*i as usize, true) };
                    }
                }
                seen.unset_bits() == 0
            },
            GroupsType::Slice {
                groups,
                overlapping: true,
                monotonic: _,
            } => {
                // @NOTE: Slice groups are sorted by their `start` value.
                let mut offset = 0;
                let mut covers_all = true;
                for [start, length] in groups {
                    covers_all &= *start <= offset;
                    offset = start + length;
                }
                covers_all && offset == num_values as IdxSize
            },

            // If we don't have overlapping data, we can just do a count.
            GroupsType::Slice {
                groups,
                overlapping: false,
                monotonic: _,
            } => groups.iter().map(|[_, l]| *l as usize).sum::<usize>() == num_values,
        }
    }

    /// Fixes groups for `AggregatedScalar` and `LiteralScalar` so that they point to valid
    /// data elements in the `AggState` values.
    fn set_groups_for_undefined_agg_states(&mut self) {
        match &self.state {
            AggState::AggregatedList(_) | AggState::NotAggregated(_) => {},
            AggState::AggregatedScalar(c) => {
                assert_eq!(self.update_groups, UpdateGroups::No);
                self.groups = Cow::Owned({
                    // one unit-length group per scalar value
                    let groups = (0..c.len() as IdxSize).map(|i| [i, 1]).collect();
                    GroupsType::new_slice(groups, false, true).into_sliceable()
                });
            },
            AggState::LiteralScalar(c) => {
                assert_eq!(c.len(), 1);
                assert_eq!(self.update_groups, UpdateGroups::No);
                self.groups = Cow::Owned({
                    // every group points at the single literal element
                    let groups = vec![[0, 1]; self.groups.len()];
                    GroupsType::new_slice(groups, true, true).into_sliceable()
                });
            },
        }
    }

    /// Clone into a context that owns its groups (no borrowed lifetime).
    pub fn into_static(&self) -> AggregationContext<'static> {
        let groups: GroupPositions = GroupPositions::to_owned(&self.groups);
        let groups: Cow<'static, GroupPositions> = Cow::Owned(groups);
        AggregationContext {
            state: self.state.clone(),
            groups,
            update_groups: self.update_groups,
            original_len: self.original_len,
        }
    }
}

/// Take a DataFrame and evaluate the expressions.
/// Implement this for Column, lt, eq, etc
pub trait PhysicalExpr: Send + Sync {
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    fn as_column(&self) -> Option<PlSmallStr> {
        None
    }

    /// Take a DataFrame and evaluate the expression.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;

    /// Some expression that are not aggregations can be done per group
    /// Think of sort, slice, filter, shift, etc.
    /// defaults to ignoring the group
    ///
    /// This method is called by an aggregation function.
    ///
    /// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
    /// In case of an expr where group behavior makes sense, this method is called.
    /// For a filter operation for instance, a Series is created per groups and filtered.
    ///
    /// An implementation of this method may apply an aggregation on the groups only. For instance
    /// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
    /// group. The implementation then has to return the `Series` exploded (because a later aggregation
    /// will use the group tuples to aggregate). The group tuples also have to be updated, because
    /// aggregation to a list sorts the exploded `Series` by group.
    ///
    /// This has some gotcha's. An implementation may also change the group tuples instead of
    /// the `Series`.
    ///
    // we allow this because we pass the vec to the Cow
    // Note to self: Don't be smart and dispatch to evaluate as default implementation
    // this means filters will be incorrect and lead to invalid results down the line
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Get the output field of this expr
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    fn is_literal(&self) -> bool {
        false
    }
    fn is_scalar(&self) -> bool;
}

impl Display for &dyn PhysicalExpr {
    // Debug-formats the underlying logical expression, if any.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.as_expression() {
            None => Ok(()),
            Some(e) => write!(f, "{e:?}"),
        }
    }
}

/// Wrapper struct that allow us to use a PhysicalExpr in polars-io.
///
/// This is used to filter rows during the scan of file.
pub struct PhysicalIoHelper {
    pub expr: Arc<dyn PhysicalExpr>,
    pub has_window_function: bool,
}

impl PhysicalIoExpr for PhysicalIoHelper {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut state: ExecutionState = Default::default();
        if self.has_window_function {
            state.insert_has_window_function_flag();
        }
        self.expr.evaluate(df, &state).map(|c| {
            // IO expression result should be boolean-typed.
            debug_assert_eq!(c.dtype(), &DataType::Boolean);
            (if c.len() == 1 && df.height() != 1 {
                // filter(lit(True)) will hit here.
                c.new_from_index(0, df.height())
            } else {
                c
            })
            .take_materialized_series()
        })
    }
}

/// Wrap a [`PhysicalExpr`] as a [`PhysicalIoExpr`], detecting whether the
/// expression tree contains a window (or rolling) function.
pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
    let has_window_function = if let Some(expr) = expr.as_expression() {
        expr.into_iter().any(|expr| {
            #[cfg(feature = "dynamic_group_by")]
            if matches!(expr, Expr::Rolling { .. }) {
                return true;
            }

            matches!(expr, Expr::Over { .. })
        })
    } else {
        false
    };
    Arc::new(PhysicalIoHelper {
        expr,
        has_window_function,
    }) as Arc<dyn PhysicalIoExpr>
}