Path: blob/main/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs
7889 views
use super::*;12pub(super) fn optimize_functions(3input: Vec<ExprIR>,4function: IRFunctionExpr,5options: FunctionOptions,6expr_arena: &mut Arena<AExpr>,7) -> PolarsResult<Option<AExpr>> {8let out = match function {9// is_null().any() -> null_count() > 010// is_not_null().any() -> null_count() < len()11// CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS12IRFunctionExpr::Boolean(IRBooleanFunction::Any { ignore_nulls: _ }) => {13let input_node = expr_arena.get(input[0].node());14match input_node {15AExpr::Function {16input,17function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),18options: _,19} => Some(AExpr::BinaryExpr {20left: expr_arena.add(new_null_count(input)),21op: Operator::Gt,22right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))),23}),24AExpr::Function {25input,26function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),27options: _,28} => {29// we should perform optimization only if the original expression is a column30// so in case of disabled CSE, we will not suffer from performance regression31if input.len() == 1 {32let is_not_null_input_node = input[0].node();33match expr_arena.get(is_not_null_input_node) {34AExpr::Column(_) => Some(AExpr::BinaryExpr {35op: Operator::Lt,36left: expr_arena.add(new_null_count(input)),37right: expr_arena.add(AExpr::Agg(IRAggExpr::Count {38input: is_not_null_input_node,39include_nulls: true,40})),41}),42_ => None,43}44} else {45None46}47},48_ => None,49}50},51// is_null().all() -> null_count() == len()52// is_not_null().all() -> null_count() == 053IRFunctionExpr::Boolean(IRBooleanFunction::All { ignore_nulls: _ }) => {54let input_node = expr_arena.get(input[0].node());55match input_node {56AExpr::Function {57input,58function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),59options: _,60} => {61// we should perform optimization only if the original expression is a column62// so in case of disabled CSE, we will not suffer from performance regression63if input.len() == 1 {64let is_null_input_node = input[0].node();65match expr_arena.get(is_null_input_node) {66AExpr::Column(_) => Some(AExpr::BinaryExpr {67op: Operator::Eq,68right: expr_arena.add(new_null_count(input)),69left: expr_arena.add(AExpr::Agg(IRAggExpr::Count {70input: is_null_input_node,71include_nulls: true,72})),73}),74_ => None,75}76} else {77None78}79},80AExpr::Function {81input,82function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),83options: _,84} => Some(AExpr::BinaryExpr {85left: expr_arena.add(new_null_count(input)),86op: Operator::Eq,87right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))),88}),89_ => None,90}91},92// sort().reverse() -> sort(reverse)93// sort_by().reverse() -> sort_by(reverse)94IRFunctionExpr::Reverse => {95let input = expr_arena.get(input[0].node());96match input {97AExpr::Sort { expr, options } => {98let mut options = *options;99options.descending = !options.descending;100Some(AExpr::Sort {101expr: *expr,102options,103})104},105AExpr::SortBy {106expr,107by,108sort_options,109} => {110let mut sort_options = sort_options.clone();111let reversed_descending = sort_options.descending.iter().map(|x| !*x).collect();112sort_options.descending = reversed_descending;113Some(AExpr::SortBy {114expr: *expr,115by: by.clone(),116sort_options,117})118},119// TODO: add support for cum_sum and other operation that allow reversing.120_ => None,121}122},123// flatten nested concat_str calls124#[cfg(all(feature = "strings", feature = "concat_str"))]125ref function @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {126delimiter: ref sep,127ignore_nulls,128}) if sep.is_empty() => {129if input130.iter()131.any(|e| is_string_concat(expr_arena.get(e.node()), ignore_nulls))132{133let mut new_inputs = Vec::with_capacity(input.len() * 2);134135for e in input {136match get_string_concat_input(e.node(), expr_arena, ignore_nulls) {137Some(inp) => new_inputs.extend_from_slice(inp),138None => new_inputs.push(e.clone()),139}140}141Some(AExpr::Function {142input: new_inputs,143function: function.clone(),144options,145})146} else {147None148}149},150IRFunctionExpr::Boolean(IRBooleanFunction::Not) => {151let y = expr_arena.get(input[0].node());152153match y {154// not(a and b) => not(a) or not(b)155AExpr::BinaryExpr {156left,157op: Operator::And | Operator::LogicalAnd,158right,159} => {160let left = *left;161let right = *right;162Some(AExpr::BinaryExpr {163left: expr_arena.add(AExpr::Function {164input: vec![ExprIR::from_node(left, expr_arena)],165function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),166options,167}),168op: Operator::Or,169right: expr_arena.add(AExpr::Function {170input: vec![ExprIR::from_node(right, expr_arena)],171function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),172options,173}),174})175},176// not(a or b) => not(a) and not(b)177AExpr::BinaryExpr {178left,179op: Operator::Or | Operator::LogicalOr,180right,181} => {182let left = *left;183let right = *right;184Some(AExpr::BinaryExpr {185left: expr_arena.add(AExpr::Function {186input: vec![ExprIR::from_node(left, expr_arena)],187function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),188options,189}),190op: Operator::And,191right: expr_arena.add(AExpr::Function {192input: vec![ExprIR::from_node(right, expr_arena)],193function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),194options,195}),196})197},198// not(not x) => x199AExpr::Function {200input,201function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),202..203} => Some(expr_arena.get(input[0].node()).clone()),204// not(lit x) => !x205AExpr::Literal(lv) if lv.bool().is_some() => {206Some(AExpr::Literal(Scalar::from(!lv.bool().unwrap()).into()))207},208// not(x.is_null) => x.is_not_null209AExpr::Function {210input,211function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),212options,213} => Some(AExpr::Function {214input: input.clone(),215function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),216options: *options,217}),218// not(x.is_not_null) => x.is_null219AExpr::Function {220input,221function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),222options,223} => Some(AExpr::Function {224input: input.clone(),225function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),226options: *options,227}),228// not(a == b) => a != b229AExpr::BinaryExpr {230left,231op: Operator::Eq,232right,233} => Some(AExpr::BinaryExpr {234left: *left,235op: Operator::NotEq,236right: *right,237}),238// not(a != b) => a == b239AExpr::BinaryExpr {240left,241op: Operator::NotEq,242right,243} => Some(AExpr::BinaryExpr {244left: *left,245op: Operator::Eq,246right: *right,247}),248// not(a < b) => a >= b249AExpr::BinaryExpr {250left,251op: Operator::Lt,252right,253} => Some(AExpr::BinaryExpr {254left: *left,255op: Operator::GtEq,256right: *right,257}),258// not(a <= b) => a > b259AExpr::BinaryExpr {260left,261op: Operator::LtEq,262right,263} => Some(AExpr::BinaryExpr {264left: *left,265op: Operator::Gt,266right: *right,267}),268// not(a > b) => a <= b269AExpr::BinaryExpr {270left,271op: Operator::Gt,272right,273} => Some(AExpr::BinaryExpr {274left: *left,275op: Operator::LtEq,276right: *right,277}),278// not(a >= b) => a < b279AExpr::BinaryExpr {280left,281op: Operator::GtEq,282right,283} => Some(AExpr::BinaryExpr {284left: *left,285op: Operator::Lt,286right: *right,287}),288#[cfg(feature = "is_between")]289// not(col('x').is_between(a,b)) => col('x') < a || col('x') > b290AExpr::Function {291input,292function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),293..294} => {295if !matches!(expr_arena.get(input[0].node()), AExpr::Column(_)) {296None297} else {298let left_cmp_op = match closed {299ClosedInterval::Both | ClosedInterval::Left => Operator::Lt,300ClosedInterval::None | ClosedInterval::Right => Operator::LtEq,301};302let right_cmp_op = match closed {303ClosedInterval::Both | ClosedInterval::Right => Operator::Gt,304ClosedInterval::None | ClosedInterval::Left => Operator::GtEq,305};306let left_left = input[0].node();307let right_left = input[1].node();308309let left_right = left_left;310let right_right = input[2].node();311312// input[0] is between input[1] and input[2]313Some(AExpr::BinaryExpr {314// input[0] (<,<=) input[1]315left: expr_arena.add(AExpr::BinaryExpr {316left: left_left,317op: left_cmp_op,318right: right_left,319}),320// OR321op: Operator::Or,322// input[0] (>,>=) input[2]323right: expr_arena.add(AExpr::BinaryExpr {324left: left_right,325op: right_cmp_op,326right: right_right,327}),328})329}330},331_ => None,332}333},334IRFunctionExpr::GatherEvery { n: 1, offset: 0 } => {335Some(expr_arena.get(input[0].node()).clone())336},337IRFunctionExpr::GatherEvery { n: 1, offset } => {338let offset_i64: i64 = offset.try_into().unwrap_or(i64::MAX);339let offset_node =340expr_arena.add(AExpr::Literal(LiteralValue::Scalar(offset_i64.into())));341let length_node = expr_arena.add(AExpr::Literal(LiteralValue::Scalar(342(usize::MAX as u64).into(),343)));344Some(AExpr::Slice {345input: input[0].node(),346offset: offset_node,347length: length_node,348})349},350_ => None,351};352Ok(out)353}354355#[cfg(all(feature = "strings", feature = "concat_str"))]356fn is_string_concat(ae: &AExpr, ignore_nulls: bool) -> bool {357matches!(ae, AExpr::Function {358function:IRFunctionExpr::StringExpr(359IRStringFunction::ConcatHorizontal{delimiter: sep, ignore_nulls: func_inore_nulls},360),361..362} if sep.is_empty() && *func_inore_nulls == ignore_nulls)363}364365#[cfg(all(feature = "strings", feature = "concat_str"))]366fn get_string_concat_input(367node: Node,368expr_arena: &Arena<AExpr>,369ignore_nulls: bool,370) -> Option<&[ExprIR]> {371match expr_arena.get(node) {372AExpr::Function {373input,374function:375IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {376delimiter: sep,377ignore_nulls: func_ignore_nulls,378}),379..380} if sep.is_empty() && *func_ignore_nulls == ignore_nulls => Some(input),381_ => None,382}383}384385386