// Path: blob/main/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs
use arrow::array::Array;1use arrow::bitmap::bitmask::BitMask;2use arrow::compute::concatenate::concatenate_validities;3use bytemuck::allocation::zeroed_vec;4use polars_core::prelude::gather::check_bounds_ca;5use polars_core::prelude::*;6use polars_utils::index::check_bounds;78/// # Safety9/// For each index pair, pair.0 < len && pair.1 < ca.null_count() must hold.10unsafe fn gather_skip_nulls_idx_pairs_unchecked<'a, T: PolarsDataType>(11ca: &'a ChunkedArray<T>,12mut index_pairs: Vec<(IdxSize, IdxSize)>,13len: usize,14) -> Vec<T::ZeroablePhysical<'a>> {15if index_pairs.is_empty() {16return zeroed_vec(len);17}1819// We sort by gather index so we can do the null scan in one pass.20index_pairs.sort_unstable_by_key(|t| t.1);21let mut pair_iter = index_pairs.iter().copied();22let (mut out_idx, mut nonnull_idx);23(out_idx, nonnull_idx) = pair_iter.next().unwrap();2425let mut out: Vec<T::ZeroablePhysical<'a>> = zeroed_vec(len);26let mut nonnull_prev_arrays = 0;27'outer: for arr in ca.downcast_iter() {28let arr_nonnull_len = arr.len() - arr.null_count();29let mut arr_scan_offset = 0;30let mut nonnull_before_offset = 0;31let mask = arr.validity().map(BitMask::from_bitmap).unwrap_or_default();3233// Is our next nonnull_idx in this array?34while nonnull_idx as usize - nonnull_prev_arrays < arr_nonnull_len {35let nonnull_idx_in_arr = nonnull_idx as usize - nonnull_prev_arrays;3637let phys_idx_in_arr = if arr.null_count() == 0 {38// Happy fast path for full non-null array.39nonnull_idx_in_arr40} else {41mask.nth_set_bit_idx(nonnull_idx_in_arr - nonnull_before_offset, arr_scan_offset)42.unwrap()43};4445unsafe {46let val = arr.value_unchecked(phys_idx_in_arr);47*out.get_unchecked_mut(out_idx as usize) = val.into();48}4950arr_scan_offset = phys_idx_in_arr;51nonnull_before_offset = nonnull_idx_in_arr;5253let Some(next_pair) = pair_iter.next() else {54break 'outer;55};56(out_idx, nonnull_idx) = next_pair;57}5859nonnull_prev_arrays += arr_nonnull_len;60}6162out63}6465pub trait 
ChunkGatherSkipNulls<I: ?Sized>: Sized {66fn gather_skip_nulls(&self, indices: &I) -> PolarsResult<Self>;67}6869impl<T: PolarsDataType> ChunkGatherSkipNulls<[IdxSize]> for ChunkedArray<T>70where71ChunkedArray<T>: ChunkFilter<T> + ChunkTake<[IdxSize]>,72{73fn gather_skip_nulls(&self, indices: &[IdxSize]) -> PolarsResult<Self> {74if self.null_count() == 0 {75return self.take(indices);76}7778// If we want many indices it's probably better to do a normal gather on79// a dense array.80if indices.len() >= self.len() / 4 {81return ChunkFilter::filter(self, &self.is_not_null())82.unwrap()83.take(indices);84}8586let bound = self.len() - self.null_count();87check_bounds(indices, bound as IdxSize)?;8889let index_pairs: Vec<_> = indices90.iter()91.enumerate()92.map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx))93.collect();94let gathered =95unsafe { gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.len()) };96let arr =97T::Array::from_zeroable_vec(gathered, self.dtype().to_arrow(CompatLevel::newest()));98Ok(ChunkedArray::from_chunk_iter_like(self, [arr]))99}100}101102impl<T: PolarsDataType> ChunkGatherSkipNulls<IdxCa> for ChunkedArray<T>103where104ChunkedArray<T>: ChunkFilter<T> + ChunkTake<IdxCa>,105{106fn gather_skip_nulls(&self, indices: &IdxCa) -> PolarsResult<Self> {107if self.null_count() == 0 {108return self.take(indices);109}110111// If we want many indices it's probably better to do a normal gather on112// a dense array.113if indices.len() >= self.len() / 4 {114return ChunkFilter::filter(self, &self.is_not_null())115.unwrap()116.take(indices);117}118119let bound = self.len() - self.null_count();120check_bounds_ca(indices, bound as IdxSize)?;121122let index_pairs: Vec<_> = if indices.null_count() == 0 {123indices124.downcast_iter()125.flat_map(|arr| arr.values_iter())126.enumerate()127.map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx))128.collect()129} else {130// Filter *after* the enumerate so we place the non-null 
gather131// requests at the right places.132indices133.downcast_iter()134.flat_map(|arr| arr.iter())135.enumerate()136.filter_map(|(out_idx, nonnull_idx)| Some((out_idx as IdxSize, *nonnull_idx?)))137.collect()138};139let gathered = unsafe {140gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.as_ref().len())141};142143let mut arr =144T::Array::from_zeroable_vec(gathered, self.dtype().to_arrow(CompatLevel::newest()));145if indices.null_count() > 0 {146arr = arr.with_validity_typed(concatenate_validities(indices.chunks()));147}148Ok(ChunkedArray::from_chunk_iter_like(self, [arr]))149}150}151152#[cfg(test)]153mod test {154use std::ops::Range;155156use rand::distr::uniform::SampleUniform;157use rand::prelude::*;158159use super::*;160161fn random_vec<T: SampleUniform + PartialOrd + Clone, R: Rng>(162rng: &mut R,163val: Range<T>,164len_range: Range<usize>,165) -> Vec<T> {166let n = rng.random_range(len_range);167(0..n).map(|_| rng.random_range(val.clone())).collect()168}169170fn random_filter<T: Clone, R: Rng>(rng: &mut R, v: &[T], pr: Range<f64>) -> Vec<Option<T>> {171let p = rng.random_range(pr);172let rand_filter = |x| Some(x).filter(|_| rng.random::<f64>() < p);173v.iter().cloned().map(rand_filter).collect()174}175176fn ref_gather_nulls(v: Vec<Option<u32>>, idx: Vec<Option<usize>>) -> Option<Vec<Option<u32>>> {177let v: Vec<u32> = v.into_iter().flatten().collect();178if idx.iter().any(|oi| oi.map(|i| i >= v.len()) == Some(true)) {179return None;180}181Some(idx.into_iter().map(|i| Some(v[i?])).collect())182}183184fn test_equal_ref(ca: &UInt32Chunked, idx_ca: &IdxCa) {185let ref_ca: Vec<Option<u32>> = ca.iter().collect();186let ref_idx_ca: Vec<Option<usize>> = idx_ca.iter().map(|i| Some(i? 
as usize)).collect();187let gather = ca.gather_skip_nulls(idx_ca).ok();188let ref_gather = ref_gather_nulls(ref_ca, ref_idx_ca);189assert_eq!(gather.map(|ca| ca.iter().collect()), ref_gather);190}191192fn gather_skip_nulls_check(ca: &UInt32Chunked, idx_ca: &IdxCa) {193test_equal_ref(ca, idx_ca);194test_equal_ref(&ca.rechunk(), idx_ca);195test_equal_ref(ca, &idx_ca.rechunk());196test_equal_ref(&ca.rechunk(), &idx_ca.rechunk());197}198199#[rustfmt::skip]200#[test]201fn test_gather_skip_nulls() {202let mut rng = SmallRng::seed_from_u64(0xdeadbeef);203204for _test in 0..20 {205let num_elem_chunks = rng.random_range(1..10);206let elem_chunks: Vec<_> = (0..num_elem_chunks).map(|_| random_vec(&mut rng, 0..u32::MAX, 0..100)).collect();207let null_elem_chunks: Vec<_> = elem_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect();208let num_nonnull_elems: usize = null_elem_chunks.iter().map(|c| c.iter().filter(|x| x.is_some()).count()).sum();209210let num_idx_chunks = rng.random_range(1..10);211let idx_chunks: Vec<_> = (0..num_idx_chunks).map(|_| random_vec(&mut rng, 0..num_nonnull_elems as IdxSize, 0..200)).collect();212let null_idx_chunks: Vec<_> = idx_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect();213214let nonnull_ca = UInt32Chunked::from_chunk_iter("".into(), elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));215let ca = UInt32Chunked::from_chunk_iter("".into(), null_elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));216let nonnull_idx_ca = IdxCa::from_chunk_iter("".into(), idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));217let idx_ca = IdxCa::from_chunk_iter("".into(), null_idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr()));218219gather_skip_nulls_check(&ca, &idx_ca);220gather_skip_nulls_check(&ca, &nonnull_idx_ca);221gather_skip_nulls_check(&nonnull_ca, &idx_ca);222gather_skip_nulls_check(&nonnull_ca, &nonnull_idx_ca);223}224}225}226227228