Path: blob/main/crates/polars-ops/src/frame/join/iejoin/mod.rs
6940 views
#![allow(unsafe_op_in_unsafe_fn)]1mod filtered_bit_array;2mod l1_l2;34use std::cmp::min;56use filtered_bit_array::FilteredBitArray;7use l1_l2::*;8use polars_core::chunked_array::ChunkedArray;9use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType};10use polars_core::frame::DataFrame;11use polars_core::prelude::*;12use polars_core::series::IsSorted;13use polars_core::utils::{_set_partition_size, split};14use polars_core::{POOL, with_match_physical_numeric_polars_type};15use polars_error::{PolarsResult, polars_err};16use polars_utils::IdxSize;17use polars_utils::binary_search::ExponentialSearch;18use polars_utils::itertools::Itertools;19use polars_utils::total_ord::{TotalEq, TotalOrd};20use rayon::prelude::*;21#[cfg(feature = "serde")]22use serde::{Deserialize, Serialize};2324use crate::frame::_finish_join;2526#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]27#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]28pub enum InequalityOperator {29#[default]30Lt,31LtEq,32Gt,33GtEq,34}3536impl InequalityOperator {37fn is_strict(&self) -> bool {38matches!(self, InequalityOperator::Gt | InequalityOperator::Lt)39}40}41#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]42#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]43pub struct IEJoinOptions {44pub operator1: InequalityOperator,45pub operator2: Option<InequalityOperator>,46}4748#[allow(clippy::too_many_arguments)]49fn ie_join_impl_t<T: PolarsNumericType>(50slice: Option<(i64, usize)>,51l1_order: IdxCa,52l2_order: &[IdxSize],53op1: InequalityOperator,54op2: InequalityOperator,55x: Series,56y_ordered_by_x: Series,57left_height: usize,58) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)> {59// Create a bit array with order corresponding to L1,60// denoting which entries have been visited while traversing L2.61let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len());6263let mut left_row_idx: Vec<IdxSize> = vec![];64let mut right_row_idx: Vec<IdxSize> = vec![];6566let slice_end = slice_end_index(slice);67let mut match_count = 0;6869let ca: &ChunkedArray<T> = x.as_ref().as_ref();70let l1_array = build_l1_array(ca, &l1_order, left_height as IdxSize)?;7172if op2.is_strict() {73// For strict inequalities, we rely on using a stable sort of l2 so that74// p values only increase as we traverse a run of equal y values.75// To handle inclusive comparisons in x and duplicate x values we also need the76// sort of l1 to be stable, so that the left hand side entries come before the right77// hand side entries (as we mark visited entries from the right hand side).78for &p in l2_order {79match_count += unsafe {80l1_array.process_entry(81p as usize,82&mut bit_array,83op1,84&mut left_row_idx,85&mut right_row_idx,86)87};8889if slice_end.is_some_and(|end| match_count >= end) {90break;91}92}93} else {94let l2_array = build_l2_array(&y_ordered_by_x, l2_order)?;9596// For non-strict inequalities in l2, we need to track runs of equal y values and only97// check for matches after we reach the end of the run and have marked all rhs entries98// in the run as visited.99let mut run_start = 0;100101for i in 0..l2_array.len() {102// Elide bound checks103unsafe {104let item = l2_array.get_unchecked(i);105let p = item.l1_index;106l1_array.mark_visited(p as usize, &mut bit_array);107108if item.run_end {109for l2_item in l2_array.get_unchecked(run_start..i + 1) {110let p = l2_item.l1_index;111match_count += l1_array.process_lhs_entry(112p as usize,113&bit_array,114op1,115&mut left_row_idx,116&mut right_row_idx,117);118}119120run_start = i + 1;121122if slice_end.is_some_and(|end| match_count >= end) {123break;124}125}126}127}128}129Ok((left_row_idx, right_row_idx))130}131132fn piecewise_merge_join_impl_t<T, P>(133slice: Option<(i64, usize)>,134left_order: Option<&[IdxSize]>,135right_order: Option<&[IdxSize]>,136left_ordered: Series,137right_ordered: Series,138mut pred: P,139) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>140where141T: PolarsNumericType,142P: FnMut(&T::Native, &T::Native) -> bool,143{144let slice_end = slice_end_index(slice);145146let mut left_row_idx: Vec<IdxSize> = vec![];147let mut right_row_idx: Vec<IdxSize> = vec![];148149let left_ca: &ChunkedArray<T> = left_ordered.as_ref().as_ref();150let right_ca: &ChunkedArray<T> = right_ordered.as_ref().as_ref();151152debug_assert!(left_order.is_none_or(|order| order.len() == left_ca.len()));153debug_assert!(right_order.is_none_or(|order| order.len() == right_ca.len()));154155let mut left_idx = 0;156let mut right_idx = 0;157let mut match_count = 0;158159while left_idx < left_ca.len() {160debug_assert!(left_ca.get(left_idx).is_some());161let left_val = unsafe { left_ca.value_unchecked(left_idx) };162while right_idx < right_ca.len() {163debug_assert!(right_ca.get(right_idx).is_some());164let right_val = unsafe { right_ca.value_unchecked(right_idx) };165if pred(&left_val, &right_val) {166// If the predicate is true, then it will also be true for all167// remaining rows from the right side.168let left_row = match left_order {169None => left_idx as IdxSize,170Some(order) => order[left_idx],171};172let right_end_idx = match slice_end {173None => right_ca.len(),174Some(end) => min(right_ca.len(), (end as usize) - match_count + right_idx),175};176for included_right_row_idx in right_idx..right_end_idx {177let right_row = match right_order {178None => included_right_row_idx as IdxSize,179Some(order) => order[included_right_row_idx],180};181left_row_idx.push(left_row);182right_row_idx.push(right_row);183}184match_count += right_end_idx - right_idx;185break;186} else {187right_idx += 1;188}189}190if right_idx == right_ca.len() {191// We've reached the end of the right side192// so there can be no more matches for LHS rows193break;194}195if slice_end.is_some_and(|end| match_count >= end as usize) {196break;197}198left_idx += 1;199}200201Ok((left_row_idx, right_row_idx))202}203204pub(super) fn iejoin_par(205left: &DataFrame,206right: &DataFrame,207selected_left: Vec<Series>,208selected_right: Vec<Series>,209options: &IEJoinOptions,210suffix: Option<PlSmallStr>,211slice: Option<(i64, usize)>,212) -> PolarsResult<DataFrame> {213let l1_descending = matches!(214options.operator1,215InequalityOperator::Gt | InequalityOperator::GtEq216);217218let l1_sort_options = SortOptions::default()219.with_maintain_order(true)220.with_nulls_last(false)221.with_order_descending(l1_descending);222223let sl = &selected_left[0];224let l1_s_l = sl225.arg_sort(l1_sort_options)226.slice(sl.null_count() as i64, sl.len() - sl.null_count());227228let sr = &selected_right[0];229let l1_s_r = sr230.arg_sort(l1_sort_options)231.slice(sr.null_count() as i64, sr.len() - sr.null_count());232233// Because we do a cartesian product, the number of partitions is squared.234// We take the sqrt, but we don't expect every partition to produce results and work can be235// imbalanced, so we multiply the number of partitions by 2, which leads to 2^2= 4236let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;237let splitted_a = split(&l1_s_l, n_partitions);238let splitted_b = split(&l1_s_r, n_partitions);239240let cartesian_prod = splitted_a241.iter()242.flat_map(|l| splitted_b.iter().map(move |r| (l, r)))243.collect::<Vec<_>>();244245let iter = cartesian_prod.par_iter().map(|(l_l1_idx, r_l1_idx)| {246if l_l1_idx.is_empty() || r_l1_idx.is_empty() {247return Ok(None);248}249fn get_extrema<'a>(250l1_idx: &'a IdxCa,251s: &'a Series,252) -> Option<(AnyValue<'a>, AnyValue<'a>)> {253let first = l1_idx.first()?;254let last = l1_idx.last()?;255256let start = s.get(first as usize).unwrap();257let end = s.get(last as usize).unwrap();258259Some(if start < end {260(start, end)261} else {262(end, start)263})264}265let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else {266return Ok(None);267};268let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else {269return Ok(None);270};271272let include_block = match options.operator1 {273InequalityOperator::Lt => min_l < max_r,274InequalityOperator::LtEq => min_l <= max_r,275InequalityOperator::Gt => max_l > min_r,276InequalityOperator::GtEq => max_l >= min_r,277};278279if include_block {280let (mut l, mut r) = unsafe {281(282selected_left283.iter()284.map(|s| s.take_unchecked(l_l1_idx))285.collect_vec(),286selected_right287.iter()288.map(|s| s.take_unchecked(r_l1_idx))289.collect_vec(),290)291};292let sorted_flag = if l1_descending {293IsSorted::Descending294} else {295IsSorted::Ascending296};297// We sorted using the first series298l[0].set_sorted_flag(sorted_flag);299r[0].set_sorted_flag(sorted_flag);300301// Compute the row indexes302let (idx_l, idx_r) = if options.operator2.is_some() {303iejoin_tuples(l, r, options, None)304} else {305piecewise_merge_join_tuples(l, r, options, None)306}?;307308if idx_l.is_empty() {309return Ok(None);310}311312// These are row indexes in the slices we have given, so we use those to gather in the313// original l1 offset arrays. This gives us indexes in the original tables.314unsafe {315Ok(Some((316l_l1_idx.take_unchecked(&idx_l),317r_l1_idx.take_unchecked(&idx_r),318)))319}320} else {321Ok(None)322}323});324325let row_indices = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;326327let mut left_idx = IdxCa::default();328let mut right_idx = IdxCa::default();329for (l, r) in row_indices.into_iter().flatten() {330left_idx.append(&l)?;331right_idx.append(&r)?;332}333if let Some((offset, end)) = slice {334left_idx = left_idx.slice(offset, end);335right_idx = right_idx.slice(offset, end);336}337338unsafe { materialize_join(left, right, &left_idx, &right_idx, suffix) }339}340341pub(super) fn iejoin(342left: &DataFrame,343right: &DataFrame,344selected_left: Vec<Series>,345selected_right: Vec<Series>,346options: &IEJoinOptions,347suffix: Option<PlSmallStr>,348slice: Option<(i64, usize)>,349) -> PolarsResult<DataFrame> {350let (left_row_idx, right_row_idx) = if options.operator2.is_some() {351iejoin_tuples(selected_left, selected_right, options, slice)352} else {353piecewise_merge_join_tuples(selected_left, selected_right, options, slice)354}?;355unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) }356}357358unsafe fn materialize_join(359left: &DataFrame,360right: &DataFrame,361left_row_idx: &IdxCa,362right_row_idx: &IdxCa,363suffix: Option<PlSmallStr>,364) -> PolarsResult<DataFrame> {365try_raise_keyboard_interrupt();366let (join_left, join_right) = {367POOL.join(368|| left.take_unchecked(left_row_idx),369|| right.take_unchecked(right_row_idx),370)371};372373_finish_join(join_left, join_right, suffix)374}375376/// Inequality join. Matches rows between two DataFrames using two inequality operators377/// (one of [<, <=, >, >=]).378/// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins"379/// and extended to work with duplicate values.380fn iejoin_tuples(381selected_left: Vec<Series>,382selected_right: Vec<Series>,383options: &IEJoinOptions,384slice: Option<(i64, usize)>,385) -> PolarsResult<(IdxCa, IdxCa)> {386if selected_left.len() != 2 {387return Err(388polars_err!(ComputeError: "IEJoin requires exactly two expressions from the left DataFrame"),389);390};391if selected_right.len() != 2 {392return Err(393polars_err!(ComputeError: "IEJoin requires exactly two expressions from the right DataFrame"),394);395};396397let op1 = options.operator1;398let op2 = match options.operator2 {399None => {400return Err(polars_err!(ComputeError: "IEJoin requires two inequality operators"));401},402Some(op2) => op2,403};404405// Determine the sort order based on the comparison operators used.406// We want to sort L1 so that "x[i] op1 x[j]" is true for j > i,407// and L2 so that "y[i] op2 y[j]" is true for j < i408// (except in the case of duplicates and strict inequalities).409// Note that the algorithms published in Khayyat et al. have incorrect logic for410// determining whether to sort descending.411let l1_descending = matches!(op1, InequalityOperator::Gt | InequalityOperator::GtEq);412let l2_descending = matches!(op2, InequalityOperator::Lt | InequalityOperator::LtEq);413414let mut x = selected_left[0].to_physical_repr().into_owned();415let left_height = x.len();416417x.extend(&selected_right[0].to_physical_repr())?;418// Rechunk because we will gather.419let x = x.rechunk();420421let mut y = selected_left[1].to_physical_repr().into_owned();422y.extend(&selected_right[1].to_physical_repr())?;423// Rechunk because we will gather.424let y = y.rechunk();425426let l1_sort_options = SortOptions::default()427.with_maintain_order(true)428.with_nulls_last(false)429.with_order_descending(l1_descending);430// Get ordering of x, skipping any null entries as these cannot be matches431let l1_order = x432.arg_sort(l1_sort_options)433.slice(x.null_count() as i64, x.len() - x.null_count());434435let y_ordered_by_x = unsafe { y.take_unchecked(&l1_order) };436let l2_sort_options = SortOptions::default()437.with_maintain_order(true)438.with_nulls_last(false)439.with_order_descending(l2_descending);440// Get the indexes into l1, ordered by y values.441// l2_order is the same as "p" from Khayyat et al.442let l2_order = y_ordered_by_x.arg_sort(l2_sort_options).slice(443y_ordered_by_x.null_count() as i64,444y_ordered_by_x.len() - y_ordered_by_x.null_count(),445);446let l2_order = l2_order.rechunk();447let l2_order = l2_order.downcast_as_array().values().as_slice();448449let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(x.dtype(), |$T| {450ie_join_impl_t::<$T>(451slice,452l1_order,453l2_order,454op1,455op2,456x,457y_ordered_by_x,458left_height459)460})?;461462debug_assert_eq!(left_row_idx.len(), right_row_idx.len());463let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);464let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);465let (left_row_idx, right_row_idx) = match slice {466None => (left_row_idx, right_row_idx),467Some((offset, len)) => (468left_row_idx.slice(offset, len),469right_row_idx.slice(offset, len),470),471};472Ok((left_row_idx, right_row_idx))473}474475/// Piecewise merge join, for joins with only a single inequality.476fn piecewise_merge_join_tuples(477selected_left: Vec<Series>,478selected_right: Vec<Series>,479options: &IEJoinOptions,480slice: Option<(i64, usize)>,481) -> PolarsResult<(IdxCa, IdxCa)> {482if selected_left.len() != 1 {483return Err(484polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the left DataFrame"),485);486};487if selected_right.len() != 1 {488return Err(489polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the right DataFrame"),490);491};492if options.operator2.is_some() {493return Err(494polars_err!(ComputeError: "Piecewise merge join expects only one inequality operator"),495);496}497498let op = options.operator1;499// The left side is sorted such that if the condition is false, it will also500// be false for the same RHS row and all following LHS rows.501// The right side is sorted such that if the condition is true then it is also502// true for the same LHS row and all following RHS rows.503// The desired sort order should match the l1 order used in iejoin_par504// so we don't need to re-sort slices when doing a parallel join.505let descending = matches!(op, InequalityOperator::Gt | InequalityOperator::GtEq);506507let left = selected_left[0].to_physical_repr().into_owned();508let mut right = selected_right[0].to_physical_repr().into_owned();509let must_cast = right.dtype().matches_schema_type(left.dtype())?;510if must_cast {511right = right.cast(left.dtype())?;512}513514fn get_sorted(series: Series, descending: bool) -> (Series, Option<IdxCa>) {515let expected_flag = if descending {516IsSorted::Descending517} else {518IsSorted::Ascending519};520if (series.is_sorted_flag() == expected_flag || series.len() <= 1) && !series.has_nulls() {521// Fast path, no need to re-sort522(series, None)523} else {524let sort_options = SortOptions::default()525.with_nulls_last(false)526.with_order_descending(descending);527528// Get order and slice to ignore any null values, which cannot be match results529let mut order = series.arg_sort(sort_options).slice(530series.null_count() as i64,531series.len() - series.null_count(),532);533order.rechunk_mut();534let ordered = unsafe { series.take_unchecked(&order) };535(ordered, Some(order))536}537}538539let (left_ordered, left_order) = get_sorted(left, descending);540debug_assert!(541left_order542.as_ref()543.is_none_or(|order| order.chunks().len() == 1)544);545let left_order = left_order546.as_ref()547.map(|order| order.downcast_get(0).unwrap().values().as_slice());548549let (right_ordered, right_order) = get_sorted(right, descending);550debug_assert!(551right_order552.as_ref()553.is_none_or(|order| order.chunks().len() == 1)554);555let right_order = right_order556.as_ref()557.map(|order| order.downcast_get(0).unwrap().values().as_slice());558559let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(left_ordered.dtype(), |$T| {560match op {561InequalityOperator::Lt => piecewise_merge_join_impl_t::<$T, _>(562slice,563left_order,564right_order,565left_ordered,566right_ordered,567|l, r| l.tot_lt(r),568),569InequalityOperator::LtEq => piecewise_merge_join_impl_t::<$T, _>(570slice,571left_order,572right_order,573left_ordered,574right_ordered,575|l, r| l.tot_le(r),576),577InequalityOperator::Gt => piecewise_merge_join_impl_t::<$T, _>(578slice,579left_order,580right_order,581left_ordered,582right_ordered,583|l, r| l.tot_gt(r),584),585InequalityOperator::GtEq => piecewise_merge_join_impl_t::<$T, _>(586slice,587left_order,588right_order,589left_ordered,590right_ordered,591|l, r| l.tot_ge(r),592),593}594})?;595596debug_assert_eq!(left_row_idx.len(), right_row_idx.len());597let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);598let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);599let (left_row_idx, right_row_idx) = match slice {600None => (left_row_idx, right_row_idx),601Some((offset, len)) => (602left_row_idx.slice(offset, len),603right_row_idx.slice(offset, len),604),605};606Ok((left_row_idx, right_row_idx))607}608609fn slice_end_index(slice: Option<(i64, usize)>) -> Option<i64> {610match slice {611Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)),612_ => None,613}614}615616617