Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs
6939 views
1
use arrow::array::Array;
2
use arrow::bitmap::bitmask::BitMask;
3
use arrow::compute::concatenate::concatenate_validities;
4
use bytemuck::allocation::zeroed_vec;
5
use polars_core::prelude::gather::check_bounds_ca;
6
use polars_core::prelude::*;
7
use polars_utils::index::check_bounds;
8
9
/// Gathers, for each `(out_idx, nonnull_idx)` pair, the `nonnull_idx`-th
/// non-null value of `ca` (counting only non-null elements, across all chunks)
/// and writes it to position `out_idx` of a dense output vector of length
/// `len`. Output slots not named by any pair are left zeroed.
///
/// # Safety
/// For each index pair, pair.0 < len && pair.1 < ca.len() - ca.null_count()
/// must hold (pair.1 indexes the *non-null* elements).
unsafe fn gather_skip_nulls_idx_pairs_unchecked<'a, T: PolarsDataType>(
    ca: &'a ChunkedArray<T>,
    mut index_pairs: Vec<(IdxSize, IdxSize)>,
    len: usize,
) -> Vec<T::ZeroablePhysical<'a>> {
    if index_pairs.is_empty() {
        return zeroed_vec(len);
    }

    // We sort by gather index so we can do the null scan in one pass.
    index_pairs.sort_unstable_by_key(|t| t.1);
    let mut pair_iter = index_pairs.iter().copied();
    let (mut out_idx, mut nonnull_idx);
    (out_idx, nonnull_idx) = pair_iter.next().unwrap();

    let mut out: Vec<T::ZeroablePhysical<'a>> = zeroed_vec(len);
    // Total number of non-null elements in the chunks processed so far; used
    // to translate the global non-null index into a per-chunk one.
    let mut nonnull_prev_arrays = 0;
    'outer: for arr in ca.downcast_iter() {
        let arr_nonnull_len = arr.len() - arr.null_count();
        // Physical position in `arr` where the previous bitmask scan stopped,
        // and how many set bits preceded it — lets each scan resume instead of
        // restarting from the front (requests are sorted, so indices only grow).
        let mut arr_scan_offset = 0;
        let mut nonnull_before_offset = 0;
        let mask = arr.validity().map(BitMask::from_bitmap).unwrap_or_default();

        // Is our next nonnull_idx in this array?
        while nonnull_idx as usize - nonnull_prev_arrays < arr_nonnull_len {
            let nonnull_idx_in_arr = nonnull_idx as usize - nonnull_prev_arrays;

            let phys_idx_in_arr = if arr.null_count() == 0 {
                // Happy fast path for full non-null array.
                nonnull_idx_in_arr
            } else {
                // Find the physical index of the n-th set validity bit,
                // scanning forward from where the previous request landed.
                mask.nth_set_bit_idx(nonnull_idx_in_arr - nonnull_before_offset, arr_scan_offset)
                    .unwrap()
            };

            // SAFETY: phys_idx_in_arr < arr.len() by construction, and the
            // function contract guarantees out_idx < len.
            unsafe {
                let val = arr.value_unchecked(phys_idx_in_arr);
                *out.get_unchecked_mut(out_idx as usize) = val.into();
            }

            arr_scan_offset = phys_idx_in_arr;
            nonnull_before_offset = nonnull_idx_in_arr;

            let Some(next_pair) = pair_iter.next() else {
                break 'outer;
            };
            (out_idx, nonnull_idx) = next_pair;
        }

        nonnull_prev_arrays += arr_nonnull_len;
    }

    out
}
65
66
/// Gather values by index into the sequence of *non-null* elements of `self`:
/// index `i` refers to the i-th non-null value rather than the i-th row.
pub trait ChunkGatherSkipNulls<I: ?Sized>: Sized {
    fn gather_skip_nulls(&self, indices: &I) -> PolarsResult<Self>;
}
69
70
impl<T: PolarsDataType> ChunkGatherSkipNulls<[IdxSize]> for ChunkedArray<T>
71
where
72
ChunkedArray<T>: ChunkFilter<T> + ChunkTake<[IdxSize]>,
73
{
74
fn gather_skip_nulls(&self, indices: &[IdxSize]) -> PolarsResult<Self> {
75
if self.null_count() == 0 {
76
return self.take(indices);
77
}
78
79
// If we want many indices it's probably better to do a normal gather on
80
// a dense array.
81
if indices.len() >= self.len() / 4 {
82
return ChunkFilter::filter(self, &self.is_not_null())
83
.unwrap()
84
.take(indices);
85
}
86
87
let bound = self.len() - self.null_count();
88
check_bounds(indices, bound as IdxSize)?;
89
90
let index_pairs: Vec<_> = indices
91
.iter()
92
.enumerate()
93
.map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx))
94
.collect();
95
let gathered =
96
unsafe { gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.len()) };
97
let arr =
98
T::Array::from_zeroable_vec(gathered, self.dtype().to_arrow(CompatLevel::newest()));
99
Ok(ChunkedArray::from_chunk_iter_like(self, [arr]))
100
}
101
}
102
103
impl<T: PolarsDataType> ChunkGatherSkipNulls<IdxCa> for ChunkedArray<T>
where
    ChunkedArray<T>: ChunkFilter<T> + ChunkTake<IdxCa>,
{
    /// Gathers the `indices`-th non-null values of `self`; a null index
    /// produces a null in the output.
    ///
    /// # Errors
    /// Returns an error if any non-null index is >= the number of non-null
    /// elements.
    fn gather_skip_nulls(&self, indices: &IdxCa) -> PolarsResult<Self> {
        // With no nulls present, a plain gather is exactly equivalent.
        if self.null_count() == 0 {
            return self.take(indices);
        }

        // If we want many indices it's probably better to do a normal gather on
        // a dense array.
        if indices.len() >= self.len() / 4 {
            return ChunkFilter::filter(self, &self.is_not_null())
                .unwrap()
                .take(indices);
        }

        // Indices address only the non-null elements.
        let bound = self.len() - self.null_count();
        check_bounds_ca(indices, bound as IdxSize)?;

        let index_pairs: Vec<_> = if indices.null_count() == 0 {
            // No null indices: every output slot receives a value.
            indices
                .downcast_iter()
                .flat_map(|arr| arr.values_iter())
                .enumerate()
                .map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx))
                .collect()
        } else {
            // Filter *after* the enumerate so we place the non-null gather
            // requests at the right places.
            indices
                .downcast_iter()
                .flat_map(|arr| arr.iter())
                .enumerate()
                .filter_map(|(out_idx, nonnull_idx)| Some((out_idx as IdxSize, *nonnull_idx?)))
                .collect()
        };
        // SAFETY: check_bounds_ca above guarantees every non-null index is
        // within the non-null element count, and out_idx < indices.len().
        let gathered = unsafe {
            gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.as_ref().len())
        };

        let mut arr =
            T::Array::from_zeroable_vec(gathered, self.dtype().to_arrow(CompatLevel::newest()));
        if indices.null_count() > 0 {
            // Null indices left zeroed slots in `gathered`; copy the indices'
            // validity onto the result so those slots read as null.
            arr = arr.with_validity_typed(concatenate_validities(indices.chunks()));
        }
        Ok(ChunkedArray::from_chunk_iter_like(self, [arr]))
    }
}
152
153
#[cfg(test)]
mod test {
    use std::ops::Range;

    use rand::distr::uniform::SampleUniform;
    use rand::prelude::*;

    use super::*;

    // Generates a vector of random length (drawn from `len_range`) filled with
    // random values drawn from `val`.
    fn random_vec<T: SampleUniform + PartialOrd + Clone, R: Rng>(
        rng: &mut R,
        val: Range<T>,
        len_range: Range<usize>,
    ) -> Vec<T> {
        let n = rng.random_range(len_range);
        (0..n).map(|_| rng.random_range(val.clone())).collect()
    }

    // Keeps each element of `v` independently with probability `p` (drawn once
    // from `pr`); dropped elements become `None`.
    fn random_filter<T: Clone, R: Rng>(rng: &mut R, v: &[T], pr: Range<f64>) -> Vec<Option<T>> {
        let p = rng.random_range(pr);
        let rand_filter = |x| Some(x).filter(|_| rng.random::<f64>() < p);
        v.iter().cloned().map(rand_filter).collect()
    }

    // Reference implementation: gather `idx` from the non-null values of `v`.
    // Returns `None` when any index is out of bounds, mirroring the Err case
    // of `gather_skip_nulls`; a null index yields a null output.
    fn ref_gather_nulls(v: Vec<Option<u32>>, idx: Vec<Option<usize>>) -> Option<Vec<Option<u32>>> {
        let v: Vec<u32> = v.into_iter().flatten().collect();
        if idx.iter().any(|oi| oi.map(|i| i >= v.len()) == Some(true)) {
            return None;
        }
        Some(idx.into_iter().map(|i| Some(v[i?])).collect())
    }

    // Asserts that gather_skip_nulls agrees with the reference implementation
    // for one (values, indices) pair.
    fn test_equal_ref(ca: &UInt32Chunked, idx_ca: &IdxCa) {
        let ref_ca: Vec<Option<u32>> = ca.iter().collect();
        let ref_idx_ca: Vec<Option<usize>> = idx_ca.iter().map(|i| Some(i? as usize)).collect();
        let gather = ca.gather_skip_nulls(idx_ca).ok();
        let ref_gather = ref_gather_nulls(ref_ca, ref_idx_ca);
        assert_eq!(gather.map(|ca| ca.iter().collect()), ref_gather);
    }

    // Runs the reference check for all four chunked/rechunked combinations so
    // both the multi-chunk and single-chunk code paths are exercised.
    fn gather_skip_nulls_check(ca: &UInt32Chunked, idx_ca: &IdxCa) {
        test_equal_ref(ca, idx_ca);
        test_equal_ref(&ca.rechunk(), idx_ca);
        test_equal_ref(ca, &idx_ca.rechunk());
        test_equal_ref(&ca.rechunk(), &idx_ca.rechunk());
    }

    #[rustfmt::skip]
    #[test]
    fn test_gather_skip_nulls() {
        // Fixed seed so failures are reproducible.
        let mut rng = SmallRng::seed_from_u64(0xdeadbeef);

        for _test in 0..20 {
            // Random multi-chunk values, with and without injected nulls.
            let num_elem_chunks = rng.random_range(1..10);
            let elem_chunks: Vec<_> = (0..num_elem_chunks).map(|_| random_vec(&mut rng, 0..u32::MAX, 0..100)).collect();
            let null_elem_chunks: Vec<_> = elem_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect();
            let num_nonnull_elems: usize = null_elem_chunks.iter().map(|c| c.iter().filter(|x| x.is_some()).count()).sum();

            // Random multi-chunk indices into the non-null elements, with and
            // without injected nulls.
            let num_idx_chunks = rng.random_range(1..10);
            let idx_chunks: Vec<_> = (0..num_idx_chunks).map(|_| random_vec(&mut rng, 0..num_nonnull_elems as IdxSize, 0..200)).collect();
            let null_idx_chunks: Vec<_> = idx_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect();

            let nonnull_ca = UInt32Chunked::from_chunk_iter("".into(), elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));
            let ca = UInt32Chunked::from_chunk_iter("".into(), null_elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));
            let nonnull_idx_ca = IdxCa::from_chunk_iter("".into(), idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));
            let idx_ca = IdxCa::from_chunk_iter("".into(), null_idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));

            // Exercise all combinations of nullable/non-null values and indices.
            gather_skip_nulls_check(&ca, &idx_ca);
            gather_skip_nulls_check(&ca, &nonnull_idx_ca);
            gather_skip_nulls_check(&nonnull_ca, &idx_ca);
            gather_skip_nulls_check(&nonnull_ca, &nonnull_idx_ca);
        }
    }
}
227
228