Path: blob/main/crates/polars-utils/src/order_statistic_tree.rs
7884 views
//! This module implements an order statistic multiset, which is implemented
//! as a weight-balanced tree (WBT).
//! It is based on the weight-balanced trees described in the following papers:
//!
//! * <https://doi.org/10.1017/S0956796811000104>
//! * <https://doi.org/10.1137/1.9781611976007.13>
//!
//! Each of the nodes in the tree contains a UnitVec of values to store
//! multiple values with the same key.

use std::cmp::Ordering;
use std::fmt::Debug;
use std::ops::RangeInclusive;

use slotmap::{Key as SlotMapKey, SlotMap, new_key_type};

use crate::UnitVec;

/// Balance factor used by `pair_is_balanced`: two sibling subtrees count as
/// balanced when neither `weight + 1` exceeds `DELTA` times the other's.
const DELTA: usize = 3;
/// Rotation factor used by `is_single` to choose between a single and a
/// double rotation when a node is rebalanced.
const GAMMA: usize = 2;

/// Comparison function that defines the ordering (and equality) of elements.
type CompareFn<T> = fn(&T, &T) -> Ordering;

// Slot-map key identifying a node. `Key::null()` is used throughout this
// module as the "no child" / "empty tree" marker.
new_key_type! {
    struct Key;
}

/// A single tree node holding every value that compares `Equal` to
/// `values[0]`.
#[derive(Debug)]
struct Node<T> {
    /// Values stored in this node. The code in this module only creates nodes
    /// with at least one value and removes a node when its last value is
    /// taken, so `values[0]` is used as the node's key.
    values: UnitVec<T>,
    /// Left child, or the null key.
    left: Key,
    /// Right child, or the null key.
    right: Key,
    /// Number of nodes in the subtree rooted at this node (this node included).
    weight: u32,
    /// Number of values stored in the subtree rooted at this node.
    num_elems: u32,
}

/// An ordered multiset that supports index lookup (`get`) and rank queries
/// (`rank_range`, `rank_unique`) by tracking subtree sizes in every node.
#[derive(Debug)]
pub struct OrderStatisticTree<T> {
    /// Storage for all nodes; children are referenced by key.
    nodes: SlotMap<Key, Node<T>>,
    /// Key of the root node, or the null key when the tree is empty.
    root: Key,
    /// Ordering used for every comparison between elements.
    compare: CompareFn<T>,
}

impl<T> OrderStatisticTree<T> {
    /// Creates an empty tree whose elements are ordered by `compare`.
    #[inline]
    pub fn new(compare: CompareFn<T>) -> Self {
        OrderStatisticTree {
            nodes: SlotMap::with_key(),
            root: Key::null(),
            compare,
        }
    }

    /// Creates an empty tree ordered by `compare`, passing `capacity` to the
    /// node storage's constructor.
    #[inline]
    pub fn with_capacity(capacity: usize, compare: CompareFn<T>) -> Self {
        OrderStatisticTree {
            nodes: SlotMap::with_capacity_and_key(capacity),
            root: Key::null(),
            compare,
        }
    }

    /// Returns `true` when the tree holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the total number of elements, counting duplicates.
    #[inline]
    pub fn len(&self) -> usize {
        self.num_elems(self.root)
    }

    /// Returns the number of nodes, i.e. the number of distinct elements
    /// under `compare`.
    #[inline]
    pub fn unique_len(&self) -> usize {
        self.tree_weight(self.root)
    }

    /// Removes every element and resets the root to the null key.
    #[inline]
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.root = Key::null();
    }

    /// Returns the total number of elements in the tree rooted at `tree`.
    fn num_elems(&self, tree: Key) -> usize {
        if tree.is_null() {
            return 0;
        }
        // SAFETY: `tree` is non-null (checked above); this relies on the
        // module-wide invariant that every non-null key held by the tree
        // refers to a live node.
        unsafe { self.nodes.get_unchecked(tree) }.num_elems as usize
    }

    /// Returns the number of tree nodes, which is equal to the number of unique
    /// elements, in the tree rooted at `tree`.
    fn tree_weight(&self, tree: Key) -> usize {
        if tree.is_null() {
            return 0;
        }
        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        unsafe { self.nodes.get_unchecked(tree) }.weight as usize
    }

    /// Allocates a node with the given children and values, computing its
    /// `weight` and `num_elems` from the children, and returns its key.
    #[must_use]
    fn new_tree_node(&mut self, left: Key, values: UnitVec<T>, right: Key) -> Key {
        let weight = self.tree_weight(left) + self.tree_weight(right) + 1;
        let num_elems = self.num_elems(left) + self.num_elems(right) + values.len();
        // NOTE(review): the `as u32` casts below truncate silently if a
        // subtree ever exceeds `u32::MAX` nodes/elements — confirm callers
        // never get that large.
        let n = Node {
            values,
            left,
            right,
            weight: weight as u32,
            num_elems: num_elems as u32,
        };
        self.nodes.insert(n)
    }

    /// Allocates a childless node holding only `value` and returns its key.
    #[must_use]
    fn new_leaf(&mut self, value: T) -> Key {
        let mut uv = UnitVec::new();
        uv.push(value);
        self.new_tree_node(Key::null(), uv, Key::null())
    }

    /// Removes the node `tree` from storage and returns it by value.
    ///
    /// # Safety
    ///
    /// `tree` must be the key of a node currently stored in `self.nodes`;
    /// the result of the removal is unwrapped without a check.
    #[must_use]
    unsafe fn drop_tree_node(&mut self, tree: Key) -> Node<T> {
        unsafe { self.nodes.remove(tree).unwrap_unchecked() }
    }

    /// Returns the element at position `idx` in the tree's sorted order, or
    /// `None` when `idx` is not smaller than `len()`.
    #[inline]
    pub fn get(&self, idx: usize) -> Option<&T> {
        self._get(idx, self.root)
    }

    /// Recursive worker for `get`: looks up position `idx` within the subtree
    /// rooted at `tree`.
    fn _get(&self, idx: usize, tree: Key) -> Option<&T> {
        if tree.is_null() {
            return None;
        }

        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        let own_elems = self.num_elems(tree);
        let left_elems = self.num_elems(n.left);
        let right_elems = self.num_elems(n.right);

        if idx < left_elems {
            self._get(idx, n.left)
        } else if idx >= own_elems - right_elems {
            // `own_elems - right_elems` is the number of elements that precede
            // the right subtree (left subtree plus this node's values).
            self._get(idx - (own_elems - right_elems), n.right)
        } else {
            n.values.get(idx - left_elems)
        }
    }

    /// Inserts `value` into the multiset. A value that compares `Equal` to an
    /// existing node's key is appended to that node instead of creating a new
    /// node.
    #[inline]
    pub fn insert(&mut self, value: T) {
        (self.root, _) = self._insert(value, self.root);
    }

    /// Recursive worker for `insert`.
    ///
    /// Returns the (possibly new) root key of the subtree and whether a new
    /// node was created, which the caller uses to update its `weight`.
    #[must_use]
    fn _insert(&mut self, value: T, tree: Key) -> (Key, bool) {
        if tree.is_null() {
            return (self.new_leaf(value), true);
        }

        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        match (self.compare)(&value, &n.values[0]) {
            Ordering::Less => {
                let (left, node_added) = self._insert(value, n.left);
                // SAFETY: the recursive call only replaced nodes inside the
                // left subtree, so `tree` itself is still stored.
                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
                n.left = left;
                n.weight += node_added as u32;
                n.num_elems += 1;
                // The left side may have grown, so rebalance by rotating right.
                (self.balance_r(tree), node_added)
            },
            Ordering::Equal => {
                // SAFETY: same key that was read successfully above.
                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
                n.values.push(value);
                n.num_elems += 1;
                (tree, false)
            },
            Ordering::Greater => {
                let (right, node_added) = self._insert(value, n.right);
                // SAFETY: the recursive call only replaced nodes inside the
                // right subtree, so `tree` itself is still stored.
                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
                n.right = right;
                n.weight += node_added as u32;
                n.num_elems += 1;
                // The right side may have grown, so rebalance by rotating left.
                (self.balance_l(tree), node_added)
            },
        }
    }

    /// Removes one element that compares `Equal` to `value` and returns it,
    /// or returns `None` when no such element is stored.
    ///
    /// When the matching node holds several values, one is taken with
    /// `pop()` from that node's value list.
    #[inline]
    pub fn remove(&mut self, value: &T) -> Option<T> {
        let deleted;
        (deleted, self.root, _) = self._remove(value, self.root);
        deleted
    }

    /// Recursive worker for `remove`.
    ///
    /// Returns the removed value (if any), the (possibly new) root key of the
    /// subtree, and whether a whole node was removed from the subtree.
    #[must_use]
    fn _remove(&mut self, value: &T, tree: Key) -> (Option<T>, Key, bool) {
        if tree.is_null() {
            return (None, tree, false);
        }

        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        match (self.compare)(value, &n.values[0]) {
            Ordering::Less => {
                let (deleted, left, node_removed) = self._remove(value, n.left);
                // SAFETY: the recursive call only touched the left subtree,
                // so `tree` itself is still stored.
                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
                n.left = left;
                n.weight -= node_removed as u32;
                n.num_elems -= deleted.is_some() as u32;
                // The left side may have shrunk, so rebalance by rotating left.
                (deleted, self.balance_l(tree), node_removed)
            },
            Ordering::Greater => {
                let (deleted, right, node_removed) = self._remove(value, n.right);
                // SAFETY: the recursive call only touched the right subtree,
                // so `tree` itself is still stored.
                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
                n.right = right;
                n.weight -= node_removed as u32;
                n.num_elems -= deleted.is_some() as u32;
                // The right side may have shrunk, so rebalance by rotating right.
                (deleted, self.balance_r(tree), node_removed)
            },
            Ordering::Equal if n.values.len() > 1 => {
                // Duplicates remain, so the node stays and only one value goes.
                // SAFETY: same key that was read successfully above.
                let n = unsafe { self.nodes.get_unchecked_mut(tree) };
                // SAFETY: the guard established `values.len() > 1`.
                let popped_value = unsafe { n.values.pop().unwrap_unchecked() };
                n.num_elems -= 1;
                (Some(popped_value), tree, false)
            },
            Ordering::Equal => {
                // Last value of this node: remove the node and join its children.
                // SAFETY: `tree` was read successfully above, so it is stored.
                let mut n = unsafe { self.drop_tree_node(tree) };
                (
                    // SAFETY: relies on the invariant that a stored node holds
                    // at least one value.
                    Some(unsafe { n.values.pop().unwrap_unchecked() }),
                    self.glue(n.left, n.right),
                    true,
                )
            },
        }
    }

    /// Joins two subtrees, where `left` holds the smaller and `right` the
    /// larger keys, into one tree and returns its root key.
    ///
    /// The new root's values are taken from the heavier side: the maximum
    /// node of `left` or the minimum node of `right`.
    #[must_use]
    fn glue(&mut self, left: Key, right: Key) -> Key {
        if left.is_null() {
            right
        } else if right.is_null() {
            left
        } else if self.tree_weight(left) > self.tree_weight(right) {
            let (deleted, left) = self.remove_max(left);
            let tree = self.new_tree_node(left, deleted, right);
            self.balance_r(tree)
        } else {
            let (deleted, right) = self.remove_min(right);
            let tree = self.new_tree_node(left, deleted, right);
            self.balance_l(tree)
        }
    }

    /// Removes the leftmost node of the non-null subtree `tree`.
    ///
    /// Returns that node's values and the root key of the remaining subtree.
    #[must_use]
    fn remove_min(&mut self, tree: Key) -> (UnitVec<T>, Key) {
        debug_assert!(!tree.is_null());
        // SAFETY: callers pass a non-null key (asserted above in debug
        // builds) that is presumed to refer to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        if n.left.is_null() {
            // SAFETY: `tree` was read successfully above, so it is stored.
            let n = unsafe { self.drop_tree_node(tree) };
            return (n.values, n.right);
        }
        let (deleted, left) = self.remove_min(n.left);
        // SAFETY: the recursive call only touched the left subtree, so
        // `tree` itself is still stored.
        let n = unsafe { self.nodes.get_unchecked_mut(tree) };
        n.left = left;
        // Exactly one node, holding `deleted.len()` values, left the subtree.
        n.weight -= 1;
        n.num_elems -= deleted.len() as u32;
        (deleted, self.balance_l(tree))
    }

    /// Removes the rightmost node of the non-null subtree `tree`.
    ///
    /// Returns that node's values and the root key of the remaining subtree.
    #[must_use]
    fn remove_max(&mut self, tree: Key) -> (UnitVec<T>, Key) {
        debug_assert!(!tree.is_null());
        // SAFETY: callers pass a non-null key (asserted above in debug
        // builds) that is presumed to refer to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        if n.right.is_null() {
            // SAFETY: `tree` was read successfully above, so it is stored.
            let n = unsafe { self.drop_tree_node(tree) };
            return (n.values, n.left);
        }
        let (deleted, right) = self.remove_max(n.right);
        // SAFETY: the recursive call only touched the right subtree, so
        // `tree` itself is still stored.
        let n = unsafe { self.nodes.get_unchecked_mut(tree) };
        n.right = right;
        // Exactly one node, holding `deleted.len()` values, left the subtree.
        n.weight -= 1;
        n.num_elems -= deleted.len() as u32;
        (deleted, self.balance_r(tree))
    }

    /// Returns `true` when at least one stored element compares `Equal` to
    /// `value`.
    #[inline]
    pub fn contains(&self, value: &T) -> bool {
        self._contains(value, self.root)
    }

    /// Recursive worker for `contains`, searching the subtree rooted at `tree`.
    fn _contains(&self, value: &T, tree: Key) -> bool {
        if tree.is_null() {
            return false;
        }
        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        match (self.compare)(value, &n.values[0]) {
            Ordering::Less => self._contains(value, n.left),
            Ordering::Equal => true,
            Ordering::Greater => self._contains(value, n.right),
        }
    }

    /// Returns `tree` unchanged when its two children are balanced against
    /// each other; otherwise rotates left and returns the new subtree root.
    ///
    /// Callers in this file use it when the right side may have become too
    /// heavy relative to the left; the rotation reads the right child.
    #[must_use]
    fn balance_l(&mut self, tree: Key) -> Key {
        // SAFETY: callers pass the key of a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        if self.pair_is_balanced(n.left, n.right) {
            return tree;
        }
        self.rotate_l(tree)
    }

    /// Rotates the subtree rooted at `tree` to the left, using a single or a
    /// double rotation depending on `is_single` applied to the right child's
    /// children.
    #[must_use]
    fn rotate_l(&mut self, tree: Key) -> Key {
        // SAFETY: callers pass the key of a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        // SAFETY: presumes `n.right` is a live node, i.e. the imbalance that
        // triggered the rotation is on the right side.
        let r = unsafe { self.nodes.get_unchecked(n.right) };
        if self.is_single(r.left, r.right) {
            self.single_l(tree)
        } else {
            self.double_l(tree)
        }
    }

    /// Single left rotation: the right child becomes the new subtree root and
    /// the old root becomes its left child.
    ///
    /// Both involved nodes are removed from storage and re-created, so the
    /// key `tree` must not be used afterwards; the new root key is returned.
    #[must_use]
    fn single_l(&mut self, tree: Key) -> Key {
        // SAFETY: `tree` and its right child were read by `rotate_l`.
        let n = unsafe { self.drop_tree_node(tree) };
        let r = unsafe { self.drop_tree_node(n.right) };
        let new_left = self.new_tree_node(n.left, n.values, r.left);
        self.new_tree_node(new_left, r.values, r.right)
    }

    /// Double left rotation: the left child of the right child becomes the
    /// new subtree root, with the old root on its left and the old right
    /// child on its right.
    ///
    /// All three involved nodes are removed from storage and re-created; the
    /// new root key is returned.
    #[must_use]
    fn double_l(&mut self, tree: Key) -> Key {
        // SAFETY: `tree` and its right child were read by `rotate_l`;
        // `r.left` is presumed non-null because `is_single` holds whenever it
        // is null (weight 0).
        let n = unsafe { self.drop_tree_node(tree) };
        let r = unsafe { self.drop_tree_node(n.right) };
        let rl = unsafe { self.drop_tree_node(r.left) };
        let new_left = self.new_tree_node(n.left, n.values, rl.left);
        let new_right = self.new_tree_node(rl.right, r.values, r.right);
        self.new_tree_node(new_left, rl.values, new_right)
    }

    /// Returns `tree` unchanged when its two children are balanced against
    /// each other; otherwise rotates right and returns the new subtree root.
    ///
    /// Callers in this file use it when the left side may have become too
    /// heavy relative to the right; the rotation reads the left child.
    #[must_use]
    fn balance_r(&mut self, tree: Key) -> Key {
        // SAFETY: callers pass the key of a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        if self.pair_is_balanced(n.right, n.left) {
            return tree;
        }
        self.rotate_r(tree)
    }

    /// Rotates the subtree rooted at `tree` to the right, using a single or a
    /// double rotation depending on `is_single` applied to the left child's
    /// children.
    #[must_use]
    fn rotate_r(&mut self, tree: Key) -> Key {
        // SAFETY: callers pass the key of a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        // SAFETY: presumes `n.left` is a live node, i.e. the imbalance that
        // triggered the rotation is on the left side.
        let l = unsafe { self.nodes.get_unchecked(n.left) };
        if self.is_single(l.right, l.left) {
            self.single_r(tree)
        } else {
            self.double_r(tree)
        }
    }

    /// Single right rotation: the left child becomes the new subtree root and
    /// the old root becomes its right child.
    ///
    /// Both involved nodes are removed from storage and re-created, so the
    /// key `tree` must not be used afterwards; the new root key is returned.
    #[must_use]
    fn single_r(&mut self, tree: Key) -> Key {
        // SAFETY: `tree` and its left child were read by `rotate_r`.
        let n = unsafe { self.drop_tree_node(tree) };
        let l = unsafe { self.drop_tree_node(n.left) };
        let new_right = self.new_tree_node(l.right, n.values, n.right);
        self.new_tree_node(l.left, l.values, new_right)
    }

    /// Double right rotation: the right child of the left child becomes the
    /// new subtree root, with the old left child on its left and the old root
    /// on its right.
    ///
    /// All three involved nodes are removed from storage and re-created; the
    /// new root key is returned.
    #[must_use]
    fn double_r(&mut self, tree: Key) -> Key {
        // SAFETY: `tree` and its left child were read by `rotate_r`;
        // `l.right` is presumed non-null because `is_single` holds whenever
        // it is null (weight 0).
        let n = unsafe { self.drop_tree_node(tree) };
        let l = unsafe { self.drop_tree_node(n.left) };
        let lr = unsafe { self.drop_tree_node(l.right) };
        let new_right = self.new_tree_node(lr.right, n.values, n.right);
        let new_left = self.new_tree_node(l.left, l.values, lr.left);
        self.new_tree_node(new_left, lr.values, new_right)
    }

    /// Returns `true` when every node in the tree satisfies the balance
    /// condition of `pair_is_balanced`. Used by the tests in this file.
    #[doc(hidden)]
    pub fn is_balanced(&self) -> bool {
        self.tree_is_balanced(self.root)
    }

    /// Recursively checks the balance condition for `tree` and all of its
    /// descendants.
    fn tree_is_balanced(&self, tree: Key) -> bool {
        if tree.is_null() {
            return true;
        }
        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        self.pair_is_balanced(n.left, n.right)
            && self.pair_is_balanced(n.right, n.left)
            && self.tree_is_balanced(n.left)
            && self.tree_is_balanced(n.right)
    }

    /// Returns `true` when neither subtree's `weight + 1` exceeds `DELTA`
    /// times the other's. The check is symmetric in its two arguments.
    fn pair_is_balanced(&self, left: Key, right: Key) -> bool {
        let a = self.tree_weight(left);
        let b = self.tree_weight(right);
        DELTA * (a + 1) >= (b + 1) && DELTA * (b + 1) >= (a + 1)
    }

    /// Returns `true` when a single rotation should be used, i.e. when
    /// `weight(left) + 1 < GAMMA * (weight(right) + 1)`; otherwise a double
    /// rotation is chosen by the callers.
    fn is_single(&self, left: Key, right: Key) -> bool {
        let a = self.tree_weight(left);
        let b = self.tree_weight(right);
        a + 1 < GAMMA * (b + 1)
    }

    /// Returns the index range occupied by the elements equal to `bound`.
    ///
    /// * `Ok(lo..=hi)` — the inclusive range of sorted positions whose
    ///   elements compare `Equal` to `bound`.
    /// * `Err(n)` — no such element exists; `n` is the number of stored
    ///   elements that order before `bound`.
    #[inline]
    pub fn rank_range(&self, bound: &T) -> Result<RangeInclusive<usize>, usize> {
        self._rank_range(bound, self.root)
    }

    /// Recursive worker for `rank_range`; results are relative to the subtree
    /// rooted at `tree`.
    fn _rank_range(&self, value: &T, tree: Key) -> Result<RangeInclusive<usize>, usize> {
        if tree.is_null() {
            return Err(0);
        }
        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        match (self.compare)(value, &n.values[0]) {
            Ordering::Less => self._rank_range(value, n.left),
            Ordering::Equal => {
                let lo = self.num_elems(n.left);
                let hi = lo + n.values.len() - 1;
                Ok(lo..=hi)
            },
            Ordering::Greater => {
                // Shift ranks found in the right subtree by the number of
                // elements that precede it (left subtree plus this node).
                let update_rank = |r| self.num_elems(tree) - self.num_elems(n.right) + r;
                self._rank_range(value, n.right)
                    .map(|rank| update_rank(*rank.start())..=update_rank(*rank.end()))
                    .map_err(update_rank)
            },
        }
    }

    /// Returns the rank of `value` among the distinct elements.
    ///
    /// * `Ok(r)` — an element equal to `value` exists and `r` distinct
    ///   elements order before it.
    /// * `Err(r)` — no such element exists; `r` distinct elements order
    ///   before `value`.
    #[inline]
    pub fn rank_unique(&self, value: &T) -> Result<usize, usize> {
        self._rank_unique(value, self.root)
    }

    /// Recursive worker for `rank_unique`; results are relative to the
    /// subtree rooted at `tree`.
    fn _rank_unique(&self, value: &T, tree: Key) -> Result<usize, usize> {
        if tree.is_null() {
            return Err(0);
        }
        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        match (self.compare)(value, &n.values[0]) {
            Ordering::Less => self._rank_unique(value, n.left),
            Ordering::Equal => Ok(self.tree_weight(n.left)),
            // Shift ranks found in the right subtree by the number of nodes
            // that precede it (left subtree plus this node).
            Ordering::Greater => self
                ._rank_unique(value, n.right)
                .map(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank)
                .map_err(|rank| self.tree_weight(tree) - self.tree_weight(n.right) + rank),
        }
    }

    /// Returns how many stored elements compare `Equal` to `value`.
    #[inline]
    pub fn count(&self, value: &T) -> usize {
        self._count(value, self.root)
    }

    /// Recursive worker for `count`, searching the subtree rooted at `tree`.
    fn _count(&self, value: &T, tree: Key) -> usize {
        if tree.is_null() {
            return 0;
        }
        // SAFETY: `tree` is non-null (checked above); relies on the invariant
        // that every non-null key held by the tree refers to a live node.
        let n = unsafe { self.nodes.get_unchecked(tree) };
        match (self.compare)(value, &n.values[0]) {
            Ordering::Less => self._count(value, n.left),
            Ordering::Equal => n.values.len(),
            Ordering::Greater => self._count(value, n.right),
        }
    }
}

/// Inserts every item of the iterator, one at a time, via `insert`.
impl<T> Extend<T> for OrderStatisticTree<T> {
    fn extend<I: IntoIterator<Item = T>>(&mut self, iterable: I) {
        let iterator = iterable.into_iter();
        for element in iterator {
            self.insert(element);
        }
    }
}

#[cfg(test)]
mod test {

    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::TestRunner;

    use super::*;

    /// Property test: inserting random pairs keeps the tree balanced and
    /// retains every inserted item.
    #[test]
    fn test_insert() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec((0i32..100, 0i32..100), 0..100), test_insert_inner)
            .unwrap()
    }

    fn test_insert_inner(items: Vec<(i32, i32)>) -> Result<(), TestCaseError> {
        // Only the first tuple field takes part in the ordering, so pairs
        // with the same first field land in the same node.
        let cmp = |a: &(i32, i32), b: &(i32, i32)| i32::cmp(&a.0, &b.0);
        let mut ost = OrderStatisticTree::new(cmp);
        for item in &items {
            ost.insert(*item);
            assert!(ost.is_balanced());
        }
        assert_eq!(ost.len(), items.len());
        let mut sorted_items = items.clone();
        sorted_items.sort();
        let mut collected_items = Vec::new();
        let mut i = 0;
        while let Some(v) = ost.get(i) {
            collected_items.push(*v);
            i += 1;
        }
        // The tree orders by the first field only; sort the collected pairs
        // fully so they can be compared with the fully sorted input.
        collected_items.sort();
        assert_eq!(ost.len(), items.len());
        assert_eq!(&collected_items, &sorted_items);
        Ok(())
    }

    /// Property test: removals keep the tree balanced and agree with a sorted
    /// `Vec` used as the reference model.
    #[test]
    fn test_remove() {
        let mut runner = TestRunner::default();
        runner
            .run(
                &(vec(0i32..100, 0..100), vec(0i32..100, 0..100)),
                test_remove_inner,
            )
            .unwrap();
    }

    fn test_remove_inner(input: (Vec<i32>, Vec<i32>)) -> Result<(), TestCaseError> {
        let (mut items, to_remove) = input;
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
            assert!(ost.is_balanced());
        }
        items.sort();
        for item in &to_remove {
            let v = ost.remove(item);
            assert!(ost.is_balanced());
            // Mirror the removal in the reference vector.
            let idx = items.binary_search(item);
            assert_eq!(v.is_some(), idx.is_ok());
            if let Ok(idx) = idx {
                items.remove(idx);
            }
            assert_eq!(ost.len(), items.len());
        }
        assert_eq!(ost.len(), items.len());
        for item in 0..100 {
            assert_eq!(ost.contains(&item), items.contains(&item));
        }
        Ok(())
    }

    /// Property test: `rank_range` matches ranks computed by counting over
    /// the input.
    #[test]
    fn test_rank() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec(0i32..100, 0..100), test_rank_inner)
            .unwrap();
    }

    fn test_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
        }
        items.sort();
        for item in 0..100 {
            let rank = ost.rank_range(&item);

            let expected_rank = if items.contains(&item) {
                let expected_rank_lower = items.iter().filter(|&x| *x < item).count();
                let expected_rank_upper = items.iter().filter(|&x| *x <= item).count() - 1;
                Ok(expected_rank_lower..=expected_rank_upper)
            } else {
                Err(items.iter().filter(|&x| *x < item).count())
            };

            assert_eq!(rank, expected_rank);
        }
        Ok(())
    }

    /// Property test: `rank_unique` and `unique_len` match the deduplicated
    /// input.
    #[test]
    fn test_unique_rank() {
        let mut runner = TestRunner::default();
        runner
            .run(&vec(0i32..50, 0..100), test_unique_rank_inner)
            .unwrap();
    }

    fn test_unique_rank_inner(mut items: Vec<i32>) -> Result<(), TestCaseError> {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &items {
            ost.insert(*item);
        }
        assert_eq!(ost.len(), items.len());
        items.sort();
        items.dedup();
        assert_eq!(ost.unique_len(), items.len());
        for item in 0..50 {
            let unique_rank = ost.rank_unique(&item);
            let expected_unique_rank = if items.contains(&item) {
                Ok(items.iter().filter(|&x| *x < item).count())
            } else {
                Err(items.iter().filter(|&x| *x < item).count())
            };
            assert_eq!(unique_rank, expected_unique_rank);
        }
        Ok(())
    }

    /// An empty tree reports zero sizes and `Err(0)` for rank queries.
    #[test]
    fn test_empty() {
        let ost = OrderStatisticTree::<i32>::new(i32::cmp);
        assert!(ost.is_empty());
        assert_eq!(ost.len(), 0);
        assert_eq!(ost.unique_len(), 0);
        assert!(ost.is_balanced());
        assert!(!ost.contains(&1));
        assert_eq!(ost.rank_range(&1), Err(0));
        assert_eq!(ost.rank_unique(&1), Err(0));
    }

    /// `clear` empties a populated tree.
    #[test]
    fn test_clear() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in 0..10 {
            ost.insert(item);
        }
        assert_eq!(ost.len(), 10);
        assert_eq!(ost.unique_len(), 10);
        ost.clear();
        assert!(ost.is_empty());
    }

    /// `extend` inserts every item of the iterator.
    #[test]
    fn test_extend() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        ost.extend(0..10);
        assert_eq!(ost.len(), 10);
        assert_eq!(ost.unique_len(), 10);
        for item in 0..10 {
            assert!(ost.contains(&item));
        }
    }

    /// `count` reports the multiplicity of each value, and 0 when absent.
    #[test]
    fn test_count() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        for item in &[1, 2, 2, 3, 3, 3] {
            ost.insert(*item);
        }
        assert_eq!(ost.count(&1), 1);
        assert_eq!(ost.count(&2), 2);
        assert_eq!(ost.count(&3), 3);
        assert_eq!(ost.count(&4), 0);
    }

    /// `get` returns elements in sorted order and `None` past the end.
    #[test]
    fn test_get() {
        let mut ost = OrderStatisticTree::new(i32::cmp);
        let mut items = [3, 1, 4, 1, 5, 9, 2, 6, 5];
        for item in items {
            ost.insert(item);
        }
        items.sort();
        for (i, item) in items.iter().enumerate() {
            assert_eq!(ost.get(i), Some(item));
        }
        assert_eq!(ost.get(items.len()), None);
    }
}