Path: blob/main/crates/polars-ops/src/series/ops/clip.rs
6939 views
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};1use polars_core::prelude::*;2use polars_core::with_match_physical_numeric_polars_type;34/// Set values outside the given boundaries to the boundary value.5pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult<Series> {6polars_ensure!(7s.dtype().to_physical().is_primitive_numeric(),8InvalidOperation: "`clip` only supports physical numeric types"9);10let n = [s.len(), min.len(), max.len()]11.into_iter()12.find(|l| *l != 1)13.unwrap_or(1);1415for (i, (name, length)) in [("self", s.len()), ("min", min.len()), ("max", max.len())]16.into_iter()17.enumerate()18{19polars_ensure!(20length == n || length == 1,21length_mismatch = "clip",22length,23n,24argument = name,25argument_idx = i26);27}2829let original_type = s.dtype();30let (min, max) = (min.strict_cast(s.dtype())?, max.strict_cast(s.dtype())?);3132let (s, min, max) = (33s.to_physical_repr(),34min.to_physical_repr(),35max.to_physical_repr(),36);3738with_match_physical_numeric_polars_type!(s.dtype(), |$T| {39let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();40let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref();41let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref();42let out = clip_helper_both_bounds(ca, min, max).into_series();43match original_type {44#[cfg(feature = "dtype-decimal")]45DataType::Decimal(precision, scale) => {46let phys = out.i128()?.as_ref().clone();47Ok(phys.into_decimal_unchecked(*precision, scale.unwrap()).into_series())48},49dt if dt.is_logical() => out.cast(original_type),50_ => Ok(out)51}52})53}5455/// Set values above the given maximum to the maximum value.56pub fn clip_max(s: &Series, max: &Series) -> PolarsResult<Series> {57polars_ensure!(58s.dtype().to_physical().is_primitive_numeric(),59InvalidOperation: "`clip` only supports physical numeric types"60);61polars_ensure!(62s.len() == max.len() || s.len() == 1 || max.len() == 1,63length_mismatch = "clip(max)",64s.len(),65max.len()66);6768let original_type = s.dtype();69let max = max.strict_cast(s.dtype())?;7071let (s, max) = (s.to_physical_repr(), max.to_physical_repr());7273with_match_physical_numeric_polars_type!(s.dtype(), |$T| {74let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();75let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref();76let out = clip_helper_single_bound(ca, max, num_traits::clamp_max).into_series();77match original_type {78#[cfg(feature = "dtype-decimal")]79DataType::Decimal(precision, scale) => {80let phys = out.i128()?.as_ref().clone();81Ok(phys.into_decimal_unchecked(*precision, scale.unwrap()).into_series())82},83dt if dt.is_logical() => out.cast(original_type),84_ => Ok(out)85}86})87}8889/// Set values below the given minimum to the minimum value.90pub fn clip_min(s: &Series, min: &Series) -> PolarsResult<Series> {91polars_ensure!(92s.dtype().to_physical().is_primitive_numeric(),93InvalidOperation: "`clip` only supports physical numeric types"94);95polars_ensure!(96s.len() == min.len() || s.len() == 1 || min.len() == 1,97length_mismatch = "clip(min)",98s.len(),99min.len()100);101102let original_type = s.dtype();103let min = min.strict_cast(s.dtype())?;104105let (s, min) = (s.to_physical_repr(), min.to_physical_repr());106107with_match_physical_numeric_polars_type!(s.dtype(), |$T| {108let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();109let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref();110let out = clip_helper_single_bound(ca, min, num_traits::clamp_min).into_series();111match original_type {112#[cfg(feature = "dtype-decimal")]113DataType::Decimal(precision, scale) => {114let phys = out.i128()?.as_ref().clone();115Ok(phys.into_decimal_unchecked(*precision, scale.unwrap()).into_series())116},117dt if dt.is_logical() => out.cast(original_type),118_ => Ok(out)119}120})121}122123fn clip_helper_both_bounds<T>(124ca: &ChunkedArray<T>,125min: &ChunkedArray<T>,126max: &ChunkedArray<T>,127) -> ChunkedArray<T>128where129T: PolarsNumericType,130T::Native: PartialOrd,131{132match (min.len(), max.len()) {133(1, 1) => match (min.get(0), max.get(0)) {134(Some(min), Some(max)) => clip_unary(ca, |v| num_traits::clamp(v, min, max)),135(Some(min), None) => clip_unary(ca, |v| num_traits::clamp_min(v, min)),136(None, Some(max)) => clip_unary(ca, |v| num_traits::clamp_max(v, max)),137(None, None) => ca.clone(),138},139(1, _) => match min.get(0) {140Some(min) => clip_binary(ca, max, |v, b| num_traits::clamp(v, min, b)),141None => clip_binary(ca, max, num_traits::clamp_max),142},143(_, 1) => match max.get(0) {144Some(max) => clip_binary(ca, min, |v, b| num_traits::clamp(v, b, max)),145None => clip_binary(ca, min, num_traits::clamp_min),146},147_ => clip_ternary(ca, min, max),148}149}150151fn clip_helper_single_bound<T, F>(152ca: &ChunkedArray<T>,153bound: &ChunkedArray<T>,154op: F,155) -> ChunkedArray<T>156where157T: PolarsNumericType,158T::Native: PartialOrd,159F: Fn(T::Native, T::Native) -> T::Native,160{161match bound.len() {1621 => match bound.get(0) {163Some(bound) => clip_unary(ca, |v| op(v, bound)),164None => ca.clone(),165},166_ => clip_binary(ca, bound, op),167}168}169170fn clip_unary<T, F>(ca: &ChunkedArray<T>, op: F) -> ChunkedArray<T>171where172T: PolarsNumericType,173F: Fn(T::Native) -> T::Native + Copy,174{175unary_elementwise(ca, |v| v.map(op))176}177178fn clip_binary<T, F>(ca: &ChunkedArray<T>, bound: &ChunkedArray<T>, op: F) -> ChunkedArray<T>179where180T: PolarsNumericType,181T::Native: PartialOrd,182F: Fn(T::Native, T::Native) -> T::Native,183{184binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) {185(Some(s), Some(bound)) => Some(op(s, bound)),186(Some(s), None) => Some(s),187(None, _) => None,188})189}190191fn clip_ternary<T>(192ca: &ChunkedArray<T>,193min: &ChunkedArray<T>,194max: &ChunkedArray<T>,195) -> ChunkedArray<T>196where197T: PolarsNumericType,198T::Native: PartialOrd,199{200ternary_elementwise(ca, min, max, |opt_v, opt_min, opt_max| {201match (opt_v, opt_min, opt_max) {202(Some(v), Some(min), Some(max)) => Some(num_traits::clamp(v, min, max)),203(Some(v), Some(min), None) => Some(num_traits::clamp_min(v, min)),204(Some(v), None, Some(max)) => Some(num_traits::clamp_max(v, max)),205(Some(v), None, None) => Some(v),206(None, _, _) => None,207}208})209}210211212