Path: blob/main/crates/polars-ops/src/frame/join/iejoin/mod.rs
8458 views
#![allow(unsafe_op_in_unsafe_fn)]1mod filtered_bit_array;2mod l1_l2;34use std::cmp::min;56use filtered_bit_array::FilteredBitArray;7use l1_l2::*;8use polars_core::chunked_array::ChunkedArray;9use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType};10use polars_core::frame::DataFrame;11use polars_core::prelude::*;12use polars_core::series::IsSorted;13use polars_core::utils::{_set_partition_size, split};14use polars_core::{POOL, with_match_physical_numeric_polars_type};15use polars_error::{PolarsResult, polars_err};16use polars_utils::IdxSize;17use polars_utils::binary_search::ExponentialSearch;18use polars_utils::itertools::Itertools;19use polars_utils::total_ord::{TotalEq, TotalOrd};20use rayon::prelude::*;21#[cfg(feature = "serde")]22use serde::{Deserialize, Serialize};2324use crate::frame::_finish_join;2526#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]27#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]28#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]29pub enum InequalityOperator {30#[default]31Lt,32LtEq,33Gt,34GtEq,35}3637impl InequalityOperator {38fn is_strict(&self) -> bool {39matches!(self, InequalityOperator::Gt | InequalityOperator::Lt)40}41}42#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]43#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]44pub struct IEJoinOptions {45pub operator1: InequalityOperator,46pub operator2: Option<InequalityOperator>,47}4849#[allow(clippy::too_many_arguments)]50fn ie_join_impl_t<T: PolarsNumericType>(51slice: Option<(i64, usize)>,52l1_order: IdxCa,53l2_order: &[IdxSize],54op1: InequalityOperator,55op2: InequalityOperator,56x: Series,57y_ordered_by_x: Series,58left_height: usize,59) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)> {60// Create a bit array with order corresponding to L1,61// denoting which entries have been visited while traversing L2.62let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len());6364let mut left_row_idx: Vec<IdxSize> = vec![];65let mut right_row_idx: Vec<IdxSize> = vec![];6667let slice_end = slice_end_index(slice);68let mut match_count = 0;6970let ca: &ChunkedArray<T> = x.as_ref().as_ref();71let l1_array = build_l1_array(ca, &l1_order, left_height as IdxSize)?;7273if op2.is_strict() {74// For strict inequalities, we rely on using a stable sort of l2 so that75// p values only increase as we traverse a run of equal y values.76// To handle inclusive comparisons in x and duplicate x values we also need the77// sort of l1 to be stable, so that the left hand side entries come before the right78// hand side entries (as we mark visited entries from the right hand side).79for &p in l2_order {80match_count += unsafe {81l1_array.process_entry(82p as usize,83&mut bit_array,84op1,85&mut left_row_idx,86&mut right_row_idx,87)88};8990if slice_end.is_some_and(|end| match_count >= end) {91break;92}93}94} else {95let l2_array = build_l2_array(&y_ordered_by_x, l2_order)?;9697// For non-strict inequalities in l2, we need to track runs of equal y values and only98// check for matches after we reach the end of the run and have marked all rhs entries99// in the run as visited.100let mut run_start = 0;101102for i in 0..l2_array.len() {103// Elide bound checks104unsafe {105let item = l2_array.get_unchecked(i);106let p = item.l1_index;107l1_array.mark_visited(p as usize, &mut bit_array);108109if item.run_end {110for l2_item in l2_array.get_unchecked(run_start..i + 1) {111let p = l2_item.l1_index;112match_count += l1_array.process_lhs_entry(113p as usize,114&bit_array,115op1,116&mut left_row_idx,117&mut right_row_idx,118);119}120121run_start = i + 1;122123if slice_end.is_some_and(|end| match_count >= end) {124break;125}126}127}128}129}130Ok((left_row_idx, right_row_idx))131}132133fn piecewise_merge_join_impl_t<T, P>(134slice: Option<(i64, usize)>,135left_order: Option<&[IdxSize]>,136right_order: Option<&[IdxSize]>,137left_ordered: Series,138right_ordered: Series,139mut pred: P,140) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>141where142T: PolarsNumericType,143P: FnMut(&T::Native, &T::Native) -> bool,144{145let slice_end = slice_end_index(slice);146147let mut left_row_idx: Vec<IdxSize> = vec![];148let mut right_row_idx: Vec<IdxSize> = vec![];149150let left_ca: &ChunkedArray<T> = left_ordered.as_ref().as_ref();151let right_ca: &ChunkedArray<T> = right_ordered.as_ref().as_ref();152153debug_assert!(left_order.is_none_or(|order| order.len() == left_ca.len()));154debug_assert!(right_order.is_none_or(|order| order.len() == right_ca.len()));155156let mut left_idx = 0;157let mut right_idx = 0;158let mut match_count = 0;159160while left_idx < left_ca.len() {161debug_assert!(left_ca.get(left_idx).is_some());162let left_val = unsafe { left_ca.value_unchecked(left_idx) };163while right_idx < right_ca.len() {164debug_assert!(right_ca.get(right_idx).is_some());165let right_val = unsafe { right_ca.value_unchecked(right_idx) };166if pred(&left_val, &right_val) {167// If the predicate is true, then it will also be true for all168// remaining rows from the right side.169let left_row = match left_order {170None => left_idx as IdxSize,171Some(order) => order[left_idx],172};173let right_end_idx = match slice_end {174None => right_ca.len(),175Some(end) => min(right_ca.len(), (end as usize) - match_count + right_idx),176};177for included_right_row_idx in right_idx..right_end_idx {178let right_row = match right_order {179None => included_right_row_idx as IdxSize,180Some(order) => order[included_right_row_idx],181};182left_row_idx.push(left_row);183right_row_idx.push(right_row);184}185match_count += right_end_idx - right_idx;186break;187} else {188right_idx += 1;189}190}191if right_idx == right_ca.len() {192// We've reached the end of the right side193// so there can be no more matches for LHS rows194break;195}196if slice_end.is_some_and(|end| match_count >= end as usize) {197break;198}199left_idx += 1;200}201202Ok((left_row_idx, right_row_idx))203}204205pub(super) fn iejoin_par(206left: &DataFrame,207right: &DataFrame,208selected_left: Vec<Series>,209selected_right: Vec<Series>,210options: &IEJoinOptions,211suffix: Option<PlSmallStr>,212slice: Option<(i64, usize)>,213) -> PolarsResult<DataFrame> {214let l1_descending = matches!(215options.operator1,216InequalityOperator::Gt | InequalityOperator::GtEq217);218219let l1_sort_options = SortOptions::default()220.with_maintain_order(true)221.with_nulls_last(false)222.with_order_descending(l1_descending);223224let sl = &selected_left[0];225let l1_s_l = sl226.arg_sort(l1_sort_options)227.slice(sl.null_count() as i64, sl.len() - sl.null_count());228229let sr = &selected_right[0];230let l1_s_r = sr231.arg_sort(l1_sort_options)232.slice(sr.null_count() as i64, sr.len() - sr.null_count());233234// Because we do a cartesian product, the number of partitions is squared.235// We take the sqrt, but we don't expect every partition to produce results and work can be236// imbalanced, so we multiply the number of partitions by 2, which leads to 2^2= 4237let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;238let splitted_a = split(&l1_s_l, n_partitions);239let splitted_b = split(&l1_s_r, n_partitions);240241let cartesian_prod = splitted_a242.iter()243.flat_map(|l| splitted_b.iter().map(move |r| (l, r)))244.collect::<Vec<_>>();245246let iter = cartesian_prod.par_iter().map(|(l_l1_idx, r_l1_idx)| {247if l_l1_idx.is_empty() || r_l1_idx.is_empty() {248return Ok(None);249}250fn get_extrema<'a>(251l1_idx: &'a IdxCa,252s: &'a Series,253) -> Option<(AnyValue<'a>, AnyValue<'a>)> {254let first = l1_idx.first()?;255let last = l1_idx.last()?;256257let start = s.get(first as usize).unwrap();258let end = s.get(last as usize).unwrap();259260Some(if start < end {261(start, end)262} else {263(end, start)264})265}266let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else {267return Ok(None);268};269let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else {270return Ok(None);271};272273let include_block = match options.operator1 {274InequalityOperator::Lt => min_l < max_r,275InequalityOperator::LtEq => min_l <= max_r,276InequalityOperator::Gt => max_l > min_r,277InequalityOperator::GtEq => max_l >= min_r,278};279280if include_block {281let (mut l, mut r) = unsafe {282(283selected_left284.iter()285.map(|s| s.take_unchecked(l_l1_idx))286.collect_vec(),287selected_right288.iter()289.map(|s| s.take_unchecked(r_l1_idx))290.collect_vec(),291)292};293let sorted_flag = if l1_descending {294IsSorted::Descending295} else {296IsSorted::Ascending297};298// We sorted using the first series299l[0].set_sorted_flag(sorted_flag);300r[0].set_sorted_flag(sorted_flag);301302// Compute the row indexes303let (idx_l, idx_r) = if options.operator2.is_some() {304iejoin_tuples(l, r, options, None)305} else {306piecewise_merge_join_tuples(l, r, options, None)307}?;308309if idx_l.is_empty() {310return Ok(None);311}312313// These are row indexes in the slices we have given, so we use those to gather in the314// original l1 offset arrays. This gives us indexes in the original tables.315unsafe {316Ok(Some((317l_l1_idx.take_unchecked(&idx_l),318r_l1_idx.take_unchecked(&idx_r),319)))320}321} else {322Ok(None)323}324});325326let row_indices = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;327328let mut left_idx = IdxCa::default();329let mut right_idx = IdxCa::default();330for (l, r) in row_indices.into_iter().flatten() {331left_idx.append(&l)?;332right_idx.append(&r)?;333}334if let Some((offset, end)) = slice {335left_idx = left_idx.slice(offset, end);336right_idx = right_idx.slice(offset, end);337}338339unsafe { materialize_join(left, right, &left_idx, &right_idx, suffix) }340}341342pub(super) fn iejoin(343left: &DataFrame,344right: &DataFrame,345selected_left: Vec<Series>,346selected_right: Vec<Series>,347options: &IEJoinOptions,348suffix: Option<PlSmallStr>,349slice: Option<(i64, usize)>,350) -> PolarsResult<DataFrame> {351let (left_row_idx, right_row_idx) = if options.operator2.is_some() {352iejoin_tuples(selected_left, selected_right, options, slice)353} else {354piecewise_merge_join_tuples(selected_left, selected_right, options, slice)355}?;356unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) }357}358359unsafe fn materialize_join(360left: &DataFrame,361right: &DataFrame,362left_row_idx: &IdxCa,363right_row_idx: &IdxCa,364suffix: Option<PlSmallStr>,365) -> PolarsResult<DataFrame> {366try_raise_keyboard_interrupt();367let (join_left, join_right) = {368POOL.join(369|| left.take_unchecked(left_row_idx),370|| right.take_unchecked(right_row_idx),371)372};373374_finish_join(join_left, join_right, suffix)375}376377/// Inequality join. Matches rows between two DataFrames using two inequality operators378/// (one of [<, <=, >, >=]).379/// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins"380/// and extended to work with duplicate values.381fn iejoin_tuples(382selected_left: Vec<Series>,383selected_right: Vec<Series>,384options: &IEJoinOptions,385slice: Option<(i64, usize)>,386) -> PolarsResult<(IdxCa, IdxCa)> {387if selected_left.len() != 2 {388return Err(389polars_err!(ComputeError: "IEJoin requires exactly two expressions from the left DataFrame"),390);391};392if selected_right.len() != 2 {393return Err(394polars_err!(ComputeError: "IEJoin requires exactly two expressions from the right DataFrame"),395);396};397398let op1 = options.operator1;399let op2 = match options.operator2 {400None => {401return Err(polars_err!(ComputeError: "IEJoin requires two inequality operators"));402},403Some(op2) => op2,404};405406// Determine the sort order based on the comparison operators used.407// We want to sort L1 so that "x[i] op1 x[j]" is true for j > i,408// and L2 so that "y[i] op2 y[j]" is true for j < i409// (except in the case of duplicates and strict inequalities).410// Note that the algorithms published in Khayyat et al. have incorrect logic for411// determining whether to sort descending.412let l1_descending = matches!(op1, InequalityOperator::Gt | InequalityOperator::GtEq);413let l2_descending = matches!(op2, InequalityOperator::Lt | InequalityOperator::LtEq);414415let mut x = selected_left[0].to_physical_repr().into_owned();416let left_height = x.len();417418x.extend(&selected_right[0].to_physical_repr())?;419// Rechunk because we will gather.420let x = x.rechunk();421422let mut y = selected_left[1].to_physical_repr().into_owned();423y.extend(&selected_right[1].to_physical_repr())?;424// Rechunk because we will gather.425let y = y.rechunk();426427let l1_sort_options = SortOptions::default()428.with_maintain_order(true)429.with_nulls_last(false)430.with_order_descending(l1_descending);431// Get ordering of x, skipping any null entries as these cannot be matches432let l1_order = x433.arg_sort(l1_sort_options)434.slice(x.null_count() as i64, x.len() - x.null_count());435436let y_ordered_by_x = unsafe { y.take_unchecked(&l1_order) };437let l2_sort_options = SortOptions::default()438.with_maintain_order(true)439.with_nulls_last(false)440.with_order_descending(l2_descending);441// Get the indexes into l1, ordered by y values.442// l2_order is the same as "p" from Khayyat et al.443let l2_order = y_ordered_by_x.arg_sort(l2_sort_options).slice(444y_ordered_by_x.null_count() as i64,445y_ordered_by_x.len() - y_ordered_by_x.null_count(),446);447let l2_order = l2_order.rechunk();448let l2_order = l2_order.downcast_as_array().values().as_slice();449450let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(x.dtype(), |$T| {451ie_join_impl_t::<$T>(452slice,453l1_order,454l2_order,455op1,456op2,457x,458y_ordered_by_x,459left_height460)461})?;462463debug_assert_eq!(left_row_idx.len(), right_row_idx.len());464let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);465let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);466let (left_row_idx, right_row_idx) = match slice {467None => (left_row_idx, right_row_idx),468Some((offset, len)) => (469left_row_idx.slice(offset, len),470right_row_idx.slice(offset, len),471),472};473Ok((left_row_idx, right_row_idx))474}475476/// Piecewise merge join, for joins with only a single inequality.477fn piecewise_merge_join_tuples(478selected_left: Vec<Series>,479selected_right: Vec<Series>,480options: &IEJoinOptions,481slice: Option<(i64, usize)>,482) -> PolarsResult<(IdxCa, IdxCa)> {483if selected_left.len() != 1 {484return Err(485polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the left DataFrame"),486);487};488if selected_right.len() != 1 {489return Err(490polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the right DataFrame"),491);492};493if options.operator2.is_some() {494return Err(495polars_err!(ComputeError: "Piecewise merge join expects only one inequality operator"),496);497}498499let op = options.operator1;500// The left side is sorted such that if the condition is false, it will also501// be false for the same RHS row and all following LHS rows.502// The right side is sorted such that if the condition is true then it is also503// true for the same LHS row and all following RHS rows.504// The desired sort order should match the l1 order used in iejoin_par505// so we don't need to re-sort slices when doing a parallel join.506let descending = matches!(op, InequalityOperator::Gt | InequalityOperator::GtEq);507508let left = selected_left[0].to_physical_repr().into_owned();509let mut right = selected_right[0].to_physical_repr().into_owned();510let must_cast = right.dtype().matches_schema_type(left.dtype())?;511if must_cast {512right = right.cast(left.dtype())?;513}514515fn get_sorted(series: Series, descending: bool) -> (Series, Option<IdxCa>) {516let expected_flag = if descending {517IsSorted::Descending518} else {519IsSorted::Ascending520};521if (series.is_sorted_flag() == expected_flag || series.len() <= 1) && !series.has_nulls() {522// Fast path, no need to re-sort523(series, None)524} else {525let sort_options = SortOptions::default()526.with_nulls_last(false)527.with_order_descending(descending);528529// Get order and slice to ignore any null values, which cannot be match results530let mut order = series.arg_sort(sort_options).slice(531series.null_count() as i64,532series.len() - series.null_count(),533);534order.rechunk_mut();535let ordered = unsafe { series.take_unchecked(&order) };536(ordered, Some(order))537}538}539540let (left_ordered, left_order) = get_sorted(left, descending);541debug_assert!(542left_order543.as_ref()544.is_none_or(|order| order.chunks().len() == 1)545);546let left_order = left_order547.as_ref()548.map(|order| order.downcast_get(0).unwrap().values().as_slice());549550let (right_ordered, right_order) = get_sorted(right, descending);551debug_assert!(552right_order553.as_ref()554.is_none_or(|order| order.chunks().len() == 1)555);556let right_order = right_order557.as_ref()558.map(|order| order.downcast_get(0).unwrap().values().as_slice());559560let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(left_ordered.dtype(), |$T| {561match op {562InequalityOperator::Lt => piecewise_merge_join_impl_t::<$T, _>(563slice,564left_order,565right_order,566left_ordered,567right_ordered,568|l, r| l.tot_lt(r),569),570InequalityOperator::LtEq => piecewise_merge_join_impl_t::<$T, _>(571slice,572left_order,573right_order,574left_ordered,575right_ordered,576|l, r| l.tot_le(r),577),578InequalityOperator::Gt => piecewise_merge_join_impl_t::<$T, _>(579slice,580left_order,581right_order,582left_ordered,583right_ordered,584|l, r| l.tot_gt(r),585),586InequalityOperator::GtEq => piecewise_merge_join_impl_t::<$T, _>(587slice,588left_order,589right_order,590left_ordered,591right_ordered,592|l, r| l.tot_ge(r),593),594}595})?;596597debug_assert_eq!(left_row_idx.len(), right_row_idx.len());598let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);599let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);600let (left_row_idx, right_row_idx) = match slice {601None => (left_row_idx, right_row_idx),602Some((offset, len)) => (603left_row_idx.slice(offset, len),604right_row_idx.slice(offset, len),605),606};607Ok((left_row_idx, right_row_idx))608}609610fn slice_end_index(slice: Option<(i64, usize)>) -> Option<i64> {611match slice {612Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)),613_ => None,614}615}616617618