Path: blob/main/crates/polars-arrow/src/bitmap/utils/iterator.rs
6939 views
use polars_utils::slice::load_padded_le_u64;12use super::get_bit_unchecked;3use crate::bitmap::MutableBitmap;4use crate::trusted_len::TrustedLen;56/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit),7/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`.8#[derive(Debug, Clone)]9pub struct BitmapIter<'a> {10bytes: &'a [u8],11word: u64,12word_len: usize,13rest_len: usize,14}1516impl<'a> BitmapIter<'a> {17/// Creates a new [`BitmapIter`].18pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {19if len == 0 {20return Self {21bytes,22word: 0,23word_len: 0,24rest_len: 0,25};26}2728assert!(bytes.len() * 8 >= offset + len);29let first_byte_idx = offset / 8;30let bytes = &bytes[first_byte_idx..];31let offset = offset % 8;3233// Make sure during our hot loop all our loads are full 8-byte loads34// by loading the remainder now if it exists.35let word = load_padded_le_u64(bytes) >> offset;36let mod8 = bytes.len() % 8;37let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };38let bytes = &bytes[first_word_bytes..];3940let word_len = (first_word_bytes * 8 - offset).min(len);41let rest_len = len - word_len;42Self {43bytes,44word,45word_len,46rest_len,47}48}4950/// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator.51///52/// This performs the same operation as `(&mut iter).take_while(|b| b).count()`.53///54/// This is a lot more efficient than consecutively polling the iterator and should therefore55/// be preferred, if the use-case allows for it.56pub fn take_leading_ones(&mut self) -> usize {57let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);58self.word_len -= word_ones;59self.word = self.word.wrapping_shr(word_ones as u32);6061if self.word_len != 0 {62return word_ones;63}6465let mut num_leading_ones = word_ones;6667while self.rest_len != 0 {68self.word_len = usize::min(self.rest_len, 64);69self.rest_len -= self.word_len;7071unsafe {72let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();73self.word = u64::from_le_bytes(chunk);74self.bytes = self.bytes.get_unchecked(8..);75}7677let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);78self.word_len -= word_ones;79self.word = self.word.wrapping_shr(word_ones as u32);80num_leading_ones += word_ones;8182if self.word_len != 0 {83return num_leading_ones;84}85}8687num_leading_ones88}8990/// Consume and returns the numbers of `0` / `false` values that the start of the iterator.91///92/// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`.93///94/// This is a lot more efficient than consecutively polling the iterator and should therefore95/// be preferred, if the use-case allows for it.96pub fn take_leading_zeros(&mut self) -> usize {97let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);98self.word_len -= word_zeros;99self.word = self.word.wrapping_shr(word_zeros as u32);100101if self.word_len != 0 {102return word_zeros;103}104105let mut num_leading_zeros = word_zeros;106107while self.rest_len != 0 {108self.word_len = usize::min(self.rest_len, 64);109self.rest_len -= self.word_len;110unsafe {111let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();112self.word = u64::from_le_bytes(chunk);113self.bytes = self.bytes.get_unchecked(8..);114}115116let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);117self.word_len -= word_zeros;118self.word = self.word.wrapping_shr(word_zeros as u32);119num_leading_zeros += word_zeros;120121if self.word_len != 0 {122return num_leading_zeros;123}124}125126num_leading_zeros127}128129/// Returns the number of remaining elements in the iterator130#[inline]131pub fn num_remaining(&self) -> usize {132self.word_len + self.rest_len133}134135/// Collect at most `n` elements from this iterator into `bitmap`136pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {137fn collect_word(138word: &mut u64,139word_len: &mut usize,140bitmap: &mut MutableBitmap,141n: &mut usize,142) {143while *n > 0 && *word_len > 0 {144{145let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);146let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);147*word = word.wrapping_shr(shift);148*word_len -= shift as usize;149*n -= shift as usize;150151bitmap.extend_constant(shift as usize, true);152}153154{155let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);156let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);157*word = word.wrapping_shr(shift);158*word_len -= shift as usize;159*n -= shift as usize;160161bitmap.extend_constant(shift as usize, false);162}163}164}165166let mut n = usize::min(n, self.num_remaining());167bitmap.reserve(n);168169collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);170171if n == 0 {172return;173}174175let num_words = n / 64;176177if num_words > 0 {178assert!(self.bytes.len() >= num_words * size_of::<u64>());179180bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);181182self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };183self.rest_len -= num_words * u64::BITS as usize;184n -= num_words * u64::BITS as usize;185}186187if n == 0 {188return;189}190191assert!(self.bytes.len() >= size_of::<u64>());192193self.word_len = usize::min(self.rest_len, 64);194self.rest_len -= self.word_len;195unsafe {196let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();197self.word = u64::from_le_bytes(chunk);198self.bytes = self.bytes.get_unchecked(8..);199}200201collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);202203debug_assert!(self.num_remaining() == 0 || n == 0);204}205}206207impl Iterator for BitmapIter<'_> {208type Item = bool;209210#[inline]211fn next(&mut self) -> Option<Self::Item> {212if self.word_len == 0 {213if self.rest_len == 0 {214return None;215}216217self.word_len = self.rest_len.min(64);218self.rest_len -= self.word_len;219220unsafe {221let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();222self.word = u64::from_le_bytes(chunk);223self.bytes = self.bytes.get_unchecked(8..);224}225}226227let ret = self.word & 1 != 0;228self.word >>= 1;229self.word_len -= 1;230Some(ret)231}232233#[inline]234fn size_hint(&self) -> (usize, Option<usize>) {235let num_remaining = self.num_remaining();236(num_remaining, Some(num_remaining))237}238}239240impl DoubleEndedIterator for BitmapIter<'_> {241#[inline]242fn next_back(&mut self) -> Option<bool> {243if self.rest_len > 0 {244self.rest_len -= 1;245Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })246} else if self.word_len > 0 {247self.word_len -= 1;248Some(self.word & (1 << self.word_len) != 0)249} else {250None251}252}253}254255unsafe impl TrustedLen for BitmapIter<'_> {}256impl ExactSizeIterator for BitmapIter<'_> {}257258#[cfg(test)]259mod tests {260use super::*;261262#[test]263fn test_collect_into_17579() {264let mut bitmap = MutableBitmap::with_capacity(64);265BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)266.collect_n_into(&mut bitmap, 129);267268let bitmap = bitmap.freeze();269270assert_eq!(bitmap.set_bits(), 4);271}272273#[test]274#[ignore = "Fuzz test. Too slow"]275fn test_fuzz_collect_into() {276for _ in 0..10_000 {277let mut set_bits = 0;278let mut unset_bits = 0;279280let mut length = 0;281let mut pattern = Vec::new();282for _ in 0..rand::random::<u64>() % 1024 {283let bs = rand::random::<u8>() % 4;284285let word = match bs {2860 => u64::MIN,2871 => u64::MAX,2882 | 3 => rand::random(),289_ => unreachable!(),290};291292pattern.extend_from_slice(&word.to_le_bytes());293set_bits += word.count_ones();294unset_bits += word.count_zeros();295length += 64;296}297298for _ in 0..rand::random::<u64>() % 7 {299let b = rand::random::<u8>();300pattern.push(b);301set_bits += b.count_ones();302unset_bits += b.count_zeros();303length += 8;304}305306let last_length = rand::random::<u64>() % 8;307if last_length != 0 {308let b = rand::random::<u8>();309pattern.push(b);310let ones = (b & ((1 << last_length) - 1)).count_ones();311set_bits += ones;312unset_bits += last_length as u32 - ones;313length += last_length;314}315316let mut iter = BitmapIter::new(&pattern, 0, length as usize);317let mut bitmap = MutableBitmap::with_capacity(length as usize);318319while iter.num_remaining() > 0 {320let len_before = bitmap.len();321let n = rand::random::<u64>() as usize % iter.num_remaining();322iter.collect_n_into(&mut bitmap, n);323324// Ensure we are booking the progress we expect325assert_eq!(bitmap.len(), len_before + n);326}327328let bitmap = bitmap.freeze();329330assert_eq!(bitmap.set_bits(), set_bits as usize);331assert_eq!(bitmap.unset_bits(), unset_bits as usize);332}333}334335#[test]336#[ignore = "Fuzz test. Too slow"]337fn test_fuzz_leading_ops() {338for _ in 0..10_000 {339let mut length = 0;340let mut pattern = Vec::new();341for _ in 0..rand::random::<u64>() % 1024 {342let bs = rand::random::<u8>() % 4;343344let word = match bs {3450 => u64::MIN,3461 => u64::MAX,3472 | 3 => rand::random(),348_ => unreachable!(),349};350351pattern.extend_from_slice(&word.to_le_bytes());352length += 64;353}354355for _ in 0..rand::random::<u64>() % 7 {356pattern.push(rand::random::<u8>());357length += 8;358}359360let last_length = rand::random::<u64>() % 8;361if last_length != 0 {362pattern.push(rand::random::<u8>());363length += last_length;364}365366let mut iter = BitmapIter::new(&pattern, 0, length as usize);367368let mut prev_remaining = iter.num_remaining();369while iter.num_remaining() != 0 {370let num_ones = iter.clone().take_leading_ones();371assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());372373let num_zeros = iter.clone().take_leading_zeros();374assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());375376// Ensure that we are making progress377assert!(iter.num_remaining() < prev_remaining);378prev_remaining = iter.num_remaining();379}380381assert_eq!(iter.take_leading_zeros(), 0);382assert_eq!(iter.take_leading_ones(), 0);383}384}385}386387388