Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/args.rs
8446 views
1
use super::*;
2
3
pub(super) type JoinIds = Vec<IdxSize>;
4
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5
pub type InnerJoinIds = (JoinIds, JoinIds);
6
7
#[cfg(feature = "chunked_ids")]
8
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9
#[cfg(feature = "chunked_ids")]
10
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12
#[cfg(not(feature = "chunked_ids"))]
13
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15
#[cfg(not(feature = "chunked_ids"))]
16
pub type ChunkJoinIds = Vec<IdxSize>;
17
18
#[cfg(feature = "serde")]
19
use serde::{Deserialize, Serialize};
20
use strum_macros::IntoStaticStr;
21
22
/// Parameters for which side to use as the build side in a join. Currently only
/// respected by the streaming engine.
//
// NOTE: additional commentary in this type uses plain `//` comments on
// purpose: with `dsl-schema` enabled the derived JSON schema may embed `///`
// text, so the existing doc comments are kept verbatim.
#[derive(Clone, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinBuildSide {
    /// Unless there's a very good reason to believe that the right side is
    /// smaller, use the left side.
    PreferLeft,
    /// Regardless of other heuristics, use the left side as build side.
    ForceLeft,

    // Right-hand counterparts of the two variants above: a soft preference
    // and a hard requirement, respectively.
    PreferRight,
    ForceRight,
}
38
39
#[derive(Clone, PartialEq, Debug, Hash, Default)]
40
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
41
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
42
pub struct JoinArgs {
43
pub how: JoinType,
44
pub validation: JoinValidation,
45
pub suffix: Option<PlSmallStr>,
46
pub slice: Option<(i64, usize)>,
47
pub nulls_equal: bool,
48
pub coalesce: JoinCoalesce,
49
pub maintain_order: MaintainOrderJoin,
50
pub build_side: Option<JoinBuildSide>,
51
}
52
53
impl JoinArgs {
54
pub fn should_coalesce(&self) -> bool {
55
self.coalesce.coalesce(&self.how)
56
}
57
}
58
59
#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
60
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
61
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
62
pub enum JoinType {
63
#[default]
64
Inner,
65
Left,
66
Right,
67
Full,
68
// Box is okay because this is inside a `Arc<JoinOptionsIR>`
69
#[cfg(feature = "asof_join")]
70
AsOf(Box<AsOfOptions>),
71
#[cfg(feature = "semi_anti_join")]
72
Semi,
73
#[cfg(feature = "semi_anti_join")]
74
Anti,
75
#[cfg(feature = "iejoin")]
76
// Options are set by optimizer/planner in Options
77
IEJoin,
78
// Options are set by optimizer/planner in Options
79
Cross,
80
}
81
82
// Policy for coalescing the key columns of a join.
//
// NOTE: plain `//` comments are used on purpose: with `dsl-schema` enabled
// the derived JSON schema may embed `///` text, and this type had none.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    // Let the join type decide (see `JoinCoalesce::coalesce`).
    #[default]
    JoinSpecific,
    // Request coalescing of the key columns.
    CoalesceColumns,
    // Request that the key columns of both sides are kept.
    KeepColumns,
}
91
92
impl JoinCoalesce {
93
pub fn coalesce(&self, join_type: &JoinType) -> bool {
94
use JoinCoalesce::*;
95
use JoinType::*;
96
match join_type {
97
Left | Inner | Right => {
98
matches!(self, JoinSpecific | CoalesceColumns)
99
},
100
Full => {
101
matches!(self, CoalesceColumns)
102
},
103
#[cfg(feature = "asof_join")]
104
AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
105
#[cfg(feature = "iejoin")]
106
IEJoin => false,
107
Cross => false,
108
#[cfg(feature = "semi_anti_join")]
109
Semi | Anti => false,
110
}
111
}
112
}
113
114
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
115
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
116
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
117
#[strum(serialize_all = "snake_case")]
118
pub enum MaintainOrderJoin {
119
#[default]
120
None,
121
Left,
122
Right,
123
LeftRight,
124
RightLeft,
125
}
126
127
impl MaintainOrderJoin {
128
pub(super) fn flip(&self) -> Self {
129
match self {
130
MaintainOrderJoin::None => MaintainOrderJoin::None,
131
MaintainOrderJoin::Left => MaintainOrderJoin::Right,
132
MaintainOrderJoin::Right => MaintainOrderJoin::Left,
133
MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
134
MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
135
}
136
}
137
}
138
139
impl JoinArgs {
140
pub fn new(how: JoinType) -> Self {
141
Self {
142
how,
143
validation: Default::default(),
144
suffix: None,
145
slice: None,
146
nulls_equal: false,
147
coalesce: Default::default(),
148
maintain_order: Default::default(),
149
build_side: None,
150
}
151
}
152
153
pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
154
self.coalesce = coalesce;
155
self
156
}
157
158
pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
159
self.suffix = suffix;
160
self
161
}
162
163
pub fn with_build_side(mut self, build_side: Option<JoinBuildSide>) -> Self {
164
self.build_side = build_side;
165
self
166
}
167
168
pub fn suffix(&self) -> &PlSmallStr {
169
const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
170
self.suffix.as_ref().unwrap_or(DEFAULT)
171
}
172
}
173
174
/// A bare join type converts into arguments with every other option at its
/// default, exactly like [`JoinArgs::new`].
impl From<JoinType> for JoinArgs {
    fn from(how: JoinType) -> Self {
        Self::new(how)
    }
}
179
180
pub trait CrossJoinFilter: Send + Sync {
181
fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
182
}
183
184
impl<T> CrossJoinFilter for T
185
where
186
T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
187
{
188
fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
189
self(df)
190
}
191
}
192
193
#[derive(Clone)]
194
pub struct CrossJoinOptions {
195
pub predicate: Arc<dyn CrossJoinFilter>,
196
}
197
198
impl CrossJoinOptions {
199
fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
200
Arc::as_ptr(&self.predicate)
201
}
202
}
203
204
impl Eq for CrossJoinOptions {}
205
206
impl PartialEq for CrossJoinOptions {
207
fn eq(&self, other: &Self) -> bool {
208
std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
209
}
210
}
211
212
impl Hash for CrossJoinOptions {
213
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
214
self.as_ptr_ref().hash(state);
215
}
216
}
217
218
impl Debug for CrossJoinOptions {
219
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220
write!(f, "CrossJoinOptions",)
221
}
222
}
223
224
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
225
#[strum(serialize_all = "snake_case")]
226
pub enum JoinTypeOptions {
227
#[cfg(feature = "iejoin")]
228
IEJoin(IEJoinOptions),
229
Cross(CrossJoinOptions),
230
}
231
232
impl Display for JoinType {
233
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234
use JoinType::*;
235
let val = match self {
236
Left => "LEFT",
237
Right => "RIGHT",
238
Inner => "INNER",
239
Full => "FULL",
240
#[cfg(feature = "asof_join")]
241
AsOf(_) => "ASOF",
242
#[cfg(feature = "iejoin")]
243
IEJoin => "IEJOIN",
244
Cross => "CROSS",
245
#[cfg(feature = "semi_anti_join")]
246
Semi => "SEMI",
247
#[cfg(feature = "semi_anti_join")]
248
Anti => "ANTI",
249
};
250
write!(f, "{val}")
251
}
252
}
253
254
impl Debug for JoinType {
255
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256
write!(f, "{self}")
257
}
258
}
259
260
impl JoinType {
261
pub fn is_equi(&self) -> bool {
262
matches!(
263
self,
264
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
265
)
266
}
267
268
pub fn is_semi_anti(&self) -> bool {
269
#[cfg(feature = "semi_anti_join")]
270
{
271
matches!(self, JoinType::Semi | JoinType::Anti)
272
}
273
#[cfg(not(feature = "semi_anti_join"))]
274
{
275
false
276
}
277
}
278
279
pub fn is_semi(&self) -> bool {
280
#[cfg(feature = "semi_anti_join")]
281
{
282
matches!(self, JoinType::Semi)
283
}
284
#[cfg(not(feature = "semi_anti_join"))]
285
{
286
false
287
}
288
}
289
290
pub fn is_anti(&self) -> bool {
291
#[cfg(feature = "semi_anti_join")]
292
{
293
matches!(self, JoinType::Anti)
294
}
295
#[cfg(not(feature = "semi_anti_join"))]
296
{
297
false
298
}
299
}
300
301
pub fn is_asof(&self) -> bool {
302
#[cfg(feature = "asof_join")]
303
{
304
matches!(self, JoinType::AsOf(_))
305
}
306
#[cfg(not(feature = "asof_join"))]
307
{
308
false
309
}
310
}
311
312
pub fn is_cross(&self) -> bool {
313
matches!(self, JoinType::Cross)
314
}
315
316
pub fn is_ie(&self) -> bool {
317
#[cfg(feature = "iejoin")]
318
{
319
matches!(self, JoinType::IEJoin)
320
}
321
#[cfg(not(feature = "iejoin"))]
322
{
323
false
324
}
325
}
326
}
327
328
// Uniqueness requirement on the join keys of the two inputs. `Debug` and
// `Display` are implemented by hand elsewhere in this file.
//
// NOTE: the `///` texts below are kept verbatim: with `dsl-schema` enabled
// the derived JSON schema may embed them.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// No unique checks
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets
    OneToOne,
}
342
343
impl JoinValidation {
344
pub fn needs_checks(&self) -> bool {
345
!matches!(self, JoinValidation::ManyToMany)
346
}
347
348
fn swap(self, swap: bool) -> Self {
349
use JoinValidation::*;
350
if swap {
351
match self {
352
ManyToMany => ManyToMany,
353
ManyToOne => OneToMany,
354
OneToMany => ManyToOne,
355
OneToOne => OneToOne,
356
}
357
} else {
358
self
359
}
360
}
361
362
pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
363
if !self.needs_checks() {
364
return Ok(());
365
}
366
polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
367
ComputeError: "{self} validation on a {join_type} join is not supported");
368
Ok(())
369
}
370
371
pub(super) fn validate_probe(
372
&self,
373
s_left: &Series,
374
s_right: &Series,
375
build_shortest_table: bool,
376
nulls_equal: bool,
377
) -> PolarsResult<()> {
378
// In default, probe is the left series.
379
//
380
// In inner join and outer join, the shortest relation will be used to create a hash table.
381
// In left join, always use the right side to create.
382
//
383
// If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
384
// If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
385
let should_swap = build_shortest_table && s_left.len() <= s_right.len();
386
let probe = if should_swap { s_right } else { s_left };
387
388
use JoinValidation::*;
389
let valid = match self.swap(should_swap) {
390
// Only check the `build` side.
391
// The other side use `validate_build` to check
392
ManyToMany | ManyToOne => true,
393
OneToMany | OneToOne => {
394
if !nulls_equal && probe.null_count() > 0 {
395
probe.n_unique()? - 1 == probe.len() - probe.null_count()
396
} else {
397
probe.n_unique()? == probe.len()
398
}
399
},
400
};
401
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
402
Ok(())
403
}
404
405
pub(super) fn validate_build(
406
&self,
407
build_size: usize,
408
expected_size: usize,
409
swapped: bool,
410
) -> PolarsResult<()> {
411
use JoinValidation::*;
412
413
// In default, build is in rhs.
414
let valid = match self.swap(swapped) {
415
// Only check the `build` side.
416
// The other side use `validate_prone` to check
417
ManyToMany | OneToMany => true,
418
ManyToOne | OneToOne => build_size == expected_size,
419
};
420
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
421
Ok(())
422
}
423
}
424
425
impl Display for JoinValidation {
426
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
427
let s = match self {
428
JoinValidation::ManyToMany => "m:m",
429
JoinValidation::ManyToOne => "m:1",
430
JoinValidation::OneToMany => "1:m",
431
JoinValidation::OneToOne => "1:1",
432
};
433
write!(f, "{s}")
434
}
435
}
436
437
impl Debug for JoinValidation {
438
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
439
write!(f, "JoinValidation: {self}")
440
}
441
}
442
443