Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/cross_join.rs
6940 views
1
use polars_core::utils::{
2
_set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,
3
concat_df_unchecked, split,
4
};
5
use polars_utils::pl_str::PlSmallStr;
6
7
use super::*;
8
9
fn slice_take(
10
total_rows: IdxSize,
11
n_rows_right: IdxSize,
12
slice: Option<(i64, usize)>,
13
inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,
14
) -> IdxCa {
15
match slice {
16
None => inner(0, total_rows, n_rows_right),
17
Some((offset, len)) => {
18
let (offset, len) = slice_offsets(offset, len, total_rows as usize);
19
inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)
20
},
21
}
22
}
23
24
fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
25
fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
26
let mut take: NoNull<IdxCa> = (offset..total_rows)
27
.map(|i| i / n_rows_right)
28
.collect_trusted();
29
take.set_sorted_flag(IsSorted::Ascending);
30
take.into_inner()
31
}
32
slice_take(total_rows, n_rows_right, slice, inner)
33
}
34
35
fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
36
fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
37
let take: NoNull<IdxCa> = (offset..total_rows)
38
.map(|i| i % n_rows_right)
39
.collect_trusted();
40
take.into_inner()
41
}
42
slice_take(total_rows, n_rows_right, slice, inner)
43
}
44
45
pub trait CrossJoin: IntoDf {
46
#[doc(hidden)]
47
/// used by streaming
48
fn _cross_join_with_names(
49
&self,
50
other: &DataFrame,
51
names: &[PlSmallStr],
52
) -> PolarsResult<DataFrame> {
53
let (mut l_df, r_df) = cross_join_dfs(self.to_df(), other, None, false)?;
54
l_df.clear_schema();
55
56
unsafe {
57
l_df.get_columns_mut().extend_from_slice(r_df.get_columns());
58
59
l_df.get_columns_mut()
60
.iter_mut()
61
.zip(names)
62
.for_each(|(s, name)| {
63
if s.name() != name {
64
s.rename(name.clone());
65
}
66
});
67
}
68
Ok(l_df)
69
}
70
71
/// Creates the Cartesian product from both frames, preserves the order of the left keys.
72
fn cross_join(
73
&self,
74
other: &DataFrame,
75
suffix: Option<PlSmallStr>,
76
slice: Option<(i64, usize)>,
77
) -> PolarsResult<DataFrame> {
78
let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true)?;
79
80
_finish_join(l_df, r_df, suffix)
81
}
82
}
83
84
// `DataFrame` uses the default method implementations provided by `CrossJoin`.
impl CrossJoin for DataFrame {}
85
86
fn cross_join_dfs(
87
df_self: &DataFrame,
88
other: &DataFrame,
89
slice: Option<(i64, usize)>,
90
parallel: bool,
91
) -> PolarsResult<(DataFrame, DataFrame)> {
92
let n_rows_left = df_self.height() as IdxSize;
93
let n_rows_right = other.height() as IdxSize;
94
let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {
95
polars_bail!(
96
ComputeError: "cross joins would produce more rows than fits into 2^32; \
97
consider compiling with polars-big-idx feature, or set 'streaming'"
98
);
99
};
100
if n_rows_left == 0 || n_rows_right == 0 {
101
return Ok((df_self.clear(), other.clear()));
102
}
103
104
// the left side has the Nth row combined with every row from right.
105
// So let's say we have the following no. of rows
106
// left: 3
107
// right: 4
108
//
109
// left take idx: 000011112222
110
// right take idx: 012301230123
111
112
let create_left_df = || {
113
// SAFETY:
114
// take left is in bounds
115
unsafe {
116
df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)
117
}
118
};
119
120
let create_right_df = || {
121
// concatenation of dataframes is very expensive if we need to make the series mutable
122
// many times, these are atomic operations
123
// so we choose a different strategy at > 100 rows (arbitrarily small number)
124
if n_rows_left > 100 || slice.is_some() {
125
// SAFETY:
126
// take right is in bounds
127
unsafe {
128
other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)
129
}
130
} else {
131
let iter = (0..n_rows_left).map(|_| other);
132
concat_df_unchecked(iter)
133
}
134
};
135
let (l_df, r_df) = if parallel {
136
try_raise_keyboard_interrupt();
137
POOL.install(|| rayon::join(create_left_df, create_right_df))
138
} else {
139
(create_left_df(), create_right_df())
140
};
141
Ok((l_df, r_df))
142
}
143
144
pub(super) fn fused_cross_filter(
145
left: &DataFrame,
146
right: &DataFrame,
147
suffix: Option<PlSmallStr>,
148
cross_join_options: &CrossJoinOptions,
149
) -> PolarsResult<DataFrame> {
150
// Because we do a cartesian product, the number of partitions is squared.
151
// We take the sqrt, but we don't expect every partition to produce results and work can be
152
// imbalanced, so we multiply the number of partitions by 2;
153
let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;
154
let splitted_a = split(left, n_partitions);
155
let splitted_b = split(right, n_partitions);
156
157
let cartesian_prod = splitted_a
158
.iter()
159
.flat_map(|l| splitted_b.iter().map(move |r| (l, r)))
160
.collect::<Vec<_>>();
161
162
let names = _finish_join(left.clear(), right.clear(), suffix)?;
163
let rename_names = names.get_column_names();
164
let rename_names = &rename_names[left.width()..];
165
166
let dfs = POOL
167
.install(|| {
168
cartesian_prod.par_iter().map(|(left, right)| {
169
let (mut left, right) = cross_join_dfs(left, right, None, false)?;
170
let mut right_columns = right.take_columns();
171
172
for (c, name) in right_columns.iter_mut().zip(rename_names) {
173
c.rename((*name).clone());
174
}
175
176
unsafe { left.hstack_mut_unchecked(&right_columns) };
177
178
cross_join_options.predicate.apply(left)
179
})
180
})
181
.collect::<PolarsResult<Vec<_>>>()?;
182
183
Ok(accumulate_dataframes_vertical_unchecked(dfs))
184
}
185
186