Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/cross_join.rs
8430 views
1
use polars_core::utils::{
2
_set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,
3
concat_df_unchecked, split,
4
};
5
use polars_utils::pl_str::PlSmallStr;
6
7
use super::*;
8
9
fn slice_take(
10
total_rows: IdxSize,
11
n_rows_right: IdxSize,
12
slice: Option<(i64, usize)>,
13
inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,
14
) -> IdxCa {
15
match slice {
16
None => inner(0, total_rows, n_rows_right),
17
Some((offset, len)) => {
18
let (offset, len) = slice_offsets(offset, len, total_rows as usize);
19
inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)
20
},
21
}
22
}
23
24
fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
25
fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
26
let mut take: NoNull<IdxCa> = (offset..total_rows)
27
.map(|i| i / n_rows_right)
28
.collect_trusted();
29
take.set_sorted_flag(IsSorted::Ascending);
30
take.into_inner()
31
}
32
slice_take(total_rows, n_rows_right, slice, inner)
33
}
34
35
fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
36
fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
37
let take: NoNull<IdxCa> = (offset..total_rows)
38
.map(|i| i % n_rows_right)
39
.collect_trusted();
40
take.into_inner()
41
}
42
slice_take(total_rows, n_rows_right, slice, inner)
43
}
44
45
pub trait CrossJoin: IntoDf {
46
/// Creates the Cartesian product from both frames, preserves the order of the left keys.
47
fn cross_join(
48
&self,
49
other: &DataFrame,
50
suffix: Option<PlSmallStr>,
51
slice: Option<(i64, usize)>,
52
maintain_order: MaintainOrderJoin,
53
) -> PolarsResult<DataFrame> {
54
let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true, maintain_order)?;
55
56
_finish_join(l_df, r_df, suffix)
57
}
58
}
59
60
// `DataFrame` relies entirely on the trait's default `cross_join` implementation.
impl CrossJoin for DataFrame {}
61
62
fn cross_join_dfs<'a>(
63
mut df_self: &'a DataFrame,
64
mut other: &'a DataFrame,
65
slice: Option<(i64, usize)>,
66
parallel: bool,
67
maintain_order: MaintainOrderJoin,
68
) -> PolarsResult<(DataFrame, DataFrame)> {
69
if df_self.height() == 0 || other.height() == 0 {
70
return Ok((df_self.clear(), other.clear()));
71
}
72
73
let left_is_primary = match maintain_order {
74
MaintainOrderJoin::None => true,
75
MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => true,
76
MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => false,
77
};
78
79
if !left_is_primary {
80
core::mem::swap(&mut df_self, &mut other);
81
}
82
83
let n_rows_left = df_self.height() as IdxSize;
84
let n_rows_right = other.height() as IdxSize;
85
let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {
86
polars_bail!(
87
ComputeError: "cross joins would produce more rows than fits into 2^32; \
88
consider compiling with polars-big-idx feature, or set 'streaming'"
89
);
90
};
91
92
// the left side has the Nth row combined with every row from right.
93
// So let's say we have the following no. of rows
94
// left: 3
95
// right: 4
96
//
97
// left take idx: 000011112222
98
// right take idx: 012301230123
99
100
let create_left_df = || {
101
// SAFETY:
102
// take left is in bounds
103
unsafe {
104
df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)
105
}
106
};
107
108
let create_right_df = || {
109
// concatenation of dataframes is very expensive if we need to make the series mutable
110
// many times, these are atomic operations
111
// so we choose a different strategy at > 100 rows (arbitrarily small number)
112
if n_rows_left > 100 || slice.is_some() {
113
// SAFETY:
114
// take right is in bounds
115
unsafe {
116
other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)
117
}
118
} else {
119
let iter = (0..n_rows_left).map(|_| other);
120
concat_df_unchecked(iter)
121
}
122
};
123
let (l_df, r_df) = if parallel {
124
try_raise_keyboard_interrupt();
125
POOL.install(|| rayon::join(create_left_df, create_right_df))
126
} else {
127
(create_left_df(), create_right_df())
128
};
129
if left_is_primary {
130
Ok((l_df, r_df))
131
} else {
132
Ok((r_df, l_df))
133
}
134
}
135
136
pub(super) fn fused_cross_filter(
137
left: &DataFrame,
138
right: &DataFrame,
139
suffix: Option<PlSmallStr>,
140
cross_join_options: &CrossJoinOptions,
141
maintain_order: MaintainOrderJoin,
142
) -> PolarsResult<DataFrame> {
143
let unfiltered_size = (left.height() as u64).saturating_mul(right.height() as u64);
144
let chunk_size = (unfiltered_size / _set_partition_size() as u64).clamp(1, 100_000);
145
let num_chunks = (unfiltered_size / chunk_size).max(1) as usize;
146
147
let left_is_primary = match maintain_order {
148
MaintainOrderJoin::None => true,
149
MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => true,
150
MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => false,
151
};
152
153
let split_chunks;
154
let cartesian_prod = if left_is_primary {
155
split_chunks = split(left, num_chunks);
156
split_chunks.iter().map(|l| (l, right)).collect::<Vec<_>>()
157
} else {
158
split_chunks = split(right, num_chunks);
159
split_chunks.iter().map(|r| (left, r)).collect::<Vec<_>>()
160
};
161
162
let names = _finish_join(left.clear(), right.clear(), suffix)?;
163
let rename_names = names.get_column_names();
164
let rename_names = &rename_names[left.width()..];
165
166
let dfs = POOL
167
.install(|| {
168
cartesian_prod.par_iter().map(|(left, right)| {
169
let (mut left, right) = cross_join_dfs(left, right, None, false, maintain_order)?;
170
let mut right_columns = right.into_columns();
171
172
for (c, name) in right_columns.iter_mut().zip(rename_names) {
173
c.rename((*name).clone());
174
}
175
176
unsafe { left.hstack_mut_unchecked(&right_columns) };
177
178
cross_join_options.predicate.apply(left)
179
})
180
})
181
.collect::<PolarsResult<Vec<_>>>()?;
182
183
Ok(accumulate_dataframes_vertical_unchecked(dfs))
184
}
185
186