Path: blob/main/crates/polars-ops/src/series/ops/interpolation/interpolate.rs
8459 views
use std::ops::{Add, Div, Mul, Sub};12use arrow::array::PrimitiveArray;3use arrow::bitmap::MutableBitmap;4use num_traits::{NumCast, Zero};5use polars_core::downcast_as_macro_arg_physical;6use polars_core::prelude::*;7#[cfg(feature = "serde")]8use serde::{Deserialize, Serialize};910use super::{linear_itp, nearest_itp};1112fn near_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)13where14T: Sub<Output = T>15+ Mul<Output = T>16+ Add<Output = T>17+ Div<Output = T>18+ NumCast19+ Copy20+ PartialOrd,21{22let diff = high - low;23for step_i in 1..steps {24let step_i: T = NumCast::from(step_i).unwrap();25let v = nearest_itp(low, step_i, diff, steps_n);26out.push(v)27}28}2930#[inline]31fn signed_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)32where33T: Sub<Output = T> + Mul<Output = T> + Add<Output = T> + Div<Output = T> + NumCast + Copy,34{35let slope = (high - low) / steps_n;36for step_i in 1..steps {37let step_i: T = NumCast::from(step_i).unwrap();38let v = linear_itp(low, step_i, slope);39out.push(v)40}41}4243fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>44where45T: PolarsNumericType,46I: Fn(T::Native, T::Native, IdxSize, T::Native, &mut Vec<T::Native>),47{48// This implementation differs from pandas as that boundary None's are not removed.49// This prevents a lot of errors due to expressions leading to different lengths.50if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() {51return chunked_arr.clone();52}5354// We first find the first and last so that we can set the null buffer.55let first = chunked_arr.first_non_null().unwrap();56let last = chunked_arr.last_non_null().unwrap() + 1;5758// Fill out with `first` nulls.59let mut out = Vec::with_capacity(chunked_arr.len());60let mut iter = chunked_arr.iter().skip(first);61for _ in 0..first {62out.push(Zero::zero());63}6465// The next element of `iter` is definitely `Some(Some(v))`, because we skipped the first66// elements `first` and if all values were missing we'd have done an early return.67let mut low = iter.next().unwrap().unwrap();68out.push(low);69while let Some(next) = iter.next() {70if let Some(v) = next {71out.push(v);72low = v;73} else {74let mut steps = 1 as IdxSize;75for next in iter.by_ref() {76steps += 1;77if let Some(high) = next {78let steps_n: T::Native = NumCast::from(steps).unwrap();79interpolation_branch(low, high, steps, steps_n, &mut out);80out.push(high);81low = high;82break;83}84}85}86}87if first != 0 || last != chunked_arr.len() {88let mut validity = MutableBitmap::with_capacity(chunked_arr.len());89validity.extend_constant(chunked_arr.len(), true);9091for i in 0..first {92unsafe { validity.set_unchecked(i, false) };93}9495for i in last..chunked_arr.len() {96unsafe { validity.set_unchecked(i, false) };97out.push(Zero::zero())98}99100let array = PrimitiveArray::new(101T::get_static_dtype().to_arrow(CompatLevel::newest()),102out.into(),103Some(validity.into()),104);105ChunkedArray::with_chunk(chunked_arr.name().clone(), array)106} else {107ChunkedArray::from_vec(chunked_arr.name().clone(), out)108}109}110111fn interpolate_nearest(s: &Series) -> Series {112match s.dtype() {113#[cfg(feature = "dtype-categorical")]114DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(),115DataType::Binary => s.clone(),116#[cfg(feature = "dtype-struct")]117DataType::Struct(_) => s.clone(),118DataType::List(_) => s.clone(),119_ => {120let logical = s.dtype();121let s = s.to_physical_repr();122123macro_rules! dispatch {124($ca:expr) => {{ interpolate_impl($ca, near_interp).into_series() }};125}126let out = downcast_as_macro_arg_physical!(s, dispatch);127match logical {128#[cfg(feature = "dtype-decimal")]129DataType::Decimal(_, _) => unsafe { out.from_physical_unchecked(logical).unwrap() },130_ => out.cast(logical).unwrap(),131}132},133}134}135136fn interpolate_linear(s: &Series) -> Series {137match s.dtype() {138#[cfg(feature = "dtype-categorical")]139DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(),140DataType::Binary => s.clone(),141#[cfg(feature = "dtype-struct")]142DataType::Struct(_) => s.clone(),143DataType::List(_) => s.clone(),144_ => {145let logical = s.dtype();146147let s = s.to_physical_repr();148149#[cfg(feature = "dtype-decimal")]150{151if matches!(logical, DataType::Decimal(_, _)) {152let out = linear_interp_signed(s.i128().unwrap());153return unsafe { out.from_physical_unchecked(logical).unwrap() };154}155}156157let out = if matches!(158logical,159DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time160) {161match s.dtype() {162// Datetime, Time, or Duration163DataType::Int64 => linear_interp_signed(s.i64().unwrap()),164// Date165DataType::Int32 => linear_interp_signed(s.i32().unwrap()),166_ => unreachable!(),167}168} else {169match s.dtype() {170#[cfg(feature = "dtype-f16")]171DataType::Float16 => linear_interp_signed(s.f16().unwrap()),172DataType::Float32 => linear_interp_signed(s.f32().unwrap()),173DataType::Float64 => linear_interp_signed(s.f64().unwrap()),174DataType::Int8175| DataType::Int16176| DataType::Int32177| DataType::Int64178| DataType::Int128179| DataType::UInt8180| DataType::UInt16181| DataType::UInt32182| DataType::UInt64183| DataType::UInt128 => {184linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap())185},186_ => s.as_ref().clone(),187}188};189match logical {190DataType::Date191| DataType::Datetime(_, _)192| DataType::Duration(_)193| DataType::Time => out.cast(logical).unwrap(),194_ => out,195}196},197}198}199200fn linear_interp_signed<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series {201interpolate_impl(ca, signed_interp::<T::Native>).into_series()202}203204#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]205#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]206#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]207pub enum InterpolationMethod {208Linear,209Nearest,210}211212pub fn interpolate(s: &Series, method: InterpolationMethod) -> Series {213match method {214InterpolationMethod::Linear => interpolate_linear(s),215InterpolationMethod::Nearest => interpolate_nearest(s),216}217}218219#[cfg(test)]220mod test {221use super::*;222223#[test]224fn test_interpolate() {225let ca = UInt32Chunked::new("".into(), &[Some(1), None, None, Some(4), Some(5)]);226let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);227let out = out.f64().unwrap();228assert_eq!(229Vec::from(out),230&[Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]231);232233let ca = UInt32Chunked::new("".into(), &[None, Some(1), None, None, Some(4), Some(5)]);234let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);235let out = out.f64().unwrap();236assert_eq!(237Vec::from(out),238&[None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]239);240241let ca = UInt32Chunked::new(242"".into(),243&[None, Some(1), None, None, Some(4), Some(5), None],244);245let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);246let out = out.f64().unwrap();247assert_eq!(248Vec::from(out),249&[250None,251Some(1.0),252Some(2.0),253Some(3.0),254Some(4.0),255Some(5.0),256None257]258);259let ca = UInt32Chunked::new(260"".into(),261&[None, Some(1), None, None, Some(4), Some(5), None],262);263let out = interpolate(&ca.into_series(), InterpolationMethod::Nearest);264let out = out.u32().unwrap();265assert_eq!(266Vec::from(out),267&[None, Some(1), Some(1), Some(4), Some(4), Some(5), None]268);269}270271#[test]272fn test_interpolate_decreasing_unsigned() {273let ca = UInt32Chunked::new("".into(), &[Some(4), None, None, Some(1)]);274let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);275let out = out.f64().unwrap();276assert_eq!(277Vec::from(out),278&[Some(4.0), Some(3.0), Some(2.0), Some(1.0)]279)280}281282#[test]283fn test_interpolate2() {284let ca = Float32Chunked::new(285"".into(),286&[287Some(4653f32),288None,289None,290None,291Some(4657f32),292None,293None,294Some(4657f32),295None,296Some(4657f32),297None,298None,299Some(4660f32),300],301);302let out = interpolate(&ca.into_series(), InterpolationMethod::Linear);303let out = out.f32().unwrap();304305assert_eq!(306Vec::from(out),307&[308Some(4653.0),309Some(4654.0),310Some(4655.0),311Some(4656.0),312Some(4657.0),313Some(4657.0),314Some(4657.0),315Some(4657.0),316Some(4657.0),317Some(4657.0),318Some(4658.0),319Some(4659.0),320Some(4660.0)321]322);323}324}325326327