Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/args.rs
6940 views
1
use super::*;
2
3
pub(super) type JoinIds = Vec<IdxSize>;
4
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5
pub type InnerJoinIds = (JoinIds, JoinIds);
6
7
#[cfg(feature = "chunked_ids")]
8
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9
#[cfg(feature = "chunked_ids")]
10
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12
#[cfg(not(feature = "chunked_ids"))]
13
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15
#[cfg(not(feature = "chunked_ids"))]
16
pub type ChunkJoinIds = Vec<IdxSize>;
17
18
#[cfg(feature = "serde")]
19
use serde::{Deserialize, Serialize};
20
use strum_macros::IntoStaticStr;
21
22
#[derive(Clone, PartialEq, Debug, Hash, Default)]
23
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
25
pub struct JoinArgs {
26
pub how: JoinType,
27
pub validation: JoinValidation,
28
pub suffix: Option<PlSmallStr>,
29
pub slice: Option<(i64, usize)>,
30
pub nulls_equal: bool,
31
pub coalesce: JoinCoalesce,
32
pub maintain_order: MaintainOrderJoin,
33
}
34
35
impl JoinArgs {
36
pub fn should_coalesce(&self) -> bool {
37
self.coalesce.coalesce(&self.how)
38
}
39
}
40
41
#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
42
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
43
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
44
pub enum JoinType {
45
#[default]
46
Inner,
47
Left,
48
Right,
49
Full,
50
// Box is okay because this is inside a `Arc<JoinOptionsIR>`
51
#[cfg(feature = "asof_join")]
52
AsOf(Box<AsOfOptions>),
53
#[cfg(feature = "semi_anti_join")]
54
Semi,
55
#[cfg(feature = "semi_anti_join")]
56
Anti,
57
#[cfg(feature = "iejoin")]
58
// Options are set by optimizer/planner in Options
59
IEJoin,
60
// Options are set by optimizer/planner in Options
61
Cross,
62
}
63
64
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
65
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
66
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
67
pub enum JoinCoalesce {
68
#[default]
69
JoinSpecific,
70
CoalesceColumns,
71
KeepColumns,
72
}
73
74
impl JoinCoalesce {
75
pub fn coalesce(&self, join_type: &JoinType) -> bool {
76
use JoinCoalesce::*;
77
use JoinType::*;
78
match join_type {
79
Left | Inner | Right => {
80
matches!(self, JoinSpecific | CoalesceColumns)
81
},
82
Full => {
83
matches!(self, CoalesceColumns)
84
},
85
#[cfg(feature = "asof_join")]
86
AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
87
#[cfg(feature = "iejoin")]
88
IEJoin => false,
89
Cross => false,
90
#[cfg(feature = "semi_anti_join")]
91
Semi | Anti => false,
92
}
93
}
94
}
95
96
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
97
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
98
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
99
#[strum(serialize_all = "snake_case")]
100
pub enum MaintainOrderJoin {
101
#[default]
102
None,
103
Left,
104
Right,
105
LeftRight,
106
RightLeft,
107
}
108
109
impl MaintainOrderJoin {
110
pub(super) fn flip(&self) -> Self {
111
match self {
112
MaintainOrderJoin::None => MaintainOrderJoin::None,
113
MaintainOrderJoin::Left => MaintainOrderJoin::Right,
114
MaintainOrderJoin::Right => MaintainOrderJoin::Left,
115
MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
116
MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
117
}
118
}
119
}
120
121
impl JoinArgs {
122
pub fn new(how: JoinType) -> Self {
123
Self {
124
how,
125
validation: Default::default(),
126
suffix: None,
127
slice: None,
128
nulls_equal: false,
129
coalesce: Default::default(),
130
maintain_order: Default::default(),
131
}
132
}
133
134
pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
135
self.coalesce = coalesce;
136
self
137
}
138
139
pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
140
self.suffix = suffix;
141
self
142
}
143
144
pub fn suffix(&self) -> &PlSmallStr {
145
const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
146
self.suffix.as_ref().unwrap_or(DEFAULT)
147
}
148
}
149
150
impl From<JoinType> for JoinArgs {
151
fn from(value: JoinType) -> Self {
152
JoinArgs::new(value)
153
}
154
}
155
156
pub trait CrossJoinFilter: Send + Sync {
157
fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
158
}
159
160
impl<T> CrossJoinFilter for T
161
where
162
T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
163
{
164
fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
165
self(df)
166
}
167
}
168
169
#[derive(Clone)]
170
pub struct CrossJoinOptions {
171
pub predicate: Arc<dyn CrossJoinFilter>,
172
}
173
174
impl CrossJoinOptions {
175
fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
176
Arc::as_ptr(&self.predicate)
177
}
178
}
179
180
impl Eq for CrossJoinOptions {}
181
182
impl PartialEq for CrossJoinOptions {
183
fn eq(&self, other: &Self) -> bool {
184
std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
185
}
186
}
187
188
impl Hash for CrossJoinOptions {
189
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
190
self.as_ptr_ref().hash(state);
191
}
192
}
193
194
impl Debug for CrossJoinOptions {
195
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
196
write!(f, "CrossJoinOptions",)
197
}
198
}
199
200
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
201
#[strum(serialize_all = "snake_case")]
202
pub enum JoinTypeOptions {
203
#[cfg(feature = "iejoin")]
204
IEJoin(IEJoinOptions),
205
Cross(CrossJoinOptions),
206
}
207
208
impl Display for JoinType {
209
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
210
use JoinType::*;
211
let val = match self {
212
Left => "LEFT",
213
Right => "RIGHT",
214
Inner => "INNER",
215
Full => "FULL",
216
#[cfg(feature = "asof_join")]
217
AsOf(_) => "ASOF",
218
#[cfg(feature = "iejoin")]
219
IEJoin => "IEJOIN",
220
Cross => "CROSS",
221
#[cfg(feature = "semi_anti_join")]
222
Semi => "SEMI",
223
#[cfg(feature = "semi_anti_join")]
224
Anti => "ANTI",
225
};
226
write!(f, "{val}")
227
}
228
}
229
230
impl Debug for JoinType {
231
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
232
write!(f, "{self}")
233
}
234
}
235
236
impl JoinType {
237
pub fn is_equi(&self) -> bool {
238
matches!(
239
self,
240
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
241
)
242
}
243
244
pub fn is_semi_anti(&self) -> bool {
245
#[cfg(feature = "semi_anti_join")]
246
{
247
matches!(self, JoinType::Semi | JoinType::Anti)
248
}
249
#[cfg(not(feature = "semi_anti_join"))]
250
{
251
false
252
}
253
}
254
255
pub fn is_semi(&self) -> bool {
256
#[cfg(feature = "semi_anti_join")]
257
{
258
matches!(self, JoinType::Semi)
259
}
260
#[cfg(not(feature = "semi_anti_join"))]
261
{
262
false
263
}
264
}
265
266
pub fn is_anti(&self) -> bool {
267
#[cfg(feature = "semi_anti_join")]
268
{
269
matches!(self, JoinType::Anti)
270
}
271
#[cfg(not(feature = "semi_anti_join"))]
272
{
273
false
274
}
275
}
276
277
pub fn is_asof(&self) -> bool {
278
#[cfg(feature = "asof_join")]
279
{
280
matches!(self, JoinType::AsOf(_))
281
}
282
#[cfg(not(feature = "asof_join"))]
283
{
284
false
285
}
286
}
287
288
pub fn is_cross(&self) -> bool {
289
matches!(self, JoinType::Cross)
290
}
291
292
pub fn is_ie(&self) -> bool {
293
#[cfg(feature = "iejoin")]
294
{
295
matches!(self, JoinType::IEJoin)
296
}
297
#[cfg(not(feature = "iejoin"))]
298
{
299
false
300
}
301
}
302
}
303
304
// Uniqueness requirement on the join keys, checked before/while joining.
// The `///` variant docs below feed the generated schema; keep them unchanged.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// No unique checks
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets
    OneToOne,
}
318
319
impl JoinValidation {
320
pub fn needs_checks(&self) -> bool {
321
!matches!(self, JoinValidation::ManyToMany)
322
}
323
324
fn swap(self, swap: bool) -> Self {
325
use JoinValidation::*;
326
if swap {
327
match self {
328
ManyToMany => ManyToMany,
329
ManyToOne => OneToMany,
330
OneToMany => ManyToOne,
331
OneToOne => OneToOne,
332
}
333
} else {
334
self
335
}
336
}
337
338
pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
339
if !self.needs_checks() {
340
return Ok(());
341
}
342
polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
343
ComputeError: "{self} validation on a {join_type} join is not supported");
344
Ok(())
345
}
346
347
pub(super) fn validate_probe(
348
&self,
349
s_left: &Series,
350
s_right: &Series,
351
build_shortest_table: bool,
352
nulls_equal: bool,
353
) -> PolarsResult<()> {
354
// In default, probe is the left series.
355
//
356
// In inner join and outer join, the shortest relation will be used to create a hash table.
357
// In left join, always use the right side to create.
358
//
359
// If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
360
// If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
361
let should_swap = build_shortest_table && s_left.len() <= s_right.len();
362
let probe = if should_swap { s_right } else { s_left };
363
364
use JoinValidation::*;
365
let valid = match self.swap(should_swap) {
366
// Only check the `build` side.
367
// The other side use `validate_build` to check
368
ManyToMany | ManyToOne => true,
369
OneToMany | OneToOne => {
370
if !nulls_equal && probe.null_count() > 0 {
371
probe.n_unique()? - 1 == probe.len() - probe.null_count()
372
} else {
373
probe.n_unique()? == probe.len()
374
}
375
},
376
};
377
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
378
Ok(())
379
}
380
381
pub(super) fn validate_build(
382
&self,
383
build_size: usize,
384
expected_size: usize,
385
swapped: bool,
386
) -> PolarsResult<()> {
387
use JoinValidation::*;
388
389
// In default, build is in rhs.
390
let valid = match self.swap(swapped) {
391
// Only check the `build` side.
392
// The other side use `validate_prone` to check
393
ManyToMany | OneToMany => true,
394
ManyToOne | OneToOne => build_size == expected_size,
395
};
396
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
397
Ok(())
398
}
399
}
400
401
impl Display for JoinValidation {
402
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
403
let s = match self {
404
JoinValidation::ManyToMany => "m:m",
405
JoinValidation::ManyToOne => "m:1",
406
JoinValidation::OneToMany => "1:m",
407
JoinValidation::OneToOne => "1:1",
408
};
409
write!(f, "{s}")
410
}
411
}
412
413
impl Debug for JoinValidation {
414
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
415
write!(f, "JoinValidation: {self}")
416
}
417
}
418
419