Path: blob/main/crates/polars-compute/src/hyperloglogplus.rs
6939 views
//! # HyperLogLogPlus1//!2//! `hyperloglogplus` module contains implementation of HyperLogLogPlus3//! algorithm for cardinality estimation so that [`crate::series::approx_n_unique`] function can4//! be efficiently implemented.5//!6//! This module borrows code from [arrow-datafusion](https://github.com/apache/arrow-datafusion/blob/93771052c5ac31f2cf22b8c25bf938656afe1047/datafusion/physical-expr/src/aggregate/hyperloglog.rs).7//!8//! # Examples9//!10//! ```11//! # use polars_compute::hyperloglogplus::*;12//! let mut hllp = HyperLogLog::new();13//! hllp.add(&12345);14//! hllp.add(&23456);15//!16//! assert_eq!(hllp.count(), 2);17//! ```1819use std::hash::{BuildHasher, Hash};20use std::marker::PhantomData;2122use polars_utils::aliases::PlFixedStateQuality;2324/// The greater is P, the smaller the error.25const HLL_P: usize = 14_usize;26/// The number of bits of the hash value used determining the number of leading zeros27const HLL_Q: usize = 64_usize - HLL_P;28const NUM_REGISTERS: usize = 1_usize << HLL_P;29/// Mask to obtain index into the registers30const HLL_P_MASK: u64 = (NUM_REGISTERS as u64) - 1;3132#[derive(Clone, Debug)]33pub struct HyperLogLog<T>34where35T: Hash + ?Sized,36{37registers: [u8; NUM_REGISTERS],38phantom: PhantomData<T>,39}4041impl<T> Default for HyperLogLog<T>42where43T: Hash + ?Sized,44{45fn default() -> Self {46Self::new()47}48}4950/// Fixed seed for the hashing so that values are consistent across runs51///52/// Note that when we later move on to have serialized HLL register binaries53/// shared across cluster, this SEED will have to be consistent across all54/// parties otherwise we might have corruption. So ideally for later this seed55/// shall be part of the serialized form (or stay unchanged across versions).56const SEED: PlFixedStateQuality = PlFixedStateQuality::with_seed(0);5758impl<T> HyperLogLog<T>59where60T: Hash + ?Sized,61{62/// Creates a new, empty HyperLogLog.63pub fn new() -> Self {64let registers = [0; NUM_REGISTERS];65Self::new_with_registers(registers)66}6768/// Creates a HyperLogLog from already populated registers69/// note that this method should not be invoked in untrusted environment70/// because the internal structure of registers are not examined.71pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self {72Self {73registers,74phantom: PhantomData,75}76}7778#[inline]79fn hash_value(&self, obj: &T) -> u64 {80SEED.hash_one(obj)81}8283/// Adds an element to the HyperLogLog.84pub fn add(&mut self, obj: &T) {85let hash = self.hash_value(obj);86let index = (hash & HLL_P_MASK) as usize;87let p = ((hash >> HLL_P) | (1_u64 << HLL_Q)).trailing_zeros() + 1;88self.registers[index] = self.registers[index].max(p as u8);89}9091/// Get the register histogram (each value in register index into92/// the histogram; u32 is enough because we only have 2**14=16384 registers93#[inline]94fn get_histogram(&self) -> [u32; HLL_Q + 2] {95let mut histogram = [0; HLL_Q + 2];96// hopefully this can be unrolled97for r in self.registers {98histogram[r as usize] += 1;99}100histogram101}102103/// Merge the other [`HyperLogLog`] into this one104pub fn merge(&mut self, other: &HyperLogLog<T>) {105assert!(106self.registers.len() == other.registers.len(),107"unexpected got unequal register size, expect {}, got {}",108self.registers.len(),109other.registers.len()110);111for i in 0..self.registers.len() {112self.registers[i] = self.registers[i].max(other.registers[i]);113}114}115116/// Guess the number of unique elements seen by the HyperLogLog.117pub fn count(&self) -> usize {118let histogram = self.get_histogram();119let m = NUM_REGISTERS as f64;120let mut z = m * hll_tau((m - histogram[HLL_Q + 1] as f64) / m);121for i in histogram[1..=HLL_Q].iter().rev() {122z += *i as f64;123z *= 0.5;124}125z += m * hll_sigma(histogram[0] as f64 / m);126(0.5 / 2_f64.ln() * m * m / z).round() as usize127}128}129130/// Helper function sigma as defined in131/// "New cardinality estimation algorithms for HyperLogLog sketches"132/// Otmar Ertl, arXiv:1702.01284133#[inline]134fn hll_sigma(x: f64) -> f64 {135if x == 1. {136f64::INFINITY137} else {138let mut y = 1.0;139let mut z = x;140let mut x = x;141loop {142x *= x;143let z_prime = z;144z += x * y;145y += y;146if z_prime == z {147break;148}149}150z151}152}153154/// Helper function tau as defined in155/// "New cardinality estimation algorithms for HyperLogLog sketches"156/// Otmar Ertl, arXiv:1702.01284157#[inline]158fn hll_tau(x: f64) -> f64 {159if x == 0.0 || x == 1.0 {1600.0161} else {162let mut y = 1.0;163let mut z = 1.0 - x;164let mut x = x;165loop {166x = x.sqrt();167let z_prime = z;168y *= 0.5;169z -= (1.0 - x).powi(2) * y;170if z_prime == z {171break;172}173}174z / 3.0175}176}177178impl<T> AsRef<[u8]> for HyperLogLog<T>179where180T: Hash + ?Sized,181{182fn as_ref(&self) -> &[u8] {183&self.registers184}185}186187impl<T> Extend<T> for HyperLogLog<T>188where189T: Hash,190{191fn extend<S: IntoIterator<Item = T>>(&mut self, iter: S) {192for elem in iter {193self.add(&elem);194}195}196}197198impl<'a, T> Extend<&'a T> for HyperLogLog<T>199where200T: 'a + Hash + ?Sized,201{202fn extend<S: IntoIterator<Item = &'a T>>(&mut self, iter: S) {203for elem in iter {204self.add(elem);205}206}207}208209#[cfg(test)]210mod tests {211use super::{HyperLogLog, NUM_REGISTERS};212213fn compare_with_delta(got: usize, expected: usize) {214let expected = expected as f64;215let diff = (got as f64) - expected;216let diff = diff.abs() / expected;217// times 6 because we want the tests to be stable218// so we allow a rather large margin of error219// this is adopted from redis's unit test version as well220let margin = 1.04 / ((NUM_REGISTERS as f64).sqrt()) * 6.0;221assert!(222diff <= margin,223"{} is not near {} percent of {} which is ({}, {})",224got,225margin,226expected,227expected * (1.0 - margin),228expected * (1.0 + margin)229);230}231232macro_rules! sized_number_test {233($SIZE: expr, $T: tt) => {{234let mut hll = HyperLogLog::<$T>::new();235for i in 0..$SIZE {236hll.add(&i);237}238compare_with_delta(hll.count(), $SIZE);239}};240}241242macro_rules! typed_large_number_test {243($SIZE: expr) => {{244sized_number_test!($SIZE, u64);245sized_number_test!($SIZE, u128);246sized_number_test!($SIZE, i64);247sized_number_test!($SIZE, i128);248}};249}250251macro_rules! typed_number_test {252($SIZE: expr) => {{253sized_number_test!($SIZE, u16);254sized_number_test!($SIZE, u32);255sized_number_test!($SIZE, i16);256sized_number_test!($SIZE, i32);257typed_large_number_test!($SIZE);258}};259}260261#[test]262fn test_empty() {263let hll = HyperLogLog::<u64>::new();264assert_eq!(hll.count(), 0);265}266267#[test]268fn test_one() {269let mut hll = HyperLogLog::<u64>::new();270hll.add(&1);271assert_eq!(hll.count(), 1);272}273274#[test]275fn test_number_100() {276typed_number_test!(100);277}278279#[test]280fn test_number_1k() {281typed_number_test!(1_000);282}283284#[test]285fn test_number_10k() {286typed_number_test!(10_000);287}288289#[test]290fn test_number_100k() {291typed_large_number_test!(100_000);292}293294#[test]295fn test_number_1m() {296typed_large_number_test!(1_000_000);297}298299#[test]300fn test_u8() {301let mut hll = HyperLogLog::<[u8]>::new();302for i in 0..1000 {303let s = i.to_string();304let b = s.as_bytes();305hll.add(b);306}307compare_with_delta(hll.count(), 1000);308}309310#[test]311fn test_string() {312let mut hll = HyperLogLog::<String>::new();313hll.extend((0..1000).map(|i| i.to_string()));314compare_with_delta(hll.count(), 1000);315}316317#[test]318fn test_empty_merge() {319let mut hll = HyperLogLog::<u64>::new();320hll.merge(&HyperLogLog::<u64>::new());321assert_eq!(hll.count(), 0);322}323324#[test]325fn test_merge_overlapped() {326let mut hll = HyperLogLog::<String>::new();327hll.extend((0..1000).map(|i| i.to_string()));328329let mut other = HyperLogLog::<String>::new();330other.extend((0..1000).map(|i| i.to_string()));331332hll.merge(&other);333compare_with_delta(hll.count(), 1000);334}335336#[test]337fn test_repetition() {338let mut hll = HyperLogLog::<u32>::new();339for i in 0..1_000_000 {340hll.add(&(i % 1000));341}342compare_with_delta(hll.count(), 1000);343}344}345346347