// crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs
mod simplify_functions;12use num_traits::Zero;3use polars_utils::float16::pf16;4use polars_utils::floor_divmod::FloorDivMod;5use polars_utils::total_ord::ToTotalOrd;6use simplify_functions::optimize_functions;7mod arity;89use crate::plans::*;1011fn new_null_count(input: &[ExprIR]) -> AExpr {12let function = IRFunctionExpr::NullCount;13let options = function.function_options();14AExpr::Function {15input: input.to_vec(),16function,17options,18}19}2021macro_rules! eval_binary_same_type {22($lhs:expr, $rhs:expr, |$l: ident, $r: ident| $ret: expr) => {{23if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) {24match (lit_left, lit_right) {25(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => {26match (l.as_any_value(), r.as_any_value()) {27(AnyValue::Float16($l), AnyValue::Float16($r)) => {28Some(AExpr::Literal(Scalar::from($ret).into()))29},30(AnyValue::Float32($l), AnyValue::Float32($r)) => {31Some(AExpr::Literal(<Scalar as From<f32>>::from($ret).into()))32},33(AnyValue::Float64($l), AnyValue::Float64($r)) => {34Some(AExpr::Literal(<Scalar as From<f64>>::from($ret).into()))35},3637(AnyValue::Int8($l), AnyValue::Int8($r)) => {38Some(AExpr::Literal(<Scalar as From<i8>>::from($ret).into()))39},40(AnyValue::Int16($l), AnyValue::Int16($r)) => {41Some(AExpr::Literal(<Scalar as From<i16>>::from($ret).into()))42},43(AnyValue::Int32($l), AnyValue::Int32($r)) => {44Some(AExpr::Literal(<Scalar as From<i32>>::from($ret).into()))45},46(AnyValue::Int64($l), AnyValue::Int64($r)) => {47Some(AExpr::Literal(<Scalar as From<i64>>::from($ret).into()))48},49(AnyValue::Int128($l), AnyValue::Int128($r)) => {50Some(AExpr::Literal(<Scalar as From<i128>>::from($ret).into()))51},5253(AnyValue::UInt8($l), AnyValue::UInt8($r)) => {54Some(AExpr::Literal(<Scalar as From<u8>>::from($ret).into()))55},56(AnyValue::UInt16($l), AnyValue::UInt16($r)) => {57Some(AExpr::Literal(<Scalar as From<u16>>::from($ret).into()))58},59(AnyValue::UInt32($l), AnyValue::UInt32($r)) => 
{60Some(AExpr::Literal(<Scalar as From<u32>>::from($ret).into()))61},62(AnyValue::UInt64($l), AnyValue::UInt64($r)) => {63Some(AExpr::Literal(<Scalar as From<u64>>::from($ret).into()))64},65(AnyValue::UInt128($l), AnyValue::UInt128($r)) => {66Some(AExpr::Literal(<Scalar as From<u128>>::from($ret).into()))67},6869_ => None,70}71.into()72},73(74LiteralValue::Dyn(DynLiteralValue::Float($l)),75LiteralValue::Dyn(DynLiteralValue::Float($r)),76) => {77let $l = *$l;78let $r = *$r;79Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(80$ret,81))))82},83(84LiteralValue::Dyn(DynLiteralValue::Int($l)),85LiteralValue::Dyn(DynLiteralValue::Int($r)),86) => {87let $l = *$l;88let $r = *$r;89Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(90$ret,91))))92},93_ => None,94}95} else {96None97}98}};99}100101macro_rules! eval_binary_cmp_same_type {102($lhs:expr, $operand: tt, $rhs:expr) => {{103if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) {104match (lit_left, lit_right) {105(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => match (l.as_any_value(), r.as_any_value()) {106(AnyValue::Float16(l), AnyValue::Float16(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())),107(AnyValue::Float32(l), AnyValue::Float32(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())),108(AnyValue::Float64(l), AnyValue::Float64(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())),109110(AnyValue::Boolean(l), AnyValue::Boolean(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),111112(AnyValue::Int8(l), AnyValue::Int8(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),113(AnyValue::Int16(l), AnyValue::Int16(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),114(AnyValue::Int32(l), 
AnyValue::Int32(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),115(AnyValue::Int64(l), AnyValue::Int64(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),116(AnyValue::Int128(l), AnyValue::Int128(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),117118(AnyValue::UInt8(l), AnyValue::UInt8(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),119(AnyValue::UInt16(l), AnyValue::UInt16(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),120(AnyValue::UInt32(l), AnyValue::UInt32(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),121(AnyValue::UInt64(l), AnyValue::UInt64(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),122(AnyValue::UInt128(l), AnyValue::UInt128(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())),123124_ => None,125}.into(),126(LiteralValue::Dyn(DynLiteralValue::Float(l)), LiteralValue::Dyn(DynLiteralValue::Float(r))) => {127let x: bool = l.to_total_ord() $operand r.to_total_ord();128Some(AExpr::Literal(Scalar::from(x).into()))129},130(LiteralValue::Dyn(DynLiteralValue::Int(l)), LiteralValue::Dyn(DynLiteralValue::Int(r))) => {131let x: bool = l $operand r;132Some(AExpr::Literal(Scalar::from(x).into()))133},134_ => None,135}136} else {137None138}139140}}141}142143pub struct SimplifyBooleanRule {}144145impl OptimizationRule for SimplifyBooleanRule {146fn optimize_expr(147&mut self,148expr_arena: &mut Arena<AExpr>,149expr_node: Node,150_schema: &Schema,151ctx: OptimizeExprContext,152) -> PolarsResult<Option<AExpr>> {153let expr = expr_arena.get(expr_node);154155let out = match expr {156// true AND x => x157AExpr::BinaryExpr { left, op, right } => {158return Ok(arity::simplify_binary(*left, *op, *right, ctx, expr_arena));159},160AExpr::Ternary {161predicate,162truthy,163falsy,164} => {165return 
Ok(arity::simplify_ternary(166*predicate, *truthy, *falsy, expr_arena,167));168},169AExpr::Function {170input,171function: IRFunctionExpr::Negate,172..173} if input.len() == 1 => {174let input = &input[0];175let ae = expr_arena.get(input.node());176eval_negate(ae)177},178_ => None,179};180Ok(out)181}182}183184fn eval_negate(ae: &AExpr) -> Option<AExpr> {185use std::ops::Neg;186let out = match ae {187AExpr::Literal(lv) => match lv {188LiteralValue::Scalar(sc) => match sc.as_any_value() {189AnyValue::Int8(v) => Scalar::from(v.checked_neg()?),190AnyValue::Int16(v) => Scalar::from(v.checked_neg()?),191AnyValue::Int32(v) => Scalar::from(v.checked_neg()?),192AnyValue::Int64(v) => Scalar::from(v.checked_neg()?),193AnyValue::Int128(v) => Scalar::from(v.checked_neg()?),194AnyValue::Float16(v) => Scalar::from(v.neg()),195AnyValue::Float32(v) => Scalar::from(v.neg()),196AnyValue::Float64(v) => Scalar::from(v.neg()),197_ => return None,198}199.into(),200LiteralValue::Dyn(d) => LiteralValue::Dyn(match d {201DynLiteralValue::Int(v) => DynLiteralValue::Int(v.checked_neg()?),202DynLiteralValue::Float(v) => DynLiteralValue::Float(v.neg()),203_ => return None,204}),205_ => return None,206},207_ => return None,208};209Some(AExpr::Literal(out))210}211212fn eval_bitwise<F>(left: &AExpr, right: &AExpr, operation: F) -> Option<AExpr>213where214F: Fn(bool, bool) -> bool,215{216if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = (left, right) {217return match (lit_left.bool(), lit_right.bool()) {218(Some(x), Some(y)) => Some(AExpr::Literal(Scalar::from(operation(x, y)).into())),219_ => None,220};221}222None223}224225#[cfg(all(feature = "strings", feature = "concat_str"))]226fn string_addition_to_linear_concat(227expr_arena: &Arena<AExpr>,228left_node: Node,229right_node: Node,230left_aexpr: &AExpr,231right_aexpr: &AExpr,232input_schema: &Schema,233) -> Option<AExpr> {234{235let left_e = ExprIR::from_node(left_node, expr_arena);236let right_e = ExprIR::from_node(right_node, 
expr_arena);237238let get_type = |ae: &AExpr| {239ae.to_dtype(&ToFieldContext::new(expr_arena, input_schema))240.ok()241};242let type_a = get_type(left_aexpr).or_else(|| get_type(right_aexpr))?;243let type_b = get_type(right_aexpr).or_else(|| get_type(right_aexpr))?;244245if type_a != type_b {246return None;247}248249if type_a.is_string() {250match (left_aexpr, right_aexpr) {251// concat + concat252(253AExpr::Function {254input: input_left,255function:256fun_l @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {257delimiter: sep_l,258ignore_nulls: ignore_nulls_l,259}),260options,261},262AExpr::Function {263input: input_right,264function:265IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {266delimiter: sep_r,267ignore_nulls: ignore_nulls_r,268}),269..270},271) => {272if sep_l.is_empty() && sep_r.is_empty() && ignore_nulls_l == ignore_nulls_r {273let mut input = Vec::with_capacity(input_left.len() + input_right.len());274input.extend_from_slice(input_left);275input.extend_from_slice(input_right);276Some(AExpr::Function {277input,278function: fun_l.clone(),279options: *options,280})281} else {282None283}284},285// concat + str286(287AExpr::Function {288input,289function:290fun @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {291delimiter: sep,292ignore_nulls,293}),294options,295},296_,297) => {298if sep.is_empty() && !ignore_nulls {299let mut input = input.clone();300input.push(right_e);301Some(AExpr::Function {302input,303function: fun.clone(),304options: *options,305})306} else {307None308}309},310// str + concat311(312_,313AExpr::Function {314input: input_right,315function:316fun @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {317delimiter: sep,318ignore_nulls,319}),320options,321},322) => {323if sep.is_empty() && !ignore_nulls {324let mut input = Vec::with_capacity(1 + input_right.len());325input.push(left_e);326input.extend_from_slice(input_right);327Some(AExpr::Function {328input,329function: 
fun.clone(),330options: *options,331})332} else {333None334}335},336_ => {337let function = IRStringFunction::ConcatHorizontal {338delimiter: "".into(),339ignore_nulls: false,340};341let options = function.function_options();342Some(AExpr::Function {343input: vec![left_e, right_e],344function: function.into(),345options,346})347},348}349} else {350None351}352}353}354355pub struct SimplifyExprRule {}356357impl OptimizationRule for SimplifyExprRule {358#[allow(clippy::float_cmp)]359fn optimize_expr(360&mut self,361expr_arena: &mut Arena<AExpr>,362expr_node: Node,363schema: &Schema,364_ctx: OptimizeExprContext,365) -> PolarsResult<Option<AExpr>> {366let expr = expr_arena.get(expr_node);367368let out = match &expr {369AExpr::SortBy { expr, by, .. } if by.is_empty() => Some(expr_arena.get(*expr).clone()),370// drop_nulls().len() -> len() - null_count()371// drop_nulls().count() -> len() - null_count()372AExpr::Agg(IRAggExpr::Count {373input,374include_nulls: _,375}) => {376let input_expr = expr_arena.get(*input);377match input_expr {378AExpr::Function {379input,380function: IRFunctionExpr::DropNulls,381options: _,382} => {383// we should perform optimization only if the original expression is a column384// so in case of disabled CSE, we will not suffer from performance regression385if input.len() == 1 {386let drop_nulls_input_node = input[0].node();387match expr_arena.get(drop_nulls_input_node) {388AExpr::Column(_) => Some(AExpr::BinaryExpr {389op: Operator::Minus,390right: expr_arena.add(new_null_count(input)),391left: expr_arena.add(AExpr::Agg(IRAggExpr::Count {392input: drop_nulls_input_node,393include_nulls: true,394})),395}),396_ => None,397}398} else {399None400}401},402_ => None,403}404},405// is_null().sum() -> null_count()406// is_not_null().sum() -> len() - null_count()407AExpr::Agg(IRAggExpr::Sum(input)) => {408let input_expr = expr_arena.get(*input);409match input_expr {410AExpr::Function {411input,412function: 
IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),413options: _,414} => Some(new_null_count(input)),415AExpr::Function {416input,417function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),418options: _,419} => {420// we should perform optimization only if the original expression is a column421// so in case of disabled CSE, we will not suffer from performance regression422if input.len() == 1 {423let is_not_null_input_node = input[0].node();424match expr_arena.get(is_not_null_input_node) {425AExpr::Column(_) => Some(AExpr::BinaryExpr {426op: Operator::Minus,427right: expr_arena.add(new_null_count(input)),428left: expr_arena.add(AExpr::Agg(IRAggExpr::Count {429input: is_not_null_input_node,430include_nulls: true,431})),432}),433_ => None,434}435} else {436None437}438},439_ => None,440}441},442// lit(left) + lit(right) => lit(left + right)443// and null propagation444AExpr::BinaryExpr { left, op, right } => {445let left_aexpr = expr_arena.get(*left);446let right_aexpr = expr_arena.get(*right);447448// lit(left) + lit(right) => lit(left + right)449use Operator::*;450#[allow(clippy::manual_map)]451let out = match op {452Plus => {453match eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + r) {454Some(new) => Some(new),455None => {456// try to replace addition of string columns with `concat_str`457#[cfg(all(feature = "strings", feature = "concat_str"))]458{459string_addition_to_linear_concat(460expr_arena,461*left,462*right,463left_aexpr,464right_aexpr,465schema,466)467}468#[cfg(not(all(feature = "strings", feature = "concat_str")))]469{470None471}472},473}474},475Minus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l - r),476Multiply => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l * r),477RustDivide => {478if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) =479(left_aexpr, right_aexpr)480{481match (lit_left, lit_right) {482(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => {483match (l.as_any_value(), r.as_any_value()) 
{484(AnyValue::Float16(x), AnyValue::Float16(y)) => {485Some(AExpr::Literal(486<Scalar as From<pf16>>::from(x / y).into(),487))488},489(AnyValue::Float32(x), AnyValue::Float32(y)) => {490Some(AExpr::Literal(491<Scalar as From<f32>>::from(x / y).into(),492))493},494(AnyValue::Float64(x), AnyValue::Float64(y)) => {495Some(AExpr::Literal(496<Scalar as From<f64>>::from(x / y).into(),497))498},499500(AnyValue::Int8(x), AnyValue::Int8(y)) => {501Some(AExpr::Literal(502<Scalar as From<i8>>::from(503x.wrapping_floor_div_mod(y).0,504)505.into(),506))507},508(AnyValue::Int16(x), AnyValue::Int16(y)) => {509Some(AExpr::Literal(510<Scalar as From<i16>>::from(511x.wrapping_floor_div_mod(y).0,512)513.into(),514))515},516(AnyValue::Int32(x), AnyValue::Int32(y)) => {517Some(AExpr::Literal(518<Scalar as From<i32>>::from(519x.wrapping_floor_div_mod(y).0,520)521.into(),522))523},524(AnyValue::Int64(x), AnyValue::Int64(y)) => {525Some(AExpr::Literal(526<Scalar as From<i64>>::from(527x.wrapping_floor_div_mod(y).0,528)529.into(),530))531},532(AnyValue::Int128(x), AnyValue::Int128(y)) => {533Some(AExpr::Literal(534<Scalar as From<i128>>::from(535x.wrapping_floor_div_mod(y).0,536)537.into(),538))539},540541(AnyValue::UInt8(x), AnyValue::UInt8(y)) => {542Some(AExpr::Literal(543<Scalar as From<u8>>::from(x / y).into(),544))545},546(AnyValue::UInt16(x), AnyValue::UInt16(y)) => {547Some(AExpr::Literal(548<Scalar as From<u16>>::from(x / y).into(),549))550},551(AnyValue::UInt32(x), AnyValue::UInt32(y)) => {552Some(AExpr::Literal(553<Scalar as From<u32>>::from(x / y).into(),554))555},556(AnyValue::UInt64(x), AnyValue::UInt64(y)) => {557Some(AExpr::Literal(558<Scalar as From<u64>>::from(x / y).into(),559))560},561(AnyValue::UInt128(x), AnyValue::UInt128(y)) => {562Some(AExpr::Literal(563<Scalar as From<u128>>::from(x / y).into(),564))565},566567_ => None,568}569},570571(572LiteralValue::Dyn(DynLiteralValue::Float(x)),573LiteralValue::Dyn(DynLiteralValue::Float(y)),574) => 
{575Some(AExpr::Literal(<Scalar as From<f64>>::from(x / y).into()))576},577(578LiteralValue::Dyn(DynLiteralValue::Int(x)),579LiteralValue::Dyn(DynLiteralValue::Int(y)),580) => Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(581x.wrapping_floor_div_mod(*y).0,582)))),583_ => None,584}585} else {586None587}588},589TrueDivide => {590if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) =591(left_aexpr, right_aexpr)592{593match (lit_left, lit_right) {594(LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => {595match (l.as_any_value(), r.as_any_value()) {596#[cfg(feature = "dtype-f16")]597(AnyValue::Float16(x), AnyValue::Float16(y)) => {598Some(AExpr::Literal(Scalar::from(x / y).into()))599},600(AnyValue::Float32(x), AnyValue::Float32(y)) => {601Some(AExpr::Literal(Scalar::from(x / y).into()))602},603(AnyValue::Float64(x), AnyValue::Float64(y)) => {604Some(AExpr::Literal(Scalar::from(x / y).into()))605},606607(AnyValue::Int8(x), AnyValue::Int8(y)) => {608Some(AExpr::Literal(609Scalar::from(x as f64 / y as f64).into(),610))611},612(AnyValue::Int16(x), AnyValue::Int16(y)) => {613Some(AExpr::Literal(614Scalar::from(x as f64 / y as f64).into(),615))616},617(AnyValue::Int32(x), AnyValue::Int32(y)) => {618Some(AExpr::Literal(619Scalar::from(x as f64 / y as f64).into(),620))621},622(AnyValue::Int64(x), AnyValue::Int64(y)) => {623Some(AExpr::Literal(624Scalar::from(x as f64 / y as f64).into(),625))626},627(AnyValue::Int128(x), AnyValue::Int128(y)) => {628Some(AExpr::Literal(629Scalar::from(x as f64 / y as f64).into(),630))631},632633(AnyValue::UInt8(x), AnyValue::UInt8(y)) => {634Some(AExpr::Literal(635Scalar::from(x as f64 / y as f64).into(),636))637},638(AnyValue::UInt16(x), AnyValue::UInt16(y)) => {639Some(AExpr::Literal(640Scalar::from(x as f64 / y as f64).into(),641))642},643(AnyValue::UInt32(x), AnyValue::UInt32(y)) => {644Some(AExpr::Literal(645Scalar::from(x as f64 / y as f64).into(),646))647},648(AnyValue::UInt64(x), AnyValue::UInt64(y)) => 
{649Some(AExpr::Literal(650Scalar::from(x as f64 / y as f64).into(),651))652},653654_ => None,655}656},657658(659LiteralValue::Dyn(DynLiteralValue::Float(x)),660LiteralValue::Dyn(DynLiteralValue::Float(y)),661) => Some(AExpr::Literal(Scalar::from(*x / *y).into())),662(663LiteralValue::Dyn(DynLiteralValue::Int(x)),664LiteralValue::Dyn(DynLiteralValue::Int(y)),665) => {666Some(AExpr::Literal(Scalar::from(*x as f64 / *y as f64).into()))667},668_ => None,669}670} else {671None672}673},674Modulus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| {675if r.is_zero() {676// TODO: this should optimize to `null` once we can express "is dynamic int but actually contains null"677return Ok(None);678}679680l.wrapping_floor_div_mod(r).1681}),682Lt => eval_binary_cmp_same_type!(left_aexpr, <, right_aexpr),683Gt => eval_binary_cmp_same_type!(left_aexpr, >, right_aexpr),684Eq | EqValidity => eval_binary_cmp_same_type!(left_aexpr, ==, right_aexpr),685NotEq | NotEqValidity => {686eval_binary_cmp_same_type!(left_aexpr, !=, right_aexpr)687},688GtEq => eval_binary_cmp_same_type!(left_aexpr, >=, right_aexpr),689LtEq => eval_binary_cmp_same_type!(left_aexpr, <=, right_aexpr),690And | LogicalAnd => eval_bitwise(left_aexpr, right_aexpr, |l, r| l & r),691Or | LogicalOr => eval_bitwise(left_aexpr, right_aexpr, |l, r| l | r),692Xor => eval_bitwise(left_aexpr, right_aexpr, |l, r| l ^ r),693FloorDivide => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| {694if r.is_zero() {695// TODO: this should optimize to `null` once we can express "is dynamic int but actually contains null"696return Ok(None);697}698699l.wrapping_floor_div_mod(r).0700}),701};702if out.is_some() {703return Ok(out);704}705706None707},708AExpr::Function {709input,710function,711options,712..713} => {714return optimize_functions(input.clone(), function.clone(), *options, expr_arena);715},716_ => None,717};718Ok(out)719}720}721722723