// Path: blob/main/crates/polars-ops/src/chunked_array/scatter.rs
// 8412 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::array::{Array, BinaryViewArrayGeneric, BooleanArray, PrimitiveArray, View, ViewType};2use polars_buffer::Buffer;3use polars_core::prelude::*;4use polars_core::series::IsSorted;5use polars_core::utils::arrow::bitmap::MutableBitmap;6use polars_core::utils::arrow::types::NativeType;7use polars_utils::index::check_bounds;89pub trait ChunkedSet<T: Copy> {10/// Invariant for implementations: if the scatter() fails, typically because11/// of bad indexes, then self should remain unmodified.12fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>13where14V: IntoIterator<Item = Option<T>>;15}1617trait PolarsOpsNumericType: PolarsNumericType {}1819impl PolarsOpsNumericType for UInt8Type {}20impl PolarsOpsNumericType for UInt16Type {}21impl PolarsOpsNumericType for UInt32Type {}22impl PolarsOpsNumericType for UInt64Type {}23#[cfg(feature = "dtype-u128")]24impl PolarsOpsNumericType for UInt128Type {}25impl PolarsOpsNumericType for Int8Type {}26impl PolarsOpsNumericType for Int16Type {}27impl PolarsOpsNumericType for Int32Type {}28impl PolarsOpsNumericType for Int64Type {}29#[cfg(feature = "dtype-i128")]30impl PolarsOpsNumericType for Int128Type {}31#[cfg(feature = "dtype-f16")]32impl PolarsOpsNumericType for Float16Type {}33impl PolarsOpsNumericType for Float32Type {}34impl PolarsOpsNumericType for Float64Type {}3536unsafe fn scatter_primitive_impl<V, T: NativeType>(37set_values: V,38arr: &mut PrimitiveArray<T>,39idx: &[IdxSize],40) where41V: IntoIterator<Item = Option<T>>,42{43let mut values_iter = set_values.into_iter();4445if let Some(validity) = arr.take_validity() {46let mut mut_validity = validity.make_mut();47arr.with_values_mut(|cur_values| {48for (idx, val) in idx.iter().zip(&mut values_iter) {49match val {50Some(value) => {51mut_validity.set_unchecked(*idx as usize, true);52*cur_values.get_unchecked_mut(*idx as usize) = value53},54None => mut_validity.set_unchecked(*idx as usize, 
false),55}56}57});58arr.set_validity(mut_validity.into())59} else {60let mut null_idx = vec![];61arr.with_values_mut(|cur_values| {62for (idx, val) in idx.iter().zip(values_iter) {63match val {64Some(value) => *cur_values.get_unchecked_mut(*idx as usize) = value,65None => {66null_idx.push(*idx);67},68}69}70});7172// Only make a validity bitmap when null values are set.73if !null_idx.is_empty() {74let mut validity = MutableBitmap::with_capacity(arr.len());75validity.extend_constant(arr.len(), true);76for idx in null_idx {77validity.set_unchecked(idx as usize, false)78}79arr.set_validity(Some(validity.into()))80}81}82}8384unsafe fn scatter_bool_impl<V>(set_values: V, arr: &mut BooleanArray, idx: &[IdxSize])85where86V: IntoIterator<Item = Option<bool>>,87{88let mut values_iter = set_values.into_iter();8990if let Some(validity) = arr.take_validity() {91let mut mut_validity = validity.make_mut();92arr.apply_values_mut(|cur_values| {93for (idx, val) in idx.iter().zip(&mut values_iter) {94match val {95Some(value) => {96mut_validity.set_unchecked(*idx as usize, true);97cur_values.set_unchecked(*idx as usize, value);98},99None => mut_validity.set_unchecked(*idx as usize, false),100}101}102});103arr.set_validity(mut_validity.into())104} else {105let mut null_idx = vec![];106arr.apply_values_mut(|cur_values| {107for (idx, val) in idx.iter().zip(values_iter) {108match val {109Some(value) => cur_values.set_unchecked(*idx as usize, value),110None => {111null_idx.push(*idx);112},113}114}115});116117// Only make a validity bitmap when null values are set.118if !null_idx.is_empty() {119let mut validity = MutableBitmap::with_capacity(arr.len());120validity.extend_constant(arr.len(), true);121for idx in null_idx {122validity.set_unchecked(idx as usize, false)123}124arr.set_validity(Some(validity.into()))125}126}127}128129unsafe fn scatter_binview_impl<'a, V, T: ViewType + ?Sized>(130set_values: V,131arr: &mut BinaryViewArrayGeneric<T>,132idx: &[IdxSize],133) where134V: 
IntoIterator<Item = Option<&'a T>>,135{136let mut values_iter = set_values.into_iter();137let buffer_offset = arr.data_buffers().len() as u32;138let mut new_buffers = Vec::new();139140if let Some(validity) = arr.take_validity() {141let mut mut_validity = validity.make_mut();142arr.with_views_mut(|views| {143for (idx, val) in idx.iter().zip(&mut values_iter) {144if let Some(v) = val {145let view =146View::new_with_buffers(v.to_bytes(), buffer_offset, &mut new_buffers);147*views.get_unchecked_mut(*idx as usize) = view;148mut_validity.set_unchecked(*idx as usize, true);149} else {150mut_validity.set_unchecked(*idx as usize, false);151}152}153});154arr.set_validity(mut_validity.into())155} else {156let mut null_idx = vec![];157arr.with_views_mut(|views| {158for (idx, val) in idx.iter().zip(values_iter) {159if let Some(v) = val {160let view =161View::new_with_buffers(v.to_bytes(), buffer_offset, &mut new_buffers);162*views.get_unchecked_mut(*idx as usize) = view;163} else {164null_idx.push(*idx);165}166}167});168169// Only make a validity bitmap when null values are set.170if !null_idx.is_empty() {171let mut validity = MutableBitmap::with_capacity(arr.len());172validity.extend_constant(arr.len(), true);173for idx in null_idx {174validity.set_unchecked(idx as usize, false)175}176arr.set_validity(Some(validity.into()))177}178}179180let mut buffers = Buffer::to_vec(core::mem::take(arr.data_buffers_mut()));181buffers.extend(new_buffers.into_iter().map(Buffer::from));182*arr.data_buffers_mut() = Buffer::from(buffers);183}184185impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {186fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>187where188V: IntoIterator<Item = Option<T::Native>>,189{190check_bounds(idx, self.len() as IdxSize)?;191let mut ca = std::mem::take(self);192193// SAFETY: we will not modify the length and we unset the sorted flag,194// making sure to update the null count as well.195unsafe {196ca.rechunk_mut();197let 
arr = ca.downcast_iter_mut().next().unwrap();198scatter_primitive_impl(values, arr, idx);199let null_count = arr.null_count();200ca.set_sorted_flag(IsSorted::Not);201ca.set_null_count(null_count);202}203204Ok(ca.into_series())205}206}207208impl<'a> ChunkedSet<&'a [u8]> for &mut BinaryChunked {209fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>210where211V: IntoIterator<Item = Option<&'a [u8]>>,212{213check_bounds(idx, self.len() as IdxSize)?;214let mut ca = std::mem::take(self);215216unsafe {217ca.rechunk_mut();218let arr = ca.downcast_iter_mut().next().unwrap();219scatter_binview_impl(values, arr, idx);220let null_count = arr.null_count();221ca.set_sorted_flag(IsSorted::Not);222ca.set_null_count(null_count);223}224225Ok(ca.into_series())226}227}228229impl<'a> ChunkedSet<&'a str> for &mut StringChunked {230fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>231where232V: IntoIterator<Item = Option<&'a str>>,233{234check_bounds(idx, self.len() as IdxSize)?;235let mut ca = std::mem::take(self);236237unsafe {238ca.rechunk_mut();239let arr = ca.downcast_iter_mut().next().unwrap();240scatter_binview_impl(values, arr, idx);241let null_count = arr.null_count();242ca.set_sorted_flag(IsSorted::Not);243ca.set_null_count(null_count);244}245246Ok(ca.into_series())247}248}249impl ChunkedSet<bool> for &mut BooleanChunked {250fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>251where252V: IntoIterator<Item = Option<bool>>,253{254check_bounds(idx, self.len() as IdxSize)?;255let mut ca = std::mem::take(self);256257unsafe {258ca.rechunk_mut();259let arr = ca.downcast_iter_mut().next().unwrap();260scatter_bool_impl(values, arr, idx);261let null_count = arr.null_count();262ca.set_sorted_flag(IsSorted::Not);263ca.set_null_count(null_count);264}265266Ok(ca.into_series())267}268}269270271