Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs
8458 views
1
use std::sync::Arc;
2
3
use polars_core::prelude::{PlHashSet, PlIndexMap};
4
use polars_utils::aliases::InitHashMaps;
5
use polars_utils::arena::{Arena, Node};
6
7
use super::aexpr::AExpr;
8
use super::ir::IR;
9
use super::{PlSmallStr, aexpr_to_leaf_names_iter};
10
use crate::plans::ExprIR;
11
12
pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>) {
13
let mut ir_stack = Vec::with_capacity(16);
14
ir_stack.push(root);
15
16
// key: output_name, value: (expr, is_original)
17
let mut input_name_to_expr_map: PlIndexMap<PlSmallStr, (ExprIR, bool)> = PlIndexMap::new();
18
let mut input_names_accessed_by_non_candidates: PlHashSet<PlSmallStr> = PlHashSet::new();
19
let mut push_candidate_idxs: Vec<usize> = vec![];
20
let mut new_current_exprs: Vec<ExprIR> = vec![];
21
let mut visited_caches = PlHashSet::new();
22
23
while let Some(current_node) = ir_stack.pop() {
24
let current_ir = lp_arena.get(current_node);
25
26
if let IR::Cache { id, .. } = current_ir {
27
if !visited_caches.insert(*id) {
28
continue;
29
}
30
}
31
32
current_ir.copy_inputs(&mut ir_stack);
33
34
let IR::HStack { input, .. } = current_ir else {
35
continue;
36
};
37
38
let input_node = *input;
39
40
let [current_ir, input_ir] = lp_arena.get_disjoint_mut([current_node, input_node]);
41
42
let IR::HStack {
43
input: _,
44
exprs: current_exprs,
45
schema: current_schema,
46
options: _,
47
} = current_ir
48
else {
49
unreachable!();
50
};
51
52
let IR::HStack {
53
input: _,
54
exprs: input_exprs,
55
schema: input_schema,
56
options: _,
57
} = input_ir
58
else {
59
continue;
60
};
61
62
input_name_to_expr_map.clear();
63
input_names_accessed_by_non_candidates.clear();
64
push_candidate_idxs.clear();
65
new_current_exprs.clear();
66
67
input_name_to_expr_map.extend(
68
input_exprs
69
.iter()
70
.map(|e| (e.output_name().clone(), (e.clone(), true))),
71
);
72
73
if input_name_to_expr_map.len() != input_exprs.len() {
74
if cfg!(debug_assertions) {
75
panic!()
76
};
77
78
continue;
79
}
80
81
for (i, e) in current_exprs.iter().enumerate() {
82
// Ignore col()
83
if let AExpr::Column(name) = expr_arena.get(e.node())
84
&& name == e.output_name()
85
{
86
continue;
87
}
88
89
if aexpr_to_leaf_names_iter(e.node(), expr_arena)
90
.all(|name| !input_name_to_expr_map.contains_key(name))
91
{
92
push_candidate_idxs.push(i);
93
}
94
}
95
96
let mut candidate_idx: usize = 0;
97
98
for (i, e) in current_exprs.iter().enumerate() {
99
if push_candidate_idxs.get(candidate_idx) == Some(&i) {
100
candidate_idx += 1;
101
continue;
102
}
103
104
for name in aexpr_to_leaf_names_iter(e.node(), expr_arena) {
105
input_names_accessed_by_non_candidates.insert(name.clone());
106
}
107
}
108
109
push_candidate_idxs.retain(|&i| {
110
let e = &current_exprs[i];
111
!input_names_accessed_by_non_candidates.contains(e.output_name())
112
});
113
114
let mut candidate_idx: usize = 0;
115
116
for (i, e) in current_exprs.iter().enumerate() {
117
// Prune col()
118
if let AExpr::Column(name) = expr_arena.get(e.node())
119
&& name == e.output_name()
120
{
121
continue;
122
}
123
124
if push_candidate_idxs.get(candidate_idx) == Some(&i) {
125
candidate_idx += 1;
126
input_name_to_expr_map.insert(e.output_name().clone(), (e.clone(), false));
127
continue;
128
}
129
130
new_current_exprs.push(e.clone());
131
}
132
133
if new_current_exprs.len() == current_exprs.len() {
134
continue;
135
}
136
137
input_exprs.clear();
138
139
for (output_name, (e, is_original)) in input_name_to_expr_map
140
.iter()
141
.map(|x| (x.0.clone(), x.1.clone()))
142
{
143
input_exprs.push(e);
144
145
if !is_original {
146
let dtype = current_schema.get(&output_name).unwrap().clone();
147
Arc::make_mut(input_schema).insert(output_name, dtype);
148
}
149
}
150
151
if new_current_exprs.is_empty() {
152
let input_ir = input_ir.clone();
153
lp_arena.replace(current_node, input_ir);
154
*ir_stack.last_mut().unwrap() = current_node;
155
continue;
156
}
157
158
let fix_output_order = current_exprs.iter().any(|e| {
159
input_schema
160
.index_of(e.output_name())
161
.is_some_and(|i| i != current_schema.index_of(e.output_name()).unwrap())
162
});
163
164
current_exprs.clear();
165
std::mem::swap(current_exprs, &mut new_current_exprs);
166
167
if fix_output_order {
168
let projection = current_schema.clone();
169
170
Arc::make_mut(current_schema)
171
.sort_by_key(|name, _| input_schema.index_of(name).unwrap_or(usize::MAX));
172
173
let current_ir = lp_arena.replace(current_node, IR::Invalid);
174
let moved_current_node = lp_arena.add(current_ir);
175
lp_arena.replace(
176
current_node,
177
IR::SimpleProjection {
178
input: moved_current_node,
179
columns: projection,
180
},
181
);
182
}
183
}
184
}
185
186