GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/bitmap/bitmask.rs
#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{self, BitChunk, BitChunks, BitmapIter, count_zeros, fmt};
use crate::bitmap::Bitmap;

/// Returns the nth set bit in w, if at least n+1 bits are set. The indexing is
/// zero-based; nth_set_bit_u32(w, 0) returns the least significant set bit in w.
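/// For example, w = 0b10110 has its set bits at positions 1, 2 and 4, so
/// nth_set_bit_u32(0b10110, 1) returns Some(2) and nth_set_bit_u32(0b10110, 3)
/// returns None.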
#[inline]
pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If we have BMI2's PDEP available, we use it. It takes the lower order
    // bits of the first argument and spreads them along its second argument
    // where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h.
    // We use this by setting the first argument to 1 << n, which means the
    // first n zero bits of it will spread to the first n one bits of w, after
    // which the one bit gets copied exactly to the nth (zero-indexed) one bit of w.
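    // For instance, with w = 0b10110 and n = 1, PDEP deposits 1 << 1 = 0b10 onto
    // the set bits of w, giving 0b00100, and trailing_zeros() then yields 2.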
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16 bits contains how many set bits there are in that block.
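        // For instance, for the low nibble w = 0b1011: set_per_2's low nibble is
        // 0b0110 (one set bit in bits 2..4, two in bits 0..2) and set_per_4's low
        // nibble is 0b0011 (three set bits in bits 0..4). The steps below then
        // binary-search these per-block counts for the nth set bit.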
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;

        if n >= set_per_32 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;

        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

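/// The 64-bit counterpart of nth_set_bit_u32: returns the index of the nth set
/// bit in w (zero-based), or None if fewer than n + 1 bits are set.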
#[inline]
pub fn nth_set_bit_u64(w: u64, n: u64) -> Option<u64> {
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 64 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u64(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros().into())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16/32 bits contains how many set bits there are in that block.
        let set_per_2 = w - ((w >> 1) & 0x5555555555555555);
        let set_per_4 = (set_per_2 & 0x3333333333333333) + ((set_per_2 >> 2) & 0x3333333333333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0x0000ffff0000ffff;
        let set_per_64 = (set_per_32 + (set_per_32 >> 32)) & 0xffffffff;

        if n >= set_per_64 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;

        let next32 = set_per_32 & 0xffff;
        if n >= next32 {
            n -= next32;
            idx += 32;
        }
        let next16 = (set_per_16 >> idx) & 0xffff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

#[derive(Default, Clone, Copy)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

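    /// Creates a `BitMask` over `len` bits of `bytes`, starting at bit `offset`.
    ///
    /// For example, `BitMask::new(&[0b0000_0110], 1, 4)` views bits 1..5 of the
    /// byte `0b0000_0110`, so `get(0)` and `get(1)` are true, `get(2)` and
    /// `get(3)` are false, and `set_bits()` is 2.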
    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check length so we can use unsafe access in our get.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byteshift on it.

        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

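    /// Returns the 32 bits of the mask starting at `idx` as a `u32` (bit `k` of
    /// the result is mask bit `idx + k`); bits at or beyond `len` read as zero.
    /// For example, a mask at least three bits long that starts with 1, 0, 1
    /// satisfies `get_u32(0) & 0b111 == 0b101`.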
    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after `start`.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the
    /// first set bit (which can be 0 as well). The returned index is absolute,
    /// not relative to `start`.
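    /// For example, if bits 3 and 7 are the only set bits at or after `start = 2`,
    /// then `nth_set_bit_idx(0, 2)` returns `Some(3)` and `nth_set_bit_idx(1, 2)`
    /// returns `Some(7)`.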
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let idx = unsafe {
                        // SAFETY: we know the nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before `end`, counting backwards.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx_rev(0, len)` finds the index of
    /// the last set bit (which can be 0 as well). The returned index is
    /// absolute (counted from the beginning), not relative to `end`.
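    /// For example, if bits 3 and 7 are the only set bits below `end = 10`, then
    /// `nth_set_bit_idx_rev(0, 10)` returns `Some(7)` and `nth_set_bit_idx_rev(1, 10)`
    /// returns `Some(3)`.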
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // We want to find bits *before* end, so if end < 32 we must mask
            // out the bits at index end and beyond.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    let idx = unsafe {
                        // SAFETY: we know the rev_nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        if idx < self.len {
            // SAFETY: we know this is in-bounds.
            unsafe { self.get_bit_unchecked(idx) }
        } else {
            false
        }
    }

    #[inline]
    /// Get a bit at a certain idx.
    ///
    /// # Safety
    ///
    /// `idx` should be smaller than `len`.
    pub unsafe fn get_bit_unchecked(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        // SAFETY: we know this is in-bounds.
        let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
        (byte >> byte_shift) & 1 == 1
    }

    pub fn iter(self) -> BitmapIter<'a> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }

    /// Returns the number of zero bits from the start before a one bit is seen.
    pub fn leading_zeros(self) -> usize {
        utils::leading_zeros(self.bytes, self.offset, self.len)
    }
    /// Returns the number of one bits from the start before a zero bit is seen.
    pub fn leading_ones(self) -> usize {
        utils::leading_ones(self.bytes, self.offset, self.len)
    }
    /// Returns the number of zero bits from the back before a one bit is seen.
    pub fn trailing_zeros(self) -> usize {
        utils::trailing_zeros(self.bytes, self.offset, self.len)
    }
    /// Returns the number of one bits from the back before a zero bit is seen.
    pub fn trailing_ones(self) -> usize {
        utils::trailing_ones(self.bytes, self.offset, self.len)
    }

    /// Checks whether two [`BitMask`]s have shared set bits.
    ///
    /// This is an optimized version of `(self & other) != 0000..`.
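    /// For example, masks with bit patterns 0011 and 0110 intersect (they share
    /// bit 1), while 0011 and 1100 do not.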
    pub fn intersects_with(self, other: Self) -> bool {
        self.num_intersections_with(other) != 0
    }

    /// Calculates the number of shared set bits between two [`BitMask`]s.
    pub fn num_intersections_with(self, other: Self) -> usize {
        super::num_intersections_with(self, other)
    }

    /// Returns an iterator over bits in bit chunks [`BitChunk`].
    ///
    /// This iterator is useful for operating on multiple bits at a time, e.g. via bitwise operations.
    pub fn chunks<T: BitChunk>(self) -> BitChunks<'a, T> {
        BitChunks::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    fn naive_nth_bit_set_u32(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    fn naive_nth_bit_set_u64(mut w: u64, mut n: u64) -> Option<u64> {
        for i in 0..64 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for i in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set_u32(rnd, i));
            }
        }
    }

    #[test]
    fn test_nth_set_bit_u64() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u64(0, n), None);
        }

        for i in 0..64 {
            assert_eq!(nth_set_bit_u64(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u64(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = 0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32;
            for i in 0..=64 {
                assert_eq!(nth_set_bit_u64(rnd, i), naive_nth_bit_set_u64(rnd, i));
            }
        }
    }
}