Path: blob/main/crates/polars-ops/src/frame/join/asof/mod.rs
6940 views
mod default;1mod groups;2use std::borrow::Cow;3use std::cmp::Ordering;45use default::*;6pub use groups::AsofJoinBy;7use polars_core::prelude::*;8use polars_utils::pl_str::PlSmallStr;9#[cfg(feature = "serde")]10use serde::{Deserialize, Serialize};1112use super::{_finish_join, build_tables};13use crate::frame::IntoDf;14use crate::series::SeriesMethods;1516#[inline]17fn ge_allow_eq<T: PartialOrd>(l: &T, r: &T, allow_eq: bool) -> bool {18match l.partial_cmp(r) {19Some(Ordering::Equal) => allow_eq,20Some(Ordering::Greater) => true,21_ => false,22}23}2425#[inline]26fn lt_allow_eq<T: PartialOrd>(l: &T, r: &T, allow_eq: bool) -> bool {27match l.partial_cmp(r) {28Some(Ordering::Equal) => allow_eq,29Some(Ordering::Less) => true,30_ => false,31}32}3334trait AsofJoinState<T> {35fn next<F: FnMut(IdxSize) -> Option<T>>(36&mut self,37left_val: &T,38right: F,39n_right: IdxSize,40) -> Option<IdxSize>;4142fn new(allow_eq: bool) -> Self;43}4445struct AsofJoinForwardState {46scan_offset: IdxSize,47allow_eq: bool,48}4950impl<T: PartialOrd> AsofJoinState<T> for AsofJoinForwardState {51fn new(allow_eq: bool) -> Self {52AsofJoinForwardState {53scan_offset: Default::default(),54allow_eq,55}56}57#[inline]58fn next<F: FnMut(IdxSize) -> Option<T>>(59&mut self,60left_val: &T,61mut right: F,62n_right: IdxSize,63) -> Option<IdxSize> {64while (self.scan_offset) < n_right {65if let Some(right_val) = right(self.scan_offset) {66if ge_allow_eq(&right_val, left_val, self.allow_eq) {67return Some(self.scan_offset);68}69}70self.scan_offset += 1;71}72None73}74}7576struct AsofJoinBackwardState {77// best_bound is the greatest right index <= left_val.78best_bound: Option<IdxSize>,79scan_offset: IdxSize,80allow_eq: bool,81}8283impl<T: PartialOrd> AsofJoinState<T> for AsofJoinBackwardState {84fn new(allow_eq: bool) -> Self {85AsofJoinBackwardState {86scan_offset: Default::default(),87best_bound: Default::default(),88allow_eq,89}90}91#[inline]92fn next<F: FnMut(IdxSize) -> Option<T>>(93&mut self,94left_val: &T,95mut right: F,96n_right: IdxSize,97) -> Option<IdxSize> {98while self.scan_offset < n_right {99if let Some(right_val) = right(self.scan_offset) {100if lt_allow_eq(&right_val, left_val, self.allow_eq) {101self.best_bound = Some(self.scan_offset);102} else {103break;104}105}106self.scan_offset += 1;107}108self.best_bound109}110}111112#[derive(Default)]113struct AsofJoinNearestState {114// best_bound is the nearest value to left_val, with ties broken towards the last element.115best_bound: Option<IdxSize>,116scan_offset: IdxSize,117allow_eq: bool,118}119120impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {121fn new(allow_eq: bool) -> Self {122AsofJoinNearestState {123scan_offset: Default::default(),124best_bound: Default::default(),125allow_eq,126}127}128#[inline]129fn next<F: FnMut(IdxSize) -> Option<T>>(130&mut self,131left_val: &T,132mut right: F,133n_right: IdxSize,134) -> Option<IdxSize> {135// Skipping ahead to the first value greater than left_val. This is136// cheaper than computing differences.137while self.scan_offset < n_right {138if let Some(scan_right_val) = right(self.scan_offset) {139if lt_allow_eq(&scan_right_val, left_val, self.allow_eq) {140self.best_bound = Some(self.scan_offset);141} else {142// Now we must compute a difference to see if scan_right_val143// is closer than our current best bound.144let scan_is_better = if let Some(best_idx) = self.best_bound {145let best_right_val = unsafe { right(best_idx).unwrap_unchecked() };146let best_diff = left_val.abs_diff(best_right_val);147let scan_diff = left_val.abs_diff(scan_right_val);148149lt_allow_eq(&scan_diff, &best_diff, self.allow_eq)150} else {151true152};153154if scan_is_better {155self.best_bound = Some(self.scan_offset);156self.scan_offset += 1;157158// It is possible there are later elements equal to our159// scan, so keep going on.160while self.scan_offset < n_right {161if let Some(next_right_val) = right(self.scan_offset) {162if next_right_val == scan_right_val && self.allow_eq {163self.best_bound = Some(self.scan_offset);164} else {165break;166}167}168169self.scan_offset += 1;170}171}172173break;174}175}176177self.scan_offset += 1;178}179180self.best_bound181}182}183184#[derive(Clone, Debug, PartialEq, Default, Hash)]185#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]186#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]187pub struct AsOfOptions {188pub strategy: AsofStrategy,189/// A tolerance in the same unit as the asof column190pub tolerance: Option<Scalar>,191/// A time duration specified as a string, for example:192/// - "5m"193/// - "2h15m"194/// - "1d6h"195pub tolerance_str: Option<PlSmallStr>,196pub left_by: Option<Vec<PlSmallStr>>,197pub right_by: Option<Vec<PlSmallStr>>,198/// Allow equal matches199pub allow_eq: bool,200pub check_sortedness: bool,201}202203fn check_asof_columns(204a: &Series,205b: &Series,206has_tolerance: bool,207check_sortedness: bool,208by_groups_present: bool,209) -> PolarsResult<()> {210let dtype_a = a.dtype();211let dtype_b = b.dtype();212if has_tolerance {213polars_ensure!(214dtype_a.to_physical().is_primitive_numeric() && dtype_b.to_physical().is_primitive_numeric(),215InvalidOperation:216"asof join with tolerance is only supported on numeric/temporal keys"217);218} else {219polars_ensure!(220dtype_a.to_physical().is_primitive() && dtype_b.to_physical().is_primitive(),221InvalidOperation:222"asof join is only supported on primitive key types"223);224}225polars_ensure!(226dtype_a == dtype_b,227ComputeError: "mismatching key dtypes in asof-join: `{}` and `{}`",228a.dtype(), b.dtype()229);230if check_sortedness {231if by_groups_present {232polars_warn!("Sortedness of columns cannot be checked when 'by' groups provided");233} else {234a.ensure_sorted_arg("asof_join")?;235b.ensure_sorted_arg("asof_join")?;236}237}238Ok(())239}240241#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]242#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]243#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]244pub enum AsofStrategy {245/// selects the last row in the right DataFrame whose ‘on’ key is less than or equal to the left’s key246#[default]247Backward,248/// selects the first row in the right DataFrame whose ‘on’ key is greater than or equal to the left’s key.249Forward,250/// selects the right in the right DataFrame whose 'on' key is nearest to the left's key.251Nearest,252}253254pub trait AsofJoin: IntoDf {255#[doc(hidden)]256#[allow(clippy::too_many_arguments)]257fn _join_asof(258&self,259other: &DataFrame,260left_key: &Series,261right_key: &Series,262strategy: AsofStrategy,263tolerance: Option<AnyValue<'static>>,264suffix: Option<PlSmallStr>,265slice: Option<(i64, usize)>,266coalesce: bool,267allow_eq: bool,268check_sortedness: bool,269) -> PolarsResult<DataFrame> {270let self_df = self.to_df();271272check_asof_columns(273left_key,274right_key,275tolerance.is_some(),276check_sortedness,277false,278)?;279let left_key = left_key.to_physical_repr();280let right_key = right_key.to_physical_repr();281282let mut take_idx = match left_key.dtype() {283DataType::Int64 => {284let ca = left_key.i64().unwrap();285join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)286},287DataType::Int32 => {288let ca = left_key.i32().unwrap();289join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)290},291#[cfg(feature = "dtype-i128")]292DataType::Int128 => {293let ca = left_key.i128().unwrap();294join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)295},296DataType::UInt64 => {297let ca = left_key.u64().unwrap();298join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)299},300DataType::UInt32 => {301let ca = left_key.u32().unwrap();302join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)303},304DataType::Float32 => {305let ca = left_key.f32().unwrap();306join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)307},308DataType::Float64 => {309let ca = left_key.f64().unwrap();310join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)311},312DataType::Boolean => {313let ca = left_key.bool().unwrap();314join_asof::<BooleanType>(ca, &right_key, strategy, allow_eq)315},316DataType::Binary => {317let ca = left_key.binary().unwrap();318join_asof::<BinaryType>(ca, &right_key, strategy, allow_eq)319},320DataType::String => {321let ca = left_key.str().unwrap();322let right_binary = right_key.cast(&DataType::Binary).unwrap();323join_asof::<BinaryType>(&ca.as_binary(), &right_binary, strategy, allow_eq)324},325DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {326let left_key = left_key.cast(&DataType::Int32).unwrap();327let right_key = right_key.cast(&DataType::Int32).unwrap();328let ca = left_key.i32().unwrap();329join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)330},331dt => polars_bail!(opq = asof_join, dt),332}?;333try_raise_keyboard_interrupt();334335// Drop right join column.336let other = if coalesce && left_key.name() == right_key.name() {337Cow::Owned(other.drop(right_key.name())?)338} else {339Cow::Borrowed(other)340};341342let mut left = self_df.clone();343if let Some((offset, len)) = slice {344left = left.slice(offset, len);345take_idx = take_idx.slice(offset, len);346}347348// SAFETY: join tuples are in bounds.349let right_df = unsafe { other.take_unchecked(&take_idx) };350351_finish_join(left, right_df, suffix)352}353}354355impl AsofJoin for DataFrame {}356357358