Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/ir/dot.rs
6940 views
1
use std::fmt;
2
use std::path::PathBuf;
3
4
use polars_core::prelude::{InitHashMaps, PlHashSet};
5
use polars_core::schema::Schema;
6
use polars_utils::pl_str::PlSmallStr;
7
use polars_utils::unique_id::UniqueId;
8
use recursive::recursive;
9
10
use super::format::ExprIRSliceDisplay;
11
use crate::prelude::ir::format::ColumnsDisplay;
12
use crate::prelude::*;
13
14
pub struct IRDotDisplay<'a> {
15
lp: IRPlanRef<'a>,
16
}
17
18
/// Leading whitespace written before every statement inside the `digraph` body.
const INDENT: &str = " ";
19
20
#[derive(Clone, Copy)]
21
enum DotNode {
22
Plain(usize),
23
Cache(UniqueId),
24
}
25
26
impl fmt::Display for DotNode {
27
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28
match self {
29
DotNode::Plain(n) => write!(f, "p{n}"),
30
DotNode::Cache(n) => write!(f, "\"{n}\""),
31
}
32
}
33
}
34
35
#[inline(always)]
36
fn write_label<'a, 'b>(
37
f: &'a mut fmt::Formatter<'b>,
38
id: DotNode,
39
mut w: impl FnMut(&mut EscapeLabel<'a>) -> fmt::Result,
40
) -> fmt::Result {
41
write!(f, "{INDENT}{id}[label=\"")?;
42
43
let mut escaped = EscapeLabel(f);
44
w(&mut escaped)?;
45
let EscapeLabel(f) = escaped;
46
47
writeln!(f, "\"]")?;
48
49
Ok(())
50
}
51
52
impl<'a> IRDotDisplay<'a> {
53
pub fn new(lp: IRPlanRef<'a>) -> Self {
54
Self { lp }
55
}
56
57
fn with_root(&self, root: Node) -> Self {
58
Self {
59
lp: self.lp.with_root(root),
60
}
61
}
62
63
fn display_expr(&self, expr: &'a ExprIR) -> ExprIRDisplay<'a> {
64
expr.display(self.lp.expr_arena)
65
}
66
67
fn display_exprs(&self, exprs: &'a [ExprIR]) -> ExprIRSliceDisplay<'a, ExprIR> {
68
ExprIRSliceDisplay {
69
exprs,
70
expr_arena: self.lp.expr_arena,
71
}
72
}
73
74
#[recursive]
75
fn _format(
76
&self,
77
f: &mut fmt::Formatter<'_>,
78
parent: Option<DotNode>,
79
last: &mut usize,
80
visited_caches: &mut PlHashSet<UniqueId>,
81
) -> std::fmt::Result {
82
use fmt::Write;
83
84
let root = self.lp.root();
85
let id = if let IR::Cache { id, .. } = root {
86
DotNode::Cache(*id)
87
} else {
88
*last += 1;
89
DotNode::Plain(*last)
90
};
91
92
if let Some(parent) = parent {
93
writeln!(f, "{INDENT}{id} -> {parent}")?;
94
}
95
96
macro_rules! recurse {
97
($input:expr) => {
98
self.with_root($input)
99
._format(f, Some(id), last, visited_caches)?;
100
};
101
}
102
103
use IR::*;
104
match root {
105
Union { inputs, .. } => {
106
for input in inputs {
107
recurse!(*input);
108
}
109
110
write_label(f, id, |f| f.write_str("UNION"))?;
111
},
112
HConcat { inputs, .. } => {
113
for input in inputs {
114
recurse!(*input);
115
}
116
117
write_label(f, id, |f| f.write_str("HCONCAT"))?;
118
},
119
Cache {
120
input,
121
id: cache_id,
122
..
123
} => {
124
if !visited_caches.contains(cache_id) {
125
visited_caches.insert(*cache_id);
126
127
recurse!(*input);
128
129
write_label(f, id, |f| f.write_str("CACHE"))?;
130
}
131
},
132
Filter { predicate, input } => {
133
recurse!(*input);
134
135
let pred = self.display_expr(predicate);
136
write_label(f, id, |f| write!(f, "FILTER BY {pred}"))?;
137
},
138
#[cfg(feature = "python")]
139
PythonScan { options } => {
140
let predicate = match &options.predicate {
141
PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)),
142
PythonPredicate::PyArrow(s) => s.clone(),
143
PythonPredicate::None => "none".to_string(),
144
};
145
let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref()));
146
let total_columns = options.schema.len();
147
148
write_label(f, id, |f| {
149
write!(
150
f,
151
"PYTHON SCAN\nπ {with_columns}/{total_columns};\nσ {predicate}"
152
)
153
})?
154
},
155
Select {
156
expr,
157
input,
158
schema,
159
..
160
} => {
161
recurse!(*input);
162
write_label(f, id, |f| write!(f, "π {}/{}", expr.len(), schema.len()))?;
163
},
164
Sort {
165
input, by_column, ..
166
} => {
167
let by_column = self.display_exprs(by_column);
168
recurse!(*input);
169
write_label(f, id, |f| write!(f, "SORT BY {by_column}"))?;
170
},
171
GroupBy {
172
input, keys, aggs, ..
173
} => {
174
let keys = self.display_exprs(keys);
175
let aggs = self.display_exprs(aggs);
176
recurse!(*input);
177
write_label(f, id, |f| write!(f, "AGG {aggs}\nBY\n{keys}"))?;
178
},
179
HStack { input, exprs, .. } => {
180
let exprs = self.display_exprs(exprs);
181
recurse!(*input);
182
write_label(f, id, |f| write!(f, "WITH COLUMNS {exprs}"))?;
183
},
184
Slice { input, offset, len } => {
185
recurse!(*input);
186
write_label(f, id, |f| write!(f, "SLICE offset: {offset}; len: {len}"))?;
187
},
188
Distinct { input, options, .. } => {
189
recurse!(*input);
190
write_label(f, id, |f| {
191
f.write_str("DISTINCT")?;
192
193
if let Some(subset) = &options.subset {
194
f.write_str(" BY ")?;
195
196
let mut subset = subset.iter();
197
198
if let Some(fst) = subset.next() {
199
f.write_str(fst)?;
200
for name in subset {
201
write!(f, ", \"{name}\"")?;
202
}
203
} else {
204
f.write_str("None")?;
205
}
206
}
207
208
Ok(())
209
})?;
210
},
211
DataFrameScan {
212
schema,
213
output_schema,
214
..
215
} => {
216
let num_columns = NumColumnsSchema(output_schema.as_ref().map(|p| p.as_ref()));
217
let total_columns = schema.len();
218
219
write_label(f, id, |f| {
220
write!(f, "TABLE\nπ {num_columns}/{total_columns}")
221
})?;
222
},
223
Scan {
224
sources,
225
file_info,
226
hive_parts: _,
227
predicate,
228
scan_type,
229
unified_scan_args,
230
output_schema: _,
231
} => {
232
let name: &str = (&**scan_type).into();
233
let path = ScanSourcesDisplay(sources);
234
let with_columns = unified_scan_args
235
.projection
236
.as_ref()
237
.map(|cols| cols.as_ref());
238
let with_columns = NumColumns(with_columns);
239
let total_columns =
240
file_info.schema.len() - usize::from(unified_scan_args.row_index.is_some());
241
242
write_label(f, id, |f| {
243
write!(f, "{name} SCAN {path}\nπ {with_columns}/{total_columns};",)?;
244
245
if let Some(predicate) = predicate.as_ref() {
246
write!(f, "\nσ {}", self.display_expr(predicate))?;
247
}
248
249
if let Some(row_index) = unified_scan_args.row_index.as_ref() {
250
write!(f, "\nrow index: {} (+{})", row_index.name, row_index.offset)?;
251
}
252
253
Ok(())
254
})?;
255
},
256
Join {
257
input_left,
258
input_right,
259
left_on,
260
right_on,
261
options,
262
..
263
} => {
264
recurse!(*input_left);
265
recurse!(*input_right);
266
267
write_label(f, id, |f| {
268
write!(f, "JOIN {}", options.args.how)?;
269
270
if !left_on.is_empty() {
271
let left_on = self.display_exprs(left_on);
272
let right_on = self.display_exprs(right_on);
273
write!(f, "\nleft: {left_on};\nright: {right_on}")?
274
}
275
Ok(())
276
})?;
277
},
278
MapFunction {
279
input, function, ..
280
} => {
281
recurse!(*input);
282
write_label(f, id, |f| write!(f, "{function}"))?;
283
},
284
ExtContext { input, .. } => {
285
recurse!(*input);
286
write_label(f, id, |f| f.write_str("EXTERNAL_CONTEXT"))?;
287
},
288
Sink { input, payload, .. } => {
289
recurse!(*input);
290
291
write_label(f, id, |f| {
292
f.write_str(match payload {
293
SinkTypeIR::Memory => "SINK (MEMORY)",
294
SinkTypeIR::File { .. } => "SINK (FILE)",
295
SinkTypeIR::Partition { .. } => "SINK (PARTITION)",
296
})
297
})?;
298
},
299
SinkMultiple { inputs } => {
300
for input in inputs {
301
recurse!(*input);
302
}
303
304
write_label(f, id, |f| f.write_str("SINK MULTIPLE"))?;
305
},
306
SimpleProjection { input, columns } => {
307
let num_columns = columns.as_ref().len();
308
let total_columns = self.lp.lp_arena.get(*input).schema(self.lp.lp_arena).len();
309
310
let columns = ColumnsDisplay(columns.as_ref());
311
recurse!(*input);
312
write_label(f, id, |f| {
313
write!(f, "simple π {num_columns}/{total_columns}\n[{columns}]")
314
})?;
315
},
316
#[cfg(feature = "merge_sorted")]
317
MergeSorted {
318
input_left,
319
input_right,
320
key,
321
} => {
322
recurse!(*input_left);
323
recurse!(*input_right);
324
325
write_label(f, id, |f| write!(f, "MERGE_SORTED ON '{key}'",))?;
326
},
327
Invalid => write_label(f, id, |f| f.write_str("INVALID"))?,
328
}
329
330
Ok(())
331
}
332
}
333
334
// A few utility structures for formatting
335
pub struct PathsDisplay<'a>(pub &'a [PathBuf]);
336
pub struct ScanSourcesDisplay<'a>(pub &'a ScanSources);
337
struct NumColumns<'a>(Option<&'a [PlSmallStr]>);
338
struct NumColumnsSchema<'a>(Option<&'a Schema>);
339
340
impl fmt::Display for ScanSourceRef<'_> {
341
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342
match self {
343
ScanSourceRef::Path(addr) => addr.display().fmt(f),
344
ScanSourceRef::File(_) => f.write_str("open-file"),
345
ScanSourceRef::Buffer(buff) => write!(f, "{} in-mem bytes", buff.len()),
346
}
347
}
348
}
349
350
impl fmt::Display for ScanSourcesDisplay<'_> {
351
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352
match self.0.len() {
353
0 => write!(f, "[]"),
354
1 => write!(f, "[{}]", self.0.at(0)),
355
2 => write!(f, "[{}, {}]", self.0.at(0), self.0.at(1)),
356
_ => write!(
357
f,
358
"[{}, ... {} other sources]",
359
self.0.at(0),
360
self.0.len() - 1,
361
),
362
}
363
}
364
}
365
366
impl fmt::Display for PathsDisplay<'_> {
367
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368
match self.0.len() {
369
0 => write!(f, "[]"),
370
1 => write!(f, "[{}]", self.0[0].display()),
371
2 => write!(f, "[{}, {}]", self.0[0].display(), self.0[1].display()),
372
_ => write!(
373
f,
374
"[{}, ... {} other files]",
375
self.0[0].display(),
376
self.0.len() - 1,
377
),
378
}
379
}
380
}
381
382
impl fmt::Display for NumColumns<'_> {
383
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384
match self.0 {
385
None => f.write_str("*"),
386
Some(columns) => columns.len().fmt(f),
387
}
388
}
389
}
390
391
impl fmt::Display for NumColumnsSchema<'_> {
392
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393
match self.0 {
394
None => f.write_str("*"),
395
Some(columns) => columns.len().fmt(f),
396
}
397
}
398
}
399
400
/// Utility structure to write to any [`fmt::Write`] sink whilst escaping the output so it
/// can be placed inside a double-quoted DOT label.
///
/// Double quotes, backslashes and newlines are escaped; a newline becomes the DOT `\n`
/// line-break escape.
pub struct EscapeLabel<'a>(pub &'a mut dyn fmt::Write);

impl fmt::Write for EscapeLabel<'_> {
    fn write_str(&mut self, mut s: &str) -> fmt::Result {
        // Backslashes must be escaped too. Otherwise input such as a Windows path would
        // turn into Graphviz escapes (`\N`, `\l`, ...) or, when it ends in `\`, swallow
        // the label's closing quote and corrupt the generated DOT.
        while let Some((at, escaped)) = s.char_indices().find_map(|(i, c)| match c {
            '"' => Some((i, r#"\""#)),
            '\n' => Some((i, r#"\n"#)),
            '\\' => Some((i, r#"\\"#)),
            _ => None,
        }) {
            self.0.write_str(&s[..at])?;
            self.0.write_str(escaped)?;
            // Every escaped character is a single ASCII byte, so `at + 1` is a char boundary.
            s = &s[at + 1..];
        }

        self.0.write_str(s)
    }
}
431
432
impl fmt::Display for IRDotDisplay<'_> {
433
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
434
writeln!(f, "digraph polars_query {{")?;
435
writeln!(f, "{INDENT}rankdir=\"BT\"")?;
436
writeln!(f, "{INDENT}node [fontname=\"Monospace\", shape=\"box\"]")?;
437
438
let mut last = 0;
439
let mut visited_caches = PlHashSet::new();
440
self._format(f, None, &mut last, &mut visited_caches)?;
441
442
writeln!(f, "}}")?;
443
444
Ok(())
445
}
446
}
447
448