Path: blob/main/crates/polars-arrow/src/legacy/kernels/set.rs
6939 views
use std::ops::BitOr;12use polars_error::polars_err;3use polars_utils::IdxSize;45use crate::array::*;6use crate::datatypes::ArrowDataType;7use crate::legacy::array::default_arrays::FromData;8use crate::legacy::error::PolarsResult;9use crate::legacy::kernels::BinaryMaskedSliceIterator;10use crate::legacy::trusted_len::TrustedLenPush;11use crate::types::NativeType;1213/// Set values in a primitive array where the primitive array has null values.14/// this is faster because we don't have to invert and combine bitmaps15pub fn set_at_nulls<T>(array: &PrimitiveArray<T>, value: T) -> PrimitiveArray<T>16where17T: NativeType,18{19let values = array.values();20if array.null_count() == 0 {21return array.clone();22}2324let validity = array.validity().unwrap();25let validity = BooleanArray::from_data_default(validity.clone(), None);2627let mut av = Vec::with_capacity(array.len());28BinaryMaskedSliceIterator::new(&validity).for_each(|(lower, upper, truthy)| {29if truthy {30av.extend_from_slice(&values[lower..upper])31} else {32av.extend_trusted_len(std::iter::repeat_n(value, upper - lower))33}34});3536PrimitiveArray::new(array.dtype().clone(), av.into(), None)37}3839/// Set values in a primitive array based on a mask array. This is fast when large chunks of bits are set or unset.40pub fn set_with_mask<T: NativeType>(41array: &PrimitiveArray<T>,42mask: &BooleanArray,43value: T,44dtype: ArrowDataType,45) -> PrimitiveArray<T> {46let values = array.values();4748let mut buf = Vec::with_capacity(array.len());49BinaryMaskedSliceIterator::new(mask).for_each(|(lower, upper, truthy)| {50if truthy {51buf.extend_trusted_len(std::iter::repeat_n(value, upper - lower))52} else {53buf.extend_from_slice(&values[lower..upper])54}55});56// make sure that where the mask is set to true, the validity buffer is also set to valid57// after we have applied the or operation we have new buffer with no offsets58let validity = array.validity().as_ref().map(|valid| {59let mask_bitmap = mask.values();60valid.bitor(mask_bitmap)61});6263PrimitiveArray::new(dtype, buf.into(), validity)64}6566/// Efficiently sets value at the indices from the iterator to `set_value`.67/// The new array is initialized with a `memcpy` from the old values.68pub fn scatter_single_non_null<T, I>(69array: &PrimitiveArray<T>,70idx: I,71set_value: T,72dtype: ArrowDataType,73) -> PolarsResult<PrimitiveArray<T>>74where75T: NativeType,76I: IntoIterator<Item = IdxSize>,77{78let mut buf = Vec::with_capacity(array.len());79buf.extend_from_slice(array.values().as_slice());80let mut_slice = buf.as_mut_slice();8182idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {83let val = mut_slice84.get_mut(idx as usize)85.ok_or_else(|| polars_err!(ComputeError: "index is out of bounds"))?;86*val = set_value;87Ok(())88})?;8990Ok(PrimitiveArray::new(91dtype,92buf.into(),93array.validity().cloned(),94))95}9697#[cfg(test)]98mod test {99use super::*;100101#[test]102fn test_set_mask() {103let mask = BooleanArray::from_iter((0..86).map(|v| v > 68 && v != 85).map(Some));104let val = UInt32Array::from_iter((0..86).map(Some));105let a = set_with_mask(&val, &mask, 100, ArrowDataType::UInt32);106let slice = a.values();107108assert_eq!(slice[a.len() - 1], 85);109assert_eq!(slice[a.len() - 2], 100);110assert_eq!(slice[67], 67);111assert_eq!(slice[68], 68);112assert_eq!(slice[1], 1);113assert_eq!(slice[0], 0);114115let mask = BooleanArray::from_slice([116false, true, false, true, false, true, false, true, false, false,117]);118let val = UInt32Array::from_slice([0; 10]);119let out = set_with_mask(&val, &mask, 1, ArrowDataType::UInt32);120assert_eq!(out.values().as_slice(), &[0, 1, 0, 1, 0, 1, 0, 1, 0, 0]);121122let val = UInt32Array::from(&[None, None, None]);123let mask = BooleanArray::from(&[Some(true), Some(true), None]);124let out = set_with_mask(&val, &mask, 1, ArrowDataType::UInt32);125let out: Vec<_> = out.iter().map(|v| v.copied()).collect();126assert_eq!(out, &[Some(1), Some(1), None])127}128129#[test]130fn test_scatter_single_non_null() {131let val = UInt32Array::from_slice([1, 2, 3]);132let out =133scatter_single_non_null(&val, std::iter::once(1), 100, ArrowDataType::UInt32).unwrap();134assert_eq!(out.values().as_slice(), &[1, 100, 3]);135let out = scatter_single_non_null(&val, std::iter::once(100), 100, ArrowDataType::UInt32);136assert!(out.is_err())137}138}139140141