Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/count_star.rs
6940 views
1
use polars_io::cloud::CloudOptions;
2
use polars_utils::mmap::MemSlice;
3
use polars_utils::plpath::PlPath;
4
5
use super::*;
6
7
pub(super) struct CountStar;
8
9
impl CountStar {
10
pub(super) fn new() -> Self {
11
Self
12
}
13
}
14
15
impl CountStar {
16
// Replace select count(*) from datasource with specialized map function.
17
pub(super) fn optimize_plan(
18
&mut self,
19
lp_arena: &mut Arena<IR>,
20
expr_arena: &mut Arena<AExpr>,
21
mut node: Node,
22
) -> PolarsResult<Option<IR>> {
23
// New-streaming always puts a sink on top.
24
if let IR::Sink { input, .. } = lp_arena.get(node) {
25
node = *input;
26
}
27
28
// Note: This will be a useful flag later for testing parallel CountLines on CSV.
29
let use_fast_file_count = match std::env::var("POLARS_FAST_FILE_COUNT_DISPATCH").as_deref()
30
{
31
Ok("1") => Some(true),
32
Ok("0") => Some(false),
33
Ok(v) => panic!("POLARS_FAST_FILE_COUNT_DISPATCH must be one of ('0', '1'), got: {v}"),
34
Err(_) => None,
35
};
36
37
Ok(visit_logical_plan_for_scan_paths(
38
node,
39
lp_arena,
40
expr_arena,
41
false,
42
use_fast_file_count,
43
)
44
.map(|count_star_expr| {
45
// MapFunction needs a leaf node, hence we create a dummy placeholder node
46
let placeholder = IR::DataFrameScan {
47
df: Arc::new(Default::default()),
48
schema: Arc::new(Default::default()),
49
output_schema: None,
50
};
51
let placeholder_node = lp_arena.add(placeholder);
52
53
let alp = IR::MapFunction {
54
input: placeholder_node,
55
function: FunctionIR::FastCount {
56
sources: count_star_expr.sources,
57
scan_type: count_star_expr.scan_type,
58
cloud_options: count_star_expr.cloud_options,
59
alias: count_star_expr.alias,
60
},
61
};
62
63
lp_arena.replace(count_star_expr.node, alp.clone());
64
alp
65
}))
66
}
67
}
68
69
struct CountStarExpr {
70
// Top node of the projection to replace
71
node: Node,
72
// Paths to the input files
73
sources: ScanSources,
74
cloud_options: Option<CloudOptions>,
75
// File Type
76
scan_type: Box<FileScanIR>,
77
// Column Alias
78
alias: Option<PlSmallStr>,
79
}
80
81
// Visit the logical plan and return CountStarExpr with the expr information gathered
82
// Return None if query is not a simple COUNT(*) FROM SOURCE
83
fn visit_logical_plan_for_scan_paths(
84
node: Node,
85
lp_arena: &Arena<IR>,
86
expr_arena: &Arena<AExpr>,
87
inside_union: bool, // Inside union's we do not check for COUNT(*) expression
88
use_fast_file_count: Option<bool>, // Overrides if Some
89
) -> Option<CountStarExpr> {
90
match lp_arena.get(node) {
91
IR::Union { inputs, .. } => {
92
enum MutableSources {
93
Addresses(Vec<PlPath>),
94
Buffers(Vec<MemSlice>),
95
}
96
97
let mut scan_type: Option<Box<FileScanIR>> = None;
98
let mut cloud_options = None;
99
let mut sources = None;
100
101
for input in inputs {
102
match visit_logical_plan_for_scan_paths(
103
*input,
104
lp_arena,
105
expr_arena,
106
true,
107
use_fast_file_count,
108
) {
109
Some(expr) => {
110
match (expr.sources, &mut sources) {
111
(
112
ScanSources::Paths(addrs),
113
Some(MutableSources::Addresses(mutable_addrs)),
114
) => mutable_addrs.extend_from_slice(&addrs[..]),
115
(ScanSources::Paths(addrs), None) => {
116
sources = Some(MutableSources::Addresses(addrs.to_vec()))
117
},
118
(
119
ScanSources::Buffers(buffers),
120
Some(MutableSources::Buffers(mutable_buffers)),
121
) => mutable_buffers.extend_from_slice(&buffers[..]),
122
(ScanSources::Buffers(buffers), None) => {
123
sources = Some(MutableSources::Buffers(buffers.to_vec()))
124
},
125
_ => return None,
126
}
127
128
// Take the first Some(_) cloud option
129
// TODO: Should check the cloud types are the same.
130
cloud_options = cloud_options.or(expr.cloud_options);
131
132
match &scan_type {
133
None => scan_type = Some(expr.scan_type),
134
Some(scan_type) => {
135
// All scans must be of the same type (e.g. csv / parquet)
136
if std::mem::discriminant(&**scan_type)
137
!= std::mem::discriminant(&*expr.scan_type)
138
{
139
return None;
140
}
141
},
142
};
143
},
144
None => return None,
145
}
146
}
147
Some(CountStarExpr {
148
sources: match sources {
149
Some(MutableSources::Addresses(addrs)) => ScanSources::Paths(addrs.into()),
150
Some(MutableSources::Buffers(buffers)) => ScanSources::Buffers(buffers.into()),
151
None => ScanSources::default(),
152
},
153
scan_type: scan_type.unwrap(),
154
cloud_options,
155
node,
156
alias: None,
157
})
158
},
159
IR::Scan {
160
scan_type,
161
sources,
162
unified_scan_args,
163
..
164
} => {
165
// New-streaming is generally on par for all except CSV (see https://github.com/pola-rs/polars/pull/22363).
166
// In the future we can potentially remove the dedicated count codepaths.
167
168
let use_fast_file_count = use_fast_file_count.unwrap_or(match scan_type.as_ref() {
169
#[cfg(feature = "csv")]
170
FileScanIR::Csv { .. } => true,
171
_ => false,
172
});
173
174
if use_fast_file_count {
175
Some(CountStarExpr {
176
sources: sources.clone(),
177
scan_type: scan_type.clone(),
178
cloud_options: unified_scan_args.cloud_options.clone(),
179
node,
180
alias: None,
181
})
182
} else {
183
None
184
}
185
},
186
// A union can insert a simple projection to ensure all projections align.
187
// We can ignore that if we are inside a count star.
188
IR::SimpleProjection { input, .. } if inside_union => visit_logical_plan_for_scan_paths(
189
*input,
190
lp_arena,
191
expr_arena,
192
false,
193
use_fast_file_count,
194
),
195
IR::Select { input, expr, .. } => {
196
if expr.len() == 1 {
197
let (valid, alias) = is_valid_count_expr(&expr[0], expr_arena);
198
if valid || inside_union {
199
return visit_logical_plan_for_scan_paths(
200
*input,
201
lp_arena,
202
expr_arena,
203
false,
204
use_fast_file_count,
205
)
206
.map(|mut expr| {
207
expr.alias = alias;
208
expr.node = node;
209
expr
210
});
211
}
212
}
213
None
214
},
215
_ => None,
216
}
217
}
218
219
fn is_valid_count_expr(e: &ExprIR, expr_arena: &Arena<AExpr>) -> (bool, Option<PlSmallStr>) {
220
match expr_arena.get(e.node()) {
221
AExpr::Len => (true, e.get_alias().cloned()),
222
_ => (false, None),
223
}
224
}
225
226