// Path: crates/polars-plan/src/plans/aexpr/mod.rs (page residue: 8424 views)
mod builder;1mod equality;2mod evaluate;3mod function_expr;4#[cfg(feature = "cse")]5mod hash;6mod minterm_iter;7pub mod predicates;8mod scalar;9mod schema;10mod traverse;1112use std::hash::{Hash, Hasher};1314pub use function_expr::*;15#[cfg(feature = "cse")]16pub(super) use hash::traverse_and_hash_aexpr;17pub use minterm_iter::MintermIter;18use polars_compute::rolling::QuantileMethod;19use polars_core::chunked_array::cast::CastOptions;20use polars_core::prelude::*;21use polars_core::utils::{get_time_units, try_get_supertype};22use polars_utils::arena::{Arena, Node};23pub use scalar::{is_length_preserving_ae, is_scalar_ae};24use strum_macros::IntoStaticStr;25pub use traverse::*;26mod properties;27pub use aexpr::function_expr::schema::FieldsMapper;28pub use builder::AExprBuilder;29pub use evaluate::{constant_evaluate, into_column};30pub use properties::*;31pub use schema::ToFieldContext;3233use crate::constants::LEN;34use crate::prelude::*;3536#[derive(Clone, Debug, IntoStaticStr)]37#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]38pub enum IRAggExpr {39Min {40input: Node,41propagate_nans: bool,42},43Max {44input: Node,45propagate_nans: bool,46},47Median(Node),48NUnique(Node),49Item {50input: Node,51/// Return a missing value if there are no values.52allow_empty: bool,53},54First(Node),55FirstNonNull(Node),56Last(Node),57LastNonNull(Node),58Mean(Node),59Implode(Node),60Quantile {61expr: Node,62quantile: Node,63method: QuantileMethod,64},65Sum(Node),66Count {67input: Node,68include_nulls: bool,69},70Std(Node, u8),71Var(Node, u8),72AggGroups(Node),73}7475impl Hash for IRAggExpr {76fn hash<H: Hasher>(&self, state: &mut H) {77std::mem::discriminant(self).hash(state);78match self {79Self::Min {80input: _,81propagate_nans,82}83| Self::Max {84input: _,85propagate_nans,86} => propagate_nans.hash(state),87Self::Quantile {88method: interpol, ..89} => interpol.hash(state),90Self::Std(_, v) | Self::Var(_, v) => v.hash(state),91Self::Count {92input: 
_,93include_nulls,94} => include_nulls.hash(state),95_ => {},96}97}98}99100impl IRAggExpr {101pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {102use IRAggExpr::*;103match (self, other) {104(105Min {106propagate_nans: l, ..107},108Min {109propagate_nans: r, ..110},111) => l == r,112(113Max {114propagate_nans: l, ..115},116Max {117propagate_nans: r, ..118},119) => l == r,120(Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,121(Std(_, l), Std(_, r)) => l == r,122(Var(_, l), Var(_, r)) => l == r,123_ => std::mem::discriminant(self) == std::mem::discriminant(other),124}125}126}127128impl From<IRAggExpr> for GroupByMethod {129fn from(value: IRAggExpr) -> Self {130use IRAggExpr::*;131match value {132Min {133input: _,134propagate_nans,135} => {136if propagate_nans {137GroupByMethod::NanMin138} else {139GroupByMethod::Min140}141},142Max {143input: _,144propagate_nans,145} => {146if propagate_nans {147GroupByMethod::NanMax148} else {149GroupByMethod::Max150}151},152Median(_) => GroupByMethod::Median,153NUnique(_) => GroupByMethod::NUnique,154First(_) => GroupByMethod::First,155FirstNonNull(_) => GroupByMethod::FirstNonNull,156Last(_) => GroupByMethod::Last,157LastNonNull(_) => GroupByMethod::LastNonNull,158Item { allow_empty, .. } => GroupByMethod::Item { allow_empty },159Mean(_) => GroupByMethod::Mean,160Implode(_) => GroupByMethod::Implode,161Sum(_) => GroupByMethod::Sum,162Count {163input: _,164include_nulls,165} => GroupByMethod::Count { include_nulls },166Std(_, ddof) => GroupByMethod::Std(ddof),167Var(_, ddof) => GroupByMethod::Var(ddof),168AggGroups(_) => GroupByMethod::Groups,169// Multi-input aggregations.170Quantile { .. 
} => unreachable!(),171}172}173}174175/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena].176#[derive(Clone, Debug, Default)]177#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]178pub enum AExpr {179/// Values in a `eval` context.180///181/// Equivalent of `pl.element()`.182Element,183Explode {184expr: Node,185options: ExplodeOptions,186},187Column(PlSmallStr),188/// Struct field value in a `struct.with_fields` context.189///190/// Equivalent of `pl.field(name)`.191#[cfg(feature = "dtype-struct")]192StructField(PlSmallStr),193Literal(LiteralValue),194BinaryExpr {195left: Node,196op: Operator,197right: Node,198},199Cast {200expr: Node,201dtype: DataType,202options: CastOptions,203},204Sort {205expr: Node,206options: SortOptions,207},208Gather {209expr: Node,210idx: Node,211returns_scalar: bool,212null_on_oob: bool,213},214SortBy {215expr: Node,216by: Vec<Node>,217sort_options: SortMultipleOptions,218},219Filter {220input: Node,221by: Node,222},223Agg(IRAggExpr),224Ternary {225predicate: Node,226truthy: Node,227falsy: Node,228},229AnonymousAgg {230input: Vec<ExprIR>,231fmt_str: Box<PlSmallStr>,232function: OpaqueStreamingAgg,233},234AnonymousFunction {235input: Vec<ExprIR>,236function: OpaqueColumnUdf,237options: FunctionOptions,238fmt_str: Box<PlSmallStr>,239},240/// Evaluates the `evaluation` expression on the output of the `expr`.241///242/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.243Eval {244expr: Node,245246/// An expression that is guaranteed to not contain any column reference beyond247/// `pl.element()` which refers to `pl.col("")`.248evaluation: Node,249250variant: EvalVariant,251},252#[cfg(feature = "dtype-struct")]253StructEval {254expr: Node,255evaluation: Vec<ExprIR>,256},257Function {258/// Function arguments259/// Some functions rely on aliases,260/// for instance assignment of struct fields.261/// Therefor we need [`ExprIr`].262input: 
Vec<ExprIR>,263/// function to apply264function: IRFunctionExpr,265options: FunctionOptions,266},267Over {268function: Node,269partition_by: Vec<Node>,270order_by: Option<(Node, SortOptions)>,271mapping: WindowMapping,272},273#[cfg(feature = "dynamic_group_by")]274Rolling {275function: Node,276index_column: Node,277period: Duration,278offset: Duration,279closed_window: ClosedWindow,280},281Slice {282input: Node,283offset: Node,284length: Node,285},286#[default]287Len,288}289290impl AExpr {291#[cfg(feature = "cse")]292pub(crate) fn col(name: PlSmallStr) -> Self {293AExpr::Column(name)294}295296#[recursive::recursive]297pub fn is_scalar(&self, arena: &Arena<AExpr>) -> bool {298match self {299AExpr::Element => false,300AExpr::Literal(lv) => lv.is_scalar(),301AExpr::Function { options, input, .. }302| AExpr::AnonymousFunction { options, input, .. } => {303if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {304true305} else if options.is_elementwise()306|| options.flags.contains(FunctionFlags::LENGTH_PRESERVING)307{308input.iter().all(|e| e.is_scalar(arena))309} else {310false311}312},313AExpr::BinaryExpr { left, right, .. } => {314is_scalar_ae(*left, arena) && is_scalar_ae(*right, arena)315},316AExpr::Ternary {317predicate,318truthy,319falsy,320} => {321is_scalar_ae(*predicate, arena)322&& is_scalar_ae(*truthy, arena)323&& is_scalar_ae(*falsy, arena)324},325AExpr::Agg(_) | AExpr::AnonymousAgg { .. } | AExpr::Len => true,326AExpr::Cast { expr, .. } => is_scalar_ae(*expr, arena),327AExpr::Eval { expr, variant, .. } => {328variant.is_length_preserving() && is_scalar_ae(*expr, arena)329},330#[cfg(feature = "dtype-struct")]331AExpr::StructEval { expr, .. } => is_scalar_ae(*expr, arena),332AExpr::Sort { expr, .. } => is_scalar_ae(*expr, arena),333AExpr::Gather { returns_scalar, .. } => *returns_scalar,334AExpr::SortBy { expr, .. } => is_scalar_ae(*expr, arena),335336// Over and Rolling implicitly zip with the context and thus are never scalars337AExpr::Over { .. 
} => false,338#[cfg(feature = "dynamic_group_by")]339AExpr::Rolling { .. } => false,340341AExpr::Explode { .. }342| AExpr::Column(_)343| AExpr::Filter { .. }344| AExpr::Slice { .. } => false,345#[cfg(feature = "dtype-struct")]346AExpr::StructField(_) => false,347}348}349350#[recursive::recursive]351pub fn is_length_preserving(&self, arena: &Arena<AExpr>) -> bool {352fn broadcasting_input_length_preserving(353n: impl IntoIterator<Item = Node>,354arena: &Arena<AExpr>,355) -> bool {356let mut num_items = 0;357let mut num_length_preserving = 0;358let mut num_scalar_or_length_preserving = 0;359360for n in n {361num_items += 1;362363if is_length_preserving_ae(n, arena) {364num_length_preserving += 1;365num_scalar_or_length_preserving += 1;366} else if is_scalar_ae(n, arena) {367num_scalar_or_length_preserving += 1;368}369}370371num_length_preserving > 0 && num_scalar_or_length_preserving == num_items372}373374match self {375AExpr::Element => true,376AExpr::Column(_) => true,377#[cfg(feature = "dtype-struct")]378AExpr::StructField(_) => true,379380// Over and Rolling implicitly zip with the context and thus should always be length381// preserving382AExpr::Over { mapping, .. } => !matches!(mapping, WindowMapping::Explode),383#[cfg(feature = "dynamic_group_by")]384AExpr::Rolling { .. } => true,385386AExpr::AnonymousAgg { .. } | AExpr::Literal(_) | AExpr::Agg(_) | AExpr::Len => false,387AExpr::Function { options, input, .. }388| AExpr::AnonymousFunction { options, input, .. } => {389if options.flags.is_elementwise() {390broadcasting_input_length_preserving(input.iter().map(|e| e.node()), arena)391} else if options.flags.is_length_preserving() {392input.iter().all(|e| e.is_length_preserving(arena))393} else {394false395}396},397AExpr::BinaryExpr { left, right, .. 
} => {398broadcasting_input_length_preserving([*left, *right], arena)399},400AExpr::Ternary {401predicate,402truthy,403falsy,404} => broadcasting_input_length_preserving([*predicate, *truthy, *falsy], arena),405AExpr::Cast { expr, .. } => is_length_preserving_ae(*expr, arena),406AExpr::Eval { expr, variant, .. } => {407variant.is_length_preserving() && is_length_preserving_ae(*expr, arena)408},409#[cfg(feature = "dtype-struct")]410AExpr::StructEval { expr, .. } => is_length_preserving_ae(*expr, arena),411AExpr::Sort { expr, .. } => is_length_preserving_ae(*expr, arena),412AExpr::Gather {413expr: _,414idx,415returns_scalar,416null_on_oob: _,417} => !returns_scalar && is_length_preserving_ae(*idx, arena),418AExpr::SortBy { expr, by, .. } => broadcasting_input_length_preserving(419std::iter::once(*expr).chain(by.iter().copied()),420arena,421),422423AExpr::Explode { .. } | AExpr::Filter { .. } | AExpr::Slice { .. } => false,424}425}426427/// Is the top-level expression fallible based on the data values.428pub fn is_fallible_top_level(&self, arena: &Arena<AExpr>) -> bool {429#[allow(clippy::collapsible_match, clippy::match_like_matches_macro)]430match self {431AExpr::Function {432input, function, ..433} => match function {434IRFunctionExpr::ListExpr(f) => match f {435IRListFunction::Get(false) => true,436#[cfg(feature = "list_gather")]437IRListFunction::Gather(false) => true,438_ => false,439},440#[cfg(feature = "dtype-array")]441IRFunctionExpr::ArrayExpr(f) => match f {442IRArrayFunction::Get(false) => true,443_ => false,444},445#[cfg(feature = "replace")]446IRFunctionExpr::ReplaceStrict { .. 
} => true,447#[cfg(all(feature = "strings", feature = "temporal"))]448IRFunctionExpr::StringExpr(f) => match f {449IRStringFunction::Strptime(_, strptime_options) => {450debug_assert!(input.len() <= 2);451452let ambiguous_arg_is_infallible_scalar = input453.get(1)454.map(|x| arena.get(x.node()))455.is_some_and(|ae| match ae {456AExpr::Literal(lv) => {457lv.extract_str().is_some_and(|ambiguous| match ambiguous {458"earliest" | "latest" | "null" => true,459"raise" => false,460v => {461if cfg!(debug_assertions) {462panic!("unhandled parameter to ambiguous: {v}")463}464false465},466})467},468_ => false,469});470471let ambiguous_is_fallible = !ambiguous_arg_is_infallible_scalar;472473!matches!(arena.get(input[0].node()), AExpr::Literal(_))474&& (strptime_options.strict || ambiguous_is_fallible)475},476_ => false,477},478_ => false,479},480AExpr::Cast {481expr,482dtype: _,483options: CastOptions::Strict,484} => !matches!(arena.get(*expr), AExpr::Literal(_)),485_ => false,486}487}488}489490#[recursive::recursive]491pub fn deep_clone_ae(ae: Node, arena: &mut Arena<AExpr>) -> Node {492let slf = arena.get(ae).clone();493494let mut children = vec![];495slf.children_rev(&mut children);496for child in &mut children {497*child = deep_clone_ae(*child, arena);498}499children.reverse();500501arena.add(slf.replace_children(&children))502}503504505