Path: blob/main/crates/polars-plan/src/plans/aexpr/mod.rs
6940 views
mod builder;1mod equality;2mod evaluate;3mod function_expr;4#[cfg(feature = "cse")]5mod hash;6mod minterm_iter;7pub mod predicates;8mod scalar;9mod schema;10mod traverse;1112use std::hash::{Hash, Hasher};1314pub use function_expr::*;15#[cfg(feature = "cse")]16pub(super) use hash::traverse_and_hash_aexpr;17pub use minterm_iter::MintermIter;18use polars_compute::rolling::QuantileMethod;19use polars_core::chunked_array::cast::CastOptions;20use polars_core::prelude::*;21use polars_core::utils::{get_time_units, try_get_supertype};22use polars_utils::arena::{Arena, Node};23pub use scalar::is_scalar_ae;24use strum_macros::IntoStaticStr;25pub use traverse::*;26mod properties;27pub use aexpr::function_expr::schema::FieldsMapper;28pub use builder::AExprBuilder;29pub use properties::*;3031use crate::constants::LEN;32use crate::plans::Context;33use crate::prelude::*;3435#[derive(Clone, Debug, IntoStaticStr)]36#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]37pub enum IRAggExpr {38Min {39input: Node,40propagate_nans: bool,41},42Max {43input: Node,44propagate_nans: bool,45},46Median(Node),47NUnique(Node),48First(Node),49Last(Node),50Mean(Node),51Implode(Node),52Quantile {53expr: Node,54quantile: Node,55method: QuantileMethod,56},57Sum(Node),58Count {59input: Node,60include_nulls: bool,61},62Std(Node, u8),63Var(Node, u8),64AggGroups(Node),65}6667impl Hash for IRAggExpr {68fn hash<H: Hasher>(&self, state: &mut H) {69std::mem::discriminant(self).hash(state);70match self {71Self::Min {72input: _,73propagate_nans,74}75| Self::Max {76input: _,77propagate_nans,78} => propagate_nans.hash(state),79Self::Quantile {80method: interpol, ..81} => interpol.hash(state),82Self::Std(_, v) | Self::Var(_, v) => v.hash(state),83Self::Count {84input: _,85include_nulls,86} => include_nulls.hash(state),87_ => {},88}89}90}9192impl IRAggExpr {93pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {94use IRAggExpr::*;95match (self, other) {96(97Min {98propagate_nans: l, ..99},100Min {101propagate_nans: r, ..102},103) => l == r,104(105Max {106propagate_nans: l, ..107},108Max {109propagate_nans: r, ..110},111) => l == r,112(Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,113(Std(_, l), Std(_, r)) => l == r,114(Var(_, l), Var(_, r)) => l == r,115_ => std::mem::discriminant(self) == std::mem::discriminant(other),116}117}118}119120impl From<IRAggExpr> for GroupByMethod {121fn from(value: IRAggExpr) -> Self {122use IRAggExpr::*;123match value {124Min {125input: _,126propagate_nans,127} => {128if propagate_nans {129GroupByMethod::NanMin130} else {131GroupByMethod::Min132}133},134Max {135input: _,136propagate_nans,137} => {138if propagate_nans {139GroupByMethod::NanMax140} else {141GroupByMethod::Max142}143},144Median(_) => GroupByMethod::Median,145NUnique(_) => GroupByMethod::NUnique,146First(_) => GroupByMethod::First,147Last(_) => GroupByMethod::Last,148Mean(_) => GroupByMethod::Mean,149Implode(_) => GroupByMethod::Implode,150Sum(_) => GroupByMethod::Sum,151Count {152input: _,153include_nulls,154} => GroupByMethod::Count { include_nulls },155Std(_, ddof) => GroupByMethod::Std(ddof),156Var(_, ddof) => GroupByMethod::Var(ddof),157AggGroups(_) => GroupByMethod::Groups,158Quantile { .. } => unreachable!(),159}160}161}162163/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena].164#[derive(Clone, Debug, Default)]165#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]166pub enum AExpr {167Explode {168expr: Node,169skip_empty: bool,170},171Column(PlSmallStr),172Literal(LiteralValue),173BinaryExpr {174left: Node,175op: Operator,176right: Node,177},178Cast {179expr: Node,180dtype: DataType,181options: CastOptions,182},183Sort {184expr: Node,185options: SortOptions,186},187Gather {188expr: Node,189idx: Node,190returns_scalar: bool,191},192SortBy {193expr: Node,194by: Vec<Node>,195sort_options: SortMultipleOptions,196},197Filter {198input: Node,199by: Node,200},201Agg(IRAggExpr),202Ternary {203predicate: Node,204truthy: Node,205falsy: Node,206},207AnonymousFunction {208input: Vec<ExprIR>,209function: OpaqueColumnUdf,210options: FunctionOptions,211fmt_str: Box<PlSmallStr>,212},213/// Evaluates the `evaluation` expression on the output of the `expr`.214///215/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.216Eval {217expr: Node,218219/// An expression that is guaranteed to not contain any column reference beyond220/// `pl.element()` which refers to `pl.col("")`.221evaluation: Node,222223variant: EvalVariant,224},225Function {226/// Function arguments227/// Some functions rely on aliases,228/// for instance assignment of struct fields.229/// Therefor we need [`ExprIr`].230input: Vec<ExprIR>,231/// function to apply232function: IRFunctionExpr,233options: FunctionOptions,234},235Window {236function: Node,237partition_by: Vec<Node>,238order_by: Option<(Node, SortOptions)>,239options: WindowType,240},241Slice {242input: Node,243offset: Node,244length: Node,245},246#[default]247Len,248}249250impl AExpr {251#[cfg(feature = "cse")]252pub(crate) fn col(name: PlSmallStr) -> Self {253AExpr::Column(name)254}255256/// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out.257pub fn get_dtype(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<DataType> {258self.to_field(schema, arena).map(|f| f.dtype().clone())259}260261#[recursive::recursive]262pub fn is_scalar(&self, arena: &Arena<AExpr>) -> bool {263match self {264AExpr::Literal(lv) => lv.is_scalar(),265AExpr::Function { options, input, .. }266| AExpr::AnonymousFunction { options, input, .. } => {267if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {268true269} else if options.is_elementwise()270|| options.flags.contains(FunctionFlags::LENGTH_PRESERVING)271{272input.iter().all(|e| e.is_scalar(arena))273} else {274false275}276},277AExpr::BinaryExpr { left, right, .. } => {278is_scalar_ae(*left, arena) && is_scalar_ae(*right, arena)279},280AExpr::Ternary {281predicate,282truthy,283falsy,284} => {285is_scalar_ae(*predicate, arena)286&& is_scalar_ae(*truthy, arena)287&& is_scalar_ae(*falsy, arena)288},289AExpr::Agg(_) | AExpr::Len => true,290AExpr::Cast { expr, .. } => is_scalar_ae(*expr, arena),291AExpr::Eval { expr, variant, .. } => match variant {292EvalVariant::List => is_scalar_ae(*expr, arena),293EvalVariant::Cumulative { .. } => is_scalar_ae(*expr, arena),294},295AExpr::Sort { expr, .. } => is_scalar_ae(*expr, arena),296AExpr::Gather { returns_scalar, .. } => *returns_scalar,297AExpr::SortBy { expr, .. } => is_scalar_ae(*expr, arena),298AExpr::Window { function, .. } => is_scalar_ae(*function, arena),299AExpr::Explode { .. }300| AExpr::Column(_)301| AExpr::Filter { .. }302| AExpr::Slice { .. } => false,303}304}305}306307308