Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/fused.rs
6940 views
1
use super::stack_opt::OptimizeExprContext;
2
use super::*;
3
4
/// Optimization rule that fuses a multiplication feeding into an addition or subtraction
/// into a single fused arithmetic function call.
pub struct FusedArithmetic;
5
6
fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AExpr {
7
let input = input
8
.iter()
9
.copied()
10
.map(|n| ExprIR::from_node(n, expr_arena))
11
.collect();
12
let mut options =
13
FunctionOptions::elementwise().with_casting_rules(CastingRules::cast_to_supertypes());
14
// order of operations change because of FMA
15
// so we must toggle this check off
16
// it is still safe as it is a trusted operation
17
unsafe { options.no_check_lengths() }
18
AExpr::Function {
19
input,
20
function: IRFunctionExpr::Fused(op),
21
options,
22
}
23
}
24
25
fn check_eligible(
26
left: &Node,
27
right: &Node,
28
expr_arena: &Arena<AExpr>,
29
schema: &Schema,
30
) -> PolarsResult<bool> {
31
let field_left = expr_arena.get(*left).to_field(schema, expr_arena)?;
32
let type_right = expr_arena.get(*right).get_dtype(schema, expr_arena)?;
33
let type_left = &field_left.dtype;
34
// Exclude literals for now as these will not benefit from fused operations downstream #9857
35
// This optimization would also interfere with the `col -> lit` type-coercion rules
36
// And it might also interfere with constant folding which is a more suitable optimizations here
37
if type_left.is_primitive_numeric()
38
&& type_right.is_primitive_numeric()
39
&& !has_aexpr_literal(*left, expr_arena)
40
&& !has_aexpr_literal(*right, expr_arena)
41
{
42
Ok(true)
43
} else {
44
Ok(false)
45
}
46
}
47
48
impl OptimizationRule for FusedArithmetic {
49
#[allow(clippy::float_cmp)]
50
fn optimize_expr(
51
&mut self,
52
expr_arena: &mut Arena<AExpr>,
53
expr_node: Node,
54
schema: &Schema,
55
ctx: OptimizeExprContext,
56
) -> PolarsResult<Option<AExpr>> {
57
// We don't want to fuse arithmetic that we send to pyarrow.
58
if ctx.in_pyarrow_scan || ctx.in_io_plugin {
59
return Ok(None);
60
}
61
62
let expr = expr_arena.get(expr_node);
63
64
use AExpr::*;
65
match expr {
66
BinaryExpr {
67
left,
68
op: Operator::Plus,
69
right,
70
} => {
71
// FUSED MULTIPLY ADD
72
// For fma the plus is always the out as the multiply takes prevalence
73
match expr_arena.get(*left) {
74
// Argument order is a + b * c
75
// so we must swap operands
76
//
77
// input
78
// (a * b) + c
79
// swapped as
80
// c + (a * b)
81
BinaryExpr {
82
left: a,
83
op: Operator::Multiply,
84
right: b,
85
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
86
let input = &[*right, *a, *b];
87
get_expr(input, FusedOperator::MultiplyAdd, expr_arena)
88
})),
89
_ => match expr_arena.get(*right) {
90
// input
91
// (a + (b * c)
92
// kept as input
93
BinaryExpr {
94
left: a,
95
op: Operator::Multiply,
96
right: b,
97
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
98
let input = &[*left, *a, *b];
99
get_expr(input, FusedOperator::MultiplyAdd, expr_arena)
100
})),
101
_ => Ok(None),
102
},
103
}
104
},
105
106
BinaryExpr {
107
left,
108
op: Operator::Minus,
109
right,
110
} => {
111
// FUSED SUB MULTIPLY
112
match expr_arena.get(*right) {
113
// input
114
// (a - (b * c)
115
// kept as input
116
BinaryExpr {
117
left: a,
118
op: Operator::Multiply,
119
right: b,
120
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
121
let input = &[*left, *a, *b];
122
get_expr(input, FusedOperator::SubMultiply, expr_arena)
123
})),
124
_ => {
125
// FUSED MULTIPLY SUB
126
match expr_arena.get(*left) {
127
// input
128
// (a * b) - c
129
// kept as input
130
BinaryExpr {
131
left: a,
132
op: Operator::Multiply,
133
right: b,
134
} => Ok(check_eligible(left, right, expr_arena, schema)?.then(|| {
135
let input = &[*a, *b, *right];
136
get_expr(input, FusedOperator::MultiplySub, expr_arena)
137
})),
138
_ => Ok(None),
139
}
140
},
141
}
142
},
143
_ => Ok(None),
144
}
145
}
146
}
147
148