Path: blob/main/crates/polars-ops/src/chunked_array/scatter.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::array::{Array, PrimitiveArray};2use polars_core::prelude::*;3use polars_core::series::IsSorted;4use polars_core::utils::arrow::bitmap::MutableBitmap;5use polars_core::utils::arrow::types::NativeType;6use polars_utils::index::check_bounds;78pub trait ChunkedSet<T: Copy> {9/// Invariant for implementations: if the scatter() fails, typically because10/// of bad indexes, then self should remain unmodified.11fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>12where13V: IntoIterator<Item = Option<T>>;14}15fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> {16if idx.is_empty() {17return Ok(());18}19let mut sorted = true;20let mut previous = idx[0];21for &i in &idx[1..] {22if i < previous {23// we will not break here as that prevents SIMD24sorted = false;25}26previous = i;27}28polars_ensure!(sorted, ComputeError: "set indices must be sorted");29Ok(())30}3132trait PolarsOpsNumericType: PolarsNumericType {}3334impl PolarsOpsNumericType for UInt8Type {}35impl PolarsOpsNumericType for UInt16Type {}36impl PolarsOpsNumericType for UInt32Type {}37impl PolarsOpsNumericType for UInt64Type {}38impl PolarsOpsNumericType for Int8Type {}39impl PolarsOpsNumericType for Int16Type {}40impl PolarsOpsNumericType for Int32Type {}41impl PolarsOpsNumericType for Int64Type {}42#[cfg(feature = "dtype-i128")]43impl PolarsOpsNumericType for Int128Type {}44impl PolarsOpsNumericType for Float32Type {}45impl PolarsOpsNumericType for Float64Type {}4647unsafe fn scatter_impl<V, T: NativeType>(48new_values_slice: &mut [T],49set_values: V,50arr: &mut PrimitiveArray<T>,51idx: &[IdxSize],52len: usize,53) where54V: IntoIterator<Item = Option<T>>,55{56let mut values_iter = set_values.into_iter();5758if arr.null_count() > 0 {59arr.apply_validity(|v| {60let mut mut_validity = v.make_mut();6162for (idx, val) in idx.iter().zip(&mut values_iter) {63match val {64Some(value) => {65mut_validity.set_unchecked(*idx as usize, true);66*new_values_slice.get_unchecked_mut(*idx as usize) = value67},68None => mut_validity.set_unchecked(*idx as usize, false),69}70}71mut_validity.into()72})73} else {74let mut null_idx = vec![];75for (idx, val) in idx.iter().zip(values_iter) {76match val {77Some(value) => *new_values_slice.get_unchecked_mut(*idx as usize) = value,78None => {79null_idx.push(*idx);80},81}82}83// only make a validity bitmap when null values are set84if !null_idx.is_empty() {85let mut validity = MutableBitmap::with_capacity(len);86validity.extend_constant(len, true);87for idx in null_idx {88validity.set_unchecked(idx as usize, false)89}90arr.set_validity(Some(validity.into()))91}92}93}9495impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {96fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>97where98V: IntoIterator<Item = Option<T::Native>>,99{100check_bounds(idx, self.len() as IdxSize)?;101let mut ca = std::mem::take(self);102ca.rechunk_mut();103104// SAFETY:105// we will not modify the length106// and we unset the sorted flag.107ca.set_sorted_flag(IsSorted::Not);108let arr = unsafe { ca.downcast_iter_mut() }.next().unwrap();109let len = arr.len();110111match arr.get_mut_values() {112Some(current_values) => {113let ptr = current_values.as_mut_ptr();114115// reborrow because the bck does not allow it116let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) };117// SAFETY:118// we checked bounds119unsafe { scatter_impl(current_values, values, arr, idx, len) };120},121None => {122let mut new_values = arr.values().as_slice().to_vec();123// SAFETY:124// we checked bounds125unsafe { scatter_impl(&mut new_values, values, arr, idx, len) };126arr.set_values(new_values.into());127},128};129130// The null count may have changed - make sure to update the ChunkedArray131let new_null_count = arr.null_count();132unsafe { ca.set_null_count(new_null_count) };133134Ok(ca.into_series())135}136}137138impl<'a> ChunkedSet<&'a str> for &'a StringChunked {139fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>140where141V: IntoIterator<Item = Option<&'a str>>,142{143check_bounds(idx, self.len() as IdxSize)?;144check_sorted(idx)?;145let mut ca_iter = self.into_iter().enumerate();146let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());147148for (current_idx, current_value) in idx.iter().zip(values) {149for (cnt_idx, opt_val_self) in &mut ca_iter {150if cnt_idx == *current_idx as usize {151builder.append_option(current_value);152break;153} else {154builder.append_option(opt_val_self);155}156}157}158// the last idx is probably not the last value so we finish the iterator159for (_, opt_val_self) in ca_iter {160builder.append_option(opt_val_self);161}162163let ca = builder.finish();164Ok(ca.into_series())165}166}167impl ChunkedSet<bool> for &BooleanChunked {168fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>169where170V: IntoIterator<Item = Option<bool>>,171{172check_bounds(idx, self.len() as IdxSize)?;173check_sorted(idx)?;174let mut ca_iter = self.into_iter().enumerate();175let mut builder = BooleanChunkedBuilder::new(self.name().clone(), self.len());176177for (current_idx, current_value) in idx.iter().zip(values) {178for (cnt_idx, opt_val_self) in &mut ca_iter {179if cnt_idx == *current_idx as usize {180builder.append_option(current_value);181break;182} else {183builder.append_option(opt_val_self);184}185}186}187// the last idx is probably not the last value so we finish the iterator188for (_, opt_val_self) in ca_iter {189builder.append_option(opt_val_self);190}191192let ca = builder.finish();193Ok(ca.into_series())194}195}196197198