Path: blob/main/crates/polars-python/src/series/scatter.rs
7889 views
use arrow::array::Array;1use polars::prelude::*;2use pyo3::prelude::*;34use super::PySeries;5use crate::utils::EnterPolarsExt;67#[pymethods]8impl PySeries {9fn scatter(&self, py: Python<'_>, idx: PySeries, values: PySeries) -> PyResult<()> {10py.enter_polars(|| {11// we take the value because we want a ref count of 1 so that we can12// have mutable access cheaply via _get_inner_mut().13let mut lock = self.series.write();14let s = std::mem::take(&mut *lock);15let result = scatter(s, &idx.series.into_inner(), &values.series.into_inner());16match result {17Ok(out) => {18*lock = out;19Ok(())20},21Err((s, e)) => {22// Restore original series:23*lock = s;24Err(e)25},26}27})28}29}3031fn scatter(mut s: Series, idx: &Series, values: &Series) -> Result<Series, (Series, PolarsError)> {32let logical_dtype = s.dtype().clone();3334let values = if logical_dtype.is_categorical() || logical_dtype.is_enum() {35if matches!(36values.dtype(),37DataType::Categorical(_, _) | DataType::Enum(_, _) | DataType::String | DataType::Null38) {39match values.strict_cast(&logical_dtype) {40Ok(values) => values,41Err(err) => return Err((s, err)),42}43} else {44return Err((45s,46polars_err!(InvalidOperation: "invalid values dtype '{}' for scattering into dtype '{}'", values.dtype(), logical_dtype),47));48}49} else {50values.clone()51};5253let idx = match polars_ops::prelude::convert_to_unsigned_index(idx, s.len()) {54Ok(idx) => idx,55Err(err) => return Err((s, err)),56};57let idx = idx.rechunk();58let idx = idx.downcast_as_array();5960if idx.null_count() > 0 {61return Err((62s,63PolarsError::ComputeError("index values should not be null".into()),64));65}6667let idx = idx.values().as_slice();6869let mut values = match values.to_physical_repr().cast(&s.dtype().to_physical()) {70Ok(values) => values,71Err(err) => return Err((s, err)),72};7374// Broadcast values input75if values.len() == 1 && idx.len() > 1 {76values = values.new_from_index(0, idx.len());77}7879// do not shadow, otherwise s is not dropped immediately80// and we want to have mutable access81s = s.to_physical_repr().into_owned();82let s_mut_ref = &mut s;83scatter_impl(s_mut_ref, logical_dtype, idx, &values).map_err(|err| (s, err))84}8586fn scatter_impl(87s: &mut Series,88logical_dtype: DataType,89idx: &[IdxSize],90values: &Series,91) -> PolarsResult<Series> {92let mutable_s = s._get_inner_mut();9394let s = match logical_dtype.to_physical() {95DataType::Int8 => {96let ca: &mut ChunkedArray<Int8Type> = mutable_s.as_mut();97let values = values.i8()?;98ca.scatter(idx, values)99},100DataType::Int16 => {101let ca: &mut ChunkedArray<Int16Type> = mutable_s.as_mut();102let values = values.i16()?;103ca.scatter(idx, values)104},105DataType::Int32 => {106let ca: &mut ChunkedArray<Int32Type> = mutable_s.as_mut();107let values = values.i32()?;108ca.scatter(idx, values)109},110DataType::Int64 => {111let ca: &mut ChunkedArray<Int64Type> = mutable_s.as_mut();112let values = values.i64()?;113ca.scatter(idx, values)114},115DataType::UInt8 => {116let ca: &mut ChunkedArray<UInt8Type> = mutable_s.as_mut();117let values = values.u8()?;118ca.scatter(idx, values)119},120DataType::UInt16 => {121let ca: &mut ChunkedArray<UInt16Type> = mutable_s.as_mut();122let values = values.u16()?;123ca.scatter(idx, values)124},125DataType::UInt32 => {126let ca: &mut ChunkedArray<UInt32Type> = mutable_s.as_mut();127let values = values.u32()?;128ca.scatter(idx, values)129},130DataType::UInt64 => {131let ca: &mut ChunkedArray<UInt64Type> = mutable_s.as_mut();132let values = values.u64()?;133ca.scatter(idx, values)134},135DataType::Float32 => {136let ca: &mut ChunkedArray<Float32Type> = mutable_s.as_mut();137let values = values.f32()?;138ca.scatter(idx, values)139},140DataType::Float64 => {141let ca: &mut ChunkedArray<Float64Type> = mutable_s.as_mut();142let values = values.f64()?;143ca.scatter(idx, values)144},145DataType::Boolean => {146let ca = s.bool()?;147let values = values.bool()?;148ca.scatter(idx, values)149},150DataType::String => {151let ca = s.str()?;152let values = values.str()?;153ca.scatter(idx, values)154},155_ => {156return Err(PolarsError::ComputeError(157format!("not yet implemented for dtype: {logical_dtype}").into(),158));159},160};161162s.and_then(|s| s.cast(&logical_dtype))163}164165166