Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/bitmap/utils/iterator.rs
6939 views
1
use polars_utils::slice::load_padded_le_u64;
2
3
use super::get_bit_unchecked;
4
use crate::bitmap::MutableBitmap;
5
use crate::trusted_len::TrustedLen;
6
7
/// Iterator over the bits of a byte slice in [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit)
/// order, i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`.
///
/// Bits are served from a cached 64-bit `word`; once it is drained, the next 8 bytes of
/// `bytes` are loaded into it.
#[derive(Clone, Debug)]
pub struct BitmapIter<'a> {
    /// Bytes not yet loaded into `word`; refills read them 8 at a time from the front.
    bytes: &'a [u8],
    /// Cached bits; the next bit to be yielded from the front is the least significant one.
    word: u64,
    /// How many of the low bits of `word` still belong to the iterator.
    word_len: usize,
    /// How many bits remain in `bytes` beyond those already cached in `word`.
    rest_len: usize,
}
16
17
impl<'a> BitmapIter<'a> {
18
/// Creates a new [`BitmapIter`].
19
pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
20
if len == 0 {
21
return Self {
22
bytes,
23
word: 0,
24
word_len: 0,
25
rest_len: 0,
26
};
27
}
28
29
assert!(bytes.len() * 8 >= offset + len);
30
let first_byte_idx = offset / 8;
31
let bytes = &bytes[first_byte_idx..];
32
let offset = offset % 8;
33
34
// Make sure during our hot loop all our loads are full 8-byte loads
35
// by loading the remainder now if it exists.
36
let word = load_padded_le_u64(bytes) >> offset;
37
let mod8 = bytes.len() % 8;
38
let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };
39
let bytes = &bytes[first_word_bytes..];
40
41
let word_len = (first_word_bytes * 8 - offset).min(len);
42
let rest_len = len - word_len;
43
Self {
44
bytes,
45
word,
46
word_len,
47
rest_len,
48
}
49
}
50
51
/// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator.
52
///
53
/// This performs the same operation as `(&mut iter).take_while(|b| b).count()`.
54
///
55
/// This is a lot more efficient than consecutively polling the iterator and should therefore
56
/// be preferred, if the use-case allows for it.
57
pub fn take_leading_ones(&mut self) -> usize {
58
let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
59
self.word_len -= word_ones;
60
self.word = self.word.wrapping_shr(word_ones as u32);
61
62
if self.word_len != 0 {
63
return word_ones;
64
}
65
66
let mut num_leading_ones = word_ones;
67
68
while self.rest_len != 0 {
69
self.word_len = usize::min(self.rest_len, 64);
70
self.rest_len -= self.word_len;
71
72
unsafe {
73
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
74
self.word = u64::from_le_bytes(chunk);
75
self.bytes = self.bytes.get_unchecked(8..);
76
}
77
78
let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
79
self.word_len -= word_ones;
80
self.word = self.word.wrapping_shr(word_ones as u32);
81
num_leading_ones += word_ones;
82
83
if self.word_len != 0 {
84
return num_leading_ones;
85
}
86
}
87
88
num_leading_ones
89
}
90
91
/// Consume and returns the numbers of `0` / `false` values that the start of the iterator.
92
///
93
/// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`.
94
///
95
/// This is a lot more efficient than consecutively polling the iterator and should therefore
96
/// be preferred, if the use-case allows for it.
97
pub fn take_leading_zeros(&mut self) -> usize {
98
let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
99
self.word_len -= word_zeros;
100
self.word = self.word.wrapping_shr(word_zeros as u32);
101
102
if self.word_len != 0 {
103
return word_zeros;
104
}
105
106
let mut num_leading_zeros = word_zeros;
107
108
while self.rest_len != 0 {
109
self.word_len = usize::min(self.rest_len, 64);
110
self.rest_len -= self.word_len;
111
unsafe {
112
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
113
self.word = u64::from_le_bytes(chunk);
114
self.bytes = self.bytes.get_unchecked(8..);
115
}
116
117
let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
118
self.word_len -= word_zeros;
119
self.word = self.word.wrapping_shr(word_zeros as u32);
120
num_leading_zeros += word_zeros;
121
122
if self.word_len != 0 {
123
return num_leading_zeros;
124
}
125
}
126
127
num_leading_zeros
128
}
129
130
/// Returns the number of remaining elements in the iterator
131
#[inline]
132
pub fn num_remaining(&self) -> usize {
133
self.word_len + self.rest_len
134
}
135
136
/// Collect at most `n` elements from this iterator into `bitmap`
137
pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {
138
fn collect_word(
139
word: &mut u64,
140
word_len: &mut usize,
141
bitmap: &mut MutableBitmap,
142
n: &mut usize,
143
) {
144
while *n > 0 && *word_len > 0 {
145
{
146
let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);
147
let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);
148
*word = word.wrapping_shr(shift);
149
*word_len -= shift as usize;
150
*n -= shift as usize;
151
152
bitmap.extend_constant(shift as usize, true);
153
}
154
155
{
156
let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);
157
let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);
158
*word = word.wrapping_shr(shift);
159
*word_len -= shift as usize;
160
*n -= shift as usize;
161
162
bitmap.extend_constant(shift as usize, false);
163
}
164
}
165
}
166
167
let mut n = usize::min(n, self.num_remaining());
168
bitmap.reserve(n);
169
170
collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
171
172
if n == 0 {
173
return;
174
}
175
176
let num_words = n / 64;
177
178
if num_words > 0 {
179
assert!(self.bytes.len() >= num_words * size_of::<u64>());
180
181
bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);
182
183
self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };
184
self.rest_len -= num_words * u64::BITS as usize;
185
n -= num_words * u64::BITS as usize;
186
}
187
188
if n == 0 {
189
return;
190
}
191
192
assert!(self.bytes.len() >= size_of::<u64>());
193
194
self.word_len = usize::min(self.rest_len, 64);
195
self.rest_len -= self.word_len;
196
unsafe {
197
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
198
self.word = u64::from_le_bytes(chunk);
199
self.bytes = self.bytes.get_unchecked(8..);
200
}
201
202
collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
203
204
debug_assert!(self.num_remaining() == 0 || n == 0);
205
}
206
}
207
208
impl Iterator for BitmapIter<'_> {
209
type Item = bool;
210
211
#[inline]
212
fn next(&mut self) -> Option<Self::Item> {
213
if self.word_len == 0 {
214
if self.rest_len == 0 {
215
return None;
216
}
217
218
self.word_len = self.rest_len.min(64);
219
self.rest_len -= self.word_len;
220
221
unsafe {
222
let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
223
self.word = u64::from_le_bytes(chunk);
224
self.bytes = self.bytes.get_unchecked(8..);
225
}
226
}
227
228
let ret = self.word & 1 != 0;
229
self.word >>= 1;
230
self.word_len -= 1;
231
Some(ret)
232
}
233
234
#[inline]
235
fn size_hint(&self) -> (usize, Option<usize>) {
236
let num_remaining = self.num_remaining();
237
(num_remaining, Some(num_remaining))
238
}
239
}
240
241
impl DoubleEndedIterator for BitmapIter<'_> {
242
#[inline]
243
fn next_back(&mut self) -> Option<bool> {
244
if self.rest_len > 0 {
245
self.rest_len -= 1;
246
Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })
247
} else if self.word_len > 0 {
248
self.word_len -= 1;
249
Some(self.word & (1 << self.word_len) != 0)
250
} else {
251
None
252
}
253
}
254
}
255
256
unsafe impl TrustedLen for BitmapIter<'_> {}
257
impl ExactSizeIterator for BitmapIter<'_> {}
258
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_collect_into_17579() {
        let mut bitmap = MutableBitmap::with_capacity(64);
        BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)
            .collect_n_into(&mut bitmap, 129);

        let bitmap = bitmap.freeze();

        assert_eq!(bitmap.set_bits(), 4);
    }

    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_collect_into() {
        for _ in 0..10_000 {
            let mut set_bits = 0;
            let mut unset_bits = 0;

            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                set_bits += word.count_ones();
                unset_bits += word.count_zeros();
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                let b = rand::random::<u8>();
                pattern.push(b);
                set_bits += b.count_ones();
                unset_bits += b.count_zeros();
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                let b = rand::random::<u8>();
                pattern.push(b);
                let ones = (b & ((1 << last_length) - 1)).count_ones();
                set_bits += ones;
                unset_bits += last_length as u32 - ones;
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);
            let mut bitmap = MutableBitmap::with_capacity(length as usize);

            while iter.num_remaining() > 0 {
                let len_before = bitmap.len();
                // Draw from `0..=num_remaining`: with the exclusive bound `0..num_remaining` the
                // last bit could never be collected, so this loop never terminated.
                let n = rand::random::<u64>() as usize % (iter.num_remaining() + 1);
                iter.collect_n_into(&mut bitmap, n);

                // Ensure we are booking the progress we expect
                assert_eq!(bitmap.len(), len_before + n);
            }

            let bitmap = bitmap.freeze();

            assert_eq!(bitmap.set_bits(), set_bits as usize);
            assert_eq!(bitmap.unset_bits(), unset_bits as usize);
        }
    }

    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_leading_ops() {
        for _ in 0..10_000 {
            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                pattern.push(rand::random::<u8>());
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                pattern.push(rand::random::<u8>());
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);

            while iter.num_remaining() != 0 {
                let prev_remaining = iter.num_remaining();

                // Compute expected run lengths with the naive iterator on a clone, then run the
                // optimized methods on `iter` itself so the state they leave behind is exercised
                // by the next round. (`take_while` also consumes the bit that ends the run, so it
                // must not be run on `iter` directly.)
                let expected_ones = iter.clone().take_while(|&b| b).count();
                assert_eq!(iter.take_leading_ones(), expected_ones);

                let expected_zeros = iter.clone().take_while(|&b| !b).count();
                assert_eq!(iter.take_leading_zeros(), expected_zeros);

                // Exactly the counted bits were consumed, and since the front bit is either a one
                // or a zero, at least one of the two runs was non-empty.
                assert_eq!(
                    iter.num_remaining(),
                    prev_remaining - expected_ones - expected_zeros
                );
                assert!(iter.num_remaining() < prev_remaining);
            }

            assert_eq!(iter.take_leading_zeros(), 0);
            assert_eq!(iter.take_leading_ones(), 0);
        }
    }
}
387
388