Path: blob/main/crates/polars-ops/src/frame/join/cross_join.rs
8430 views
use polars_core::utils::{1_set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,2concat_df_unchecked, split,3};4use polars_utils::pl_str::PlSmallStr;56use super::*;78fn slice_take(9total_rows: IdxSize,10n_rows_right: IdxSize,11slice: Option<(i64, usize)>,12inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,13) -> IdxCa {14match slice {15None => inner(0, total_rows, n_rows_right),16Some((offset, len)) => {17let (offset, len) = slice_offsets(offset, len, total_rows as usize);18inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)19},20}21}2223fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {24fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {25let mut take: NoNull<IdxCa> = (offset..total_rows)26.map(|i| i / n_rows_right)27.collect_trusted();28take.set_sorted_flag(IsSorted::Ascending);29take.into_inner()30}31slice_take(total_rows, n_rows_right, slice, inner)32}3334fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {35fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {36let take: NoNull<IdxCa> = (offset..total_rows)37.map(|i| i % n_rows_right)38.collect_trusted();39take.into_inner()40}41slice_take(total_rows, n_rows_right, slice, inner)42}4344pub trait CrossJoin: IntoDf {45/// Creates the Cartesian product from both frames, preserves the order of the left keys.46fn cross_join(47&self,48other: &DataFrame,49suffix: Option<PlSmallStr>,50slice: Option<(i64, usize)>,51maintain_order: MaintainOrderJoin,52) -> PolarsResult<DataFrame> {53let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true, maintain_order)?;5455_finish_join(l_df, r_df, suffix)56}57}5859impl CrossJoin for DataFrame {}6061fn cross_join_dfs<'a>(62mut df_self: &'a DataFrame,63mut other: &'a DataFrame,64slice: Option<(i64, usize)>,65parallel: bool,66maintain_order: MaintainOrderJoin,67) -> PolarsResult<(DataFrame, DataFrame)> {68if df_self.height() == 0 || other.height() == 0 {69return Ok((df_self.clear(), other.clear()));70}7172let left_is_primary = match maintain_order {73MaintainOrderJoin::None => true,74MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => true,75MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => false,76};7778if !left_is_primary {79core::mem::swap(&mut df_self, &mut other);80}8182let n_rows_left = df_self.height() as IdxSize;83let n_rows_right = other.height() as IdxSize;84let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {85polars_bail!(86ComputeError: "cross joins would produce more rows than fits into 2^32; \87consider compiling with polars-big-idx feature, or set 'streaming'"88);89};9091// the left side has the Nth row combined with every row from right.92// So let's say we have the following no. of rows93// left: 394// right: 495//96// left take idx: 00001111222297// right take idx: 0123012301239899let create_left_df = || {100// SAFETY:101// take left is in bounds102unsafe {103df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)104}105};106107let create_right_df = || {108// concatenation of dataframes is very expensive if we need to make the series mutable109// many times, these are atomic operations110// so we choose a different strategy at > 100 rows (arbitrarily small number)111if n_rows_left > 100 || slice.is_some() {112// SAFETY:113// take right is in bounds114unsafe {115other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)116}117} else {118let iter = (0..n_rows_left).map(|_| other);119concat_df_unchecked(iter)120}121};122let (l_df, r_df) = if parallel {123try_raise_keyboard_interrupt();124POOL.install(|| rayon::join(create_left_df, create_right_df))125} else {126(create_left_df(), create_right_df())127};128if left_is_primary {129Ok((l_df, r_df))130} else {131Ok((r_df, l_df))132}133}134135pub(super) fn fused_cross_filter(136left: &DataFrame,137right: &DataFrame,138suffix: Option<PlSmallStr>,139cross_join_options: &CrossJoinOptions,140maintain_order: MaintainOrderJoin,141) -> PolarsResult<DataFrame> {142let unfiltered_size = (left.height() as u64).saturating_mul(right.height() as u64);143let chunk_size = (unfiltered_size / _set_partition_size() as u64).clamp(1, 100_000);144let num_chunks = (unfiltered_size / chunk_size).max(1) as usize;145146let left_is_primary = match maintain_order {147MaintainOrderJoin::None => true,148MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => true,149MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => false,150};151152let split_chunks;153let cartesian_prod = if left_is_primary {154split_chunks = split(left, num_chunks);155split_chunks.iter().map(|l| (l, right)).collect::<Vec<_>>()156} else {157split_chunks = split(right, num_chunks);158split_chunks.iter().map(|r| (left, r)).collect::<Vec<_>>()159};160161let names = _finish_join(left.clear(), right.clear(), suffix)?;162let rename_names = names.get_column_names();163let rename_names = &rename_names[left.width()..];164165let dfs = POOL166.install(|| {167cartesian_prod.par_iter().map(|(left, right)| {168let (mut left, right) = cross_join_dfs(left, right, None, false, maintain_order)?;169let mut right_columns = right.into_columns();170171for (c, name) in right_columns.iter_mut().zip(rename_names) {172c.rename((*name).clone());173}174175unsafe { left.hstack_mut_unchecked(&right_columns) };176177cross_join_options.predicate.apply(left)178})179})180.collect::<PolarsResult<Vec<_>>>()?;181182Ok(accumulate_dataframes_vertical_unchecked(dfs))183}184185186