Path: blob/main/crates/polars-ops/src/frame/join/asof/default.rs
8446 views
use arrow::array::Array;1use arrow::bitmap::Bitmap;2use num_traits::Zero;3use polars_core::prelude::*;4use polars_utils::abs_diff::AbsDiff;5use polars_utils::total_ord::TotalOrd;67use super::{8AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy,9};1011fn join_asof_impl<'a, T, S, F>(12left: &'a T::Array,13right: &'a T::Array,14mut filter: F,15allow_eq: bool,16) -> IdxCa17where18T: PolarsDataType,19S: AsofJoinState<T::Physical<'a>>,20F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,21{22if left.len() == left.null_count() || right.len() == right.null_count() {23return IdxCa::full_null(PlSmallStr::EMPTY, left.len());24}2526let mut out = vec![0; left.len()];27let mut mask = vec![0; left.len().div_ceil(8)];28let mut state = S::new(allow_eq);2930if left.null_count() == 0 && right.null_count() == 0 {31for (i, val_l) in left.values_iter().enumerate() {32if let Some(r_idx) = state.next(33&val_l,34// SAFETY: next() only calls with indices < right.len().35|j| Some(unsafe { right.value_unchecked(j as usize) }),36right.len() as IdxSize,37) {38// SAFETY: r_idx is non-null and valid.39unsafe {40let val_r = right.value_unchecked(r_idx as usize);41*out.get_unchecked_mut(i) = r_idx;42*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);43}44}45}46} else {47for (i, opt_val_l) in left.iter().enumerate() {48if let Some(val_l) = opt_val_l {49if let Some(r_idx) = state.next(50&val_l,51// SAFETY: next() only calls with indices < right.len().52|j| unsafe { right.get_unchecked(j as usize) },53right.len() as IdxSize,54) {55// SAFETY: r_idx is non-null and valid.56unsafe {57let val_r = right.value_unchecked(r_idx as usize);58*out.get_unchecked_mut(i) = r_idx;59*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);60}61}62}63}64}6566let bitmap = Bitmap::try_new(mask, out.len()).unwrap();67IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap))68}6970fn join_asof_forward<'a, T, F>(71left: &'a T::Array,72right: &'a T::Array,73filter: F,74allow_eq: bool,75) -> IdxCa76where77T: PolarsDataType,78T::Physical<'a>: TotalOrd,79F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,80{81join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq)82}8384fn join_asof_backward<'a, T, F>(85left: &'a T::Array,86right: &'a T::Array,87filter: F,88allow_eq: bool,89) -> IdxCa90where91T: PolarsDataType,92T::Physical<'a>: TotalOrd,93F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,94{95join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)96}9798fn join_asof_nearest<'a, T, F>(99left: &'a T::Array,100right: &'a T::Array,101filter: F,102allow_eq: bool,103) -> IdxCa104where105T: PolarsDataType,106T::Physical<'a>: NumericNative,107F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,108{109join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq)110}111112pub(crate) fn join_asof_numeric<T: PolarsNumericType>(113input_ca: &ChunkedArray<T>,114other: &Series,115strategy: AsofStrategy,116tolerance: Option<AnyValue<'static>>,117allow_eq: bool,118) -> PolarsResult<IdxCa> {119let other = input_ca.unpack_series_matching_type(other)?;120121let ca = input_ca.rechunk();122let other = other.rechunk();123let left = ca.downcast_as_array();124let right = other.downcast_as_array();125126let out = if let Some(t) = tolerance {127let native_tolerance = t.try_extract::<T::Native>()?;128let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());129let filter = |l: T::Native, r: T::Native| l.abs_diff(r) <= abs_tolerance;130match strategy {131AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),132AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),133AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),134}135} else {136let filter = |_l: T::Native, _r: T::Native| true;137match strategy {138AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),139AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),140AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),141}142};143Ok(out)144}145146pub(crate) fn join_asof<T>(147input_ca: &ChunkedArray<T>,148other: &Series,149strategy: AsofStrategy,150allow_eq: bool,151) -> PolarsResult<IdxCa>152where153T: PolarsDataType,154for<'a> T::Physical<'a>: TotalOrd,155{156let other = input_ca.unpack_series_matching_type(other)?;157158let ca = input_ca.rechunk();159let other = other.rechunk();160let left = ca.downcast_iter().next().unwrap();161let right = other.downcast_iter().next().unwrap();162163let filter = |_l: T::Physical<'_>, _r: T::Physical<'_>| true;164Ok(match strategy {165AsofStrategy::Forward => {166join_asof_impl::<T, AsofJoinForwardState, _>(left, right, filter, allow_eq)167},168AsofStrategy::Backward => {169join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)170},171AsofStrategy::Nearest => polars_bail!(InvalidOperation:172"AsOf strategy \"nearest\" is not supported for {} data type",173T::get_static_dtype()174),175})176}177178#[cfg(test)]179mod test {180use arrow::array::PrimitiveArray;181182use super::*;183184#[test]185fn test_asof_backward() {186let a = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]);187let b = PrimitiveArray::from_slice([1, 2, 3, 3]);188189let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);190assert_eq!(tuples.len(), a.len());191assert_eq!(192tuples.to_vec(),193&[None, Some(1), Some(3), Some(3), Some(3), Some(3)]194);195196let b = PrimitiveArray::from_slice([1, 2, 4, 5]);197let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);198assert_eq!(199tuples.to_vec(),200&[None, Some(1), Some(1), Some(1), Some(1), Some(2)]201);202203let a = PrimitiveArray::from_slice([2, 4, 4, 4]);204let b = PrimitiveArray::from_slice([1, 2, 3, 3]);205let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);206assert_eq!(tuples.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]);207}208209#[test]210fn test_asof_backward_tolerance() {211let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]);212let b = PrimitiveArray::from_slice([10, 20, 30, 30]);213let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);214assert_eq!(215tuples.to_vec(),216&[None, Some(1), None, Some(3), Some(3), None]217);218}219220#[test]221fn test_asof_forward_tolerance() {222let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]);223let b = PrimitiveArray::from_slice([10, 20, 33, 55]);224let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);225assert_eq!(226tuples.to_vec(),227&[None, Some(1), None, Some(2), Some(2), None, Some(3)]228);229}230231#[test]232fn test_asof_forward() {233let a = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]);234let b = PrimitiveArray::from_slice([1, 2, 4, 5]);235236let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, true);237assert_eq!(tuples.len(), a.len());238assert_eq!(tuples.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]);239}240}241242243