Path: blob/main/crates/polars-plan/src/plans/aexpr/properties/general.rs
6940 views
use polars_utils::idx_vec::UnitVec;1use polars_utils::unitvec;23use super::super::*;45impl AExpr {6pub(crate) fn is_leaf(&self) -> bool {7matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)8}910pub(crate) fn is_col(&self) -> bool {11matches!(self, AExpr::Column(_))12}1314/// Checks whether this expression is elementwise. This only checks the top level expression.15pub(crate) fn is_elementwise_top_level(&self) -> bool {16use AExpr::*;1718match self {19AnonymousFunction { options, .. } => options.is_elementwise(),2021Function { options, .. } => options.is_elementwise(),2223Literal(v) => v.is_scalar(),2425Eval { variant, .. } => match variant {26EvalVariant::List => true,27EvalVariant::Cumulative { min_samples: _ } => false,28},2930BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,3132Agg { .. }33| Explode { .. }34| Filter { .. }35| Gather { .. }36| Len37| Slice { .. }38| Sort { .. }39| SortBy { .. }40| Window { .. } => false,41}42}4344/// Checks whether this expression is row-separable. This only checks the top level expression.45pub(crate) fn is_row_separable_top_level(&self) -> bool {46use AExpr::*;4748match self {49AnonymousFunction { options, .. } => options.is_row_separable(),50Function { options, .. } => options.is_row_separable(),51Literal(v) => v.is_scalar(),52Explode { .. } | Filter { .. } => true,53_ => self.is_elementwise_top_level(),54}55}5657pub(crate) fn does_not_modify_top_level(&self) -> bool {58match self {59AExpr::Column(_) => true,60AExpr::Function { function, .. } => {61matches!(function, IRFunctionExpr::SetSortedFlag(_))62},63_ => false,64}65}66}6768// Traversal utilities69fn property_and_traverse<F>(stack: &mut UnitVec<Node>, ae: &AExpr, property: F) -> bool70where71F: Fn(&AExpr) -> bool,72{73if !property(ae) {74return false;75}76ae.inputs_rev(stack);77true78}7980fn property_rec<F>(node: Node, expr_arena: &Arena<AExpr>, property: F) -> bool81where82F: Fn(&mut UnitVec<Node>, &AExpr, &Arena<AExpr>) -> bool,83{84let mut stack = unitvec![];85let mut ae = expr_arena.get(node);8687loop {88if !property(&mut stack, ae, expr_arena) {89return false;90}9192let Some(node) = stack.pop() else {93break;94};9596ae = expr_arena.get(node);97}9899true100}101102/// Checks if the top-level expression node does not modify. If this is the case, then `stack` will103/// be extended further with any nested expression nodes.104fn does_not_modify(stack: &mut UnitVec<Node>, ae: &AExpr, _expr_arena: &Arena<AExpr>) -> bool {105property_and_traverse(stack, ae, |ae| ae.does_not_modify_top_level())106}107108pub fn does_not_modify_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {109property_rec(node, expr_arena, does_not_modify)110}111112pub fn is_prop<P: Fn(&AExpr) -> bool>(113stack: &mut UnitVec<Node>,114ae: &AExpr,115expr_arena: &Arena<AExpr>,116prop_top_level: P,117) -> bool {118use AExpr::*;119120if !prop_top_level(ae) {121return false;122}123124match ae {125// Literals that aren't being projected are allowed to be non-scalar, so we don't add them126// for inspection. (e.g. `is_in(<literal>)`).127#[cfg(feature = "is_in")]128Function {129function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),130input,131..132} => (|| {133if let Some(rhs) = input.get(1) {134assert_eq!(input.len(), 2); // A.is_in(B)135let rhs = rhs.node();136137if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {138stack.extend([input[0].node()]);139return;140}141};142143ae.inputs_rev(stack);144})(),145_ => ae.inputs_rev(stack),146}147148true149}150151/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will152/// be extended further with any nested expression nodes.153pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {154is_prop(stack, ae, expr_arena, |ae| ae.is_elementwise_top_level())155}156157pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool158where159Node: From<&'a N>,160{161nodes162.iter()163.all(|n| is_elementwise_rec(n.into(), expr_arena))164}165166/// Recursive variant of `is_elementwise`167pub fn is_elementwise_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {168property_rec(node, expr_arena, is_elementwise)169}170171/// Checks if the top-level expression node is row-separable. If this is the case, then `stack` will172/// be extended further with any nested expression nodes.173pub fn is_row_separable(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {174is_prop(stack, ae, expr_arena, |ae| ae.is_row_separable_top_level())175}176177pub fn all_row_separable<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool178where179Node: From<&'a N>,180{181nodes182.iter()183.all(|n| is_row_separable_rec(n.into(), expr_arena))184}185186/// Recursive variant of `is_row_separable`187pub fn is_row_separable_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {188property_rec(node, expr_arena, is_row_separable)189}190191#[derive(Debug, Clone)]192pub enum ExprPushdownGroup {193/// Can be pushed. (elementwise, infallible)194///195/// e.g. non-strict cast196Pushable,197/// Cannot be pushed, but doesn't block pushables. (elementwise, fallible)198///199/// Fallible expressions are categorized into this group rather than the Barrier group. The200/// effect of this means we push more predicates, but the expression may no longer error201/// if the problematic rows are filtered out.202///203/// e.g. strict-cast, list.get(null_on_oob=False), to_datetime(strict=True)204Fallible,205/// Cannot be pushed, and blocks all expressions at the current level. (non-elementwise)206///207/// e.g. sort()208Barrier,209}210211impl ExprPushdownGroup {212/// Note:213/// * `stack` is not extended with any nodes if a barrier expression is seen.214/// * This function is not recursive - the caller should repeatedly215/// call this function with the `stack` to perform a recursive check.216pub fn update_with_expr(217&mut self,218stack: &mut UnitVec<Node>,219ae: &AExpr,220expr_arena: &Arena<AExpr>,221) -> &mut Self {222match self {223ExprPushdownGroup::Pushable | ExprPushdownGroup::Fallible => {224// Downgrade to unpushable if fallible225if match ae {226// Rows that go OOB on get/gather may be filtered out in earlier operations,227// so we don't push these down.228AExpr::Function {229function: IRFunctionExpr::ListExpr(IRListFunction::Get(false)),230..231} => true,232233#[cfg(feature = "list_gather")]234AExpr::Function {235function: IRFunctionExpr::ListExpr(IRListFunction::Gather(false)),236..237} => true,238239#[cfg(feature = "dtype-array")]240AExpr::Function {241function: IRFunctionExpr::ArrayExpr(IRArrayFunction::Get(false)),242..243} => true,244245#[cfg(all(feature = "strings", feature = "temporal"))]246AExpr::Function {247input,248function:249IRFunctionExpr::StringExpr(IRStringFunction::Strptime(_, strptime_options)),250..251} => {252debug_assert!(input.len() <= 2);253254let ambiguous_arg_is_infallible_scalar = input255.get(1)256.map(|x| expr_arena.get(x.node()))257.is_some_and(|ae| match ae {258AExpr::Literal(lv) => {259lv.extract_str().is_some_and(|ambiguous| match ambiguous {260"earliest" | "latest" | "null" => true,261"raise" => false,262v => {263if cfg!(debug_assertions) {264panic!("unhandled parameter to ambiguous: {v}")265}266false267},268})269},270_ => false,271});272273let ambiguous_is_fallible = !ambiguous_arg_is_infallible_scalar;274275strptime_options.strict || ambiguous_is_fallible276},277AExpr::Cast {278expr,279dtype: _,280options: CastOptions::Strict,281} => !matches!(expr_arena.get(*expr), AExpr::Literal(_)),282283_ => false,284} {285*self = ExprPushdownGroup::Fallible;286}287288// Downgrade to barrier if non-elementwise289if !is_elementwise(stack, ae, expr_arena) {290*self = ExprPushdownGroup::Barrier291}292},293294ExprPushdownGroup::Barrier => {},295}296297self298}299300pub fn update_with_expr_rec<'a>(301&mut self,302mut ae: &'a AExpr,303expr_arena: &'a Arena<AExpr>,304scratch: Option<&mut UnitVec<Node>>,305) -> &mut Self {306let mut local_scratch = unitvec![];307let stack = scratch.unwrap_or(&mut local_scratch);308309loop {310self.update_with_expr(stack, ae, expr_arena);311312if let ExprPushdownGroup::Barrier = self {313return self;314}315316let Some(node) = stack.pop() else {317break;318};319320ae = expr_arena.get(node);321}322323self324}325326pub fn blocks_pushdown(&self, maintain_errors: bool) -> bool {327match self {328ExprPushdownGroup::Barrier => true,329ExprPushdownGroup::Fallible => maintain_errors,330ExprPushdownGroup::Pushable => false,331}332}333}334335pub fn can_pre_agg_exprs(336exprs: &[ExprIR],337expr_arena: &Arena<AExpr>,338_input_schema: &Schema,339) -> bool {340exprs341.iter()342.all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))343}344345/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be346/// implemented physically, so this isn't a complete list.347pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {348let aexpr = expr_arena.get(agg);349350match aexpr {351AExpr::Len => true,352AExpr::Column(_) | AExpr::Literal(_) => false,353// We only allow expressions that end with an aggregation.354AExpr::Agg(_) => {355let has_aggregation =356|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));357358// check if the aggregation type is partitionable359// only simple aggregation like col().sum360// that can be divided in to the aggregation of their partitions are allowed361let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {362use AExpr::*;363match ae {364// struct is needed to keep both states365#[cfg(feature = "dtype-struct")]366Agg(IRAggExpr::Mean(_)) => {367// only numeric means for now.368// logical types seem to break because of casts to float.369matches!(370expr_arena371.get(agg)372.get_dtype(_input_schema, expr_arena)373.map(|dt| { dt.is_primitive_numeric() }),374Ok(true)375)376},377// only allowed expressions378Agg(agg_e) => {379matches!(380agg_e,381IRAggExpr::Min { .. }382| IRAggExpr::Max { .. }383| IRAggExpr::Sum(_)384| IRAggExpr::Last(_)385| IRAggExpr::First(_)386| IRAggExpr::Count {387input: _,388include_nulls: true389}390)391},392Function { input, options, .. } => {393options.is_elementwise()394&& input.len() == 1395&& !has_aggregation(input[0].node())396},397BinaryExpr { left, right, .. } => {398!has_aggregation(*left) && !has_aggregation(*right)399},400Ternary {401truthy,402falsy,403predicate,404..405} => {406!has_aggregation(*truthy)407&& !has_aggregation(*falsy)408&& !has_aggregation(*predicate)409},410Literal(lv) => lv.is_scalar(),411Column(_) | Len | Cast { .. } => true,412_ => false,413}414});415416#[cfg(feature = "object")]417{418for name in aexpr_to_leaf_names(agg, expr_arena) {419let dtype = _input_schema.get(&name).unwrap();420421if let DataType::Object(_) = dtype {422return false;423}424}425}426can_partition427},428_ => false,429}430}431432/// Identifies columns that are guaranteed to be non-NULL after applying this filter.433///434/// This is conservative in that it will not give false positives, but may not identify all columns.435///436/// Note, this must be called with the root node of filter expressions (the root nodes after splitting437/// with MintermIter is also allowed).438pub(crate) fn predicate_non_null_column_outputs(439predicate_node: Node,440expr_arena: &Arena<AExpr>,441non_null_column_callback: &mut dyn FnMut(&PlSmallStr),442) {443let mut minterm_iter = MintermIter::new(predicate_node, expr_arena);444let stack: &mut UnitVec<Node> = &mut unitvec![];445446/// Only traverse the first input, e.g. `A.is_in(B)` we don't consider B.447macro_rules! traverse_first_input {448// &[ExprIR]449($inputs:expr) => {{450if let Some(expr_ir) = $inputs.first() {451stack.push(expr_ir.node())452}453454false455}};456}457458loop {459use AExpr::*;460461let node = if let Some(node) = stack.pop() {462node463} else if let Some(minterm_node) = minterm_iter.next() {464// Some additional leaf exprs can be pruned.465match expr_arena.get(minterm_node) {466Function {467input,468function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),469options: _,470} if !input.is_empty() => input.first().unwrap().node(),471472Function {473input,474function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),475options: _,476} if !input.is_empty() => match expr_arena.get(input.first().unwrap().node()) {477Function {478input,479function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),480options: _,481} if !input.is_empty() => input.first().unwrap().node(),482483_ => minterm_node,484},485486_ => minterm_node,487}488} else {489break;490};491492let ae = expr_arena.get(node);493494// This match we traverse a subset of the operations that are guaranteed to maintain NULLs.495//496// This must not catch any operations that materialize NULLs, as otherwise e.g.497// `e.fill_null(False) >= False` will include NULLs498let traverse_all_inputs = match ae {499BinaryExpr {500left: _,501op,502right: _,503} => {504use Operator::*;505506match op {507Eq | NotEq | Lt | LtEq | Gt | GtEq | Plus | Minus | Multiply | Divide508| TrueDivide | FloorDivide | Modulus | Xor => true,509510// These can turn NULLs into true/false. E.g.:511// * (L & False) >= False becomes True512// * L | True becomes True513EqValidity | NotEqValidity | Or | LogicalOr | And | LogicalAnd => false,514}515},516517Cast { dtype, .. } => {518// Forbid nested types, it's currently buggy:519// >>> pl.select(a=pl.lit(None), b=pl.lit(None).cast(pl.Struct({})))520// | a | b |521// | --- | --- |522// | null | struct[0] |523// |------|-----------|524// | null | {} |525//526// (issue at https://github.com/pola-rs/polars/issues/23276)527!dtype.is_nested()528},529530Function {531input,532function: _,533options,534} => {535if options536.flags537.contains(FunctionFlags::PRESERVES_NULL_FIRST_INPUT)538{539traverse_first_input!(input)540} else {541options542.flags543.contains(FunctionFlags::PRESERVES_NULL_ALL_INPUTS)544}545},546547Column(name) => {548non_null_column_callback(name);549false550},551552_ => false,553};554555if traverse_all_inputs {556ae.inputs_rev(stack);557}558}559}560561562