Path: blob/main/crates/polars-ops/src/frame/join/asof/mod.rs
8458 views
mod default;1mod groups;2use std::borrow::Cow;3use std::cmp::Ordering;45use default::*;6pub use groups::AsofJoinBy;7use polars_core::prelude::*;8use polars_utils::pl_str::PlSmallStr;9use polars_utils::total_ord::TotalOrd;10#[cfg(feature = "serde")]11use serde::{Deserialize, Serialize};1213use super::{_finish_join, build_tables};14use crate::frame::IntoDf;15use crate::series::SeriesMethods;1617#[inline]18fn ge_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {19match l.tot_cmp(r) {20Ordering::Equal => allow_eq,21Ordering::Greater => true,22Ordering::Less => false,23}24}2526#[inline]27fn lt_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {28match l.tot_cmp(r) {29Ordering::Equal => allow_eq,30Ordering::Less => true,31Ordering::Greater => false,32}33}3435trait AsofJoinState<T> {36fn next<F: FnMut(IdxSize) -> Option<T>>(37&mut self,38left_val: &T,39right: F,40n_right: IdxSize,41) -> Option<IdxSize>;4243fn new(allow_eq: bool) -> Self;44}4546struct AsofJoinForwardState {47scan_offset: IdxSize,48allow_eq: bool,49}5051impl<T: TotalOrd> AsofJoinState<T> for AsofJoinForwardState {52fn new(allow_eq: bool) -> Self {53AsofJoinForwardState {54scan_offset: Default::default(),55allow_eq,56}57}58#[inline]59fn next<F: FnMut(IdxSize) -> Option<T>>(60&mut self,61left_val: &T,62mut right: F,63n_right: IdxSize,64) -> Option<IdxSize> {65while (self.scan_offset) < n_right {66if let Some(right_val) = right(self.scan_offset) {67if ge_allow_eq(&right_val, left_val, self.allow_eq) {68return Some(self.scan_offset);69}70}71self.scan_offset += 1;72}73None74}75}7677struct AsofJoinBackwardState {78// best_bound is the greatest right index <= left_val.79best_bound: Option<IdxSize>,80scan_offset: IdxSize,81allow_eq: bool,82}8384impl<T: TotalOrd> AsofJoinState<T> for AsofJoinBackwardState {85fn new(allow_eq: bool) -> Self {86AsofJoinBackwardState {87scan_offset: Default::default(),88best_bound: Default::default(),89allow_eq,90}91}92#[inline]93fn next<F: FnMut(IdxSize) -> Option<T>>(94&mut self,95left_val: &T,96mut right: F,97n_right: IdxSize,98) -> Option<IdxSize> {99while self.scan_offset < n_right {100if let Some(right_val) = right(self.scan_offset) {101if lt_allow_eq(&right_val, left_val, self.allow_eq) {102self.best_bound = Some(self.scan_offset);103} else {104break;105}106}107self.scan_offset += 1;108}109self.best_bound110}111}112113#[derive(Default)]114struct AsofJoinNearestState {115/// The last value that is strictly smaller than the current116/// left value.117strictly_smaller: Option<IdxSize>,118/// If `allow_eq == false`: the first value strictly greater than the119/// current left value.120/// If `allow_eq == true`: the last value of the first chunk of equal121/// values that are strictly greater than the current left value.122upper_candidate: IdxSize,123allow_eq: bool,124}125126impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {127fn new(allow_eq: bool) -> Self {128AsofJoinNearestState {129allow_eq,130..Default::default()131}132}133#[inline]134fn next<F: FnMut(IdxSize) -> Option<T>>(135&mut self,136left_val: &T,137mut right: F,138n_right: IdxSize,139) -> Option<IdxSize> {140// Skipping ahead to the first value greater than left_val. This is141// cheaper than computing differences.142while self.upper_candidate < n_right {143let Some(scan_right_val) = right(self.upper_candidate) else {144self.upper_candidate += 1;145continue;146};147if scan_right_val > *left_val {148break;149}150self.upper_candidate += 1;151}152153if self.allow_eq154&& self.upper_candidate > 0155&& right(self.upper_candidate - 1) == Some(*left_val)156{157return Some(self.upper_candidate - 1);158}159160// It is possible there are later elements equal to our161// scan, so keep going on.162while self.upper_candidate + 1 < n_right163&& right(self.upper_candidate + 1) == right(self.upper_candidate)164{165self.upper_candidate += 1;166}167168let mut cursor = self.strictly_smaller.unwrap_or(0);169while cursor < self.upper_candidate {170let Some(scan_right_val) = right(cursor) else {171cursor += 1;172continue;173};174if scan_right_val >= *left_val {175break;176}177self.strictly_smaller = Some(cursor);178cursor += 1;179}180181let mut right_get = |idx: IdxSize| (idx < n_right).then(|| right(idx)).flatten();182let lower = self.strictly_smaller.and_then(&mut right_get);183let upper = right_get(self.upper_candidate);184match (lower, upper) {185(None, None) => None,186(Some(_), None) => self.strictly_smaller,187(None, Some(_)) => Some(self.upper_candidate),188(Some(lo), Some(hi)) => {189let lo_diff = left_val.abs_diff(lo);190let hi_diff = left_val.abs_diff(hi);191if hi_diff <= lo_diff {192Some(self.upper_candidate)193} else {194self.strictly_smaller195}196},197}198}199}200201#[derive(Clone, Debug, PartialEq, Default, Hash)]202#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]203#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]204pub struct AsOfOptions {205pub strategy: AsofStrategy,206/// A tolerance in the same unit as the asof column207pub tolerance: Option<Scalar>,208/// A time duration specified as a string, for example:209/// - "5m"210/// - "2h15m"211/// - "1d6h"212pub tolerance_str: Option<PlSmallStr>,213pub left_by: Option<Vec<PlSmallStr>>,214pub right_by: Option<Vec<PlSmallStr>>,215/// Allow equal matches216pub allow_eq: bool,217pub check_sortedness: bool,218}219220fn check_asof_columns(221a: &Series,222b: &Series,223has_tolerance: bool,224check_sortedness: bool,225by_groups_present: bool,226) -> PolarsResult<()> {227let dtype_a = a.dtype();228let dtype_b = b.dtype();229if has_tolerance {230polars_ensure!(231dtype_a.to_physical().is_primitive_numeric() && dtype_b.to_physical().is_primitive_numeric(),232InvalidOperation:233"asof join with tolerance is only supported on numeric/temporal keys"234);235} else {236polars_ensure!(237dtype_a.to_physical().is_primitive() && dtype_b.to_physical().is_primitive(),238InvalidOperation:239"asof join is only supported on primitive key types"240);241}242polars_ensure!(243dtype_a == dtype_b,244ComputeError: "mismatching key dtypes in asof-join: `{}` and `{}`",245a.dtype(), b.dtype()246);247if check_sortedness {248if by_groups_present {249polars_warn!("Sortedness of columns cannot be checked when 'by' groups provided");250} else {251a.ensure_sorted_arg("asof_join")?;252b.ensure_sorted_arg("asof_join")?;253}254}255Ok(())256}257258#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]259#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]260#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]261pub enum AsofStrategy {262/// selects the last row in the right DataFrame whose ‘on’ key is less than or equal to the left’s key263#[default]264Backward,265/// selects the first row in the right DataFrame whose ‘on’ key is greater than or equal to the left’s key.266Forward,267/// selects the right in the right DataFrame whose 'on' key is nearest to the left's key.268Nearest,269}270271pub trait AsofJoin: IntoDf {272#[doc(hidden)]273#[allow(clippy::too_many_arguments)]274fn _join_asof(275&self,276other: &DataFrame,277left_key: &Series,278right_key: &Series,279strategy: AsofStrategy,280tolerance: Option<AnyValue<'static>>,281suffix: Option<PlSmallStr>,282slice: Option<(i64, usize)>,283coalesce: bool,284allow_eq: bool,285check_sortedness: bool,286) -> PolarsResult<DataFrame> {287let self_df = self.to_df();288289check_asof_columns(290left_key,291right_key,292tolerance.is_some(),293check_sortedness,294false,295)?;296let left_key = left_key.to_physical_repr();297let right_key = right_key.to_physical_repr();298299let mut take_idx = match left_key.dtype() {300#[cfg(feature = "dtype-i128")]301DataType::Int128 => {302let ca = left_key.i128().unwrap();303join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)304},305DataType::Int64 => {306let ca = left_key.i64().unwrap();307join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)308},309DataType::Int32 => {310let ca = left_key.i32().unwrap();311join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)312},313#[cfg(feature = "dtype-u128")]314DataType::UInt128 => {315let ca = left_key.u128().unwrap();316join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)317},318DataType::UInt64 => {319let ca = left_key.u64().unwrap();320join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)321},322DataType::UInt32 => {323let ca = left_key.u32().unwrap();324join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)325},326#[cfg(feature = "dtype-f16")]327DataType::Float16 => {328let ca = left_key.f16().unwrap();329join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)330},331DataType::Float32 => {332let ca = left_key.f32().unwrap();333join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)334},335DataType::Float64 => {336let ca = left_key.f64().unwrap();337join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)338},339DataType::Boolean => {340let ca = left_key.bool().unwrap();341join_asof::<BooleanType>(ca, &right_key, strategy, allow_eq)342},343DataType::Binary => {344let ca = left_key.binary().unwrap();345join_asof::<BinaryType>(ca, &right_key, strategy, allow_eq)346},347DataType::String => {348let ca = left_key.str().unwrap();349let right_binary = right_key.cast(&DataType::Binary).unwrap();350join_asof::<BinaryType>(&ca.as_binary(), &right_binary, strategy, allow_eq)351},352DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {353let left_key = left_key.cast(&DataType::Int32).unwrap();354let right_key = right_key.cast(&DataType::Int32).unwrap();355let ca = left_key.i32().unwrap();356join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)357},358dt => polars_bail!(opq = asof_join, dt),359}?;360try_raise_keyboard_interrupt();361362// Drop right join column.363let other = if coalesce && left_key.name() == right_key.name() {364Cow::Owned(other.drop(right_key.name())?)365} else {366Cow::Borrowed(other)367};368369let mut left = self_df.clone();370if let Some((offset, len)) = slice {371left = left.slice(offset, len);372take_idx = take_idx.slice(offset, len);373}374375// SAFETY: join tuples are in bounds.376let right_df = unsafe { other.take_unchecked(&take_idx) };377378_finish_join(left, right_df, suffix)379}380}381382impl AsofJoin for DataFrame {}383384385