Path: blob/main/crates/polars-ops/src/frame/join/cross_join.rs
6940 views
use polars_core::utils::{1_set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,2concat_df_unchecked, split,3};4use polars_utils::pl_str::PlSmallStr;56use super::*;78fn slice_take(9total_rows: IdxSize,10n_rows_right: IdxSize,11slice: Option<(i64, usize)>,12inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,13) -> IdxCa {14match slice {15None => inner(0, total_rows, n_rows_right),16Some((offset, len)) => {17let (offset, len) = slice_offsets(offset, len, total_rows as usize);18inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)19},20}21}2223fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {24fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {25let mut take: NoNull<IdxCa> = (offset..total_rows)26.map(|i| i / n_rows_right)27.collect_trusted();28take.set_sorted_flag(IsSorted::Ascending);29take.into_inner()30}31slice_take(total_rows, n_rows_right, slice, inner)32}3334fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {35fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {36let take: NoNull<IdxCa> = (offset..total_rows)37.map(|i| i % n_rows_right)38.collect_trusted();39take.into_inner()40}41slice_take(total_rows, n_rows_right, slice, inner)42}4344pub trait CrossJoin: IntoDf {45#[doc(hidden)]46/// used by streaming47fn _cross_join_with_names(48&self,49other: &DataFrame,50names: &[PlSmallStr],51) -> PolarsResult<DataFrame> {52let (mut l_df, r_df) = cross_join_dfs(self.to_df(), other, None, false)?;53l_df.clear_schema();5455unsafe {56l_df.get_columns_mut().extend_from_slice(r_df.get_columns());5758l_df.get_columns_mut()59.iter_mut()60.zip(names)61.for_each(|(s, name)| {62if s.name() != name {63s.rename(name.clone());64}65});66}67Ok(l_df)68}6970/// Creates the Cartesian product from both frames, preserves the order of the left keys.71fn cross_join(72&self,73other: &DataFrame,74suffix: Option<PlSmallStr>,75slice: Option<(i64, usize)>,76) -> PolarsResult<DataFrame> {77let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true)?;7879_finish_join(l_df, r_df, suffix)80}81}8283impl CrossJoin for DataFrame {}8485fn cross_join_dfs(86df_self: &DataFrame,87other: &DataFrame,88slice: Option<(i64, usize)>,89parallel: bool,90) -> PolarsResult<(DataFrame, DataFrame)> {91let n_rows_left = df_self.height() as IdxSize;92let n_rows_right = other.height() as IdxSize;93let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {94polars_bail!(95ComputeError: "cross joins would produce more rows than fits into 2^32; \96consider compiling with polars-big-idx feature, or set 'streaming'"97);98};99if n_rows_left == 0 || n_rows_right == 0 {100return Ok((df_self.clear(), other.clear()));101}102103// the left side has the Nth row combined with every row from right.104// So let's say we have the following no. of rows105// left: 3106// right: 4107//108// left take idx: 000011112222109// right take idx: 012301230123110111let create_left_df = || {112// SAFETY:113// take left is in bounds114unsafe {115df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)116}117};118119let create_right_df = || {120// concatenation of dataframes is very expensive if we need to make the series mutable121// many times, these are atomic operations122// so we choose a different strategy at > 100 rows (arbitrarily small number)123if n_rows_left > 100 || slice.is_some() {124// SAFETY:125// take right is in bounds126unsafe {127other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)128}129} else {130let iter = (0..n_rows_left).map(|_| other);131concat_df_unchecked(iter)132}133};134let (l_df, r_df) = if parallel {135try_raise_keyboard_interrupt();136POOL.install(|| rayon::join(create_left_df, create_right_df))137} else {138(create_left_df(), create_right_df())139};140Ok((l_df, r_df))141}142143pub(super) fn fused_cross_filter(144left: &DataFrame,145right: &DataFrame,146suffix: Option<PlSmallStr>,147cross_join_options: &CrossJoinOptions,148) -> PolarsResult<DataFrame> {149// Because we do a cartesian product, the number of partitions is squared.150// We take the sqrt, but we don't expect every partition to produce results and work can be151// imbalanced, so we multiply the number of partitions by 2;152let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;153let splitted_a = split(left, n_partitions);154let splitted_b = split(right, n_partitions);155156let cartesian_prod = splitted_a157.iter()158.flat_map(|l| splitted_b.iter().map(move |r| (l, r)))159.collect::<Vec<_>>();160161let names = _finish_join(left.clear(), right.clear(), suffix)?;162let rename_names = names.get_column_names();163let rename_names = &rename_names[left.width()..];164165let dfs = POOL166.install(|| {167cartesian_prod.par_iter().map(|(left, right)| {168let (mut left, right) = cross_join_dfs(left, right, None, false)?;169let mut right_columns = right.take_columns();170171for (c, name) in right_columns.iter_mut().zip(rename_names) {172c.rename((*name).clone());173}174175unsafe { left.hstack_mut_unchecked(&right_columns) };176177cross_join_options.predicate.apply(left)178})179})180.collect::<PolarsResult<Vec<_>>>()?;181182Ok(accumulate_dataframes_vertical_unchecked(dfs))183}184185186