Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/fused.rs
8431 views
1
use super::stack_opt::OptimizeExprContext;
2
use super::*;
3
4
/// Optimization rule that rewrites `+`/`-` expressions with a multiplication
/// operand into a single fused arithmetic function expression.
pub struct FusedArithmetic {}
5
6
fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AExpr {
7
let input = input
8
.iter()
9
.copied()
10
.map(|n| ExprIR::from_node(n, expr_arena))
11
.collect();
12
let mut options =
13
FunctionOptions::elementwise().with_casting_rules(CastingRules::cast_to_supertypes());
14
// order of operations change because of FMA
15
// so we must toggle this check off
16
// it is still safe as it is a trusted operation
17
unsafe { options.no_check_lengths() }
18
AExpr::Function {
19
input,
20
function: IRFunctionExpr::Fused(op),
21
options,
22
}
23
}
24
25
fn check_eligible(
26
left: &Node,
27
right: &Node,
28
expr_arena: &Arena<AExpr>,
29
schema: &Schema,
30
) -> PolarsResult<bool> {
31
let field_left = expr_arena
32
.get(*left)
33
.to_field(&ToFieldContext::new(expr_arena, schema))?;
34
let type_right = expr_arena
35
.get(*right)
36
.to_dtype(&ToFieldContext::new(expr_arena, schema))?;
37
let type_left = &field_left.dtype;
38
// Exclude literals for now as these will not benefit from fused operations downstream #9857
39
// This optimization would also interfere with the `col -> lit` type-coercion rules
40
// And it might also interfere with constant folding which is a more suitable optimizations here
41
if type_left.is_primitive_numeric()
42
&& type_right.is_primitive_numeric()
43
&& !has_aexpr_literal(*left, expr_arena)
44
&& !has_aexpr_literal(*right, expr_arena)
45
{
46
Ok(true)
47
} else {
48
Ok(false)
49
}
50
}
51
52
impl OptimizationRule for FusedArithmetic {
53
#[allow(clippy::float_cmp)]
54
fn optimize_expr(
55
&mut self,
56
expr_arena: &mut Arena<AExpr>,
57
expr_node: Node,
58
schema: &Schema,
59
ctx: OptimizeExprContext,
60
) -> PolarsResult<Option<AExpr>> {
61
// We don't want to fuse arithmetic that we send to pyarrow.
62
if ctx.in_pyarrow_scan || ctx.in_io_plugin {
63
return Ok(None);
64
}
65
66
let expr = expr_arena.get(expr_node);
67
68
use AExpr::*;
69
match expr {
70
BinaryExpr {
71
left,
72
op: Operator::Plus,
73
right,
74
} => {
75
// FUSED MULTIPLY ADD
76
// For fma the plus is always the out as the multiply takes prevalence
77
match expr_arena.get(*left) {
78
// Argument order is a + b * c
79
// so we must swap operands
80
//
81
// input
82
// (a * b) + c
83
// swapped as
84
// c + (a * b)
85
BinaryExpr {
86
left: a,
87
op: Operator::Multiply,
88
right: b,
89
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
90
let input = &[*right, *a, *b];
91
get_expr(input, FusedOperator::MultiplyAdd, expr_arena)
92
})),
93
_ => match expr_arena.get(*right) {
94
// input
95
// (a + (b * c)
96
// kept as input
97
BinaryExpr {
98
left: a,
99
op: Operator::Multiply,
100
right: b,
101
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
102
let input = &[*left, *a, *b];
103
get_expr(input, FusedOperator::MultiplyAdd, expr_arena)
104
})),
105
_ => Ok(None),
106
},
107
}
108
},
109
110
BinaryExpr {
111
left,
112
op: Operator::Minus,
113
right,
114
} => {
115
// FUSED SUB MULTIPLY
116
match expr_arena.get(*right) {
117
// input
118
// (a - (b * c)
119
// kept as input
120
BinaryExpr {
121
left: a,
122
op: Operator::Multiply,
123
right: b,
124
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
125
let input = &[*left, *a, *b];
126
get_expr(input, FusedOperator::SubMultiply, expr_arena)
127
})),
128
_ => {
129
// FUSED MULTIPLY SUB
130
match expr_arena.get(*left) {
131
// input
132
// (a * b) - c
133
// kept as input
134
BinaryExpr {
135
left: a,
136
op: Operator::Multiply,
137
right: b,
138
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
139
let input = &[*a, *b, *right];
140
get_expr(input, FusedOperator::MultiplySub, expr_arena)
141
})),
142
_ => Ok(None),
143
}
144
},
145
}
146
},
147
_ => Ok(None),
148
}
149
}
150
}
151
152