// Path: blob/main/crates/polars-python/src/series/scatter.rs (8393 views)
use arrow::array::Array;1use polars::prelude::*;2use polars_core::with_match_physical_numeric_polars_type;3use pyo3::prelude::*;45use super::PySeries;6use crate::utils::EnterPolarsExt;78#[pymethods]9impl PySeries {10fn scatter(&self, py: Python<'_>, idx: PySeries, values: PySeries) -> PyResult<()> {11py.enter_polars(|| {12// We take the value because we want a ref count of 1 so that we can13// have mutable access cheaply via _get_inner_mut().14let mut lock = self.series.write();15let s = std::mem::take(&mut *lock);16let result = scatter(s, &idx.series.into_inner(), &values.series.into_inner());17match result {18Ok(out) => {19*lock = out;20Ok(())21},22Err((s, e)) => {23*lock = s; // Restore original series.24Err(e)25},26}27})28}29}3031fn scatter(s: Series, idx: &Series, values: &Series) -> Result<Series, (Series, PolarsError)> {32let logical_dtype = s.dtype().clone();33let converted_values;34let values = if logical_dtype.is_categorical() || logical_dtype.is_enum() {35if matches!(36values.dtype(),37DataType::Categorical(_, _) | DataType::Enum(_, _) | DataType::String | DataType::Null38) {39converted_values = values.strict_cast(&logical_dtype);40match converted_values {41Ok(ref values) => values,42Err(err) => return Err((s, err)),43}44} else {45return Err((46s,47polars_err!(InvalidOperation: "invalid values dtype '{}' for scattering into dtype '{}'", values.dtype(), logical_dtype),48));49}50} else if logical_dtype.is_decimal() {51if values.dtype().is_numeric() {52converted_values = values.strict_cast(&logical_dtype);53match converted_values {54Ok(ref values) => values,55Err(err) => return Err((s, err)),56}57} else {58return Err((59s,60polars_err!(InvalidOperation: "invalid values dtype '{}' for scattering into dtype '{}'", values.dtype(), logical_dtype),61));62}63} else {64values65};6667let null_on_oob = false;68let idx = match polars_ops::prelude::convert_and_bound_index(idx, s.len(), null_on_oob) {69Ok(idx) => idx,70Err(err) => return Err((s, err)),71};72let idx = 
idx.rechunk();73let idx = idx.downcast_as_array();74if idx.has_nulls() {75return Err((76s,77PolarsError::ComputeError("index values should not be null".into()),78));79}80let idx = idx.values().as_slice();8182let mut values = match values.to_physical_repr().cast(&s.dtype().to_physical()) {83Ok(values) => values,84Err(err) => return Err((s, err)),85};8687// Broadcast values input.88if values.len() == 1 && idx.len() > 1 {89values = values.new_from_index(0, idx.len());90}9192let mut phys = s.to_physical_repr().into_owned();93drop(s); // Reduce refcount to make use of in-place mutation of possible.94let ret = scatter_impl(&mut phys, &logical_dtype, idx, &values);95match ret {96Ok(s) => Ok(unsafe { s.from_physical_unchecked(&logical_dtype).unwrap() }),97Err(e) => Err((98unsafe { phys.from_physical_unchecked(&logical_dtype).unwrap() },99e,100)),101}102}103104fn scatter_impl(105s: &mut Series,106logical_dtype: &DataType,107idx: &[IdxSize],108values: &Series,109) -> PolarsResult<Series> {110let mutable_s = s._get_inner_mut();111112match mutable_s.dtype() {113dt if dt.is_primitive_numeric() => {114with_match_physical_numeric_polars_type!(dt, |$T| {115let ca: &mut ChunkedArray<$T> = mutable_s.as_mut();116let values: &ChunkedArray<$T> = values.as_ref().as_ref();117ca.scatter(idx, values)118})119},120DataType::Boolean => {121let ca: &mut ChunkedArray<BooleanType> = mutable_s.as_mut();122let values = values.bool()?;123ca.scatter(idx, values)124},125DataType::Binary => {126let ca: &mut ChunkedArray<BinaryType> = mutable_s.as_mut();127let values = values.binary()?;128ca.scatter(idx, values)129},130DataType::String => {131let ca: &mut ChunkedArray<StringType> = mutable_s.as_mut();132let values = values.str()?;133ca.scatter(idx, values)134},135_ => Err(PolarsError::ComputeError(136format!("not yet implemented for dtype: {logical_dtype}").into(),137)),138}139}140141142