Path: blob/main/crates/polars-arrow/src/bitmap/utils/iterator.rs
8396 views
use polars_utils::slice::load_padded_le_u64;12use super::get_bit_unchecked;3use crate::bitmap::MutableBitmap;4use crate::trusted_len::TrustedLen;56/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit),7/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`.8#[derive(Debug, Clone)]9pub struct BitmapIter<'a> {10bytes: &'a [u8],11word: u64,12word_len: usize,13rest_len: usize,14}1516impl<'a> BitmapIter<'a> {17/// Creates a new [`BitmapIter`].18pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {19if len == 0 {20return Self {21bytes,22word: 0,23word_len: 0,24rest_len: 0,25};26}2728assert!(bytes.len() * 8 >= offset + len);29let first_byte_idx = offset / 8;30let bytes = &bytes[first_byte_idx..];31let offset = offset % 8;3233// Make sure during our hot loop all our loads are full 8-byte loads34// by loading the remainder now if it exists.35let word = load_padded_le_u64(bytes) >> offset;36let mod8 = bytes.len() % 8;37let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };38let bytes = &bytes[first_word_bytes..];3940let word_len = (first_word_bytes * 8 - offset).min(len);41let rest_len = len - word_len;42Self {43bytes,44word,45word_len,46rest_len,47}48}4950/// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator.51///52/// This performs the same operation as `(&mut iter).take_while(|b| b).count()`.53///54/// This is a lot more efficient than consecutively polling the iterator and should therefore55/// be preferred, if the use-case allows for it.56pub fn take_leading_ones(&mut self) -> usize {57let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);58self.word_len -= word_ones;59self.word = self.word.wrapping_shr(word_ones as u32);6061if self.word_len != 0 {62return word_ones;63}6465let mut num_leading_ones = word_ones;6667while self.rest_len != 0 {68self.word_len = usize::min(self.rest_len, 64);69self.rest_len -= self.word_len;7071unsafe {72let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();73self.word = u64::from_le_bytes(chunk);74self.bytes = self.bytes.get_unchecked(8..);75}7677let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);78self.word_len -= word_ones;79self.word = self.word.wrapping_shr(word_ones as u32);80num_leading_ones += word_ones;8182if self.word_len != 0 {83return num_leading_ones;84}85}8687num_leading_ones88}8990/// Consume and returns the numbers of `0` / `false` values that the start of the iterator.91///92/// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`.93///94/// This is a lot more efficient than consecutively polling the iterator and should therefore95/// be preferred, if the use-case allows for it.96pub fn take_leading_zeros(&mut self) -> usize {97let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);98self.word_len -= word_zeros;99self.word = self.word.wrapping_shr(word_zeros as u32);100101if self.word_len != 0 {102return word_zeros;103}104105let mut num_leading_zeros = word_zeros;106107while self.rest_len != 0 {108self.word_len = usize::min(self.rest_len, 64);109self.rest_len -= self.word_len;110unsafe {111let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();112self.word = u64::from_le_bytes(chunk);113self.bytes = self.bytes.get_unchecked(8..);114}115116let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);117self.word_len -= word_zeros;118self.word = self.word.wrapping_shr(word_zeros as u32);119num_leading_zeros += word_zeros;120121if self.word_len != 0 {122return num_leading_zeros;123}124}125126num_leading_zeros127}128129/// Returns the number of remaining elements in the iterator130#[inline]131pub fn num_remaining(&self) -> usize {132self.word_len + self.rest_len133}134135/// Collect at most `n` elements from this iterator into `bitmap`136pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {137fn collect_word(138word: &mut u64,139word_len: &mut usize,140bitmap: &mut MutableBitmap,141n: &mut usize,142) {143while *n > 0 && *word_len > 0 {144{145let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);146let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);147*word = word.wrapping_shr(shift);148*word_len -= shift as usize;149*n -= shift as usize;150151bitmap.extend_constant(shift as usize, true);152}153154{155let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);156let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);157*word = word.wrapping_shr(shift);158*word_len -= shift as usize;159*n -= shift as usize;160161bitmap.extend_constant(shift as usize, false);162}163}164}165166let mut n = usize::min(n, self.num_remaining());167bitmap.reserve(n);168169collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);170171if n == 0 {172return;173}174175let num_words = n / 64;176177if num_words > 0 {178assert!(self.bytes.len() >= num_words * size_of::<u64>());179180bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);181182self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };183self.rest_len -= num_words * u64::BITS as usize;184n -= num_words * u64::BITS as usize;185}186187if n == 0 {188return;189}190191assert!(self.bytes.len() >= size_of::<u64>());192193self.word_len = usize::min(self.rest_len, 64);194self.rest_len -= self.word_len;195unsafe {196let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();197self.word = u64::from_le_bytes(chunk);198self.bytes = self.bytes.get_unchecked(8..);199}200201collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);202203debug_assert!(self.num_remaining() == 0 || n == 0);204}205}206207impl Iterator for BitmapIter<'_> {208type Item = bool;209210#[inline]211fn next(&mut self) -> Option<Self::Item> {212if self.word_len == 0 {213if self.rest_len == 0 {214return None;215}216217self.word_len = self.rest_len.min(64);218self.rest_len -= self.word_len;219220unsafe {221let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();222self.word = u64::from_le_bytes(chunk);223self.bytes = self.bytes.get_unchecked(8..);224}225}226227let ret = self.word & 1 != 0;228self.word >>= 1;229self.word_len -= 1;230Some(ret)231}232233#[inline]234fn size_hint(&self) -> (usize, Option<usize>) {235let num_remaining = self.num_remaining();236(num_remaining, Some(num_remaining))237}238239#[inline]240fn nth(&mut self, mut n: usize) -> Option<Self::Item> {241if n >= self.word_len + self.rest_len {242self.word = 0;243self.word_len = 0;244self.rest_len = 0;245return None;246}247248// Advance words in buffer, skip words as needed249if n >= self.word_len {250n -= self.word_len;251252let word_offset = n / 64;253n -= word_offset * 64;254self.rest_len -= word_offset * 64;255256self.word_len = self.rest_len.min(64);257self.rest_len -= self.word_len;258259let byte_offset = 8 * word_offset;260261// Safety: bytes is large enough at construction time.262debug_assert!(byte_offset + 8 <= self.bytes.len());263unsafe {264let chunk = self265.bytes266.get_unchecked(byte_offset..byte_offset + 8)267.try_into()268.unwrap();269self.word = u64::from_le_bytes(chunk);270self.bytes = self.bytes.get_unchecked(byte_offset + 8..);271}272}273274// At this point, n < self.word_len275debug_assert!(self.word_len > n);276277// Advance index by n and take value at final index278self.word >>= n;279self.word_len -= n;280281let ret = self.word & 1 != 0;282self.word >>= 1;283self.word_len -= 1;284Some(ret)285}286}287288impl DoubleEndedIterator for BitmapIter<'_> {289#[inline]290fn next_back(&mut self) -> Option<bool> {291if self.rest_len > 0 {292self.rest_len -= 1;293Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })294} else if self.word_len > 0 {295self.word_len -= 1;296Some(self.word & (1 << self.word_len) != 0)297} else {298None299}300}301}302303unsafe impl TrustedLen for BitmapIter<'_> {}304impl ExactSizeIterator for BitmapIter<'_> {}305306#[cfg(test)]307mod tests {308use super::*;309310#[test]311fn test_collect_into_17579() {312let mut bitmap = MutableBitmap::with_capacity(64);313BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)314.collect_n_into(&mut bitmap, 129);315316let bitmap = bitmap.freeze();317318assert_eq!(bitmap.set_bits(), 4);319}320321#[test]322#[ignore = "Fuzz test. Too slow"]323fn test_fuzz_collect_into() {324for _ in 0..10_000 {325let mut set_bits = 0;326let mut unset_bits = 0;327328let mut length = 0;329let mut pattern = Vec::new();330for _ in 0..rand::random::<u64>() % 1024 {331let bs = rand::random::<u8>() % 4;332333let word = match bs {3340 => u64::MIN,3351 => u64::MAX,3362 | 3 => rand::random(),337_ => unreachable!(),338};339340pattern.extend_from_slice(&word.to_le_bytes());341set_bits += word.count_ones();342unset_bits += word.count_zeros();343length += 64;344}345346for _ in 0..rand::random::<u64>() % 7 {347let b = rand::random::<u8>();348pattern.push(b);349set_bits += b.count_ones();350unset_bits += b.count_zeros();351length += 8;352}353354let last_length = rand::random::<u64>() % 8;355if last_length != 0 {356let b = rand::random::<u8>();357pattern.push(b);358let ones = (b & ((1 << last_length) - 1)).count_ones();359set_bits += ones;360unset_bits += last_length as u32 - ones;361length += last_length;362}363364let mut iter = BitmapIter::new(&pattern, 0, length as usize);365let mut bitmap = MutableBitmap::with_capacity(length as usize);366367while iter.num_remaining() > 0 {368let len_before = bitmap.len();369let n = rand::random::<u64>() as usize % iter.num_remaining();370iter.collect_n_into(&mut bitmap, n);371372// Ensure we are booking the progress we expect373assert_eq!(bitmap.len(), len_before + n);374}375376let bitmap = bitmap.freeze();377378assert_eq!(bitmap.set_bits(), set_bits as usize);379assert_eq!(bitmap.unset_bits(), unset_bits as usize);380}381}382383#[test]384#[ignore = "Fuzz test. Too slow"]385fn test_fuzz_leading_ops() {386for _ in 0..10_000 {387let mut length = 0;388let mut pattern = Vec::new();389for _ in 0..rand::random::<u64>() % 1024 {390let bs = rand::random::<u8>() % 4;391392let word = match bs {3930 => u64::MIN,3941 => u64::MAX,3952 | 3 => rand::random(),396_ => unreachable!(),397};398399pattern.extend_from_slice(&word.to_le_bytes());400length += 64;401}402403for _ in 0..rand::random::<u64>() % 7 {404pattern.push(rand::random::<u8>());405length += 8;406}407408let last_length = rand::random::<u64>() % 8;409if last_length != 0 {410pattern.push(rand::random::<u8>());411length += last_length;412}413414let mut iter = BitmapIter::new(&pattern, 0, length as usize);415416let mut prev_remaining = iter.num_remaining();417while iter.num_remaining() != 0 {418let num_ones = iter.clone().take_leading_ones();419assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());420421let num_zeros = iter.clone().take_leading_zeros();422assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());423424// Ensure that we are making progress425assert!(iter.num_remaining() < prev_remaining);426prev_remaining = iter.num_remaining();427}428429assert_eq!(iter.take_leading_zeros(), 0);430assert_eq!(iter.take_leading_ones(), 0);431}432}433434#[test]435#[allow(clippy::iter_nth_zero)]436fn test_bitmap_iter_nth() {437// Calling nth repeatedly advances through the bitmap438{439let mut iter = BitmapIter::new(&[0b10110001], 0, 8);440assert_eq!(iter.nth(0), Some(true));441assert_eq!(iter.nth(0), Some(false));442assert_eq!(iter.nth(2), Some(true));443assert_eq!(iter.nth(3), None);444445assert_eq!(iter.next(), None);446}447448// Test parity with next()-based implementation on of singular call to nth()449for len in [0, 1, 2, 63, 64, 65, 127, 128, 129] {450for offset in [0, 1, 2] {451// binary '01010101' == 85452let iter = BitmapIter::new(453&[4540, 1, 2, 4, 8, 16, 32, 64, 85, 170, 85, 170, 85, 170, 85, 170, 255, 0,455],456offset,457len,458);459460for i in 0..=len {461let mut iter_expected = iter.clone();462let mut iter_test = iter.clone();463464let prev_rest_len = iter_test.rest_len;465let prev_word_len = iter_test.word_len;466467assert_eq!(len, prev_rest_len + prev_word_len);468469// Iterate.470let out = iter_test.nth(i);471for _ in 0..i {472iter_expected.next();473}474let expected = iter_expected.next();475476// Check value.477assert_eq!(out, expected);478479// Check internal sate.480let final_rest_len = iter_test.rest_len;481let final_word_len = iter_test.word_len;482match out {483Some(_) => assert_eq!(484prev_rest_len + prev_word_len,485i + 1 + final_rest_len + final_word_len486),487None => {488assert!(i >= prev_rest_len + prev_word_len);489assert_eq!(final_rest_len + final_word_len, 0)490},491};492}493}494}495496// Check internal state on repeat calls to nth().497{498for len in [0, 63, 64, 65, 126, 128, 129] {499let mut iter =500BitmapIter::new(&[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0], 0, len);501for step in [0, 1, 2, 3] {502for i in (0..len + step + 1).step_by(step + 1) {503let prev_rest_len = iter.rest_len;504let prev_word_len = iter.word_len;505506let out = iter.nth(step);507508let final_rest_len = iter.rest_len;509let final_word_len = iter.word_len;510match out {511Some(_) => assert_eq!(512prev_rest_len + prev_word_len,513step + 1 + final_rest_len + final_word_len514),515None => {516assert!(i >= prev_rest_len + prev_word_len);517assert_eq!(final_rest_len + final_word_len, 0)518},519};520}521}522}523}524525// Edge cases526let mut iter = BitmapIter::new(&[], 0, 0);527assert_eq!(iter.nth(0), None);528}529}530531532