Path: blob/main/crates/polars-plan/src/plans/optimizer/fused.rs
6940 views
use super::stack_opt::OptimizeExprContext;1use super::*;23pub struct FusedArithmetic {}45fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AExpr {6let input = input7.iter()8.copied()9.map(|n| ExprIR::from_node(n, expr_arena))10.collect();11let mut options =12FunctionOptions::elementwise().with_casting_rules(CastingRules::cast_to_supertypes());13// order of operations change because of FMA14// so we must toggle this check off15// it is still safe as it is a trusted operation16unsafe { options.no_check_lengths() }17AExpr::Function {18input,19function: IRFunctionExpr::Fused(op),20options,21}22}2324fn check_eligible(25left: &Node,26right: &Node,27expr_arena: &Arena<AExpr>,28schema: &Schema,29) -> PolarsResult<bool> {30let field_left = expr_arena.get(*left).to_field(schema, expr_arena)?;31let type_right = expr_arena.get(*right).get_dtype(schema, expr_arena)?;32let type_left = &field_left.dtype;33// Exclude literals for now as these will not benefit from fused operations downstream #985734// This optimization would also interfere with the `col -> lit` type-coercion rules35// And it might also interfere with constant folding which is a more suitable optimizations here36if type_left.is_primitive_numeric()37&& type_right.is_primitive_numeric()38&& !has_aexpr_literal(*left, expr_arena)39&& !has_aexpr_literal(*right, expr_arena)40{41Ok(true)42} else {43Ok(false)44}45}4647impl OptimizationRule for FusedArithmetic {48#[allow(clippy::float_cmp)]49fn optimize_expr(50&mut self,51expr_arena: &mut Arena<AExpr>,52expr_node: Node,53schema: &Schema,54ctx: OptimizeExprContext,55) -> PolarsResult<Option<AExpr>> {56// We don't want to fuse arithmetic that we send to pyarrow.57if ctx.in_pyarrow_scan || ctx.in_io_plugin {58return Ok(None);59}6061let expr = expr_arena.get(expr_node);6263use AExpr::*;64match expr {65BinaryExpr {66left,67op: Operator::Plus,68right,69} => {70// FUSED MULTIPLY ADD71// For fma the plus is always the out as the multiply takes prevalence72match expr_arena.get(*left) {73// Argument order is a + b * c74// so we must swap operands75//76// input77// (a * b) + c78// swapped as79// c + (a * b)80BinaryExpr {81left: a,82op: Operator::Multiply,83right: b,84} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {85let input = &[*right, *a, *b];86get_expr(input, FusedOperator::MultiplyAdd, expr_arena)87})),88_ => match expr_arena.get(*right) {89// input90// (a + (b * c)91// kept as input92BinaryExpr {93left: a,94op: Operator::Multiply,95right: b,96} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {97let input = &[*left, *a, *b];98get_expr(input, FusedOperator::MultiplyAdd, expr_arena)99})),100_ => Ok(None),101},102}103},104105BinaryExpr {106left,107op: Operator::Minus,108right,109} => {110// FUSED SUB MULTIPLY111match expr_arena.get(*right) {112// input113// (a - (b * c)114// kept as input115BinaryExpr {116left: a,117op: Operator::Multiply,118right: b,119} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {120let input = &[*left, *a, *b];121get_expr(input, FusedOperator::SubMultiply, expr_arena)122})),123_ => {124// FUSED MULTIPLY SUB125match expr_arena.get(*left) {126// input127// (a * b) - c128// kept as input129BinaryExpr {130left: a,131op: Operator::Multiply,132right: b,133} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {134let input = &[*a, *b, *right];135get_expr(input, FusedOperator::MultiplySub, expr_arena)136})),137_ => Ok(None),138}139},140}141},142_ => Ok(None),143}144}145}146147148