Path: blob/main/crates/polars-plan/src/plans/optimizer/fused.rs
8431 views
use super::stack_opt::OptimizeExprContext;1use super::*;23pub struct FusedArithmetic {}45fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AExpr {6let input = input7.iter()8.copied()9.map(|n| ExprIR::from_node(n, expr_arena))10.collect();11let mut options =12FunctionOptions::elementwise().with_casting_rules(CastingRules::cast_to_supertypes());13// order of operations change because of FMA14// so we must toggle this check off15// it is still safe as it is a trusted operation16unsafe { options.no_check_lengths() }17AExpr::Function {18input,19function: IRFunctionExpr::Fused(op),20options,21}22}2324fn check_eligible(25left: &Node,26right: &Node,27expr_arena: &Arena<AExpr>,28schema: &Schema,29) -> PolarsResult<bool> {30let field_left = expr_arena31.get(*left)32.to_field(&ToFieldContext::new(expr_arena, schema))?;33let type_right = expr_arena34.get(*right)35.to_dtype(&ToFieldContext::new(expr_arena, schema))?;36let type_left = &field_left.dtype;37// Exclude literals for now as these will not benefit from fused operations downstream #985738// This optimization would also interfere with the `col -> lit` type-coercion rules39// And it might also interfere with constant folding which is a more suitable optimizations here40if type_left.is_primitive_numeric()41&& type_right.is_primitive_numeric()42&& !has_aexpr_literal(*left, expr_arena)43&& !has_aexpr_literal(*right, expr_arena)44{45Ok(true)46} else {47Ok(false)48}49}5051impl OptimizationRule for FusedArithmetic {52#[allow(clippy::float_cmp)]53fn optimize_expr(54&mut self,55expr_arena: &mut Arena<AExpr>,56expr_node: Node,57schema: &Schema,58ctx: OptimizeExprContext,59) -> PolarsResult<Option<AExpr>> {60// We don't want to fuse arithmetic that we send to pyarrow.61if ctx.in_pyarrow_scan || ctx.in_io_plugin {62return Ok(None);63}6465let expr = expr_arena.get(expr_node);6667use AExpr::*;68match expr {69BinaryExpr {70left,71op: Operator::Plus,72right,73} => {74// FUSED MULTIPLY ADD75// For fma the plus is always the out as the multiply takes prevalence76match expr_arena.get(*left) {77// Argument order is a + b * c78// so we must swap operands79//80// input81// (a * b) + c82// swapped as83// c + (a * b)84BinaryExpr {85left: a,86op: Operator::Multiply,87right: b,88} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {89let input = &[*right, *a, *b];90get_expr(input, FusedOperator::MultiplyAdd, expr_arena)91})),92_ => match expr_arena.get(*right) {93// input94// (a + (b * c)95// kept as input96BinaryExpr {97left: a,98op: Operator::Multiply,99right: b,100} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {101let input = &[*left, *a, *b];102get_expr(input, FusedOperator::MultiplyAdd, expr_arena)103})),104_ => Ok(None),105},106}107},108109BinaryExpr {110left,111op: Operator::Minus,112right,113} => {114// FUSED SUB MULTIPLY115match expr_arena.get(*right) {116// input117// (a - (b * c)118// kept as input119BinaryExpr {120left: a,121op: Operator::Multiply,122right: b,123} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {124let input = &[*left, *a, *b];125get_expr(input, FusedOperator::SubMultiply, expr_arena)126})),127_ => {128// FUSED MULTIPLY SUB129match expr_arena.get(*left) {130// input131// (a * b) - c132// kept as input133BinaryExpr {134left: a,135op: Operator::Multiply,136right: b,137} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {138let input = &[*a, *b, *right];139get_expr(input, FusedOperator::MultiplySub, expr_arena)140})),141_ => Ok(None),142}143},144}145},146_ => Ok(None),147}148}149}150151152