// Path: blob/main/crates/polars-plan/src/plans/aexpr/properties/general.rs
// 8446 views
use polars_utils::idx_vec::UnitVec;1use polars_utils::unitvec;23use super::super::*;45impl AExpr {6pub(crate) fn is_leaf(&self) -> bool {7matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)8}910pub(crate) fn is_col(&self) -> bool {11matches!(self, AExpr::Column(_))12}1314/// Checks whether this expression is elementwise. This only checks the top level expression.15pub(crate) fn is_elementwise_top_level(&self) -> bool {16use AExpr::*;1718match self {19AnonymousFunction { options, .. } => options.is_elementwise(),2021Function { options, .. } => options.is_elementwise(),2223Literal(v) => v.is_scalar(),2425Eval { variant, .. } => variant.is_elementwise(),2627Element | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,2829#[cfg(feature = "dtype-struct")]30StructEval { .. } | StructField(_) => true,3132#[cfg(feature = "dynamic_group_by")]33Rolling { .. } => false,3435Agg { .. }36| AnonymousAgg { .. }37| Explode { .. }38| Filter { .. }39| Gather { .. }40| Len41| Slice { .. }42| Sort { .. }43| SortBy { .. }44| Over { .. } => false,45}46}4748/// Checks whether this expression is row-separable. This only checks the top level expression.49pub(crate) fn is_row_separable_top_level(&self) -> bool {50use AExpr::*;5152match self {53AnonymousFunction { options, .. } => options.is_row_separable(),54Function { options, .. } => options.is_row_separable(),55Literal(v) => v.is_scalar(),56Explode { .. } | Filter { .. } => true,57_ => self.is_elementwise_top_level(),58}59}6061pub(crate) fn does_not_modify_top_level(&self) -> bool {62match self {63AExpr::Column(_) => true,64AExpr::Function { function, .. 
} => {65matches!(function, IRFunctionExpr::SetSortedFlag(_))66},67_ => false,68}69}70}7172// Traversal utilities73fn property_and_traverse<F>(stack: &mut UnitVec<Node>, ae: &AExpr, property: F) -> bool74where75F: Fn(&AExpr) -> bool,76{77if !property(ae) {78return false;79}80ae.inputs_rev(stack);81true82}8384fn property_rec<F>(node: Node, expr_arena: &Arena<AExpr>, property: F) -> bool85where86F: Fn(&mut UnitVec<Node>, &AExpr, &Arena<AExpr>) -> bool,87{88let mut stack = unitvec![];89let mut ae = expr_arena.get(node);9091loop {92if !property(&mut stack, ae, expr_arena) {93return false;94}9596let Some(node) = stack.pop() else {97break;98};99100ae = expr_arena.get(node);101}102103true104}105106/// Checks if the top-level expression node does not modify. If this is the case, then `stack` will107/// be extended further with any nested expression nodes.108fn does_not_modify(stack: &mut UnitVec<Node>, ae: &AExpr, _expr_arena: &Arena<AExpr>) -> bool {109property_and_traverse(stack, ae, |ae| ae.does_not_modify_top_level())110}111112pub fn does_not_modify_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {113property_rec(node, expr_arena, does_not_modify)114}115116pub fn is_prop<P: Fn(&AExpr) -> bool>(117stack: &mut UnitVec<Node>,118ae: &AExpr,119expr_arena: &Arena<AExpr>,120prop_top_level: P,121) -> bool {122use AExpr::*;123124if !prop_top_level(ae) {125return false;126}127128match ae {129// Literals that aren't being projected are allowed to be non-scalar, so we don't add them130// for inspection. (e.g. `is_in(<literal>)`).131#[cfg(feature = "is_in")]132Function {133function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),134input,135..136} => (|| {137if let Some(rhs) = input.get(1) {138assert_eq!(input.len(), 2); // A.is_in(B)139let rhs = rhs.node();140141if matches!(expr_arena.get(rhs), AExpr::Literal { .. 
}) {142stack.extend([input[0].node()]);143return;144}145};146ae.inputs_rev(stack);147})(),148_ => {149ae.inputs_rev(stack);150},151}152153true154}155156/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will157/// be extended further with any nested expression nodes.158pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {159is_prop(stack, ae, expr_arena, |ae| ae.is_elementwise_top_level())160}161162pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool163where164Node: From<&'a N>,165{166nodes167.iter()168.all(|n| is_elementwise_rec(n.into(), expr_arena))169}170171/// Recursive variant of `is_elementwise`172pub fn is_elementwise_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {173property_rec(node, expr_arena, is_elementwise)174}175176/// Checks if the top-level expression node is row-separable. If this is the case, then `stack` will177/// be extended further with any nested expression nodes.178pub fn is_row_separable(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {179is_prop(stack, ae, expr_arena, |ae| ae.is_row_separable_top_level())180}181182pub fn all_row_separable<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool183where184Node: From<&'a N>,185{186nodes187.iter()188.all(|n| is_row_separable_rec(n.into(), expr_arena))189}190191/// Recursive variant of `is_row_separable`192pub fn is_row_separable_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {193property_rec(node, expr_arena, is_row_separable)194}195196#[derive(Debug, Clone)]197pub enum ExprPushdownGroup {198/// Can be pushed. (elementwise, infallible)199///200/// e.g. non-strict cast201Pushable,202/// Cannot be pushed, but doesn't block pushables. (elementwise, fallible)203///204/// Fallible expressions are categorized into this group rather than the Barrier group. 
The205/// effect of this means we push more predicates, but the expression may no longer error206/// if the problematic rows are filtered out.207///208/// e.g. strict-cast, list.get(null_on_oob=False), to_datetime(strict=True)209Fallible,210/// Cannot be pushed, and blocks all expressions at the current level. (non-elementwise)211///212/// e.g. sort()213Barrier,214}215216impl ExprPushdownGroup {217/// Note:218/// * `stack` is not extended with any nodes if a barrier expression is seen.219/// * This function is not recursive - the caller should repeatedly220/// call this function with the `stack` to perform a recursive check.221pub fn update_with_expr(222&mut self,223stack: &mut UnitVec<Node>,224ae: &AExpr,225expr_arena: &Arena<AExpr>,226) -> &mut Self {227match self {228ExprPushdownGroup::Pushable | ExprPushdownGroup::Fallible => {229// Downgrade to unpushable if fallible230if ae.is_fallible_top_level(expr_arena) {231*self = ExprPushdownGroup::Fallible;232}233234// Downgrade to barrier if non-elementwise235if !is_elementwise(stack, ae, expr_arena) {236*self = ExprPushdownGroup::Barrier237}238},239240ExprPushdownGroup::Barrier => {},241}242243self244}245246pub fn update_with_expr_rec<'a>(247&mut self,248mut ae: &'a AExpr,249expr_arena: &'a Arena<AExpr>,250scratch: Option<&mut UnitVec<Node>>,251) -> &mut Self {252let mut local_scratch = unitvec![];253let stack = scratch.unwrap_or(&mut local_scratch);254255loop {256self.update_with_expr(stack, ae, expr_arena);257258if let ExprPushdownGroup::Barrier = self {259return self;260}261262let Some(node) = stack.pop() else {263break;264};265266ae = expr_arena.get(node);267}268269self270}271272pub fn blocks_pushdown(&self, maintain_errors: bool) -> bool {273match self {274ExprPushdownGroup::Barrier => true,275ExprPushdownGroup::Fallible => maintain_errors,276ExprPushdownGroup::Pushable => false,277}278}279}280281pub fn can_pre_agg_exprs(282exprs: &[ExprIR],283expr_arena: &Arena<AExpr>,284_input_schema: &Schema,285) -> bool 
{286exprs287.iter()288.all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))289}290291/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be292/// implemented physically, so this isn't a complete list.293pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {294let aexpr = expr_arena.get(agg);295296match aexpr {297AExpr::Len => true,298AExpr::Column(_) | AExpr::Literal(_) => false,299// We only allow expressions that end with an aggregation.300AExpr::Agg(_) => {301let has_aggregation =302|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));303304// check if the aggregation type is partitionable305// only simple aggregation like col().sum306// that can be divided in to the aggregation of their partitions are allowed307let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {308use AExpr::*;309match ae {310// struct is needed to keep both states311#[cfg(feature = "dtype-struct")]312Agg(IRAggExpr::Mean(_)) => {313// only numeric means for now.314// logical types seem to break because of casts to float.315matches!(316expr_arena317.get(agg)318.to_dtype(&ToFieldContext::new(expr_arena, _input_schema))319.map(|dt| { dt.is_primitive_numeric() }),320Ok(true)321)322},323// only allowed expressions324Agg(agg_e) => {325matches!(326agg_e,327IRAggExpr::Min { .. }328| IRAggExpr::Max { .. }329| IRAggExpr::Sum(_)330| IRAggExpr::Last(_)331| IRAggExpr::First(_)332| IRAggExpr::Count {333input: _,334include_nulls: true335}336)337},338Function { input, options, .. } => {339options.is_elementwise()340&& input.len() == 1341&& !has_aggregation(input[0].node())342},343BinaryExpr { left, right, .. } => {344!has_aggregation(*left) && !has_aggregation(*right)345},346Ternary {347truthy,348falsy,349predicate,350..351} => {352!has_aggregation(*truthy)353&& !has_aggregation(*falsy)354&& !has_aggregation(*predicate)355},356Literal(lv) => lv.is_scalar(),357Column(_) | Len | Cast { .. 
} => true,358_ => false,359}360});361362#[cfg(feature = "object")]363{364for name in aexpr_to_leaf_names(agg, expr_arena) {365let dtype = _input_schema.get(&name).unwrap();366367if let DataType::Object(_) = dtype {368return false;369}370}371}372can_partition373},374_ => false,375}376}377378/// Identifies columns that are guaranteed to be non-NULL after applying this filter.379///380/// This is conservative in that it will not give false positives, but may not identify all columns.381///382/// Note, this must be called with the root node of filter expressions (the root nodes after splitting383/// with MintermIter is also allowed).384pub(crate) fn predicate_non_null_column_outputs(385predicate_node: Node,386expr_arena: &Arena<AExpr>,387non_null_column_callback: &mut dyn FnMut(&PlSmallStr),388) {389let mut minterm_iter = MintermIter::new(predicate_node, expr_arena);390let stack: &mut UnitVec<Node> = &mut unitvec![];391392/// Only traverse the first input, e.g. `A.is_in(B)` we don't consider B.393macro_rules! 
traverse_first_input {394// &[ExprIR]395($inputs:expr) => {{396if let Some(expr_ir) = $inputs.first() {397stack.push(expr_ir.node())398}399400false401}};402}403404loop {405use AExpr::*;406407let node = if let Some(node) = stack.pop() {408node409} else if let Some(minterm_node) = minterm_iter.next() {410// Some additional leaf exprs can be pruned.411match expr_arena.get(minterm_node) {412Function {413input,414function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),415options: _,416} if !input.is_empty() => input.first().unwrap().node(),417418Function {419input,420function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),421options: _,422} if !input.is_empty() => match expr_arena.get(input.first().unwrap().node()) {423Function {424input,425function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),426options: _,427} if !input.is_empty() => input.first().unwrap().node(),428429_ => minterm_node,430},431432_ => minterm_node,433}434} else {435break;436};437438let ae = expr_arena.get(node);439440// This match we traverse a subset of the operations that are guaranteed to maintain NULLs.441//442// This must not catch any operations that materialize NULLs, as otherwise e.g.443// `e.fill_null(False) >= False` will include NULLs444let traverse_all_inputs = match ae {445BinaryExpr {446left: _,447op,448right: _,449} => {450use Operator::*;451452match op {453Eq | NotEq | Lt | LtEq | Gt | GtEq | Plus | Minus | Multiply | RustDivide454| TrueDivide | FloorDivide | Modulus | Xor => true,455456// These can turn NULLs into true/false. E.g.:457// * (L & False) >= False becomes True458// * L | True becomes True459EqValidity | NotEqValidity | Or | LogicalOr | And | LogicalAnd => false,460}461},462463Cast { dtype, .. 
} => {464// Forbid nested types, it's currently buggy:465// >>> pl.select(a=pl.lit(None), b=pl.lit(None).cast(pl.Struct({})))466// | a | b |467// | --- | --- |468// | null | struct[0] |469// |------|-----------|470// | null | {} |471//472// (issue at https://github.com/pola-rs/polars/issues/23276)473!dtype.is_nested()474},475476Function {477input,478function: _,479options,480} => {481if options482.flags483.contains(FunctionFlags::PRESERVES_NULL_FIRST_INPUT)484{485traverse_first_input!(input)486} else {487options488.flags489.contains(FunctionFlags::PRESERVES_NULL_ALL_INPUTS)490}491},492493Column(name) => {494non_null_column_callback(name);495false496},497498_ => false,499};500501if traverse_all_inputs {502ae.inputs_rev(stack);503}504}505}506507508