// Path: blob/main/crates/polars-arrow/src/bitmap/bitmask.rs
// (scraped page metadata: 8424 views)
#[cfg(feature = "simd")]1use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};23use polars_utils::slice::load_padded_le_u64;45use super::iterator::FastU56BitmapIter;6use super::utils::{self, BitChunk, BitChunks, BitmapIter, count_zeros, fmt};7use crate::bitmap::Bitmap;89/// Returns the nth set bit in w, if n+1 bits are set. The indexing is10/// zero-based, nth_set_bit_u32(w, 0) returns the least significant set bit in w.11#[inline]12pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {13// If we have BMI2's PDEP available, we use it. It takes the lower order14// bits of the first argument and spreads it along its second argument15// where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h.16// We use this by setting the first argument to 1 << n, which means the17// first n-1 zero bits of it will spread to the first n-1 one bits of w,18// after which the one bit will exactly get copied to the nth one bit of w.19#[cfg(all(not(miri), target_feature = "bmi2"))]20{21if n >= 32 {22return None;23}2425let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };26if nth_set_bit == 0 {27return None;28}2930Some(nth_set_bit.trailing_zeros())31}3233#[cfg(any(miri, not(target_feature = "bmi2")))]34{35// Each block of 2/4/8/16 bits contains how many set bits there are in that block.36let set_per_2 = w - ((w >> 1) & 0x55555555);37let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);38let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;39let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;40let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;4142if n >= set_per_32 {43return None;44}4546let mut idx = 0;47let mut n = n;4849let next16 = set_per_16 & 0xff;50if n >= next16 {51n -= next16;52idx += 16;53}54let next8 = (set_per_8 >> idx) & 0xff;55if n >= next8 {56n -= next8;57idx += 8;58}59let next4 = (set_per_4 >> idx) & 0b1111;60if n >= next4 {61n -= next4;62idx += 4;63}64let next2 = (set_per_2 >> idx) & 
0b11;65if n >= next2 {66n -= next2;67idx += 2;68}69let next1 = (w >> idx) & 0b1;70if n >= next1 {71idx += 1;72}73Some(idx)74}75}7677#[inline]78pub fn nth_set_bit_u64(w: u64, n: u64) -> Option<u64> {79#[cfg(all(not(miri), target_feature = "bmi2"))]80{81if n >= 64 {82return None;83}8485let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u64(1 << n, w) };86if nth_set_bit == 0 {87return None;88}8990Some(nth_set_bit.trailing_zeros().into())91}9293#[cfg(any(miri, not(target_feature = "bmi2")))]94{95// Each block of 2/4/8/16/32 bits contains how many set bits there are in that block.96let set_per_2 = w - ((w >> 1) & 0x5555555555555555);97let set_per_4 = (set_per_2 & 0x3333333333333333) + ((set_per_2 >> 2) & 0x3333333333333333);98let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f0f0f0f0f;99let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff00ff00ff;100let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0x0000ffff0000ffff;101let set_per_64 = (set_per_32 + (set_per_32 >> 32)) & 0xffffffff;102103if n >= set_per_64 {104return None;105}106107let mut idx = 0;108let mut n = n;109110let next32 = set_per_32 & 0xffff;111if n >= next32 {112n -= next32;113idx += 32;114}115let next16 = (set_per_16 >> idx) & 0xffff;116if n >= next16 {117n -= next16;118idx += 16;119}120let next8 = (set_per_8 >> idx) & 0xff;121if n >= next8 {122n -= next8;123idx += 8;124}125let next4 = (set_per_4 >> idx) & 0b1111;126if n >= next4 {127n -= next4;128idx += 4;129}130let next2 = (set_per_2 >> idx) & 0b11;131if n >= next2 {132n -= next2;133idx += 2;134}135let next1 = (w >> idx) & 0b1;136if n >= next1 {137idx += 1;138}139Some(idx)140}141}142143#[derive(Default, Clone, Copy)]144pub struct BitMask<'a> {145bytes: &'a [u8],146offset: usize,147len: usize,148}149150impl std::fmt::Debug for BitMask<'_> {151fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {152let Self { bytes, offset, len } = self;153let offset_num_bytes = offset / 8;154let offset_in_byte = offset % 
8;155fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)156}157}158159impl<'a> BitMask<'a> {160pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {161let (bytes, offset, len) = bitmap.as_slice();162Self::new(bytes, offset, len)163}164165pub fn inner(&self) -> (&[u8], usize, usize) {166(self.bytes, self.offset, self.len)167}168169pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {170// Check length so we can use unsafe access in our get.171assert!(bytes.len() * 8 >= len + offset);172Self { bytes, offset, len }173}174175#[inline(always)]176pub fn len(&self) -> usize {177self.len178}179180#[inline]181pub fn advance_by(&mut self, idx: usize) {182assert!(idx <= self.len);183self.offset += idx;184self.len -= idx;185}186187#[inline]188pub fn split_at(&self, idx: usize) -> (Self, Self) {189assert!(idx <= self.len);190unsafe { self.split_at_unchecked(idx) }191}192193/// # Safety194/// The index must be in-bounds.195#[inline]196pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {197debug_assert!(idx <= self.len);198let left = Self { len: idx, ..*self };199let right = Self {200len: self.len - idx,201offset: self.offset + idx,202..*self203};204(left, right)205}206207#[inline]208pub fn sliced(&self, offset: usize, length: usize) -> Self {209assert!(offset.checked_add(length).unwrap() <= self.len);210unsafe { self.sliced_unchecked(offset, length) }211}212213/// # Safety214/// The index must be in-bounds.215#[inline]216pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {217if cfg!(debug_assertions) {218assert!(offset.checked_add(length).unwrap() <= self.len);219}220221Self {222bytes: self.bytes,223offset: self.offset + offset,224len: length,225}226}227228pub fn unset_bits(&self) -> usize {229count_zeros(self.bytes, self.offset, self.len)230}231232pub fn set_bits(&self) -> usize {233self.len - self.unset_bits()234}235236pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {237FastU56BitmapIter::new(self.bytes, self.offset, 
self.len)238}239240#[cfg(feature = "simd")]241#[inline]242pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>243where244T: MaskElement,245LaneCount<N>: SupportedLaneCount,246{247// We don't support 64-lane masks because then we couldn't load our248// bitwise mask as a u64 and then do the byteshift on it.249250let lanes = LaneCount::<N>::BITMASK_LEN;251assert!(lanes < 64);252253let start_byte_idx = (self.offset + idx) / 8;254let byte_shift = (self.offset + idx) % 8;255if idx + lanes <= self.len {256// SAFETY: fast path, we know this is completely in-bounds.257let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });258Mask::from_bitmask(mask >> byte_shift)259} else if idx < self.len {260// SAFETY: we know that at least the first byte is in-bounds.261// This is partially out of bounds, we have to do extra masking.262let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });263let num_out_of_bounds = idx + lanes - self.len;264let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);265Mask::from_bitmask(shifted)266} else {267Mask::from_bitmask(0u64)268}269}270271#[inline]272pub fn get_u32(&self, idx: usize) -> u32 {273let start_byte_idx = (self.offset + idx) / 8;274let byte_shift = (self.offset + idx) % 8;275if idx + 32 <= self.len {276// SAFETY: fast path, we know this is completely in-bounds.277let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });278(mask >> byte_shift) as u32279} else if idx < self.len {280// SAFETY: we know that at least the first byte is in-bounds.281// This is partially out of bounds, we have to do extra masking.282let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) 
});283let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;284((mask >> byte_shift) as u32) & out_of_bounds_mask285} else {2860287}288}289290/// Computes the index of the nth set bit after start.291///292/// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the293/// first bit set (which can be 0 as well). The returned index is absolute,294/// not relative to start.295pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {296while start < self.len {297let next_u32_mask = self.get_u32(start);298if next_u32_mask == u32::MAX {299// Happy fast path for dense non-null section.300if n < 32 {301return Some(start + n);302}303n -= 32;304} else {305let ones = next_u32_mask.count_ones() as usize;306if n < ones {307let idx = unsafe {308// SAFETY: we know the nth bit is in the mask.309nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize310};311return Some(start + idx);312}313n -= ones;314}315316start += 32;317}318319None320}321322/// Computes the index of the nth set bit before end, counting backwards.323///324/// Both are zero-indexed, so nth_set_bit_idx_rev(0, len) finds the index of325/// the last bit set (which can be 0 as well). 
The returned index is326/// absolute (and starts at the beginning), not relative to end.327pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {328while end > 0 {329// We want to find bits *before* end, so if end < 32 we must mask330// out the bits after the endth.331let (u32_mask_start, u32_mask_mask) = if end >= 32 {332(end - 32, u32::MAX)333} else {334(0, (1 << end) - 1)335};336let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;337if next_u32_mask == u32::MAX {338// Happy fast path for dense non-null section.339if n < 32 {340return Some(end - 1 - n);341}342n -= 32;343} else {344let ones = next_u32_mask.count_ones() as usize;345if n < ones {346let rev_n = ones - 1 - n;347let idx = unsafe {348// SAFETY: we know the rev_nth bit is in the mask.349nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize350};351return Some(u32_mask_start + idx);352}353n -= ones;354}355356end = u32_mask_start;357}358359None360}361362#[inline]363pub fn get(&self, idx: usize) -> bool {364if idx < self.len {365// SAFETY: we know this is in-bounds.366unsafe { self.get_bit_unchecked(idx) }367} else {368false369}370}371372#[inline]373/// Get a bit at a certain idx.374///375/// # Safety376///377/// `idx` should be smaller than `len`378pub unsafe fn get_bit_unchecked(&self, idx: usize) -> bool {379let byte_idx = (self.offset + idx) / 8;380let byte_shift = (self.offset + idx) % 8;381382// SAFETY: we know this is in-bounds.383let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };384(byte >> byte_shift) & 1 == 1385}386387pub fn iter(self) -> BitmapIter<'a> {388BitmapIter::new(self.bytes, self.offset, self.len)389}390391/// Returns the number of zero bits from the start before a one bit is seen392pub fn leading_zeros(self) -> usize {393utils::leading_zeros(self.bytes, self.offset, self.len)394}395/// Returns the number of one bits from the start before a zero bit is seen396pub fn leading_ones(self) -> usize 
{397utils::leading_ones(self.bytes, self.offset, self.len)398}399/// Returns the number of zero bits from the back before a one bit is seen400pub fn trailing_zeros(self) -> usize {401utils::trailing_zeros(self.bytes, self.offset, self.len)402}403/// Returns the number of one bits from the back before a zero bit is seen404pub fn trailing_ones(self) -> usize {405utils::trailing_ones(self.bytes, self.offset, self.len)406}407408/// Checks whether two [`Bitmap`]s have shared set bits.409///410/// This is an optimized version of `(self & other) != 0000..`.411pub fn intersects_with(self, other: Self) -> bool {412self.num_intersections_with(other) != 0413}414415/// Calculates the number of shared set bits between two [`Bitmap`]s.416pub fn num_intersections_with(self, other: Self) -> usize {417super::num_intersections_with(self, other)418}419420/// Returns an iterator over bits in bit chunks [`BitChunk`].421///422/// This iterator is useful to operate over multiple bits via e.g. bitwise.423pub fn chunks<T: BitChunk>(self) -> BitChunks<'a, T> {424BitChunks::new(self.bytes, self.offset, self.len)425}426}427428#[cfg(test)]429mod test {430use super::*;431432fn naive_nth_bit_set_u32(mut w: u32, mut n: u32) -> Option<u32> {433for i in 0..32 {434if w & (1 << i) != 0 {435if n == 0 {436return Some(i);437}438n -= 1;439w ^= 1 << i;440}441}442None443}444445fn naive_nth_bit_set_u64(mut w: u64, mut n: u64) -> Option<u64> {446for i in 0..64 {447if w & (1 << i) != 0 {448if n == 0 {449return Some(i);450}451n -= 1;452w ^= 1 << i;453}454}455None456}457458#[test]459fn test_nth_set_bit_u32() {460for n in 0..256 {461assert_eq!(nth_set_bit_u32(0, n), None);462}463464for i in 0..32 {465assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));466assert_eq!(nth_set_bit_u32(1 << i, 1), None);467}468469for i in 0..10000 {470let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;471for i in 0..=32 {472assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set_u32(rnd, 
i));473}474}475}476477#[test]478fn test_nth_set_bit_u64() {479for n in 0..256 {480assert_eq!(nth_set_bit_u64(0, n), None);481}482483for i in 0..64 {484assert_eq!(nth_set_bit_u64(1 << i, 0), Some(i));485assert_eq!(nth_set_bit_u64(1 << i, 1), None);486}487488for i in 0..10000 {489let rnd = 0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32;490for i in 0..=64 {491assert_eq!(nth_set_bit_u64(rnd, i), naive_nth_bit_set_u64(rnd, i));492}493}494}495}496497498