Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/bitmap/utils/iterator.rs
8396 views
1
use polars_utils::slice::load_padded_le_u64;
2
3
use super::get_bit_unchecked;
4
use crate::bitmap::MutableBitmap;
5
use crate::trusted_len::TrustedLen;
6
7
/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit),
/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`.
///
/// Up to 64 pending bits are cached in `word`; every bit after those is read lazily from
/// `bytes`, one full 8-byte little-endian word at a time.
#[derive(Clone, Debug)]
pub struct BitmapIter<'a> {
    /// Backing bytes for the `rest_len` bits that follow the cached `word`, beginning at bit 0
    /// of the first byte. After construction its length is a multiple of 8.
    bytes: &'a [u8],
    /// Cached bits; the next bit yielded from the front is the least significant one.
    word: u64,
    /// How many low bits of `word` are still pending. Higher bits of `word` are not part of
    /// the iteration.
    word_len: usize,
    /// How many pending bits remain in `bytes`.
    rest_len: usize,
}
16
17
impl<'a> BitmapIter<'a> {
18
/// Creates a new [`BitmapIter`].
19
pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
20
if len == 0 {
21
return Self {
22
bytes,
23
word: 0,
24
word_len: 0,
25
rest_len: 0,
26
};
27
}
28
29
assert!(bytes.len() * 8 >= offset + len);
30
let first_byte_idx = offset / 8;
31
let bytes = &bytes[first_byte_idx..];
32
let offset = offset % 8;
33
34
// Make sure during our hot loop all our loads are full 8-byte loads
35
// by loading the remainder now if it exists.
36
let word = load_padded_le_u64(bytes) >> offset;
37
let mod8 = bytes.len() % 8;
38
let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };
39
let bytes = &bytes[first_word_bytes..];
40
41
let word_len = (first_word_bytes * 8 - offset).min(len);
42
let rest_len = len - word_len;
43
Self {
44
bytes,
45
word,
46
word_len,
47
rest_len,
48
}
49
}
50
51
/// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator.
52
///
53
/// This performs the same operation as `(&mut iter).take_while(|b| b).count()`.
54
///
55
/// This is a lot more efficient than consecutively polling the iterator and should therefore
56
/// be preferred, if the use-case allows for it.
57
pub fn take_leading_ones(&mut self) -> usize {
58
let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
59
self.word_len -= word_ones;
60
self.word = self.word.wrapping_shr(word_ones as u32);
61
62
if self.word_len != 0 {
63
return word_ones;
64
}
65
66
let mut num_leading_ones = word_ones;
67
68
while self.rest_len != 0 {
69
self.word_len = usize::min(self.rest_len, 64);
70
self.rest_len -= self.word_len;
71
72
unsafe {
73
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
74
self.word = u64::from_le_bytes(chunk);
75
self.bytes = self.bytes.get_unchecked(8..);
76
}
77
78
let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
79
self.word_len -= word_ones;
80
self.word = self.word.wrapping_shr(word_ones as u32);
81
num_leading_ones += word_ones;
82
83
if self.word_len != 0 {
84
return num_leading_ones;
85
}
86
}
87
88
num_leading_ones
89
}
90
91
/// Consume and returns the numbers of `0` / `false` values that the start of the iterator.
92
///
93
/// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`.
94
///
95
/// This is a lot more efficient than consecutively polling the iterator and should therefore
96
/// be preferred, if the use-case allows for it.
97
pub fn take_leading_zeros(&mut self) -> usize {
98
let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
99
self.word_len -= word_zeros;
100
self.word = self.word.wrapping_shr(word_zeros as u32);
101
102
if self.word_len != 0 {
103
return word_zeros;
104
}
105
106
let mut num_leading_zeros = word_zeros;
107
108
while self.rest_len != 0 {
109
self.word_len = usize::min(self.rest_len, 64);
110
self.rest_len -= self.word_len;
111
unsafe {
112
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
113
self.word = u64::from_le_bytes(chunk);
114
self.bytes = self.bytes.get_unchecked(8..);
115
}
116
117
let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
118
self.word_len -= word_zeros;
119
self.word = self.word.wrapping_shr(word_zeros as u32);
120
num_leading_zeros += word_zeros;
121
122
if self.word_len != 0 {
123
return num_leading_zeros;
124
}
125
}
126
127
num_leading_zeros
128
}
129
130
/// Returns the number of remaining elements in the iterator
131
#[inline]
132
pub fn num_remaining(&self) -> usize {
133
self.word_len + self.rest_len
134
}
135
136
/// Collect at most `n` elements from this iterator into `bitmap`
137
pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {
138
fn collect_word(
139
word: &mut u64,
140
word_len: &mut usize,
141
bitmap: &mut MutableBitmap,
142
n: &mut usize,
143
) {
144
while *n > 0 && *word_len > 0 {
145
{
146
let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);
147
let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);
148
*word = word.wrapping_shr(shift);
149
*word_len -= shift as usize;
150
*n -= shift as usize;
151
152
bitmap.extend_constant(shift as usize, true);
153
}
154
155
{
156
let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);
157
let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);
158
*word = word.wrapping_shr(shift);
159
*word_len -= shift as usize;
160
*n -= shift as usize;
161
162
bitmap.extend_constant(shift as usize, false);
163
}
164
}
165
}
166
167
let mut n = usize::min(n, self.num_remaining());
168
bitmap.reserve(n);
169
170
collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
171
172
if n == 0 {
173
return;
174
}
175
176
let num_words = n / 64;
177
178
if num_words > 0 {
179
assert!(self.bytes.len() >= num_words * size_of::<u64>());
180
181
bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);
182
183
self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };
184
self.rest_len -= num_words * u64::BITS as usize;
185
n -= num_words * u64::BITS as usize;
186
}
187
188
if n == 0 {
189
return;
190
}
191
192
assert!(self.bytes.len() >= size_of::<u64>());
193
194
self.word_len = usize::min(self.rest_len, 64);
195
self.rest_len -= self.word_len;
196
unsafe {
197
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
198
self.word = u64::from_le_bytes(chunk);
199
self.bytes = self.bytes.get_unchecked(8..);
200
}
201
202
collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
203
204
debug_assert!(self.num_remaining() == 0 || n == 0);
205
}
206
}
207
208
impl Iterator for BitmapIter<'_> {
209
type Item = bool;
210
211
#[inline]
212
fn next(&mut self) -> Option<Self::Item> {
213
if self.word_len == 0 {
214
if self.rest_len == 0 {
215
return None;
216
}
217
218
self.word_len = self.rest_len.min(64);
219
self.rest_len -= self.word_len;
220
221
unsafe {
222
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
223
self.word = u64::from_le_bytes(chunk);
224
self.bytes = self.bytes.get_unchecked(8..);
225
}
226
}
227
228
let ret = self.word & 1 != 0;
229
self.word >>= 1;
230
self.word_len -= 1;
231
Some(ret)
232
}
233
234
#[inline]
235
fn size_hint(&self) -> (usize, Option<usize>) {
236
let num_remaining = self.num_remaining();
237
(num_remaining, Some(num_remaining))
238
}
239
240
#[inline]
241
fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
242
if n >= self.word_len + self.rest_len {
243
self.word = 0;
244
self.word_len = 0;
245
self.rest_len = 0;
246
return None;
247
}
248
249
// Advance words in buffer, skip words as needed
250
if n >= self.word_len {
251
n -= self.word_len;
252
253
let word_offset = n / 64;
254
n -= word_offset * 64;
255
self.rest_len -= word_offset * 64;
256
257
self.word_len = self.rest_len.min(64);
258
self.rest_len -= self.word_len;
259
260
let byte_offset = 8 * word_offset;
261
262
// Safety: bytes is large enough at construction time.
263
debug_assert!(byte_offset + 8 <= self.bytes.len());
264
unsafe {
265
let chunk = self
266
.bytes
267
.get_unchecked(byte_offset..byte_offset + 8)
268
.try_into()
269
.unwrap();
270
self.word = u64::from_le_bytes(chunk);
271
self.bytes = self.bytes.get_unchecked(byte_offset + 8..);
272
}
273
}
274
275
// At this point, n < self.word_len
276
debug_assert!(self.word_len > n);
277
278
// Advance index by n and take value at final index
279
self.word >>= n;
280
self.word_len -= n;
281
282
let ret = self.word & 1 != 0;
283
self.word >>= 1;
284
self.word_len -= 1;
285
Some(ret)
286
}
287
}
288
289
impl DoubleEndedIterator for BitmapIter<'_> {
290
#[inline]
291
fn next_back(&mut self) -> Option<bool> {
292
if self.rest_len > 0 {
293
self.rest_len -= 1;
294
Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })
295
} else if self.word_len > 0 {
296
self.word_len -= 1;
297
Some(self.word & (1 << self.word_len) != 0)
298
} else {
299
None
300
}
301
}
302
}
303
304
unsafe impl TrustedLen for BitmapIter<'_> {}
305
impl ExactSizeIterator for BitmapIter<'_> {}
306
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_collect_into_17579() {
        let mut bitmap = MutableBitmap::with_capacity(64);
        BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)
            .collect_n_into(&mut bitmap, 129);

        let bitmap = bitmap.freeze();

        assert_eq!(bitmap.set_bits(), 4);
    }

    #[test]
    fn test_leading_runs_single_word() {
        // LSB-first: 0b0000_0111, 0b1111_0000 -> 1,1,1,0,0,0,0,0, 0,0,0,0,1,1,1,1
        let bytes = [0b0000_0111u8, 0b1111_0000];
        let mut iter = BitmapIter::new(&bytes, 0, 16);
        assert_eq!(iter.take_leading_zeros(), 0);
        assert_eq!(iter.take_leading_ones(), 3);
        assert_eq!(iter.take_leading_zeros(), 9);
        assert_eq!(iter.num_remaining(), 4);
        assert_eq!(iter.next_back(), Some(true));
        assert_eq!(iter.take_leading_ones(), 3);
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next_back(), None);
    }

    #[test]
    fn test_leading_runs_across_words() {
        // 9 bytes of ones followed by 8 bytes of zeros: 72 set bits, then 64 unset bits.
        let mut bytes = [0u8; 17];
        bytes[..9].fill(0xFF);

        let mut iter = BitmapIter::new(&bytes, 0, 136);
        assert_eq!(iter.take_leading_zeros(), 0);
        assert_eq!(iter.take_leading_ones(), 72);
        assert_eq!(iter.take_leading_zeros(), 64);
        assert_eq!(iter.num_remaining(), 0);

        // Same data, but starting 4 bits in.
        let mut iter = BitmapIter::new(&bytes, 4, 132);
        assert_eq!(iter.take_leading_ones(), 68);
        assert_eq!(iter.take_leading_zeros(), 64);
        assert_eq!(iter.num_remaining(), 0);
    }

    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_collect_into() {
        for _ in 0..10_000 {
            let mut set_bits = 0;
            let mut unset_bits = 0;

            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                set_bits += word.count_ones();
                unset_bits += word.count_zeros();
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                let b = rand::random::<u8>();
                pattern.push(b);
                set_bits += b.count_ones();
                unset_bits += b.count_zeros();
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                let b = rand::random::<u8>();
                pattern.push(b);
                let ones = (b & ((1 << last_length) - 1)).count_ones();
                set_bits += ones;
                unset_bits += last_length as u32 - ones;
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);
            let mut bitmap = MutableBitmap::with_capacity(length as usize);

            while iter.num_remaining() > 0 {
                let len_before = bitmap.len();
                // `n` ranges over `0..=num_remaining()`. Using `% num_remaining()` alone could
                // never consume the final bit and would loop forever.
                let n = rand::random::<u64>() as usize % (iter.num_remaining() + 1);
                iter.collect_n_into(&mut bitmap, n);

                // Ensure we are booking the progress we expect
                assert_eq!(bitmap.len(), len_before + n);
            }

            let bitmap = bitmap.freeze();

            assert_eq!(bitmap.set_bits(), set_bits as usize);
            assert_eq!(bitmap.unset_bits(), unset_bits as usize);
        }
    }

    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_leading_ops() {
        for _ in 0..10_000 {
            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                pattern.push(rand::random::<u8>());
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                pattern.push(rand::random::<u8>());
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);

            let mut prev_remaining = iter.num_remaining();
            while iter.num_remaining() != 0 {
                let num_ones = iter.clone().take_leading_ones();
                assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());

                let num_zeros = iter.clone().take_leading_zeros();
                assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());

                // Ensure that we are making progress
                assert!(iter.num_remaining() < prev_remaining);
                prev_remaining = iter.num_remaining();
            }

            assert_eq!(iter.take_leading_zeros(), 0);
            assert_eq!(iter.take_leading_ones(), 0);
        }
    }

    #[test]
    #[allow(clippy::iter_nth_zero)]
    fn test_bitmap_iter_nth() {
        // Calling nth repeatedly advances through the bitmap
        {
            let mut iter = BitmapIter::new(&[0b10110001], 0, 8);
            assert_eq!(iter.nth(0), Some(true));
            assert_eq!(iter.nth(0), Some(false));
            assert_eq!(iter.nth(2), Some(true));
            assert_eq!(iter.nth(3), None);

            assert_eq!(iter.next(), None);
        }

        // Test parity of a single nth() call with the equivalent sequence of next() calls
        for len in [0, 1, 2, 63, 64, 65, 127, 128, 129] {
            for offset in [0, 1, 2] {
                // binary '01010101' == 85
                let iter = BitmapIter::new(
                    &[
                        0, 1, 2, 4, 8, 16, 32, 64, 85, 170, 85, 170, 85, 170, 85, 170, 255, 0,
                    ],
                    offset,
                    len,
                );

                for i in 0..=len {
                    let mut iter_expected = iter.clone();
                    let mut iter_test = iter.clone();

                    let prev_rest_len = iter_test.rest_len;
                    let prev_word_len = iter_test.word_len;

                    assert_eq!(len, prev_rest_len + prev_word_len);

                    // Iterate.
                    let out = iter_test.nth(i);
                    for _ in 0..i {
                        iter_expected.next();
                    }
                    let expected = iter_expected.next();

                    // Check value.
                    assert_eq!(out, expected);

                    // Check internal state.
                    let final_rest_len = iter_test.rest_len;
                    let final_word_len = iter_test.word_len;
                    match out {
                        Some(_) => assert_eq!(
                            prev_rest_len + prev_word_len,
                            i + 1 + final_rest_len + final_word_len
                        ),
                        None => {
                            assert!(i >= prev_rest_len + prev_word_len);
                            assert_eq!(final_rest_len + final_word_len, 0)
                        },
                    };
                }
            }
        }

        // Check internal state on repeat calls to nth().
        {
            let bytes: [u8; 17] = [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0];
            for len in [0usize, 63, 64, 65, 126, 128, 129] {
                for step in [0, 1, 2, 3] {
                    // A fresh iterator per step size, so every step size is actually exercised.
                    let mut iter = BitmapIter::new(&bytes, 0, len);
                    for i in (0..len + step + 1).step_by(step + 1) {
                        let prev_rest_len = iter.rest_len;
                        let prev_word_len = iter.word_len;
                        // `i` bits have been consumed so far (or all of them, once exhausted).
                        assert_eq!(prev_rest_len + prev_word_len, len.saturating_sub(i));

                        let out = iter.nth(step);

                        let final_rest_len = iter.rest_len;
                        let final_word_len = iter.word_len;
                        match out {
                            Some(_) => assert_eq!(
                                prev_rest_len + prev_word_len,
                                step + 1 + final_rest_len + final_word_len
                            ),
                            None => {
                                // `nth(step)` returns `None` only when `step` reaches past the end.
                                assert!(step >= prev_rest_len + prev_word_len);
                                assert_eq!(final_rest_len + final_word_len, 0)
                            },
                        };
                    }
                }
            }
        }

        // Edge cases
        let mut iter = BitmapIter::new(&[], 0, 0);
        assert_eq!(iter.nth(0), None);
    }
}
531
532