Path: blob/main/crates/polars-ops/src/frame/join/asof/default.rs
6940 views
use arrow::array::Array;1use arrow::bitmap::Bitmap;2use num_traits::Zero;3use polars_core::prelude::*;4use polars_utils::abs_diff::AbsDiff;56use super::{7AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy,8};910fn join_asof_impl<'a, T, S, F>(11left: &'a T::Array,12right: &'a T::Array,13mut filter: F,14allow_eq: bool,15) -> IdxCa16where17T: PolarsDataType,18S: AsofJoinState<T::Physical<'a>>,19F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,20{21if left.len() == left.null_count() || right.len() == right.null_count() {22return IdxCa::full_null(PlSmallStr::EMPTY, left.len());23}2425let mut out = vec![0; left.len()];26let mut mask = vec![0; left.len().div_ceil(8)];27let mut state = S::new(allow_eq);2829if left.null_count() == 0 && right.null_count() == 0 {30for (i, val_l) in left.values_iter().enumerate() {31if let Some(r_idx) = state.next(32&val_l,33// SAFETY: next() only calls with indices < right.len().34|j| Some(unsafe { right.value_unchecked(j as usize) }),35right.len() as IdxSize,36) {37// SAFETY: r_idx is non-null and valid.38unsafe {39let val_r = right.value_unchecked(r_idx as usize);40*out.get_unchecked_mut(i) = r_idx;41*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);42}43}44}45} else {46for (i, opt_val_l) in left.iter().enumerate() {47if let Some(val_l) = opt_val_l {48if let Some(r_idx) = state.next(49&val_l,50// SAFETY: next() only calls with indices < right.len().51|j| unsafe { right.get_unchecked(j as usize) },52right.len() as IdxSize,53) {54// SAFETY: r_idx is non-null and valid.55unsafe {56let val_r = right.value_unchecked(r_idx as usize);57*out.get_unchecked_mut(i) = r_idx;58*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);59}60}61}62}63}6465let bitmap = Bitmap::try_new(mask, out.len()).unwrap();66IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap))67}6869fn join_asof_forward<'a, T, F>(70left: &'a T::Array,71right: &'a T::Array,72filter: F,73allow_eq: bool,74) -> IdxCa75where76T: PolarsDataType,77T::Physical<'a>: PartialOrd,78F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,79{80join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq)81}8283fn join_asof_backward<'a, T, F>(84left: &'a T::Array,85right: &'a T::Array,86filter: F,87allow_eq: bool,88) -> IdxCa89where90T: PolarsDataType,91T::Physical<'a>: PartialOrd,92F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,93{94join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)95}9697fn join_asof_nearest<'a, T, F>(98left: &'a T::Array,99right: &'a T::Array,100filter: F,101allow_eq: bool,102) -> IdxCa103where104T: PolarsDataType,105T::Physical<'a>: NumericNative,106F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,107{108join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq)109}110111pub(crate) fn join_asof_numeric<T: PolarsNumericType>(112input_ca: &ChunkedArray<T>,113other: &Series,114strategy: AsofStrategy,115tolerance: Option<AnyValue<'static>>,116allow_eq: bool,117) -> PolarsResult<IdxCa> {118let other = input_ca.unpack_series_matching_type(other)?;119120let ca = input_ca.rechunk();121let other = other.rechunk();122let left = ca.downcast_as_array();123let right = other.downcast_as_array();124125let out = if let Some(t) = tolerance {126let native_tolerance = t.try_extract::<T::Native>()?;127let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());128let filter = |l: T::Native, r: T::Native| l.abs_diff(r) <= abs_tolerance;129match strategy {130AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),131AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),132AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),133}134} else {135let filter = |_l: T::Native, _r: T::Native| true;136match strategy {137AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),138AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),139AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),140}141};142Ok(out)143}144145pub(crate) fn join_asof<T>(146input_ca: &ChunkedArray<T>,147other: &Series,148strategy: AsofStrategy,149allow_eq: bool,150) -> PolarsResult<IdxCa>151where152T: PolarsDataType,153for<'a> T::Physical<'a>: PartialOrd,154{155let other = input_ca.unpack_series_matching_type(other)?;156157let ca = input_ca.rechunk();158let other = other.rechunk();159let left = ca.downcast_iter().next().unwrap();160let right = other.downcast_iter().next().unwrap();161162let filter = |_l: T::Physical<'_>, _r: T::Physical<'_>| true;163Ok(match strategy {164AsofStrategy::Forward => {165join_asof_impl::<T, AsofJoinForwardState, _>(left, right, filter, allow_eq)166},167AsofStrategy::Backward => {168join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)169},170AsofStrategy::Nearest => unimplemented!(),171})172}173174#[cfg(test)]175mod test {176use arrow::array::PrimitiveArray;177178use super::*;179180#[test]181fn test_asof_backward() {182let a = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]);183let b = PrimitiveArray::from_slice([1, 2, 3, 3]);184185let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);186assert_eq!(tuples.len(), a.len());187assert_eq!(188tuples.to_vec(),189&[None, Some(1), Some(3), Some(3), Some(3), Some(3)]190);191192let b = PrimitiveArray::from_slice([1, 2, 4, 5]);193let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);194assert_eq!(195tuples.to_vec(),196&[None, Some(1), Some(1), Some(1), Some(1), Some(2)]197);198199let a = PrimitiveArray::from_slice([2, 4, 4, 4]);200let b = PrimitiveArray::from_slice([1, 2, 3, 3]);201let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);202assert_eq!(tuples.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]);203}204205#[test]206fn test_asof_backward_tolerance() {207let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]);208let b = PrimitiveArray::from_slice([10, 20, 30, 30]);209let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);210assert_eq!(211tuples.to_vec(),212&[None, Some(1), None, Some(3), Some(3), None]213);214}215216#[test]217fn test_asof_forward_tolerance() {218let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]);219let b = PrimitiveArray::from_slice([10, 20, 33, 55]);220let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);221assert_eq!(222tuples.to_vec(),223&[None, Some(1), None, Some(2), Some(2), None, Some(3)]224);225}226227#[test]228fn test_asof_forward() {229let a = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]);230let b = PrimitiveArray::from_slice([1, 2, 4, 5]);231232let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, true);233assert_eq!(tuples.len(), a.len());234assert_eq!(tuples.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]);235}236}237238239