// Path: blob/main/crates/polars-ops/src/series/ops/interpolation/interpolate.rs
// 6939 views
use std::ops::{Add, Div, Mul, Sub};12use arrow::array::PrimitiveArray;3use arrow::bitmap::MutableBitmap;4use num_traits::{NumCast, Zero};5use polars_core::downcast_as_macro_arg_physical;6use polars_core::prelude::*;7#[cfg(feature = "serde")]8use serde::{Deserialize, Serialize};910use super::{linear_itp, nearest_itp};1112fn near_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)13where14T: Sub<Output = T>15+ Mul<Output = T>16+ Add<Output = T>17+ Div<Output = T>18+ NumCast19+ Copy20+ PartialOrd,21{22let diff = high - low;23for step_i in 1..steps {24let step_i: T = NumCast::from(step_i).unwrap();25let v = nearest_itp(low, step_i, diff, steps_n);26out.push(v)27}28}2930#[inline]31fn signed_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)32where33T: Sub<Output = T> + Mul<Output = T> + Add<Output = T> + Div<Output = T> + NumCast + Copy,34{35let slope = (high - low) / steps_n;36for step_i in 1..steps {37let step_i: T = NumCast::from(step_i).unwrap();38let v = linear_itp(low, step_i, slope);39out.push(v)40}41}4243fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>44where45T: PolarsNumericType,46I: Fn(T::Native, T::Native, IdxSize, T::Native, &mut Vec<T::Native>),47{48// This implementation differs from pandas as that boundary None's are not removed.49// This prevents a lot of errors due to expressions leading to different lengths.50if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() {51return chunked_arr.clone();52}5354// We first find the first and last so that we can set the null buffer.55let first = chunked_arr.first_non_null().unwrap();56let last = chunked_arr.last_non_null().unwrap() + 1;5758// Fill out with `first` nulls.59let mut out = Vec::with_capacity(chunked_arr.len());60let mut iter = chunked_arr.iter().skip(first);61for _ in 0..first {62out.push(Zero::zero());63}6465// The next element of `iter` is definitely `Some(Some(v))`, because we 
skipped the first66// elements `first` and if all values were missing we'd have done an early return.67let mut low = iter.next().unwrap().unwrap();68out.push(low);69while let Some(next) = iter.next() {70if let Some(v) = next {71out.push(v);72low = v;73} else {74let mut steps = 1 as IdxSize;75for next in iter.by_ref() {76steps += 1;77if let Some(high) = next {78let steps_n: T::Native = NumCast::from(steps).unwrap();79interpolation_branch(low, high, steps, steps_n, &mut out);80out.push(high);81low = high;82break;83}84}85}86}87if first != 0 || last != chunked_arr.len() {88let mut validity = MutableBitmap::with_capacity(chunked_arr.len());89validity.extend_constant(chunked_arr.len(), true);9091for i in 0..first {92unsafe { validity.set_unchecked(i, false) };93}9495for i in last..chunked_arr.len() {96unsafe { validity.set_unchecked(i, false) };97out.push(Zero::zero())98}99100let array = PrimitiveArray::new(101T::get_static_dtype().to_arrow(CompatLevel::newest()),102out.into(),103Some(validity.into()),104);105ChunkedArray::with_chunk(chunked_arr.name().clone(), array)106} else {107ChunkedArray::from_vec(chunked_arr.name().clone(), out)108}109}110111fn interpolate_nearest(s: &Series) -> Series {112match s.dtype() {113#[cfg(feature = "dtype-categorical")]114DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(),115DataType::Binary => s.clone(),116#[cfg(feature = "dtype-struct")]117DataType::Struct(_) => s.clone(),118DataType::List(_) => s.clone(),119_ => {120let logical = s.dtype();121let s = s.to_physical_repr();122123macro_rules! 
dispatch {124($ca:expr) => {{ interpolate_impl($ca, near_interp).into_series() }};125}126let out = downcast_as_macro_arg_physical!(s, dispatch);127match logical {128#[cfg(feature = "dtype-decimal")]129DataType::Decimal(_, _) => unsafe { out.from_physical_unchecked(logical).unwrap() },130_ => out.cast(logical).unwrap(),131}132},133}134}135136fn interpolate_linear(s: &Series) -> Series {137match s.dtype() {138#[cfg(feature = "dtype-categorical")]139DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(),140DataType::Binary => s.clone(),141#[cfg(feature = "dtype-struct")]142DataType::Struct(_) => s.clone(),143DataType::List(_) => s.clone(),144_ => {145let logical = s.dtype();146147let s = s.to_physical_repr();148149#[cfg(feature = "dtype-decimal")]150{151if matches!(logical, DataType::Decimal(_, _)) {152let out = linear_interp_signed(s.i128().unwrap());153return unsafe { out.from_physical_unchecked(logical).unwrap() };154}155}156157let out = if matches!(158logical,159DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time160) {161match s.dtype() {162// Datetime, Time, or Duration163DataType::Int64 => linear_interp_signed(s.i64().unwrap()),164// Date165DataType::Int32 => linear_interp_signed(s.i32().unwrap()),166_ => unreachable!(),167}168} else {169match s.dtype() {170DataType::Float32 => linear_interp_signed(s.f32().unwrap()),171DataType::Float64 => linear_interp_signed(s.f64().unwrap()),172DataType::Int8173| DataType::Int16174| DataType::Int32175| DataType::Int64176| DataType::Int128177| DataType::UInt8178| DataType::UInt16179| DataType::UInt32180| DataType::UInt64 => {181linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap())182},183_ => s.as_ref().clone(),184}185};186match logical {187DataType::Date188| DataType::Datetime(_, _)189| DataType::Duration(_)190| DataType::Time => out.cast(logical).unwrap(),191_ => out,192}193},194}195}196197fn linear_interp_signed<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series 
{198interpolate_impl(ca, signed_interp::<T::Native>).into_series()199}200201#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]202#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]203#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]204pub enum InterpolationMethod {205Linear,206Nearest,207}208209pub fn interpolate(s: &Series, method: InterpolationMethod) -> Series {210match method {211InterpolationMethod::Linear => interpolate_linear(s),212InterpolationMethod::Nearest => interpolate_nearest(s),213}214}215216#[cfg(test)]217mod test {218use super::*;219220#[test]221fn test_interpolate() {222let ca = UInt32Chunked::new("".into(), &[Some(1), None, None, Some(4), Some(5)]);223let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);224let out = out.f64().unwrap();225assert_eq!(226Vec::from(out),227&[Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]228);229230let ca = UInt32Chunked::new("".into(), &[None, Some(1), None, None, Some(4), Some(5)]);231let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);232let out = out.f64().unwrap();233assert_eq!(234Vec::from(out),235&[None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]236);237238let ca = UInt32Chunked::new(239"".into(),240&[None, Some(1), None, None, Some(4), Some(5), None],241);242let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);243let out = out.f64().unwrap();244assert_eq!(245Vec::from(out),246&[247None,248Some(1.0),249Some(2.0),250Some(3.0),251Some(4.0),252Some(5.0),253None254]255);256let ca = UInt32Chunked::new(257"".into(),258&[None, Some(1), None, None, Some(4), Some(5), None],259);260let out = interpolate(&ca.into_series(), InterpolationMethod::Nearest);261let out = out.u32().unwrap();262assert_eq!(263Vec::from(out),264&[None, Some(1), Some(1), Some(4), Some(4), Some(5), None]265);266}267268#[test]269fn test_interpolate_decreasing_unsigned() {270let ca = UInt32Chunked::new("".into(), &[Some(4), None, None, Some(1)]);271let 
out = interpolate(&ca.into_series(), InterpolationMethod::Linear);272let out = out.f64().unwrap();273assert_eq!(274Vec::from(out),275&[Some(4.0), Some(3.0), Some(2.0), Some(1.0)]276)277}278279#[test]280fn test_interpolate2() {281let ca = Float32Chunked::new(282"".into(),283&[284Some(4653f32),285None,286None,287None,288Some(4657f32),289None,290None,291Some(4657f32),292None,293Some(4657f32),294None,295None,296Some(4660f32),297],298);299let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);300let out = out.f32().unwrap();301302assert_eq!(303Vec::from(out),304&[305Some(4653.0),306Some(4654.0),307Some(4655.0),308Some(4656.0),309Some(4657.0),310Some(4657.0),311Some(4657.0),312Some(4657.0),313Some(4657.0),314Some(4657.0),315Some(4658.0),316Some(4659.0),317Some(4660.0)318]319);320}321}322323324