Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-utils/src/order_statistic_tree.rs
7884 views
//! This module implements an order statistic multiset backed by a
//! weight-balanced tree (WBT).
//! The tree follows the weight-balanced tree described in the following papers:
//!
//! * <https://doi.org/10.1017/S0956796811000104>
//! * <https://doi.org/10.1137/1.9781611976007.13>
//!
//! Each node in the tree stores a UnitVec of values, so that multiple values
//! comparing equal under the tree's comparison function share a single node.
use std::cmp::Ordering;
12
use std::fmt::Debug;
13
use std::ops::RangeInclusive;
14
15
use slotmap::{Key as SlotMapKey, SlotMap, new_key_type};
16
17
use crate::UnitVec;
18
19
/// Balance parameter (Δ): a subtree's weight (node count + 1) may be at most
/// `DELTA` times the weight of its sibling.
const DELTA: usize = 3;

/// Rotation parameter (Γ): decides whether rebalancing uses a single or a
/// double rotation.
const GAMMA: usize = 2;

/// Comparison function defining the order of the tree's elements.
type CompareFn<T> = fn(&T, &T) -> Ordering;
new_key_type! {
25
struct Key;
26
}
27
28
#[derive(Debug)]
29
struct Node<T> {
30
values: UnitVec<T>,
31
left: Key,
32
right: Key,
33
weight: u32,
34
num_elems: u32,
35
}
36
37
#[derive(Debug)]
38
pub struct OrderStatisticTree<T> {
39
nodes: SlotMap<Key, Node<T>>,
40
root: Key,
41
compare: CompareFn<T>,
42
}
43
44
impl<T> OrderStatisticTree<T> {
45
#[inline]
46
pub fn new(compare: CompareFn<T>) -> Self {
47
OrderStatisticTree {
48
nodes: SlotMap::with_key(),
49
root: Key::null(),
50
compare,
51
}
52
}
53
54
#[inline]
55
pub fn with_capacity(capacity: usize, compare: CompareFn<T>) -> Self {
56
OrderStatisticTree {
57
nodes: SlotMap::with_capacity_and_key(capacity),
58
root: Key::null(),
59
compare,
60
}
61
}
62
63
#[inline]
64
pub fn is_empty(&self) -> bool {
65
self.len() == 0
66
}
67
68
#[inline]
69
pub fn len(&self) -> usize {
70
self.num_elems(self.root)
71
}
72
73
#[inline]
74
pub fn unique_len(&self) -> usize {
75
self.tree_weight(self.root)
76
}
77
78
#[inline]
79
pub fn clear(&mut self) {
80
self.nodes.clear();
81
self.root = Key::null();
82
}
83
84
/// Returns the total number of elements in the tree rooted at `tree`.
85
fn num_elems(&self, tree: Key) -> usize {
86
if tree.is_null() {
87
return 0;
88
}
89
unsafe { self.nodes.get_unchecked(tree) }.num_elems as usize
90
}
91
92
/// Returns the number of tree nodes, which is equal to the number of unique
93
/// elements, in the tree rooted at `tree`.
94
fn tree_weight(&self, tree: Key) -> usize {
95
if tree.is_null() {
96
return 0;
97
}
98
unsafe { self.nodes.get_unchecked(tree) }.weight as usize
99
}
100
101
#[must_use]
102
fn new_tree_node(&mut self, left: Key, values: UnitVec<T>, right: Key) -> Key {
103
let weight = self.tree_weight(left) + self.tree_weight(right) + 1;
104
let num_elems = self.num_elems(left) + self.num_elems(right) + values.len();
105
let n = Node {
106
values,
107
left,
108
right,
109
weight: weight as u32,
110
num_elems: num_elems as u32,
111
};
112
self.nodes.insert(n)
113
}
114
115
#[must_use]
116
fn new_leaf(&mut self, value: T) -> Key {
117
let mut uv = UnitVec::new();
118
uv.push(value);
119
self.new_tree_node(Key::null(), uv, Key::null())
120
}
121
122
#[must_use]
123
unsafe fn drop_tree_node(&mut self, tree: Key) -> Node<T> {
124
unsafe { self.nodes.remove(tree).unwrap_unchecked() }
125
}
126
127
#[inline]
128
pub fn get(&self, idx: usize) -> Option<&T> {
129
self._get(idx, self.root)
130
}
131
132
fn _get(&self, idx: usize, tree: Key) -> Option<&T> {
133
if tree.is_null() {
134
return None;
135
}
136
137
let n = unsafe { self.nodes.get_unchecked(tree) };
138
let own_elems = self.num_elems(tree);
139
let left_elems = self.num_elems(n.left);
140
let right_elems = self.num_elems(n.right);
141
142
if idx < left_elems {
143
self._get(idx, n.left)
144
} else if idx >= own_elems - right_elems {
145
self._get(idx - (own_elems - right_elems), n.right)
146
} else {
147
n.values.get(idx - left_elems)
148
}
149
}
150
151
#[inline]
152
pub fn insert(&mut self, value: T) {
153
(self.root, _) = self._insert(value, self.root);
154
}
155
156
#[must_use]
157
fn _insert(&mut self, value: T, tree: Key) -> (Key, bool) {
158
if tree.is_null() {
159
return (self.new_leaf(value), true);
160
}
161
162
let n = unsafe { self.nodes.get_unchecked(tree) };
163
match (self.compare)(&value, &n.values[0]) {
164
Ordering::Less => {
165
let (left, node_added) = self._insert(value, n.left);
166
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
167
n.left = left;
168
n.weight += node_added as u32;
169
n.num_elems += 1;
170
(self.balance_r(tree), node_added)
171
},
172
Ordering::Equal => {
173
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
174
n.values.push(value);
175
n.num_elems += 1;
176
(tree, false)
177
},
178
Ordering::Greater => {
179
let (right, node_added) = self._insert(value, n.right);
180
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
181
n.right = right;
182
n.weight += node_added as u32;
183
n.num_elems += 1;
184
(self.balance_l(tree), node_added)
185
},
186
}
187
}
188
189
#[inline]
190
pub fn remove(&mut self, value: &T) -> Option<T> {
191
let deleted;
192
(deleted, self.root, _) = self._remove(value, self.root);
193
deleted
194
}
195
196
#[must_use]
197
fn _remove(&mut self, value: &T, tree: Key) -> (Option<T>, Key, bool) {
198
if tree.is_null() {
199
return (None, tree, false);
200
}
201
202
let n = unsafe { self.nodes.get_unchecked(tree) };
203
match (self.compare)(value, &n.values[0]) {
204
Ordering::Less => {
205
let (deleted, left, node_removed) = self._remove(value, n.left);
206
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
207
n.left = left;
208
n.weight -= node_removed as u32;
209
n.num_elems -= deleted.is_some() as u32;
210
(deleted, self.balance_l(tree), node_removed)
211
},
212
Ordering::Greater => {
213
let (deleted, right, node_removed) = self._remove(value, n.right);
214
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
215
n.right = right;
216
n.weight -= node_removed as u32;
217
n.num_elems -= deleted.is_some() as u32;
218
(deleted, self.balance_r(tree), node_removed)
219
},
220
Ordering::Equal if n.values.len() > 1 => {
221
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
222
let popped_value = unsafe { n.values.pop().unwrap_unchecked() };
223
n.num_elems -= 1;
224
(Some(popped_value), tree, false)
225
},
226
Ordering::Equal => {
227
let mut n = unsafe { self.drop_tree_node(tree) };
228
(
229
Some(unsafe { n.values.pop().unwrap_unchecked() }),
230
self.glue(n.left, n.right),
231
true,
232
)
233
},
234
}
235
}
236
237
#[must_use]
238
fn glue(&mut self, left: Key, right: Key) -> Key {
239
if left.is_null() {
240
right
241
} else if right.is_null() {
242
left
243
} else if self.tree_weight(left) > self.tree_weight(right) {
244
let (deleted, left) = self.remove_max(left);
245
let tree = self.new_tree_node(left, deleted, right);
246
self.balance_r(tree)
247
} else {
248
let (deleted, right) = self.remove_min(right);
249
let tree = self.new_tree_node(left, deleted, right);
250
self.balance_l(tree)
251
}
252
}
253
254
#[must_use]
255
fn remove_min(&mut self, tree: Key) -> (UnitVec<T>, Key) {
256
debug_assert!(!tree.is_null());
257
let n = unsafe { self.nodes.get_unchecked(tree) };
258
if n.left.is_null() {
259
let n = unsafe { self.drop_tree_node(tree) };
260
return (n.values, n.right);
261
}
262
let (deleted, left) = self.remove_min(n.left);
263
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
264
n.left = left;
265
n.weight -= 1;
266
n.num_elems -= deleted.len() as u32;
267
(deleted, self.balance_l(tree))
268
}
269
270
#[must_use]
271
fn remove_max(&mut self, tree: Key) -> (UnitVec<T>, Key) {
272
debug_assert!(!tree.is_null());
273
let n = unsafe { self.nodes.get_unchecked(tree) };
274
if n.right.is_null() {
275
let n = unsafe { self.drop_tree_node(tree) };
276
return (n.values, n.left);
277
}
278
let (deleted, right) = self.remove_max(n.right);
279
let n = unsafe { self.nodes.get_unchecked_mut(tree) };
280
n.right = right;
281
n.weight -= 1;
282
n.num_elems -= deleted.len() as u32;
283
(deleted, self.balance_r(tree))
284
}
285
286
#[inline]
287
pub fn contains(&self, value: &T) -> bool {
288
self._contains(value, self.root)
289
}
290
291
fn _contains(&self, value: &T, tree: Key) -> bool {
292
if tree.is_null() {
293
return false;
294
}
295
let n = unsafe { self.nodes.get_unchecked(tree) };
296
match (self.compare)(value, &n.values[0]) {
297
Ordering::Less => self._contains(value, n.left),
298
Ordering::Equal => true,
299
Ordering::Greater => self._contains(value, n.right),
300
}
301
}
302
303
#[must_use]
304
fn balance_l(&mut self, tree: Key) -> Key {
305
let n = unsafe { self.nodes.get_unchecked(tree) };
306
if self.pair_is_balanced(n.left, n.right) {
307
return tree;
308
}
309
self.rotate_l(tree)
310
}
311
312
#[must_use]
313
fn rotate_l(&mut self, tree: Key) -> Key {
314
let n = unsafe { self.nodes.get_unchecked(tree) };
315
let r = unsafe { self.nodes.get_unchecked(n.right) };
316
if self.is_single(r.left, r.right) {
317
self.single_l(tree)
318
} else {
319
self.double_l(tree)
320
}
321
}
322
323
#[must_use]
324
fn single_l(&mut self, tree: Key) -> Key {
325
let n = unsafe { self.drop_tree_node(tree) };
326
let r = unsafe { self.drop_tree_node(n.right) };
327
let new_left = self.new_tree_node(n.left, n.values, r.left);
328
self.new_tree_node(new_left, r.values, r.right)
329
}
330
331
#[must_use]
332
fn double_l(&mut self, tree: Key) -> Key {
333
let n = unsafe { self.drop_tree_node(tree) };
334
let r = unsafe { self.drop_tree_node(n.right) };
335
let rl = unsafe { self.drop_tree_node(r.left) };
336
let new_left = self.new_tree_node(n.left, n.values, rl.left);
337
let new_right = self.new_tree_node(rl.right, r.values, r.right);
338
self.new_tree_node(new_left, rl.values, new_right)
339
}
340
341
#[must_use]
342
fn balance_r(&mut self, tree: Key) -> Key {
343
let n = unsafe { self.nodes.get_unchecked(tree) };
344
if self.pair_is_balanced(n.right, n.left) {
345
return tree;
346
}
347
self.rotate_r(tree)
348
}
349
350
#[must_use]
351
fn rotate_r(&mut self, tree: Key) -> Key {
352
let n = unsafe { self.nodes.get_unchecked(tree) };
353
let l = unsafe { self.nodes.get_unchecked(n.left) };
354
if self.is_single(l.right, l.left) {
355
self.single_r(tree)
356
} else {
357
self.double_r(tree)
358
}
359
}
360
361
#[must_use]
362
fn single_r(&mut self, tree: Key) -> Key {
363
let n = unsafe { self.drop_tree_node(tree) };
364
let l = unsafe { self.drop_tree_node(n.left) };
365
let new_right = self.new_tree_node(l.right, n.values, n.right);
366
self.new_tree_node(l.left, l.values, new_right)
367
}
368
369
#[must_use]
370
fn double_r(&mut self, tree: Key) -> Key {
371
let n = unsafe { self.drop_tree_node(tree) };
372
let l = unsafe { self.drop_tree_node(n.left) };
373
let lr = unsafe { self.drop_tree_node(l.right) };
374
let new_right = self.new_tree_node(lr.right, n.values, n.right);
375
let new_left = self.new_tree_node(l.left, l.values, lr.left);
376
self.new_tree_node(new_left, lr.values, new_right)
377
}
378
379
#[doc(hidden)]
380
pub fn is_balanced(&self) -> bool {
381
self.tree_is_balanced(self.root)
382
}
383
384
fn tree_is_balanced(&self, tree: Key) -> bool {
385
if tree.is_null() {
386
return true;
387
}
388
let n = unsafe { self.nodes.get_unchecked(tree) };
389
self.pair_is_balanced(n.left, n.right)
390
&& self.pair_is_balanced(n.right, n.left)
391
&& self.tree_is_balanced(n.left)
392
&& self.tree_is_balanced(n.right)
393
}
394
395
fn pair_is_balanced(&self, left: Key, right: Key) -> bool {
396
let a = self.tree_weight(left);
397
let b = self.tree_weight(right);
398
DELTA * (a + 1) >= (b + 1) && DELTA * (b + 1) >= (a + 1)
399
}
400
401
fn is_single(&self, left: Key, right: Key) -> bool {
402
let a = self.tree_weight(left);
403
let b = self.tree_weight(right);
404
a + 1 < GAMMA * (b + 1)
405
}
406
407
#[inline]
408
pub fn rank_range(&self, bound: &T) -> Result<RangeInclusive<usize>, usize> {
409
self._rank_range(bound, self.root)
410
}
411
412
fn _rank_range(&self, value: &T, tree: Key) -> Result<RangeInclusive<usize>, usize> {
413
if tree.is_null() {
414
return Err(0);
415
}
416
let n = unsafe { self.nodes.get_unchecked(tree) };
417
match (self.compare)(value, &n.values[0]) {
418
Ordering::Less => self._rank_range(value, n.left),
419
Ordering::Equal => {
420
let lo = self.num_elems(n.left);
421
let hi = lo + n.values.len() - 1;
422
Ok(lo..=hi)
423
},
424
Ordering::Greater => {
425
let update_rank = |r| self.num_elems(tree) - self.num_elems(n.right) + r;
426
self._rank_range(value, n.right)
427
.map(|rank| update_rank(*rank.start())..=update_rank(*rank.end()))
428
.map_err(update_rank)
429
},
430
}
431
}
432
433
#[inline]
434
pub fn rank_unique(&self, value: &T) -> Result<usize, usize> {
435
self._rank_unique(value, self.root)
436
}
437
438
fn _rank_unique(&self, value: &T, tree: Key) -> Result<usize, usize> {
439
if tree.is_null() {
440
return Err(0);
441
}
442
let n = unsafe { self.nodes.get_unchecked(tree) };
443
match (self.compare)(value, &n.values[0]) {
444
Ordering::Less => self._rank_unique(value, n.left),
445
Ordering::Equal => Ok(self.tree_weight(n.left)),
446
Ordering::Greater => self
447
._rank_unique(value, n.right)
448
.map(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank)
449
.map_err(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank),
450
}
451
}
452
453
#[inline]
454
pub fn count(&self, value: &T) -> usize {
455
self._count(value, self.root)
456
}
457
458
fn _count(&self, value: &T, tree: Key) -> usize {
459
if tree.is_null() {
460
return 0;
461
}
462
let n = unsafe { self.nodes.get_unchecked(tree) };
463
match (self.compare)(value, &n.values[0]) {
464
Ordering::Less => self._count(value, n.left),
465
Ordering::Equal => n.values.len(),
466
Ordering::Greater => self._count(value, n.right),
467
}
468
}
469
}
470
471
impl<T> Extend<T> for OrderStatisticTree<T> {
472
fn extend<I: IntoIterator<Item = T>>(&mut self, iterable: I) {
473
let iterator = iterable.into_iter();
474
for element in iterator {
475
self.insert(element);
476
}
477
}
478
}
479
480
#[cfg(test)]
mod test {
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::TestRunner;

    use super::*;

    #[test]
    fn test_insert() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec((0i32..100, 0i32..100), 0..100), test_insert_inner)
            .unwrap()
    }

    /// Inserts key/payload pairs ordered only by key, so many pairs share a
    /// node while remaining distinguishable.
    fn test_insert_inner(items: Vec<(i32, i32)>) -> Result<(), TestCaseError> {
        let cmp = |a: &(i32, i32), b: &(i32, i32)| i32::cmp(&a.0, &b.0);
        let mut ost = OrderStatisticTree::new(cmp);
        for item in &items {
            ost.insert(*item);
            prop_assert!(ost.is_balanced());
        }
        prop_assert_eq!(ost.len(), items.len());

        let mut collected_items = Vec::new();
        let mut i = 0;
        while let Some(v) = ost.get(i) {
            collected_items.push(*v);
            i += 1;
        }
        // `get` must walk the elements in key order...
        prop_assert!(collected_items.windows(2).all(|w| w[0].0 <= w[1].0));

        // ...and visit every inserted item exactly once.
        let mut sorted_items = items;
        sorted_items.sort();
        collected_items.sort();
        prop_assert_eq!(&collected_items, &sorted_items);
        Ok(())
    }

    #[test]
    fn test_remove() {
        let mut runner = TestRunner::default();
        runner
            .run(
                &(vec(0i32..100, 0..100), vec(0i32..100, 0..100)),
                test_remove_inner,
            )
            .unwrap();
    }

    fn test_remove_inner(input: (Vec<i32>, Vec<i32>)) -> Result<(), TestCaseError> {
        let (mut items, to_remove) = input;
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
            prop_assert!(ost.is_balanced());
        }
        items.sort();
        for item in &to_remove {
            let removed = ost.remove(item);
            prop_assert!(ost.is_balanced());
            let idx = items.binary_search(item);
            // A hit must hand back an element equal to the query.
            prop_assert_eq!(removed, idx.ok().map(|_| *item));
            if let Ok(idx) = idx {
                items.remove(idx);
            }
            prop_assert_eq!(ost.len(), items.len());
        }
        prop_assert_eq!(ost.len(), items.len());

        // Node bookkeeping must also survive the removals.
        let mut distinct = items.clone();
        distinct.dedup();
        prop_assert_eq!(ost.unique_len(), distinct.len());

        for item in 0..100 {
            prop_assert_eq!(ost.contains(&item), items.binary_search(&item).is_ok());
        }
        Ok(())
    }

    #[test]
    fn test_rank() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec(0i32..100, 0..100), test_rank_inner)
            .unwrap();
    }

    fn test_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
        }
        items.sort();
        for item in 0..100 {
            let below = items.iter().filter(|&&x| x < item).count();
            let at_or_below = items.iter().filter(|&&x| x <= item).count();
            let expected_rank = if at_or_below > below {
                Ok(below..=at_or_below - 1)
            } else {
                Err(below)
            };
            prop_assert_eq!(ost.rank_range(&item), expected_rank);
        }
        Ok(())
    }

    #[test]
    fn test_unique_rank() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec(0i32..50, 0..100), test_unique_rank_inner)
            .unwrap();
    }

    fn test_unique_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
        }
        prop_assert_eq!(ost.len(), items.len());
        items.sort();
        items.dedup();
        prop_assert_eq!(ost.unique_len(), items.len());
        for item in 0..50 {
            let below = items.iter().filter(|&&x| x < item).count();
            let expected_unique_rank = if items.binary_search(&item).is_ok() {
                Ok(below)
            } else {
                Err(below)
            };
            prop_assert_eq!(ost.rank_unique(&item), expected_unique_rank);
        }
        Ok(())
    }

    #[test]
    fn test_empty() {
        let ost = OrderStatisticTree::<i32>::new(i32::cmp);
        assert!(ost.is_empty());
        assert_eq!(ost.len(), 0);
        assert_eq!(ost.unique_len(), 0);
        assert!(ost.is_balanced());
        assert!(!ost.contains(&1));
        assert_eq!(ost.rank_range(&1), Err(0));
        assert_eq!(ost.rank_unique(&1), Err(0));
    }

    #[test]
    fn test_clear() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in 0..10 {
            ost.insert(item);
        }
        assert_eq!(ost.len(), 10);
        assert_eq!(ost.unique_len(), 10);
        ost.clear();
        assert!(ost.is_empty());
    }

    #[test]
    fn test_extend() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        ost.extend(0..10);
        assert_eq!(ost.len(), 10);
        assert_eq!(ost.unique_len(), 10);
        for item in 0..10 {
            assert!(ost.contains(&item));
        }
    }

    #[test]
    fn test_count() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &[1, 2, 2, 3, 3, 3] {
            ost.insert(*item);
        }
        assert_eq!(ost.count(&1), 1);
        assert_eq!(ost.count(&2), 2);
        assert_eq!(ost.count(&3), 3);
        assert_eq!(ost.count(&4), 0);
    }

    #[test]
    fn test_get() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        let mut items = [3, 1, 4, 1, 5, 9, 2, 6, 5];
        for item in items {
            ost.insert(item);
        }
        items.sort();
        for (i, item) in items.iter().enumerate() {
            assert_eq!(ost.get(i), Some(item));
        }
        assert_eq!(ost.get(items.len()), None);
    }
}