Path: blob/main/crates/polars-arrow/src/bitmap/bitmask.rs
#[cfg(feature = "simd")]1use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};23use polars_utils::slice::load_padded_le_u64;45use super::iterator::FastU56BitmapIter;6use super::utils::{BitmapIter, count_zeros, fmt};7use crate::bitmap::Bitmap;89/// Returns the nth set bit in w, if n+1 bits are set. The indexing is10/// zero-based, nth_set_bit_u32(w, 0) returns the least significant set bit in w.11#[inline]12pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {13// If we have BMI2's PDEP available, we use it. It takes the lower order14// bits of the first argument and spreads it along its second argument15// where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h.16// We use this by setting the first argument to 1 << n, which means the17// first n-1 zero bits of it will spread to the first n-1 one bits of w,18// after which the one bit will exactly get copied to the nth one bit of w.19#[cfg(all(not(miri), target_feature = "bmi2"))]20{21if n >= 32 {22return None;23}2425let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };26if nth_set_bit == 0 {27return None;28}2930Some(nth_set_bit.trailing_zeros())31}3233#[cfg(any(miri, not(target_feature = "bmi2")))]34{35// Each block of 2/4/8/16 bits contains how many set bits there are in that block.36let set_per_2 = w - ((w >> 1) & 0x55555555);37let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);38let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;39let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;40let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;41if n >= set_per_32 {42return None;43}4445let mut idx = 0;46let mut n = n;47let next16 = set_per_16 & 0xff;48if n >= next16 {49n -= next16;50idx += 16;51}52let next8 = (set_per_8 >> idx) & 0xff;53if n >= next8 {54n -= next8;55idx += 8;56}57let next4 = (set_per_4 >> idx) & 0b1111;58if n >= next4 {59n -= next4;60idx += 4;61}62let next2 = (set_per_2 >> idx) & 0b11;63if n >= next2 {64n -= next2;65idx += 2;66}67let next1 = (w >> idx) & 0b1;68if n >= next1 {69idx += 1;70}71Some(idx)72}73}7475#[derive(Default, Clone)]76pub struct BitMask<'a> {77bytes: &'a [u8],78offset: usize,79len: usize,80}8182impl std::fmt::Debug for BitMask<'_> {83fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {84let Self { bytes, offset, len } = self;85let offset_num_bytes = offset / 8;86let offset_in_byte = offset % 8;87fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)88}89}9091impl<'a> BitMask<'a> {92pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {93let (bytes, offset, len) = bitmap.as_slice();94Self::new(bytes, offset, len)95}9697pub fn inner(&self) -> (&[u8], usize, usize) {98(self.bytes, self.offset, self.len)99}100101pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {102// Check length so we can use unsafe access in our get.103assert!(bytes.len() * 8 >= len + offset);104Self { bytes, offset, len }105}106107#[inline(always)]108pub fn len(&self) -> usize {109self.len110}111112#[inline]113pub fn advance_by(&mut self, idx: usize) {114assert!(idx <= self.len);115self.offset += idx;116self.len -= idx;117}118119#[inline]120pub fn split_at(&self, idx: usize) -> (Self, Self) {121assert!(idx <= self.len);122unsafe { self.split_at_unchecked(idx) }123}124125/// # Safety126/// The index must be in-bounds.127#[inline]128pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {129debug_assert!(idx <= self.len);130let left = Self { len: idx, ..*self };131let right = Self {132len: self.len - idx,133offset: 
#[derive(Default, Clone)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check the length up front so our getters can use unchecked access.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// The offset and length must describe an in-bounds range.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byteshift on it.
        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after start.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the
    /// first set bit (which can be 0 as well). The returned index is absolute,
    /// not relative to start.
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let idx = unsafe {
                        // SAFETY: we know the nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before end, counting backwards.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx_rev(0, len)` finds the index of
    /// the last set bit (which can be 0 as well). The returned index is
    /// absolute (counted from the beginning), not relative to end.
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // We want to find bits *before* end, so if end < 32 we must mask
            // out the bits at position end and beyond.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    let idx = unsafe {
                        // SAFETY: we know the rev_nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        if idx < self.len {
            // SAFETY: we know this is in-bounds.
            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
            (byte >> byte_shift) & 1 == 1
        } else {
            false
        }
    }

    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }
}
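// Illustrative sketch (hypothetical test module, not from the original
// source): constructs a BitMask over a single byte to show that bits are
// addressed LSB-first within each byte, and that `get_u32` zeroes out any
// bits past `len` on its partially out-of-bounds path.
#[cfg(test)]
mod bitmask_get_examples {
    use super::BitMask;

    #[test]
    fn get_and_get_u32_semantics() {
        // 0b1010_1100 has bits set at indices 2, 3, 5 and 7.
        let mask = BitMask::new(&[0b1010_1100], 0, 8);
        assert!(!mask.get(0));
        assert!(mask.get(2) && mask.get(3) && mask.get(5) && mask.get(7));
        // Reads past `len` return false instead of panicking.
        assert!(!mask.get(8));

        // Only the low `len - idx` bits survive the out-of-bounds masking.
        assert_eq!(mask.get_u32(0), 0b1010_1100);
        assert_eq!(mask.get_u32(3), 0b1_0101);
    }
}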
#[cfg(test)]
mod test {
    use super::*;

    fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for i in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set(rnd, i));
            }
        }
    }
}
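// Illustrative sketch (hypothetical test module, not from the original
// source): cross-checks the forward and backward rank queries. With k set
// bits, the nth set bit counted from the front must equal the
// (k - 1 - n)th set bit counted from the back.
#[cfg(test)]
mod nth_set_bit_idx_examples {
    use super::BitMask;

    #[test]
    fn forward_and_reverse_agree() {
        let mask = BitMask::new(&[0b1010_1100], 0, 8);
        let expected: [usize; 4] = [2, 3, 5, 7];
        for (n, &idx) in expected.iter().enumerate() {
            assert_eq!(mask.nth_set_bit_idx(n, 0), Some(idx));
            // From the back, n = 0 is the last set bit.
            assert_eq!(mask.nth_set_bit_idx_rev(expected.len() - 1 - n, 8), Some(idx));
        }
        // There is no fifth set bit, in either direction.
        assert_eq!(mask.nth_set_bit_idx(4, 0), None);
        assert_eq!(mask.nth_set_bit_idx_rev(4, 8), None);
    }
}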