// Path: crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs
mod simplify_functions;12use polars_utils::float16::pf16;3use polars_utils::floor_divmod::FloorDivMod;4use polars_utils::total_ord::ToTotalOrd;5use simplify_functions::optimize_functions;6mod arity;78use crate::plans::*;910fn new_null_count(input: &[ExprIR]) -> AExpr {11let function = IRFunctionExpr::NullCount;12let options = function.function_options();13AExpr::Function {14input: input.to_vec(),15function,16options,17}18}1920macro_rules! eval_binary_same_type {21($lhs:expr, $rhs:expr, |$l: ident, $r: ident| $ret: expr) => {{22if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) {23match (lit_left, lit_right) {24(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => {25match (l.as_any_value(), r.as_any_value()) {26(AnyValue::Float16($l), AnyValue::Float16($r)) => {27Some(AExpr::Literal(Scalar::from($ret).into()))28},29(AnyValue::Float32($l), AnyValue::Float32($r)) => {30Some(AExpr::Literal(<Scalar as From<f32>>::from($ret).into()))31},32(AnyValue::Float64($l), AnyValue::Float64($r)) => {33Some(AExpr::Literal(<Scalar as From<f64>>::from($ret).into()))34},3536(AnyValue::Int8($l), AnyValue::Int8($r)) => {37Some(AExpr::Literal(<Scalar as From<i8>>::from($ret).into()))38},39(AnyValue::Int16($l), AnyValue::Int16($r)) => {40Some(AExpr::Literal(<Scalar as From<i16>>::from($ret).into()))41},42(AnyValue::Int32($l), AnyValue::Int32($r)) => {43Some(AExpr::Literal(<Scalar as From<i32>>::from($ret).into()))44},45(AnyValue::Int64($l), AnyValue::Int64($r)) => {46Some(AExpr::Literal(<Scalar as From<i64>>::from($ret).into()))47},48(AnyValue::Int128($l), AnyValue::Int128($r)) => {49Some(AExpr::Literal(<Scalar as From<i128>>::from($ret).into()))50},5152(AnyValue::UInt8($l), AnyValue::UInt8($r)) => {53Some(AExpr::Literal(<Scalar as From<u8>>::from($ret).into()))54},55(AnyValue::UInt16($l), AnyValue::UInt16($r)) => {56Some(AExpr::Literal(<Scalar as From<u16>>::from($ret).into()))57},58(AnyValue::UInt32($l), AnyValue::UInt32($r)) => {59Some(AExpr::Literal(<Scalar as 
From<u32>>::from($ret).into()))60},61(AnyValue::UInt64($l), AnyValue::UInt64($r)) => {62Some(AExpr::Literal(<Scalar as From<u64>>::from($ret).into()))63},64(AnyValue::UInt128($l), AnyValue::UInt128($r)) => {65Some(AExpr::Literal(<Scalar as From<u128>>::from($ret).into()))66},6768_ => None,69}70.into()71},72(73LiteralValue::Dyn(DynLiteralValue::Float($l)),74LiteralValue::Dyn(DynLiteralValue::Float($r)),75) => {76let $l = *$l;77let $r = *$r;78Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(79$ret,80))))81},82(83LiteralValue::Dyn(DynLiteralValue::Int($l)),84LiteralValue::Dyn(DynLiteralValue::Int($r)),85) => {86let $l = *$l;87let $r = *$r;88Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(89$ret,90))))91},92_ => None,93}94} else {95None96}97}};98}99100macro_rules! eval_binary_cmp_same_type {101($lhs:expr, $operand: tt, $rhs:expr) => {{102if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) {103match (lit_left, lit_right) {104(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => match (l.as_any_value(), r.as_any_value()) {105(AnyValue::Float16(l), AnyValue::Float16(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())),106(AnyValue::Float32(l), AnyValue::Float32(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())),107(AnyValue::Float64(l), AnyValue::Float64(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())),108109(AnyValue::Boolean(l), AnyValue::Boolean(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),110111(AnyValue::Int8(l), AnyValue::Int8(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),112(AnyValue::Int16(l), AnyValue::Int16(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),113(AnyValue::Int32(l), AnyValue::Int32(r)) => 
Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),114(AnyValue::Int64(l), AnyValue::Int64(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),115(AnyValue::Int128(l), AnyValue::Int128(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),116117(AnyValue::UInt8(l), AnyValue::UInt8(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),118(AnyValue::UInt16(l), AnyValue::UInt16(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),119(AnyValue::UInt32(l), AnyValue::UInt32(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),120(AnyValue::UInt64(l), AnyValue::UInt64(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),121(AnyValue::UInt128(l), AnyValue::UInt128(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),122123_ => None,124}.into(),125(LiteralValue::Dyn(DynLiteralValue::Float(l)), LiteralValue::Dyn(DynLiteralValue::Float(r))) => {126let x: bool = l.to_total_ord() $operand r.to_total_ord();127Some(AExpr::Literal(Scalar::from(x).into()))128},129(LiteralValue::Dyn(DynLiteralValue::Int(l)), LiteralValue::Dyn(DynLiteralValue::Int(r))) => {130let x: bool = l $operand r;131Some(AExpr::Literal(Scalar::from(x).into()))132},133_ => None,134}135} else {136None137}138139}}140}141142pub struct SimplifyBooleanRule {}143144impl OptimizationRule for SimplifyBooleanRule {145fn optimize_expr(146&mut self,147expr_arena: &mut Arena<AExpr>,148expr_node: Node,149_schema: &Schema,150ctx: OptimizeExprContext,151) -> PolarsResult<Option<AExpr>> {152let expr = expr_arena.get(expr_node);153154let out = match expr {155// true AND x => x156AExpr::BinaryExpr { left, op, right } => {157return Ok(arity::simplify_binary(*left, *op, *right, ctx, expr_arena));158},159AExpr::Ternary {160predicate,161truthy,162falsy,163} => {164return 
Ok(arity::simplify_ternary(165*predicate, *truthy, *falsy, expr_arena,166));167},168AExpr::Function {169input,170function: IRFunctionExpr::Negate,171..172} if input.len() == 1 => {173let input = &input[0];174let ae = expr_arena.get(input.node());175eval_negate(ae)176},177_ => None,178};179Ok(out)180}181}182183fn eval_negate(ae: &AExpr) -> Option<AExpr> {184use std::ops::Neg;185let out = match ae {186AExpr::Literal(lv) => match lv {187LiteralValue::Scalar(sc) => match sc.as_any_value() {188AnyValue::Int8(v) => Scalar::from(v.checked_neg()?),189AnyValue::Int16(v) => Scalar::from(v.checked_neg()?),190AnyValue::Int32(v) => Scalar::from(v.checked_neg()?),191AnyValue::Int64(v) => Scalar::from(v.checked_neg()?),192AnyValue::Int128(v) => Scalar::from(v.checked_neg()?),193AnyValue::Float16(v) => Scalar::from(v.neg()),194AnyValue::Float32(v) => Scalar::from(v.neg()),195AnyValue::Float64(v) => Scalar::from(v.neg()),196_ => return None,197}198.into(),199LiteralValue::Dyn(d) => LiteralValue::Dyn(match d {200DynLiteralValue::Int(v) => DynLiteralValue::Int(v.checked_neg()?),201DynLiteralValue::Float(v) => DynLiteralValue::Float(v.neg()),202_ => return None,203}),204_ => return None,205},206_ => return None,207};208Some(AExpr::Literal(out))209}210211fn eval_bitwise<F>(left: &AExpr, right: &AExpr, operation: F) -> Option<AExpr>212where213F: Fn(bool, bool) -> bool,214{215if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = (left, right) {216return match (lit_left.bool(), lit_right.bool()) {217(Some(x), Some(y)) => Some(AExpr::Literal(Scalar::from(operation(x, y)).into())),218_ => None,219};220}221None222}223224#[cfg(all(feature = "strings", feature = "concat_str"))]225fn string_addition_to_linear_concat(226expr_arena: &Arena<AExpr>,227left_node: Node,228right_node: Node,229left_aexpr: &AExpr,230right_aexpr: &AExpr,231input_schema: &Schema,232) -> Option<AExpr> {233{234let left_e = ExprIR::from_node(left_node, expr_arena);235let right_e = ExprIR::from_node(right_node, 
expr_arena);236237let get_type = |ae: &AExpr| {238ae.to_dtype(&ToFieldContext::new(expr_arena, input_schema))239.ok()240};241let type_a = get_type(left_aexpr).or_else(|| get_type(right_aexpr))?;242let type_b = get_type(right_aexpr).or_else(|| get_type(right_aexpr))?;243244if type_a != type_b {245return None;246}247248if type_a.is_string() {249match (left_aexpr, right_aexpr) {250// concat + concat251(252AExpr::Function {253input: input_left,254function:255fun_l @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {256delimiter: sep_l,257ignore_nulls: ignore_nulls_l,258}),259options,260},261AExpr::Function {262input: input_right,263function:264IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {265delimiter: sep_r,266ignore_nulls: ignore_nulls_r,267}),268..269},270) => {271if sep_l.is_empty() && sep_r.is_empty() && ignore_nulls_l == ignore_nulls_r {272let mut input = Vec::with_capacity(input_left.len() + input_right.len());273input.extend_from_slice(input_left);274input.extend_from_slice(input_right);275Some(AExpr::Function {276input,277function: fun_l.clone(),278options: *options,279})280} else {281None282}283},284// concat + str285(286AExpr::Function {287input,288function:289fun @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {290delimiter: sep,291ignore_nulls,292}),293options,294},295_,296) => {297if sep.is_empty() && !ignore_nulls {298let mut input = input.clone();299input.push(right_e);300Some(AExpr::Function {301input,302function: fun.clone(),303options: *options,304})305} else {306None307}308},309// str + concat310(311_,312AExpr::Function {313input: input_right,314function:315fun @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {316delimiter: sep,317ignore_nulls,318}),319options,320},321) => {322if sep.is_empty() && !ignore_nulls {323let mut input = Vec::with_capacity(1 + input_right.len());324input.push(left_e);325input.extend_from_slice(input_right);326Some(AExpr::Function {327input,328function: 
fun.clone(),329options: *options,330})331} else {332None333}334},335_ => {336let function = IRStringFunction::ConcatHorizontal {337delimiter: "".into(),338ignore_nulls: false,339};340let options = function.function_options();341Some(AExpr::Function {342input: vec![left_e, right_e],343function: function.into(),344options,345})346},347}348} else {349None350}351}352}353354pub struct SimplifyExprRule {}355356impl OptimizationRule for SimplifyExprRule {357#[allow(clippy::float_cmp)]358fn optimize_expr(359&mut self,360expr_arena: &mut Arena<AExpr>,361expr_node: Node,362schema: &Schema,363_ctx: OptimizeExprContext,364) -> PolarsResult<Option<AExpr>> {365let expr = expr_arena.get(expr_node);366367let out = match &expr {368AExpr::SortBy { expr, by, .. } if by.is_empty() => Some(expr_arena.get(*expr).clone()),369// drop_nulls().len() -> len() - null_count()370// drop_nulls().count() -> len() - null_count()371AExpr::Agg(IRAggExpr::Count {372input,373include_nulls: _,374}) => {375let input_expr = expr_arena.get(*input);376match input_expr {377AExpr::Function {378input,379function: IRFunctionExpr::DropNulls,380options: _,381} => {382// we should perform optimization only if the original expression is a column383// so in case of disabled CSE, we will not suffer from performance regression384if input.len() == 1 {385let drop_nulls_input_node = input[0].node();386match expr_arena.get(drop_nulls_input_node) {387AExpr::Column(_) => Some(AExpr::BinaryExpr {388op: Operator::Minus,389right: expr_arena.add(new_null_count(input)),390left: expr_arena.add(AExpr::Agg(IRAggExpr::Count {391input: drop_nulls_input_node,392include_nulls: true,393})),394}),395_ => None,396}397} else {398None399}400},401_ => None,402}403},404// is_null().sum() -> null_count()405// is_not_null().sum() -> len() - null_count()406AExpr::Agg(IRAggExpr::Sum(input)) => {407let input_expr = expr_arena.get(*input);408match input_expr {409AExpr::Function {410input,411function: 
IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),412options: _,413} => Some(new_null_count(input)),414AExpr::Function {415input,416function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),417options: _,418} => {419// we should perform optimization only if the original expression is a column420// so in case of disabled CSE, we will not suffer from performance regression421if input.len() == 1 {422let is_not_null_input_node = input[0].node();423match expr_arena.get(is_not_null_input_node) {424AExpr::Column(_) => Some(AExpr::BinaryExpr {425op: Operator::Minus,426right: expr_arena.add(new_null_count(input)),427left: expr_arena.add(AExpr::Agg(IRAggExpr::Count {428input: is_not_null_input_node,429include_nulls: true,430})),431}),432_ => None,433}434} else {435None436}437},438_ => None,439}440},441// lit(left) + lit(right) => lit(left + right)442// and null propagation443AExpr::BinaryExpr { left, op, right } => {444let left_aexpr = expr_arena.get(*left);445let right_aexpr = expr_arena.get(*right);446447// lit(left) + lit(right) => lit(left + right)448use Operator::*;449#[allow(clippy::manual_map)]450let out = match op {451Plus => {452match eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + r) {453Some(new) => Some(new),454None => {455// try to replace addition of string columns with `concat_str`456#[cfg(all(feature = "strings", feature = "concat_str"))]457{458string_addition_to_linear_concat(459expr_arena,460*left,461*right,462left_aexpr,463right_aexpr,464schema,465)466}467#[cfg(not(all(feature = "strings", feature = "concat_str")))]468{469None470}471},472}473},474Minus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l - r),475Multiply => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l * r),476Divide => {477if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) =478(left_aexpr, right_aexpr)479{480match (lit_left, lit_right) {481(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => {482match (l.as_any_value(), r.as_any_value()) 
{483(AnyValue::Float16(x), AnyValue::Float16(y)) => {484Some(AExpr::Literal(485<Scalar as From<pf16>>::from(x / y).into(),486))487},488(AnyValue::Float32(x), AnyValue::Float32(y)) => {489Some(AExpr::Literal(490<Scalar as From<f32>>::from(x / y).into(),491))492},493(AnyValue::Float64(x), AnyValue::Float64(y)) => {494Some(AExpr::Literal(495<Scalar as From<f64>>::from(x / y).into(),496))497},498499(AnyValue::Int8(x), AnyValue::Int8(y)) => {500Some(AExpr::Literal(501<Scalar as From<i8>>::from(502x.wrapping_floor_div_mod(y).0,503)504.into(),505))506},507(AnyValue::Int16(x), AnyValue::Int16(y)) => {508Some(AExpr::Literal(509<Scalar as From<i16>>::from(510x.wrapping_floor_div_mod(y).0,511)512.into(),513))514},515(AnyValue::Int32(x), AnyValue::Int32(y)) => {516Some(AExpr::Literal(517<Scalar as From<i32>>::from(518x.wrapping_floor_div_mod(y).0,519)520.into(),521))522},523(AnyValue::Int64(x), AnyValue::Int64(y)) => {524Some(AExpr::Literal(525<Scalar as From<i64>>::from(526x.wrapping_floor_div_mod(y).0,527)528.into(),529))530},531(AnyValue::Int128(x), AnyValue::Int128(y)) => {532Some(AExpr::Literal(533<Scalar as From<i128>>::from(534x.wrapping_floor_div_mod(y).0,535)536.into(),537))538},539540(AnyValue::UInt8(x), AnyValue::UInt8(y)) => {541Some(AExpr::Literal(542<Scalar as From<u8>>::from(x / y).into(),543))544},545(AnyValue::UInt16(x), AnyValue::UInt16(y)) => {546Some(AExpr::Literal(547<Scalar as From<u16>>::from(x / y).into(),548))549},550(AnyValue::UInt32(x), AnyValue::UInt32(y)) => {551Some(AExpr::Literal(552<Scalar as From<u32>>::from(x / y).into(),553))554},555(AnyValue::UInt64(x), AnyValue::UInt64(y)) => {556Some(AExpr::Literal(557<Scalar as From<u64>>::from(x / y).into(),558))559},560(AnyValue::UInt128(x), AnyValue::UInt128(y)) => {561Some(AExpr::Literal(562<Scalar as From<u128>>::from(x / y).into(),563))564},565566_ => None,567}568},569570(571LiteralValue::Dyn(DynLiteralValue::Float(x)),572LiteralValue::Dyn(DynLiteralValue::Float(y)),573) => 
{574Some(AExpr::Literal(<Scalar as From<f64>>::from(x / y).into()))575},576(577LiteralValue::Dyn(DynLiteralValue::Int(x)),578LiteralValue::Dyn(DynLiteralValue::Int(y)),579) => Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(580x.wrapping_floor_div_mod(*y).0,581)))),582_ => None,583}584} else {585None586}587},588TrueDivide => {589if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) =590(left_aexpr, right_aexpr)591{592match (lit_left, lit_right) {593(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => {594match (l.as_any_value(), r.as_any_value()) {595#[cfg(feature = "dtype-f16")]596(AnyValue::Float16(x), AnyValue::Float16(y)) => {597Some(AExpr::Literal(Scalar::from(x / y).into()))598},599(AnyValue::Float32(x), AnyValue::Float32(y)) => {600Some(AExpr::Literal(Scalar::from(x / y).into()))601},602(AnyValue::Float64(x), AnyValue::Float64(y)) => {603Some(AExpr::Literal(Scalar::from(x / y).into()))604},605606(AnyValue::Int8(x), AnyValue::Int8(y)) => {607Some(AExpr::Literal(608Scalar::from(x as f64 / y as f64).into(),609))610},611(AnyValue::Int16(x), AnyValue::Int16(y)) => {612Some(AExpr::Literal(613Scalar::from(x as f64 / y as f64).into(),614))615},616(AnyValue::Int32(x), AnyValue::Int32(y)) => {617Some(AExpr::Literal(618Scalar::from(x as f64 / y as f64).into(),619))620},621(AnyValue::Int64(x), AnyValue::Int64(y)) => {622Some(AExpr::Literal(623Scalar::from(x as f64 / y as f64).into(),624))625},626(AnyValue::Int128(x), AnyValue::Int128(y)) => {627Some(AExpr::Literal(628Scalar::from(x as f64 / y as f64).into(),629))630},631632(AnyValue::UInt8(x), AnyValue::UInt8(y)) => {633Some(AExpr::Literal(634Scalar::from(x as f64 / y as f64).into(),635))636},637(AnyValue::UInt16(x), AnyValue::UInt16(y)) => {638Some(AExpr::Literal(639Scalar::from(x as f64 / y as f64).into(),640))641},642(AnyValue::UInt32(x), AnyValue::UInt32(y)) => {643Some(AExpr::Literal(644Scalar::from(x as f64 / y as f64).into(),645))646},647(AnyValue::UInt64(x), AnyValue::UInt64(y)) => 
{648Some(AExpr::Literal(649Scalar::from(x as f64 / y as f64).into(),650))651},652653_ => None,654}655},656657(658LiteralValue::Dyn(DynLiteralValue::Float(x)),659LiteralValue::Dyn(DynLiteralValue::Float(y)),660) => Some(AExpr::Literal(Scalar::from(*x / *y).into())),661(662LiteralValue::Dyn(DynLiteralValue::Int(x)),663LiteralValue::Dyn(DynLiteralValue::Int(y)),664) => {665Some(AExpr::Literal(Scalar::from(*x as f64 / *y as f64).into()))666},667_ => None,668}669} else {670None671}672},673Modulus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l674.wrapping_floor_div_mod(r)675.1),676Lt => eval_binary_cmp_same_type!(left_aexpr, <, right_aexpr),677Gt => eval_binary_cmp_same_type!(left_aexpr, >, right_aexpr),678Eq | EqValidity => eval_binary_cmp_same_type!(left_aexpr, ==, right_aexpr),679NotEq | NotEqValidity => {680eval_binary_cmp_same_type!(left_aexpr, !=, right_aexpr)681},682GtEq => eval_binary_cmp_same_type!(left_aexpr, >=, right_aexpr),683LtEq => eval_binary_cmp_same_type!(left_aexpr, <=, right_aexpr),684And | LogicalAnd => eval_bitwise(left_aexpr, right_aexpr, |l, r| l & r),685Or | LogicalOr => eval_bitwise(left_aexpr, right_aexpr, |l, r| l | r),686Xor => eval_bitwise(left_aexpr, right_aexpr, |l, r| l ^ r),687FloorDivide => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l688.wrapping_floor_div_mod(r)689.0),690};691if out.is_some() {692return Ok(out);693}694695None696},697AExpr::Function {698input,699function,700options,701..702} => {703return optimize_functions(input.clone(), function.clone(), *options, expr_arena);704},705_ => None,706};707Ok(out)708}709}710711712