Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/visitor/hash.rs
6940 views
1
use std::hash::{Hash, Hasher};
2
use std::sync::Arc;
3
4
use polars_utils::arena::Arena;
5
6
use super::*;
7
#[cfg(feature = "python")]
8
use crate::plans::PythonOptions;
9
use crate::plans::{AExpr, IR};
10
use crate::prelude::ExprIR;
11
use crate::prelude::aexpr::traverse_and_hash_aexpr;
12
13
impl IRNode {
14
pub(crate) fn hashable_and_cmp<'a>(
15
&'a self,
16
lp_arena: &'a Arena<IR>,
17
expr_arena: &'a Arena<AExpr>,
18
) -> HashableEqLP<'a> {
19
HashableEqLP {
20
node: *self,
21
lp_arena,
22
expr_arena,
23
ignore_cache: false,
24
}
25
}
26
}
27
28
pub(crate) struct HashableEqLP<'a> {
29
node: IRNode,
30
lp_arena: &'a Arena<IR>,
31
expr_arena: &'a Arena<AExpr>,
32
ignore_cache: bool,
33
}
34
35
impl HashableEqLP<'_> {
36
/// When encountering a Cache node, ignore it and take the input.
37
#[cfg(feature = "cse")]
38
pub(crate) fn ignore_caches(mut self) -> Self {
39
self.ignore_cache = true;
40
self
41
}
42
}
43
44
fn hash_option_expr<H: Hasher>(expr: &Option<ExprIR>, expr_arena: &Arena<AExpr>, state: &mut H) {
45
if let Some(e) = expr {
46
e.traverse_and_hash(expr_arena, state)
47
}
48
}
49
50
fn hash_exprs<H: Hasher>(exprs: &[ExprIR], expr_arena: &Arena<AExpr>, state: &mut H) {
51
for e in exprs {
52
e.traverse_and_hash(expr_arena, state);
53
}
54
}
55
56
/// Hashes a Python scan predicate: first the variant tag, then the payload
/// carried by that variant (nothing for `None`).
#[cfg(feature = "python")]
fn hash_python_predicate<H: Hasher>(
    pred: &crate::prelude::PythonPredicate,
    expr_arena: &Arena<AExpr>,
    state: &mut H,
) {
    use crate::prelude::PythonPredicate as Pred;

    // The tag distinguishes the variants before any payload is hashed.
    let tag = std::mem::discriminant(pred);
    tag.hash(state);

    match pred {
        Pred::Polars(expr) => expr.traverse_and_hash(expr_arena, state),
        Pred::PyArrow(source) => source.hash(state),
        Pred::None => {},
    }
}
70
71
/// Compares two Python scan predicates: they are equal when they are the same
/// variant and their payloads compare equal. Mixed variants are never equal.
#[cfg(feature = "python")]
fn pred_eq(
    l: &crate::prelude::PythonPredicate,
    r: &crate::prelude::PythonPredicate,
    expr_arena: &Arena<AExpr>,
) -> bool {
    use crate::prelude::PythonPredicate as Pred;

    match (l, r) {
        (Pred::Polars(lhs), Pred::Polars(rhs)) => expr_ir_eq(lhs, rhs, expr_arena),
        (Pred::PyArrow(lhs), Pred::PyArrow(rhs)) => lhs == rhs,
        (Pred::None, Pred::None) => true,
        _ => false,
    }
}
85
86
impl Hash for HashableEqLP<'_> {
87
// This hashes the variant, not the whole plan
88
fn hash<H: Hasher>(&self, state: &mut H) {
89
let alp = self.node.to_alp(self.lp_arena);
90
std::mem::discriminant(alp).hash(state);
91
match alp {
92
#[cfg(feature = "python")]
93
IR::PythonScan {
94
options:
95
PythonOptions {
96
scan_fn,
97
schema,
98
output_schema,
99
with_columns,
100
python_source,
101
n_rows,
102
predicate,
103
validate_schema,
104
is_pure,
105
},
106
} => {
107
// Hash the Python function object using the pointer to the object
108
// This should be the same as calling id() in python, but we don't need the GIL
109
if let Some(scan_fn) = scan_fn {
110
let ptr_addr = scan_fn.0.as_ptr() as usize;
111
ptr_addr.hash(state);
112
}
113
// Hash the stable fields
114
// We include the schema since it can be set by the user
115
schema.hash(state);
116
output_schema.hash(state);
117
with_columns.hash(state);
118
python_source.hash(state);
119
n_rows.hash(state);
120
hash_python_predicate(predicate, self.expr_arena, state);
121
validate_schema.hash(state);
122
is_pure.hash(state);
123
},
124
IR::Slice {
125
offset,
126
len,
127
input: _,
128
} => {
129
len.hash(state);
130
offset.hash(state);
131
},
132
IR::Filter {
133
input: _,
134
predicate,
135
} => {
136
predicate.traverse_and_hash(self.expr_arena, state);
137
},
138
IR::Scan {
139
sources,
140
file_info: _,
141
hive_parts: _,
142
predicate,
143
output_schema: _,
144
scan_type,
145
unified_scan_args,
146
} => {
147
// We don't have to traverse the schema, hive partitions etc. as they are derivative from the paths.
148
scan_type.hash(state);
149
sources.hash(state);
150
hash_option_expr(predicate, self.expr_arena, state);
151
unified_scan_args.hash(state);
152
},
153
IR::DataFrameScan {
154
df,
155
schema: _,
156
output_schema,
157
..
158
} => {
159
(Arc::as_ptr(df) as usize).hash(state);
160
output_schema.hash(state);
161
},
162
IR::SimpleProjection { columns, input: _ } => {
163
columns.hash(state);
164
},
165
IR::Select {
166
input: _,
167
expr,
168
schema: _,
169
options,
170
} => {
171
hash_exprs(expr, self.expr_arena, state);
172
options.hash(state);
173
},
174
IR::Sort {
175
input: _,
176
by_column,
177
slice,
178
sort_options,
179
} => {
180
hash_exprs(by_column, self.expr_arena, state);
181
slice.hash(state);
182
sort_options.hash(state);
183
},
184
IR::GroupBy {
185
input: _,
186
keys,
187
aggs,
188
schema: _,
189
apply,
190
maintain_order,
191
options,
192
} => {
193
hash_exprs(keys, self.expr_arena, state);
194
hash_exprs(aggs, self.expr_arena, state);
195
apply.is_none().hash(state);
196
maintain_order.hash(state);
197
options.hash(state);
198
},
199
IR::Join {
200
input_left: _,
201
input_right: _,
202
schema: _,
203
left_on,
204
right_on,
205
options,
206
} => {
207
hash_exprs(left_on, self.expr_arena, state);
208
hash_exprs(right_on, self.expr_arena, state);
209
options.hash(state);
210
},
211
IR::HStack {
212
input: _,
213
exprs,
214
schema: _,
215
options,
216
} => {
217
hash_exprs(exprs, self.expr_arena, state);
218
options.hash(state);
219
},
220
IR::Distinct { input: _, options } => {
221
options.hash(state);
222
},
223
IR::MapFunction { input: _, function } => {
224
function.hash(state);
225
},
226
IR::Union { inputs: _, options } => options.hash(state),
227
IR::HConcat {
228
inputs: _,
229
schema: _,
230
options,
231
} => {
232
options.hash(state);
233
},
234
IR::ExtContext {
235
input: _,
236
contexts,
237
schema: _,
238
} => {
239
for node in contexts {
240
traverse_and_hash_aexpr(*node, self.expr_arena, state);
241
}
242
},
243
IR::Sink { input: _, payload } => {
244
payload.traverse_and_hash(self.expr_arena, state);
245
},
246
IR::SinkMultiple { .. } => {},
247
IR::Cache { input: _, id } => {
248
id.hash(state);
249
},
250
#[cfg(feature = "merge_sorted")]
251
IR::MergeSorted {
252
input_left: _,
253
input_right: _,
254
key,
255
} => {
256
key.hash(state);
257
},
258
IR::Invalid => unreachable!(),
259
}
260
}
261
}
262
263
fn expr_irs_eq(l: &[ExprIR], r: &[ExprIR], expr_arena: &Arena<AExpr>) -> bool {
264
l.len() == r.len() && l.iter().zip(r).all(|(l, r)| expr_ir_eq(l, r, expr_arena))
265
}
266
267
fn expr_ir_eq(l: &ExprIR, r: &ExprIR, expr_arena: &Arena<AExpr>) -> bool {
268
l.get_alias() == r.get_alias() && {
269
let l = AexprNode::new(l.node());
270
let r = AexprNode::new(r.node());
271
l.hashable_and_cmp(expr_arena) == r.hashable_and_cmp(expr_arena)
272
}
273
}
274
275
fn opt_expr_ir_eq(l: &Option<ExprIR>, r: &Option<ExprIR>, expr_arena: &Arena<AExpr>) -> bool {
276
match (l, r) {
277
(None, None) => true,
278
(Some(l), Some(r)) => expr_ir_eq(l, r, expr_arena),
279
_ => false,
280
}
281
}
282
283
impl HashableEqLP<'_> {
284
fn is_equal(&self, other: &Self) -> bool {
285
let alp_l = self.node.to_alp(self.lp_arena);
286
let alp_r = other.node.to_alp(self.lp_arena);
287
if std::mem::discriminant(alp_l) != std::mem::discriminant(alp_r) {
288
return false;
289
}
290
match (alp_l, alp_r) {
291
#[cfg(feature = "python")]
292
(
293
IR::PythonScan {
294
options:
295
PythonOptions {
296
scan_fn: scan_fn_l,
297
schema: schema_l,
298
output_schema: output_schema_l,
299
with_columns: with_columns_l,
300
python_source: python_source_l,
301
n_rows: n_rows_l,
302
predicate: predicate_l,
303
validate_schema: validate_schema_l,
304
is_pure: is_pure_l,
305
},
306
},
307
IR::PythonScan {
308
options:
309
PythonOptions {
310
scan_fn: scan_fn_r,
311
schema: schema_r,
312
output_schema: output_schema_r,
313
with_columns: with_columns_r,
314
python_source: python_source_r,
315
n_rows: n_rows_r,
316
predicate: predicate_r,
317
validate_schema: validate_schema_r,
318
is_pure: is_pure_r,
319
},
320
},
321
) => {
322
// Require both to be pure to compare equal for CSE.
323
if !(*is_pure_l && *is_pure_r) {
324
return false;
325
}
326
327
let scan_fn_eq = match (scan_fn_l, scan_fn_r) {
328
(None, None) => true,
329
(Some(a), Some(b)) => a.0.as_ptr() == b.0.as_ptr(),
330
_ => false,
331
};
332
333
scan_fn_eq
334
&& schema_l == schema_r
335
&& output_schema_l == output_schema_r
336
&& with_columns_l == with_columns_r
337
&& python_source_l == python_source_r
338
&& n_rows_l == n_rows_r
339
&& validate_schema_l == validate_schema_r
340
&& pred_eq(predicate_l, predicate_r, self.expr_arena)
341
},
342
(
343
IR::Slice {
344
input: _,
345
offset: ol,
346
len: ll,
347
},
348
IR::Slice {
349
input: _,
350
offset: or,
351
len: lr,
352
},
353
) => ol == or && ll == lr,
354
(
355
IR::Filter {
356
input: _,
357
predicate: l,
358
},
359
IR::Filter {
360
input: _,
361
predicate: r,
362
},
363
) => expr_ir_eq(l, r, self.expr_arena),
364
(
365
IR::Scan {
366
sources: pl,
367
file_info: _,
368
hive_parts: _,
369
predicate: pred_l,
370
output_schema: _,
371
scan_type: stl,
372
unified_scan_args: ol,
373
},
374
IR::Scan {
375
sources: pr,
376
file_info: _,
377
hive_parts: _,
378
predicate: pred_r,
379
output_schema: _,
380
scan_type: str,
381
unified_scan_args: or,
382
},
383
) => {
384
pl == pr
385
&& stl == str
386
&& ol == or
387
&& opt_expr_ir_eq(pred_l, pred_r, self.expr_arena)
388
},
389
(
390
IR::DataFrameScan {
391
df: dfl,
392
schema: _,
393
output_schema: s_l,
394
},
395
IR::DataFrameScan {
396
df: dfr,
397
schema: _,
398
output_schema: s_r,
399
},
400
) => std::ptr::eq(Arc::as_ptr(dfl), Arc::as_ptr(dfr)) && s_l == s_r,
401
(
402
IR::SimpleProjection {
403
input: _,
404
columns: cl,
405
},
406
IR::SimpleProjection {
407
input: _,
408
columns: cr,
409
},
410
) => cl == cr,
411
(
412
IR::Select {
413
input: _,
414
expr: el,
415
options: ol,
416
schema: _,
417
},
418
IR::Select {
419
input: _,
420
expr: er,
421
options: or,
422
schema: _,
423
},
424
) => ol == or && expr_irs_eq(el, er, self.expr_arena),
425
(
426
IR::Sort {
427
input: _,
428
by_column: cl,
429
slice: l_slice,
430
sort_options: l_options,
431
},
432
IR::Sort {
433
input: _,
434
by_column: cr,
435
slice: r_slice,
436
sort_options: r_options,
437
},
438
) => {
439
(l_slice == r_slice && l_options == r_options)
440
&& expr_irs_eq(cl, cr, self.expr_arena)
441
},
442
(
443
IR::GroupBy {
444
input: _,
445
keys: keys_l,
446
aggs: aggs_l,
447
schema: _,
448
apply: apply_l,
449
maintain_order: maintain_l,
450
options: ol,
451
},
452
IR::GroupBy {
453
input: _,
454
keys: keys_r,
455
aggs: aggs_r,
456
schema: _,
457
apply: apply_r,
458
maintain_order: maintain_r,
459
options: or,
460
},
461
) => {
462
apply_l.is_none()
463
&& apply_r.is_none()
464
&& ol == or
465
&& maintain_l == maintain_r
466
&& expr_irs_eq(keys_l, keys_r, self.expr_arena)
467
&& expr_irs_eq(aggs_l, aggs_r, self.expr_arena)
468
},
469
(
470
IR::Join {
471
input_left: _,
472
input_right: _,
473
schema: _,
474
left_on: ll,
475
right_on: rl,
476
options: ol,
477
},
478
IR::Join {
479
input_left: _,
480
input_right: _,
481
schema: _,
482
left_on: lr,
483
right_on: rr,
484
options: or,
485
},
486
) => {
487
ol == or
488
&& expr_irs_eq(ll, lr, self.expr_arena)
489
&& expr_irs_eq(rl, rr, self.expr_arena)
490
},
491
(
492
IR::HStack {
493
input: _,
494
exprs: el,
495
schema: _,
496
options: ol,
497
},
498
IR::HStack {
499
input: _,
500
exprs: er,
501
schema: _,
502
options: or,
503
},
504
) => ol == or && expr_irs_eq(el, er, self.expr_arena),
505
(
506
IR::Distinct {
507
input: _,
508
options: ol,
509
},
510
IR::Distinct {
511
input: _,
512
options: or,
513
},
514
) => ol == or,
515
(
516
IR::MapFunction {
517
input: _,
518
function: l,
519
},
520
IR::MapFunction {
521
input: _,
522
function: r,
523
},
524
) => l == r,
525
(
526
IR::Union {
527
inputs: _,
528
options: l,
529
},
530
IR::Union {
531
inputs: _,
532
options: r,
533
},
534
) => l == r,
535
(
536
IR::HConcat {
537
inputs: _,
538
schema: _,
539
options: l,
540
},
541
IR::HConcat {
542
inputs: _,
543
schema: _,
544
options: r,
545
},
546
) => l == r,
547
(
548
IR::ExtContext {
549
input: _,
550
contexts: l,
551
schema: _,
552
},
553
IR::ExtContext {
554
input: _,
555
contexts: r,
556
schema: _,
557
},
558
) => {
559
l.len() == r.len()
560
&& l.iter().zip(r.iter()).all(|(l, r)| {
561
let l = AexprNode::new(*l).hashable_and_cmp(self.expr_arena);
562
let r = AexprNode::new(*r).hashable_and_cmp(self.expr_arena);
563
l == r
564
})
565
},
566
_ => false,
567
}
568
}
569
}
570
571
impl PartialEq for HashableEqLP<'_> {
572
fn eq(&self, other: &Self) -> bool {
573
let mut scratch_1 = vec![];
574
let mut scratch_2 = vec![];
575
576
scratch_1.push(self.node.node());
577
scratch_2.push(other.node.node());
578
579
loop {
580
match (scratch_1.pop(), scratch_2.pop()) {
581
(Some(l), Some(r)) => {
582
let l = IRNode::new(l);
583
let r = IRNode::new(r);
584
let l_alp = l.to_alp(self.lp_arena);
585
let r_alp = r.to_alp(self.lp_arena);
586
587
if self.ignore_cache {
588
match (l_alp, r_alp) {
589
(IR::Cache { input: l, .. }, IR::Cache { input: r, .. }) => {
590
scratch_1.push(*l);
591
scratch_2.push(*r);
592
continue;
593
},
594
(IR::Cache { input: l, .. }, _) => {
595
scratch_1.push(*l);
596
scratch_2.push(r.node());
597
continue;
598
},
599
(_, IR::Cache { input: r, .. }) => {
600
scratch_1.push(l.node());
601
scratch_2.push(*r);
602
continue;
603
},
604
_ => {},
605
}
606
}
607
608
if !l
609
.hashable_and_cmp(self.lp_arena, self.expr_arena)
610
.is_equal(&r.hashable_and_cmp(self.lp_arena, self.expr_arena))
611
{
612
return false;
613
}
614
615
l_alp.copy_inputs(&mut scratch_1);
616
r_alp.copy_inputs(&mut scratch_2);
617
},
618
(None, None) => return true,
619
_ => return false,
620
}
621
}
622
}
623
}
624
625