// Source: pola-rs/polars, crates/polars-arrow/src/bitmap/bitmask.rs
#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{BitmapIter, count_zeros, fmt};
use crate::bitmap::Bitmap;

/// Returns the index of the nth set bit in w, if at least n+1 bits are set. The
/// indexing is zero-based; nth_set_bit_u32(w, 0) returns the least significant
/// set bit in w.
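///
/// For example, with w = 0b10110 (bits 1, 2 and 4 set): nth_set_bit_u32(w, 0) == Some(1),
/// nth_set_bit_u32(w, 2) == Some(4) and nth_set_bit_u32(w, 3) == None.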
#[inline]
pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If we have BMI2's PDEP available, we use it. It takes the low-order bits
    // of its first argument and deposits them into its second argument at the
    // positions where that argument has one bits. So PDEP(abcdefgh, 11001001)
    // becomes ef00g00h.
    // We use this by setting the first argument to 1 << n, which means the n
    // zero bits below the one bit spread into the first n one bits of w, after
    // which the one bit gets copied to exactly the nth (zero-based) one bit of w.
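    // For example, with w = 0b10110 and n = 1: PDEP(0b10, 0b10110) = 0b00100, whose
    // trailing_zeros() is 2, the index of the second-lowest set bit of w.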
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16 bits contains how many set bits there are in that block.
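        // For example, with w = 0b1101: set_per_2 = 0b1001 (the low 2-bit block holds
        // 1 set bit and the high block holds 2), and set_per_4 = 0b0011 (3 in total).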
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;
        if n >= set_per_32 {
            return None;
        }
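
        // Binary-search downwards through the per-block counts: at each step, if the
        // nth set bit is not within the lower half of the current window, skip that
        // half's set bits and advance idx past it.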
        let mut idx = 0;
        let mut n = n;
        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}
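
/// A zero-copy view of `len` bits starting at bit `offset` within `bytes`.
/// Bits are addressed least-significant-bit first within each byte.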
#[derive(Default, Clone)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check length so we can use unsafe access in our get.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byteshift on it.

        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
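            // Zero out the loaded bits that fall past the end of the view; only the
            // low `self.len - idx` bits are still in bounds.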
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after start.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the
    /// first bit set (which can be 0 as well). The returned index is absolute,
    /// not relative to start.
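    ///
    /// For example, if only bits 3 and 5 of the mask are set, then
    /// `nth_set_bit_idx(1, 0) == Some(5)` and `nth_set_bit_idx(0, 4) == Some(5)`.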
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
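        // Scan the mask 32 bits at a time: either skip a whole u32 window (subtracting
        // its popcount from n) or locate the answer within it via nth_set_bit_u32.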
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let idx = unsafe {
                        // SAFETY: we know the nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before end, counting backwards.
    ///
    /// Both are zero-indexed, so nth_set_bit_idx_rev(0, len) finds the index of
    /// the last bit set (which can be 0 as well). The returned index is
    /// absolute (and starts at the beginning), not relative to end.
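    ///
    /// For example, if only bits 3 and 5 of an 8-bit mask are set, then
    /// nth_set_bit_idx_rev(0, 8) == Some(5) and nth_set_bit_idx_rev(1, 8) == Some(3).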
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // We want to find bits *before* end, so if end < 32 we must mask
            // out the bits at positions end and above.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    let idx = unsafe {
                        // SAFETY: we know the rev_nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        if idx < self.len {
            // SAFETY: we know this is in-bounds.
            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
            (byte >> byte_shift) & 1 == 1
        } else {
            false
        }
    }

    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for i in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set(rnd, i));
            }
        }
    }
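
    // Usage sketch (added for illustration): exercises the public BitMask API on a
    // small hand-written buffer.
    #[test]
    fn test_bitmask_basic_usage() {
        // Byte 0 = 0b1010_1100 has bits 2, 3, 5 and 7 set (LSB-first); byte 1 has bit 8 set.
        let bytes = [0b1010_1100u8, 0b0000_0001];
        // View 10 bits starting at bit 1, i.e. buffer bits 1..11.
        let mask = BitMask::new(&bytes, 1, 10);
        assert_eq!(mask.len(), 10);
        // Buffer bit 2 is bit 1 of the view.
        assert!(mask.get(1));
        assert!(!mask.get(0));
        // Buffer bits 2, 3, 5, 7 and 8 fall inside the view.
        assert_eq!(mask.set_bits(), 5);
        // First set bit of the view is at view index 1; the last one is at view index 7.
        assert_eq!(mask.nth_set_bit_idx(0, 0), Some(1));
        assert_eq!(mask.nth_set_bit_idx_rev(0, mask.len()), Some(7));
    }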
}