Path: blob/main/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs
6939 views
use std::ops::{Add, Div, Mul, Sub};12use arrow::array::PrimitiveArray;3use arrow::bitmap::MutableBitmap;4use bytemuck::allocation::zeroed_vec;5use num_traits::{NumCast, Zero};6use polars_core::prelude::*;7use polars_utils::slice::SliceAble;89use super::linear_itp;1011/// # Safety12/// - `x` must be non-empty.13#[inline]14unsafe fn signed_interp_by_sorted<T, F>(y_start: T, y_end: T, x: &[F], out: &mut Vec<T>)15where16T: Sub<Output = T>17+ Mul<Output = T>18+ Add<Output = T>19+ Div<Output = T>20+ NumCast21+ Copy22+ Zero,23F: Sub<Output = F> + NumCast + Copy,24{25let range_y = y_end - y_start;26let x_start;27let range_x;28let iter;29unsafe {30x_start = x.get_unchecked(0);31range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap();32iter = x.slice_unchecked(1..x.len() - 1).iter();33}34let slope = range_y / range_x;35for x_i in iter {36let x_delta = NumCast::from(*x_i - *x_start).unwrap();37let v = linear_itp(y_start, x_delta, slope);38out.push(v)39}40}4142/// # Safety43/// - `x` must be non-empty.44/// - `sorting_indices` must be the same size as `x`45#[inline]46unsafe fn signed_interp_by<T, F>(47y_start: T,48y_end: T,49x: &[F],50out: &mut [T],51sorting_indices: &[IdxSize],52) where53T: Sub<Output = T>54+ Mul<Output = T>55+ Add<Output = T>56+ Div<Output = T>57+ NumCast58+ Copy59+ Zero,60F: Sub<Output = F> + NumCast + Copy,61{62let range_y = y_end - y_start;63let x_start;64let range_x;65let iter;66unsafe {67x_start = x.get_unchecked(0);68range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap();69iter = x.slice_unchecked(1..x.len() - 1).iter();70}71let slope = range_y / range_x;72for (idx, x_i) in iter.enumerate() {73let x_delta = NumCast::from(*x_i - *x_start).unwrap();74let v = linear_itp(y_start, x_delta, slope);75unsafe {76let out_idx = sorting_indices.get_unchecked(idx + 1);77*out.get_unchecked_mut(*out_idx as usize) = v;78}79}80}8182fn interpolate_impl_by_sorted<T, F, I>(83chunked_arr: &ChunkedArray<T>,84by: &ChunkedArray<F>,85interpolation_branch: I,86) -> PolarsResult<ChunkedArray<T>>87where88T: PolarsNumericType,89F: PolarsNumericType,90I: Fn(T::Native, T::Native, &[F::Native], &mut Vec<T::Native>),91{92// This implementation differs from pandas as that boundary None's are not removed.93// This prevents a lot of errors due to expressions leading to different lengths.94if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() {95return Ok(chunked_arr.clone());96}9798polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression");99let by = by.rechunk();100let by_values = by.cont_slice().unwrap();101102// We first find the first and last so that we can set the null buffer.103let first = chunked_arr.first_non_null().unwrap();104let last = chunked_arr.last_non_null().unwrap() + 1;105106// Fill out with `first` nulls.107let mut out = Vec::with_capacity(chunked_arr.len());108let mut iter = chunked_arr.iter().enumerate().skip(first);109for _ in 0..first {110out.push(Zero::zero());111}112113// The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first114// `first` elements and if all values were missing we'd have done an early return.115let (mut low_idx, opt_low) = iter.next().unwrap();116let mut low = opt_low.unwrap();117out.push(low);118while let Some((idx, next)) = iter.next() {119if let Some(v) = next {120out.push(v);121low = v;122low_idx = idx;123} else {124for (high_idx, next) in iter.by_ref() {125if let Some(high) = next {126// SAFETY: we are in bounds, and `x` is non-empty.127unsafe {128let x = &by_values.slice_unchecked(low_idx..high_idx + 1);129interpolation_branch(low, high, x, &mut out);130}131out.push(high);132low = high;133low_idx = high_idx;134break;135}136}137}138}139if first != 0 || last != chunked_arr.len() {140let mut validity = MutableBitmap::with_capacity(chunked_arr.len());141validity.extend_constant(chunked_arr.len(), true);142143for i in 0..first {144unsafe { validity.set_unchecked(i, false) };145}146147for i in last..chunked_arr.len() {148unsafe { validity.set_unchecked(i, false) }149out.push(Zero::zero());150}151152let array = PrimitiveArray::new(153T::get_static_dtype().to_arrow(CompatLevel::newest()),154out.into(),155Some(validity.into()),156);157Ok(ChunkedArray::with_chunk(chunked_arr.name().clone(), array))158} else {159Ok(ChunkedArray::from_vec(chunked_arr.name().clone(), out))160}161}162163// Sort on behalf of user164fn interpolate_impl_by<T, F, I>(165ca: &ChunkedArray<T>,166by: &ChunkedArray<F>,167interpolation_branch: I,168) -> PolarsResult<ChunkedArray<T>>169where170T: PolarsNumericType,171F: PolarsNumericType,172I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]),173{174// This implementation differs from pandas as that boundary None's are not removed.175// This prevents a lot of errors due to expressions leading to different lengths.176if !ca.has_nulls() || ca.null_count() == ca.len() {177return Ok(ca.clone());178}179180polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression");181let sorting_indices = by.arg_sort(Default::default());182let sorting_indices = sorting_indices183.cont_slice()184.expect("arg sort produces single chunk");185let by_sorted = unsafe { by.take_unchecked(sorting_indices) };186let ca_sorted = unsafe { ca.take_unchecked(sorting_indices) };187let by_sorted_values = by_sorted188.cont_slice()189.expect("We already checked for nulls, and `take_unchecked` produces single chunk");190191// We first find the first and last so that we can set the null buffer.192let first = ca_sorted.first_non_null().unwrap();193let last = ca_sorted.last_non_null().unwrap() + 1;194195let mut out = zeroed_vec(ca_sorted.len());196let mut iter = ca_sorted.iter().enumerate().skip(first);197198// The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first199// `first` elements and if all values were missing we'd have done an early return.200let (mut low_idx, opt_low) = iter.next().unwrap();201let mut low = opt_low.unwrap();202unsafe {203let out_idx = sorting_indices.get_unchecked(low_idx);204*out.get_unchecked_mut(*out_idx as usize) = low;205}206while let Some((idx, next)) = iter.next() {207if let Some(v) = next {208unsafe {209let out_idx = sorting_indices.get_unchecked(idx);210*out.get_unchecked_mut(*out_idx as usize) = v;211}212low = v;213low_idx = idx;214} else {215for (high_idx, next) in iter.by_ref() {216if let Some(high) = next {217// SAFETY: we are in bounds, and the slices are the same length (and non-empty).218unsafe {219interpolation_branch(220low,221high,222by_sorted_values.slice_unchecked(low_idx..high_idx + 1),223&mut out,224sorting_indices.slice_unchecked(low_idx..high_idx + 1),225);226let out_idx = sorting_indices.get_unchecked(high_idx);227*out.get_unchecked_mut(*out_idx as usize) = high;228}229low = high;230low_idx = high_idx;231break;232}233}234}235}236if first != 0 || last != ca_sorted.len() {237let mut validity = MutableBitmap::with_capacity(ca_sorted.len());238validity.extend_constant(ca_sorted.len(), true);239240for i in 0..first {241unsafe {242let out_idx = sorting_indices.get_unchecked(i);243validity.set_unchecked(*out_idx as usize, false);244}245}246247for i in last..ca_sorted.len() {248unsafe {249let out_idx = sorting_indices.get_unchecked(i);250validity.set_unchecked(*out_idx as usize, false);251}252}253254let array = PrimitiveArray::new(255T::get_static_dtype().to_arrow(CompatLevel::newest()),256out.into(),257Some(validity.into()),258);259Ok(ChunkedArray::with_chunk(ca_sorted.name().clone(), array))260} else {261Ok(ChunkedArray::from_vec(ca_sorted.name().clone(), out))262}263}264265pub fn interpolate_by(s: &Column, by: &Column, by_is_sorted: bool) -> PolarsResult<Column> {266polars_ensure!(s.len() == by.len(), InvalidOperation: "`by` column must be the same length as Series ({}), got {}", s.len(), by.len());267268fn func<T, F>(269ca: &ChunkedArray<T>,270by: &ChunkedArray<F>,271is_sorted: bool,272) -> PolarsResult<Column>273where274T: PolarsNumericType,275F: PolarsNumericType,276ChunkedArray<T>: IntoColumn,277{278if is_sorted {279interpolate_impl_by_sorted(ca, by, |y_start, y_end, x, out| unsafe {280signed_interp_by_sorted(y_start, y_end, x, out)281})282.map(|x| x.into_column())283} else {284interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe {285signed_interp_by(y_start, y_end, x, out, sorting_indices)286})287.map(|x| x.into_column())288}289}290291match (s.dtype(), by.dtype()) {292(DataType::Float64, DataType::Float64) => {293func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted)294},295(DataType::Float64, DataType::Float32) => {296func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted)297},298(DataType::Float32, DataType::Float64) => {299func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted)300},301(DataType::Float32, DataType::Float32) => {302func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted)303},304(DataType::Float64, DataType::Int64) => {305func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted)306},307(DataType::Float64, DataType::Int32) => {308func(s.f64().unwrap(), by.i32().unwrap(), by_is_sorted)309},310(DataType::Float64, DataType::UInt64) => {311func(s.f64().unwrap(), by.u64().unwrap(), by_is_sorted)312},313(DataType::Float64, DataType::UInt32) => {314func(s.f64().unwrap(), by.u32().unwrap(), by_is_sorted)315},316(DataType::Float32, DataType::Int64) => {317func(s.f32().unwrap(), by.i64().unwrap(), by_is_sorted)318},319(DataType::Float32, DataType::Int32) => {320func(s.f32().unwrap(), by.i32().unwrap(), by_is_sorted)321},322(DataType::Float32, DataType::UInt64) => {323func(s.f32().unwrap(), by.u64().unwrap(), by_is_sorted)324},325(DataType::Float32, DataType::UInt32) => {326func(s.f32().unwrap(), by.u32().unwrap(), by_is_sorted)327},328#[cfg(feature = "dtype-date")]329(_, DataType::Date) => interpolate_by(s, &by.cast(&DataType::Int32).unwrap(), by_is_sorted),330#[cfg(feature = "dtype-datetime")]331(_, DataType::Datetime(_, _)) => {332interpolate_by(s, &by.cast(&DataType::Int64).unwrap(), by_is_sorted)333},334(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {335interpolate_by(&s.cast(&DataType::Float64).unwrap(), by, by_is_sorted)336},337_ => {338polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \339Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \340UInt64, UInt32, Float32 or Float64")341},342}343}344345346