Path: blob/main/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs
6940 views
#![allow(unsafe_op_in_unsafe_fn)]1use polars_core::chunked_array::ChunkedArray;2use polars_core::datatypes::{IdxCa, PolarsNumericType};3use polars_core::prelude::Series;4use polars_core::with_match_physical_numeric_polars_type;5use polars_error::PolarsResult;6use polars_utils::IdxSize;7use polars_utils::total_ord::TotalOrd;89use super::*;1011/// Create a vector of L1 items from the array of LHS x values concatenated with RHS x values12/// and their ordering.13pub(super) fn build_l1_array<T>(14ca: &ChunkedArray<T>,15order: &IdxCa,16right_df_offset: IdxSize,17) -> PolarsResult<Vec<L1Item<T::Native>>>18where19T: PolarsNumericType,20{21assert_eq!(order.null_count(), 0);22assert_eq!(ca.chunks().len(), 1);23let arr = ca.downcast_get(0).unwrap();24// Even if there are nulls, they will not be selected by order.25let values = arr.values().as_slice();2627let mut array: Vec<L1Item<T::Native>> = Vec::with_capacity(ca.len());2829for order_arr in order.downcast_iter() {30for index in order_arr.values().as_slice().iter().copied() {31debug_assert!(arr.get(index as usize).is_some());32let value = unsafe { *values.get_unchecked(index as usize) };33let row_index = if index < right_df_offset {34// Row from LHS35index as i64 + 136} else {37// Row from RHS38-((index - right_df_offset) as i64) - 139};40array.push(L1Item { row_index, value });41}42}4344Ok(array)45}4647pub(super) fn build_l2_array(s: &Series, order: &[IdxSize]) -> PolarsResult<Vec<L2Item>> {48with_match_physical_numeric_polars_type!(s.dtype(), |$T| {49build_l2_array_impl::<$T>(s.as_ref().as_ref(), order)50})51}5253/// Create a vector of L2 items from the array of y values ordered according to the L1 order,54/// and their ordering. We don't need to store actual y values but only track whether we're at55/// the end of a run of equal values.56fn build_l2_array_impl<T>(ca: &ChunkedArray<T>, order: &[IdxSize]) -> PolarsResult<Vec<L2Item>>57where58T: PolarsNumericType,59T::Native: TotalOrd,60{61assert_eq!(ca.chunks().len(), 1);6263let mut array = Vec::with_capacity(ca.len());64let mut prev_index = 0;65let mut prev_value = T::Native::default();6667let arr = ca.downcast_get(0).unwrap();68// Even if there are nulls, they will not be selected by order.69let values = arr.values().as_slice();7071for (i, l1_index) in order.iter().copied().enumerate() {72debug_assert!(arr.get(l1_index as usize).is_some());73let value = unsafe { *values.get_unchecked(l1_index as usize) };74if i > 0 {75array.push(L2Item {76l1_index: prev_index,77run_end: value.tot_ne(&prev_value),78});79}80prev_index = l1_index;81prev_value = value;82}83if !order.is_empty() {84array.push(L2Item {85l1_index: prev_index,86run_end: true,87});88}89Ok(array)90}9192/// Item in L1 array used in the IEJoin algorithm93#[derive(Clone, Copy, Debug)]94pub(super) struct L1Item<T> {95/// 1 based index for entries from the LHS df, or -1 based index for entries from the RHS96pub(super) row_index: i64,97/// X value98pub(super) value: T,99}100101/// Item in L2 array used in the IEJoin algorithm102#[derive(Clone, Copy, Debug)]103pub(super) struct L2Item {104/// Corresponding index into the L1 array of105pub(super) l1_index: IdxSize,106/// Whether this is the end of a run of equal y values107pub(super) run_end: bool,108}109110pub(super) trait L1Array {111unsafe fn process_entry(112&self,113l1_index: usize,114bit_array: &mut FilteredBitArray,115op1: InequalityOperator,116left_row_ids: &mut Vec<IdxSize>,117right_row_ids: &mut Vec<IdxSize>,118) -> i64;119120unsafe fn process_lhs_entry(121&self,122l1_index: usize,123bit_array: &FilteredBitArray,124op1: InequalityOperator,125left_row_ids: &mut Vec<IdxSize>,126right_row_ids: &mut Vec<IdxSize>,127) -> i64;128129unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray);130}131132/// Find the position in the L1 array where we should begin checking for matches,133/// given the index in L1 corresponding to the current position in L2.134unsafe fn find_search_start_index<T>(135l1_array: &[L1Item<T>],136index: usize,137operator: InequalityOperator,138) -> usize139where140T: NumericNative,141T: TotalOrd,142{143let sub_l1 = l1_array.get_unchecked(index..);144let value = l1_array.get_unchecked(index).value;145146match operator {147InequalityOperator::Gt => {148sub_l1.partition_point_exponential(|a| a.value.tot_ge(&value)) + index149},150InequalityOperator::Lt => {151sub_l1.partition_point_exponential(|a| a.value.tot_le(&value)) + index152},153InequalityOperator::GtEq => {154sub_l1.partition_point_exponential(|a| value.tot_lt(&a.value)) + index155},156InequalityOperator::LtEq => {157sub_l1.partition_point_exponential(|a| value.tot_gt(&a.value)) + index158},159}160}161162fn find_matches_in_l1<T>(163l1_array: &[L1Item<T>],164l1_index: usize,165row_index: i64,166bit_array: &FilteredBitArray,167op1: InequalityOperator,168left_row_ids: &mut Vec<IdxSize>,169right_row_ids: &mut Vec<IdxSize>,170) -> i64171where172T: NumericNative,173T: TotalOrd,174{175debug_assert!(row_index > 0);176let mut match_count = 0;177178// This entry comes from the left hand side DataFrame.179// Find all following entries in L1 (meaning they satisfy the first operator)180// that have already been visited (so satisfy the second operator).181// Because we use a stable sort for l2, we know that we won't find any182// matches for duplicate y values when traversing forwards in l1.183let start_index = unsafe { find_search_start_index(l1_array, l1_index, op1) };184unsafe {185bit_array.on_set_bits_from(start_index, |set_bit: usize| {186// SAFETY187// set bit is within bounds.188let right_row_index = l1_array.get_unchecked(set_bit).row_index;189debug_assert!(right_row_index < 0);190left_row_ids.push((row_index - 1) as IdxSize);191right_row_ids.push((-right_row_index) as IdxSize - 1);192match_count += 1;193})194};195196match_count197}198199impl<T> L1Array for Vec<L1Item<T>>200where201T: NumericNative,202{203unsafe fn process_entry(204&self,205l1_index: usize,206bit_array: &mut FilteredBitArray,207op1: InequalityOperator,208left_row_ids: &mut Vec<IdxSize>,209right_row_ids: &mut Vec<IdxSize>,210) -> i64 {211let row_index = self.get_unchecked(l1_index).row_index;212let from_lhs = row_index > 0;213if from_lhs {214find_matches_in_l1(215self,216l1_index,217row_index,218bit_array,219op1,220left_row_ids,221right_row_ids,222)223} else {224bit_array.set_bit_unchecked(l1_index);2250226}227}228229unsafe fn process_lhs_entry(230&self,231l1_index: usize,232bit_array: &FilteredBitArray,233op1: InequalityOperator,234left_row_ids: &mut Vec<IdxSize>,235right_row_ids: &mut Vec<IdxSize>,236) -> i64 {237let row_index = self.get_unchecked(l1_index).row_index;238let from_lhs = row_index > 0;239if from_lhs {240find_matches_in_l1(241self,242l1_index,243row_index,244bit_array,245op1,246left_row_ids,247right_row_ids,248)249} else {2500251}252}253254unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray) {255let from_lhs = self.get_unchecked(index).row_index > 0;256// We only mark RHS entries as visited,257// so that we don't try to match LHS entries with other LHS entries.258if !from_lhs {259bit_array.set_bit_unchecked(index);260}261}262}263264265