Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_join.py
6939 views
1
from __future__ import annotations
2
3
import typing
4
import warnings
5
from datetime import date, datetime
6
from typing import TYPE_CHECKING, Any, Callable, Literal
7
8
import numpy as np
9
import pandas as pd
10
import pytest
11
12
import polars as pl
13
from polars.exceptions import (
14
ColumnNotFoundError,
15
ComputeError,
16
DuplicateError,
17
InvalidOperationError,
18
SchemaError,
19
)
20
from polars.testing import assert_frame_equal, assert_series_equal
21
from tests.unit.conftest import time_func
22
23
if TYPE_CHECKING:
24
from polars._typing import JoinStrategy, PolarsDataType
25
26
27
def test_semi_anti_join() -> None:
    """Semi joins keep left rows with a key match; anti joins keep the others.

    Checked eagerly and lazily on a single key, then eagerly on a composite key.
    """
    left = pl.DataFrame({"key": [1, 2, 3], "payload": ["f", "i", None]})
    right = pl.DataFrame({"key": [3, 4, 5, None]})

    anti_expected = {"key": [1, 2], "payload": ["f", "i"]}
    semi_expected = {"key": [3], "payload": [None]}

    # eager
    anti = left.join(right, on="key", how="anti")
    assert anti.to_dict(as_series=False) == anti_expected
    semi = left.join(right, on="key", how="semi")
    assert semi.to_dict(as_series=False) == semi_expected

    # lazy
    anti_lazy = left.lazy().join(right.lazy(), on="key", how="anti").collect()
    assert anti_lazy.to_dict(as_series=False) == anti_expected
    semi_lazy = left.lazy().join(right.lazy(), on="key", how="semi").collect()
    assert semi_lazy.to_dict(as_series=False) == semi_expected

    # composite (two-column) keys; the right side repeats ("c", 3) on purpose
    left = pl.DataFrame(
        {"a": [1, 2, 3, 1], "b": ["a", "b", "c", "a"], "payload": [10, 20, 30, 40]}
    )
    right = pl.DataFrame({"a": [3, 3, 4, 5], "b": ["c", "c", "d", "e"]})
    keys = ["a", "b"]

    anti = left.join(right, on=keys, how="anti")
    assert anti.to_dict(as_series=False) == {
        "a": [1, 2, 1],
        "b": ["a", "b", "a"],
        "payload": [10, 20, 40],
    }
    semi = left.join(right, on=keys, how="semi")
    assert semi.to_dict(as_series=False) == {
        "a": [3],
        "b": ["c"],
        "payload": [30],
    }
66
67
68
def test_join_same_cat_src() -> None:
    """A Categorical frame joined with an aggregation of itself keeps its values."""
    df = pl.DataFrame(
        data={"column": ["a", "a", "b"], "more": [1, 2, 3]},
        schema=[("column", pl.Categorical), ("more", pl.Int32)],
    )
    group_means = df.group_by("column").agg(pl.col("more").mean())

    result = df.join(group_means, on="column")

    expected = pl.DataFrame(
        {
            "column": ["a", "a", "b"],
            "more": [1, 2, 3],
            "more_right": [1.5, 1.5, 3.0],
        },
        schema=[
            ("column", pl.Categorical),
            ("more", pl.Int32),
            ("more_right", pl.Float64),
        ],
    )
    assert_frame_equal(result, expected, check_row_order=False)
90
91
92
@pytest.mark.parametrize("reverse", [False, True])
def test_sorted_merge_joins(reverse: bool) -> None:
    """Joins on columns flagged as sorted must agree with plain hash joins.

    The key column is marked with ``set_sorted`` (ascending, or descending when
    ``reverse``). This is meant to route the join through the sorted-merge path,
    whose output is compared, ignoring row order, with the unflagged join.
    """
    # A seeded generator makes any failure reproducible. The property under test
    # (sorted-merge join == hash join) must hold for any sorted input.
    rng = np.random.default_rng(0)
    n = 30
    df_a = pl.DataFrame({"a": np.sort(rng.integers(0, n // 2, n))}).with_row_index(
        "row_a"
    )
    df_b = pl.DataFrame(
        {"a": np.sort(rng.integers(0, n // 2, n // 2))}
    ).with_row_index("row_b")

    if reverse:
        df_a = df_a.select(pl.all().reverse())
        df_b = df_b.select(pl.all().reverse())

    join_strategies: list[JoinStrategy] = ["left", "inner"]
    # NOTE(review): casting sorted integers to `str` does not keep them
    # lexicographically sorted (e.g. "10" < "2"), so `set_sorted` asserts a false
    # order in that case — confirm this is intended.
    for cast_to in [int, str, float]:
        for how in join_strategies:
            df_a_ = df_a.with_columns(pl.col("a").cast(cast_to))
            df_b_ = df_b.with_columns(pl.col("a").cast(cast_to))

            # hash join
            out_hash_join = df_a_.join(df_b_, on="a", how=how)

            # sorted merge join
            out_sorted_merge_join = df_a_.with_columns(
                pl.col("a").set_sorted(descending=reverse)
            ).join(
                df_b_.with_columns(pl.col("a").set_sorted(descending=reverse)),
                on="a",
                how=how,
            )

            assert_frame_equal(
                out_hash_join, out_sorted_merge_join, check_row_order=False
            )
127
128
129
def test_join_negative_integers() -> None:
    """Inner joins on negative integer keys work for every signed integer width."""
    left = pl.DataFrame({"a": [-1, -6, -3, 0]})
    right = pl.DataFrame(
        {
            "a": [-6, -1, -4, -2, 0],
            "b": [-6, -1, -4, -2, 0],
        }
    )
    expected = pl.DataFrame({"a": [-6, -1, 0], "b": [-6, -1, 0]})

    for dtype in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        cast_all = pl.all().cast(dtype)
        joined = left.with_columns(cast_all).join(
            right.with_columns(cast_all), on="a", how="inner"
        )
        assert_frame_equal(joined, expected.select(cast_all), check_row_order=False)
152
153
154
def test_deprecated() -> None:
    """`maintain_order="left"` joins give the expected values eagerly and lazily.

    NOTE(review): despite its name, this test does not assert that any
    deprecation warning is emitted.
    """
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
    other = pl.DataFrame({"a": [1, 2], "c": [3, 4]})
    expected = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [3, 4]}).to_numpy()

    eager = df.join(other=other, on="a", maintain_order="left")
    np.testing.assert_equal(eager.to_numpy(), expected)

    lazy = (
        df.lazy()
        .join(other=other.lazy(), on="a", maintain_order="left")
        .collect()
    )
    np.testing.assert_equal(lazy.to_numpy(), expected)
170
171
172
def test_deprecated_parameter_join_nulls() -> None:
    """The old `join_nulls` keyword still works but warns that it was renamed."""
    df = pl.DataFrame({"a": [1, None]})
    expected_warning = r"the argument `join_nulls` for `DataFrame.join` is deprecated. It was renamed to `nulls_equal`"

    with pytest.deprecated_call(match=expected_warning):
        result = df.join(df, on="a", join_nulls=True)  # type: ignore[call-arg]

    # with nulls treated as equal, the null key matches itself as well
    assert_frame_equal(result, df, check_row_order=False)
179
180
181
def test_join_on_expressions() -> None:
    """Join keys may be arbitrary expressions rather than plain column names."""
    df_a = pl.DataFrame({"a": [1, 2, 3]})
    df_b = pl.DataFrame({"b": [1, 4, 9, 9, 0]})

    squared = (pl.col("a") ** 2).cast(int)
    result = df_a.join(df_b, left_on=squared, right_on=pl.col("b"))

    expected = pl.DataFrame({"a": [1, 2, 3, 3], "b": [1, 4, 9, 9]})
    assert_frame_equal(result, expected, check_row_order=False)
191
192
193
def test_join_lazy_frame_on_expression() -> None:
    """A lazy join on an expression key must match the eager result.

    Regression test for a lazy frame projection pushdown bug:
    https://github.com/pola-rs/polars/issues/19822
    """
    df = pl.DataFrame(data={"a": [0, 1], "b": [2, 3]})

    lazy_join = (
        df.lazy()
        .join(df.lazy(), left_on=pl.coalesce("b", "a"), right_on="a")
        .select("a")
        .collect()
    )

    eager_join = df.join(df, left_on=pl.coalesce("b", "a"), right_on="a").select("a")

    assert lazy_join.shape == eager_join.shape
    # Also compare schema and values: a pushdown bug could project the wrong
    # column while still producing a frame of the same shape.
    assert_frame_equal(lazy_join, eager_join, check_row_order=False)
209
210
211
def test_right_join_schema_maintained_22516() -> None:
    """Eager and lazy right joins followed by `pl.len()` report the same count."""
    left = pl.DataFrame({"number": [1]})
    right = pl.DataFrame({"invoice_number": [1]})

    eager_count = (
        left.join(right, left_on="number", right_on="invoice_number", how="right")
        .select(pl.len())
        .item()
    )
    lazy_count = (
        left.lazy()
        .join(right.lazy(), left_on="number", right_on="invoice_number", how="right")
        .select(pl.len())
        .collect()
        .item()
    )

    assert lazy_count == eager_count
226
227
228
def test_join() -> None:
    """Inner, left and full joins on a string key, plus key-argument validation."""
    df_left = pl.DataFrame(
        {
            "a": ["a", "b", "a", "z"],
            "b": [1, 2, 3, 4],
            "c": [6, 5, 4, 3],
        }
    )
    df_right = pl.DataFrame(
        {
            "a": ["b", "c", "b", "a"],
            "k": [0, 3, 9, 6],
            "c": [1, 0, 2, 1],
        }
    )

    # inner (the default strategy)
    inner = df_left.join(
        df_right, left_on="a", right_on="a", maintain_order="left_right"
    ).sort("a")
    assert_series_equal(inner["b"], pl.Series("b", [1, 3, 2, 2]))

    # left: the unmatched "z" row gets nulls from the right side
    left = df_left.join(
        df_right, left_on="a", right_on="a", how="left", maintain_order="left_right"
    ).sort("a")
    assert left["c_right"].is_null().sum() == 1
    assert_series_equal(left["b"], pl.Series("b", [1, 3, 2, 2, 4]))

    # full: "z" (left only) and "c" (right only) each produce one null-padded row
    full = df_left.join(df_right, left_on="a", right_on="a", how="full").sort("a")
    for column in ("c_right", "c", "b", "k", "a"):
        assert full[column].null_count() == 1

    # a key is required: either `on`, or both `left_on` and `right_on`
    with pytest.raises(ValueError):
        df_left.join(df_right)
    with pytest.raises(ValueError):
        df_left.join(df_right, right_on="a")
    with pytest.raises(ValueError):
        df_left.join(df_right, left_on="a")

    df_a = pl.DataFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]})
    df_b = pl.DataFrame(
        {"foo": [1, 1, 1], "bar": ["a", "c", "c"], "ham": ["let", "var", "const"]}
    )

    # smoke test: joining on several differently named columns must not raise
    df_a.join(df_b, left_on=["a", "b"], right_on=["foo", "bar"])

    eager_join = df_a.join(df_b, left_on="a", right_on="foo")
    lazy_join = df_a.lazy().join(df_b.lazy(), left_on="a", right_on="foo").collect()
    assert lazy_join.shape == eager_join.shape

    sort_columns = ["a", "b", "bar", "ham"]
    assert_frame_equal(
        lazy_join.sort(by=sort_columns), eager_join.sort(by=sort_columns)
    )
284
285
286
def test_joins_dispatch() -> None:
    """Smoke-test join dispatch over many key combinations and strategies."""
    # don't change the data of this dataframe, this triggered:
    # https://github.com/pola-rs/polars/issues/1688
    dfa = pl.DataFrame(
        {
            "a": ["a", "b", "c", "a"],
            "b": [1, 2, 3, 1],
            "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-01"],
            "datetime": [13241324, 12341256, 12341234, 13241324],
        }
    ).with_columns(
        pl.col("date").str.strptime(pl.Date), pl.col("datetime").cast(pl.Datetime)
    )

    key_combinations = [
        ["a", "b", "date", "datetime"],
        ["date", "datetime"],
        ["date", "datetime", "a"],
        ["date", "a"],
        ["a", "datetime"],
        ["date"],
    ]
    join_strategies: list[JoinStrategy] = ["left", "inner", "full"]
    for how in join_strategies:
        for keys in key_combinations:
            dfa.join(dfa, on=keys, how=how)
310
311
312
def test_join_on_cast() -> None:
    """Joining on a cast expression matches Int32 keys against Int64 keys."""
    df_a = (
        pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]})
        .with_row_index()
        .with_columns(pl.col("a").cast(pl.Int32))
    )
    df_b = pl.DataFrame({"a": [-2, -3, 3, 10]})

    expected = {
        "index": [1, 2, 3, 5],
        "a": [-2, 3, 3, 10],
        "a_right": [-2, 3, 3, 10],
    }

    # eager: compare values only (dtypes of the key columns may differ)
    assert_frame_equal(
        df_a.join(df_b, on=pl.col("a").cast(pl.Int64)),
        pl.DataFrame(expected),
        check_row_order=False,
        check_dtypes=False,
    )

    # lazy, with the left row order preserved
    lazy_result = (
        df_a.lazy()
        .join(df_b.lazy(), on=pl.col("a").cast(pl.Int64), maintain_order="left")
        .collect()
    )
    assert lazy_result.to_dict(as_series=False) == expected
342
343
344
def test_join_chunks_alignment_4720() -> None:
    """Left-joining the output of a cross join gives the expected rows.

    Regression test for https://github.com/pola-rs/polars/issues/4720. Two
    orderings of the key list are exercised.
    """
    df1 = pl.DataFrame(
        {
            "index1": pl.arange(0, 2, eager=True),
            "index2": pl.arange(10, 12, eager=True),
        }
    )
    df2 = pl.DataFrame({"index3": pl.arange(100, 102, eager=True)})
    df3 = pl.DataFrame(
        {
            "index1": pl.arange(0, 2, eager=True),
            "index2": pl.arange(10, 12, eager=True),
            "index3": pl.arange(100, 102, eager=True),
        }
    )
    expected = pl.DataFrame(
        {
            "index1": [0, 0, 1, 1],
            "index2": [10, 10, 11, 11],
            "index3": [100, 101, 100, 101],
        }
    )

    for keys in (["index1", "index2", "index3"], ["index3", "index1", "index2"]):
        result = df1.join(df2, how="cross").join(df3, on=keys, how="left")
        assert_frame_equal(result, expected, check_row_order=False)
398
399
400
def test_jit_sort_joins() -> None:
    """Left/inner joins where only one side is sorted on the key match pandas.

    Each strategy is checked twice: first with the sorted frame on the left,
    then with it on the right.
    """
    # A seeded generator makes any failure reproducible. The numpy dtype is
    # given explicitly because of different defaults on Windows.
    rng = np.random.default_rng(0)

    n_a = 200
    dfa = pd.DataFrame(
        {
            "a": rng.integers(0, 100, n_a, dtype=np.int64),
            "b": np.arange(0, n_a, dtype=np.int64),
        }
    )

    n_b = 40
    dfb = pd.DataFrame(
        {
            "a": rng.integers(0, 100, n_b, dtype=np.int64),
            "b": np.arange(0, n_b, dtype=np.int64),
        }
    )
    dfa_pl = pl.from_pandas(dfa).sort("a")  # explicitly sorted on the join key
    dfb_pl = pl.from_pandas(dfb)  # not explicitly sorted

    result_cols = ["a", "b", "b_right"]
    join_strategies: list[Literal["left", "inner"]] = ["left", "inner"]
    for how in join_strategies:
        # left key sorted, right key not sorted
        pd_result = dfa.merge(dfb, on="a", how=how)
        pd_result.columns = pd.Index(result_cols)
        pl_result = dfa_pl.join(dfb_pl, on="a", how=how).sort(result_cols)

        expected = (
            pl.from_pandas(pd_result)
            .with_columns(pl.all().cast(int))
            .sort(result_cols)
        )
        assert_frame_equal(expected, pl_result)
        # NOTE(review): `pl_result` was explicitly sorted above, so this flag
        # check is probably trivially satisfied — confirm it is still useful.
        assert pl_result["a"].flags["SORTED_ASC"]

        # right key sorted, left key not sorted
        pd_result = dfb.merge(dfa, on="a", how=how)
        pd_result.columns = pd.Index(result_cols)
        pl_result = dfb_pl.join(dfa_pl, on="a", how=how).sort(result_cols)

        expected = (
            pl.from_pandas(pd_result)
            .with_columns(pl.all().cast(int))
            .sort(result_cols)
        )
        assert_frame_equal(expected, pl_result)
        assert pl_result["a"].flags["SORTED_ASC"]
448
449
450
def test_join_panic_on_binary_expr_5915() -> None:
    """A lazy join whose left key is a binary expression must not panic."""
    lhs = pl.DataFrame({"a": [1, 2, 3]}).lazy()
    rhs = pl.DataFrame({"b": [1, 4, 9, 9, 0]}).lazy()

    left_keys = [(pl.col("a") + 1).cast(int)]
    right_keys = [pl.col("b")]
    query = lhs.join(rhs, left_on=left_keys, right_on=right_keys)

    assert query.collect().to_dict(as_series=False) == {"a": [3], "b": [4]}
456
457
458
def test_semi_join_projection_pushdown_6423() -> None:
    """Two chained lazy semi joins followed by a projection keep the match."""
    df1 = pl.DataFrame({"x": [1]}).lazy()
    df2 = pl.DataFrame({"y": [1], "x": [1]}).lazy()

    query = (
        df1.join(df2, left_on="x", right_on="y", how="semi")
        .join(df2, left_on="x", right_on="y", how="semi")
        .select(["x"])
    )
    assert query.collect().to_dict(as_series=False) == {"x": [1]}
467
468
469
def test_semi_join_projection_pushdown_6455() -> None:
    """Semi-joining against each id's latest timestamp, then projecting, works."""
    lf = pl.DataFrame(
        {
            "id": [1, 1, 2],
            "timestamp": [
                datetime(2022, 12, 11),
                datetime(2022, 12, 12),
                datetime(2022, 1, 1),
            ],
            "value": [1, 2, 4],
        }
    ).lazy()

    latest = lf.group_by("id").agg(pl.col("timestamp").max())
    latest_rows = lf.join(latest, on=["id", "timestamp"], how="semi")

    result = latest_rows.select(["id", "value"]).collect()
    assert result.to_dict(as_series=False) == {
        "id": [1, 2],
        "value": [2, 4],
    }
488
489
490
def test_update() -> None:
    """Exercise `update` on DataFrame and LazyFrame.

    Covers key selection (`on`, `left_on`/`right_on`, positional alignment when
    no key is given), the `how` strategies, `include_nulls`, and rejection of
    unsupported strategies.
    """
    df1 = pl.DataFrame(
        {
            "key1": [1, 2, 3, 4],
            "key2": [1, 2, 3, 4],
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": ["1", "2", "3", "4"],
            "d": [
                date(2023, 1, 1),
                date(2023, 1, 2),
                date(2023, 1, 3),
                date(2023, 1, 4),
            ],
        }
    )

    df2 = pl.DataFrame(
        {
            "key1": [1, 2, 3, 4],
            "key2": [1, 2, 3, 5],
            "a": [1, 1, 1, 1],
            "b": [2, 2, 2, 2],
            "c": ["3", "3", "3", "3"],
            "d": [
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 5, 5),
            ],
        }
    )

    # update only on key1
    expected = pl.DataFrame(
        {
            "key1": [1, 2, 3, 4],
            "key2": [1, 2, 3, 5],
            "a": [1, 1, 1, 1],
            "b": [2, 2, 2, 2],
            "c": ["3", "3", "3", "3"],
            "d": [
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 5, 5),
            ],
        }
    )
    assert_frame_equal(df1.update(df2, on="key1"), expected)

    # update on key1 using different left/right names
    assert_frame_equal(
        df1.update(
            df2.rename({"key1": "key1b"}),
            left_on="key1",
            right_on="key1b",
        ),
        expected,
    )

    # update on key1 and key2. This should fail to match the last item.
    expected = pl.DataFrame(
        {
            "key1": [1, 2, 3, 4],
            "key2": [1, 2, 3, 4],
            "a": [1, 1, 1, 4],
            "b": [2, 2, 2, 4],
            "c": ["3", "3", "3", "4"],
            "d": [
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 1, 4),
            ],
        }
    )
    assert_frame_equal(df1.update(df2, on=["key1", "key2"]), expected)

    # update on key1 and key2 using different left/right names
    assert_frame_equal(
        df1.update(
            df2.rename({"key1": "key1b", "key2": "key2b"}),
            left_on=["key1", "key2"],
            right_on=["key1b", "key2b"],
        ),
        expected,
    )

    df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [400, 500, 600, 700]})

    new_df = pl.DataFrame({"B": [4, None, 6], "C": [7, 8, 9]})

    # no key: rows are aligned by position; the null in `new_df` keeps 500
    assert df.update(new_df).to_dict(as_series=False) == {
        "A": [1, 2, 3, 4],
        "B": [4, 500, 6, 700],
    }
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df2 = pl.DataFrame({"a": [2, 3], "b": [8, 9]})

    assert df1.update(df2, on="a").to_dict(as_series=False) == {
        "a": [1, 2, 3],
        "b": [4, 8, 9],
    }

    a = pl.LazyFrame({"a": [1, 2, 3]})
    b = pl.LazyFrame({"b": [4, 5], "c": [3, 1]})
    c = a.update(b)

    assert_frame_equal(a, c)

    # check behaviour of 'how' param
    result = a.update(b, left_on="a", right_on="c")
    assert result.collect().to_series().to_list() == [1, 2, 3]

    result = a.update(b, how="inner", left_on="a", right_on="c")
    assert sorted(result.collect().to_series().to_list()) == [1, 3]

    # `sorted` alone makes the comparison order-independent
    result = a.update(b.rename({"b": "a"}), how="full", on="a")
    assert sorted(result.collect().to_series().to_list()) == [1, 2, 3, 4, 5]

    # check behavior of include_nulls=True
    df = pl.DataFrame(
        {
            "A": [1, 2, 3, 4],
            "B": [400, 500, 600, 700],
        }
    )
    new_df = pl.DataFrame(
        {
            "B": [-66, None, -99],
            "C": [5, 3, 1],
        }
    )
    out = df.update(new_df, left_on="A", right_on="C", how="full", include_nulls=True)
    expected = pl.DataFrame(
        {
            "A": [1, 2, 3, 4, 5],
            "B": [-99, 500, None, 700, -66],
        }
    )
    assert_frame_equal(out, expected, check_row_order=False)

    # edge-case #11684
    x = pl.DataFrame({"a": [0, 1]})
    y = pl.DataFrame({"a": [2, 3]})
    assert sorted(x.update(y, on="a", how="full")["a"].to_list()) == [0, 1, 2, 3]

    # disallowed join strategies
    for join_strategy in ("cross", "anti", "semi"):
        with pytest.raises(
            ValueError,
            match=f"`how` must be one of {{'left', 'inner', 'full'}}; found '{join_strategy}'",
        ):
            a.update(b, how=join_strategy)  # type: ignore[arg-type]
645
646
647
def test_join_frame_consistency() -> None:
    """Mixing DataFrame and LazyFrame in `join`/`join_asof` raises TypeError."""
    df = pl.DataFrame({"A": [1, 2, 3]})
    ldf = pl.DataFrame({"A": [1, 2, 5]}).lazy()

    cases: list[tuple[Callable[[], object], str]] = [
        (
            lambda: ldf.join(df, on="A"),  # type: ignore[arg-type]
            "expected `other`.*LazyFrame",
        ),
        (
            lambda: df.join(ldf, on="A"),  # type: ignore[arg-type]
            "expected `other`.*DataFrame",
        ),
        (
            lambda: ldf.join_asof(df, on="A"),  # type: ignore[arg-type]
            "expected `other`.*LazyFrame",
        ),
        (
            lambda: df.join_asof(ldf, on="A"),  # type: ignore[arg-type]
            "expected `other`.*DataFrame",
        ),
    ]
    for call, pattern in cases:
        with pytest.raises(TypeError, match=pattern):
            call()
659
660
661
def test_join_concat_projection_pd_case_7071() -> None:
    """Projecting one column after concatenating a join result with its input."""
    ldf = pl.DataFrame({"id": [1, 2], "value": [100, 200]}).lazy()
    ldf2 = pl.DataFrame({"id": [1, 3], "value": [100, 300]}).lazy()

    joined = ldf.join(ldf2, on=["id", "value"])
    result = pl.concat([joined, ldf2]).select("id")

    assert_frame_equal(result, pl.DataFrame({"id": [1, 1, 3]}).lazy())
671
672
673
@pytest.mark.may_fail_auto_streaming  # legacy full join is not order-preserving whereas new-streaming is
def test_join_sorted_fast_paths_null() -> None:
    """Every join strategy with a sorted left key and a null key on the right."""
    df1 = pl.DataFrame({"x": [0, 1, 0]}).sort("x")
    df2 = pl.DataFrame({"x": [0, None], "y": [0, 1]})

    def joined(how: JoinStrategy) -> dict[str, list[Any]]:
        return df1.join(df2, on="x", how=how).to_dict(as_series=False)

    assert joined("inner") == {"x": [0, 0], "y": [0, 0]}
    assert joined("left") == {"x": [0, 0, 1], "y": [0, 0, None]}
    assert joined("anti") == {"x": [1]}
    assert joined("semi") == {"x": [0, 0]}
    assert joined("full") == {
        "x": [0, 0, 1, None],
        "x_right": [0, 0, None, None],
        "y": [0, 0, None, 1],
    }
692
693
694
def test_full_outer_join_list_() -> None:
    """Full join with a List column and `maintain_order="right_left"`."""
    schema = {"id": pl.Int64, "vals": pl.List(pl.Float64)}
    join_schema = dict(schema)
    join_schema.update({f"{name}_right": dtype for name, dtype in schema.items()})

    df1 = pl.DataFrame({"id": [1], "vals": [[]]}, schema=schema)  # type: ignore[arg-type]
    df2 = pl.DataFrame({"id": [2, 3], "vals": [[], [4]]}, schema=schema)  # type: ignore[arg-type]

    # no ids overlap, so every row is padded with nulls on one side
    expected = pl.DataFrame(
        {
            "id": [None, None, 1],
            "vals": [None, None, []],
            "id_right": [2, 3, None],
            "vals_right": [[], [4.0], None],
        },
        schema=join_schema,  # type: ignore[arg-type]
    )

    result = df1.join(df2, on="id", how="full", maintain_order="right_left")
    assert_frame_equal(result, expected)
710
711
712
@pytest.mark.slow
def test_join_validation() -> None:
    """`validate` accepts or rejects a join according to key multiplicity."""

    def check_validation(
        unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy
    ) -> None:
        # Each case is (left, right, validate, expect_error). `unique` has no
        # repeated key; `duplicate` repeats at least one key.
        cases: list[
            tuple[
                pl.DataFrame,
                pl.DataFrame,
                Literal["m:m", "m:1", "1:m", "1:1"],
                bool,
            ]
        ] = [
            # one to many
            (unique, duplicate, "1:m", False),
            (duplicate, unique, "1:m", True),
            # one to one
            (unique, duplicate, "1:1", True),
            (duplicate, unique, "1:1", True),
            # many to one
            (unique, duplicate, "m:1", True),
            (duplicate, unique, "m:1", False),
            # many to many
            (duplicate, unique, "m:m", False),
            (unique, duplicate, "m:m", False),
        ]
        for left, right, validate, expect_error in cases:
            if expect_error:
                with pytest.raises(ComputeError):
                    left.join(right, on=on, how=how, validate=validate)
            else:
                left.join(right, on=on, how=how, validate=validate)

    # test data
    short_unique = pl.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "id_str": ["1", "2", "3", "4"],
            "name": ["hello", "world", "rust", "polars"],
        }
    )
    short_duplicate = pl.DataFrame(
        {"id": [1, 2, 3, 1], "id_str": ["1", "2", "3", "1"], "cnt": [2, 4, 6, 1]}
    )
    long_unique = pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5],
            "id_str": ["1", "2", "3", "4", "5"],
            "name": ["hello", "world", "rust", "polars", "meow"],
        }
    )
    long_duplicate = pl.DataFrame(
        {
            "id": [1, 2, 3, 1, 5],
            "id_str": ["1", "2", "3", "1", "5"],
            "cnt": [2, 4, 6, 1, 8],
        }
    )

    join_strategies: list[JoinStrategy] = ["inner", "full", "left"]

    for join_col in ["id", "id_str"]:
        for how in join_strategies:
            # same size
            check_validation(long_unique, long_duplicate, join_col, how)
            # left longer
            check_validation(long_unique, short_duplicate, join_col, how)
            # right longer
            check_validation(short_unique, long_duplicate, join_col, how)
795
796
797
@typing.no_type_check
def test_join_validation_many_keys() -> None:
    """`validate` judges uniqueness over the combination of all join keys."""
    join_types = ["inner", "left", "full"]
    on = ["val1", "val2"]
    unique_rows = {"val1": [11, 12, 13, 14], "val2": [1, 2, 3, 4]}
    repeated_rows = {"val1": [11, 11, 12, 13, 14], "val2": [1, 1, 2, 3, 4]}

    # unique in both: every validation mode passes
    df1 = pl.DataFrame(unique_rows)
    df2 = pl.DataFrame(unique_rows)
    for join_type in join_types:
        for val in ["m:m", "m:1", "1:1", "1:m"]:
            df1.join(df2, on=on, how=join_type, validate=val)

    # many in lhs
    df1 = pl.DataFrame(repeated_rows)
    for join_type in join_types:
        for val in ["1:1", "1:m"]:
            with pytest.raises(ComputeError):
                df1.join(df2, on=on, how=join_type, validate=val)

    # many in rhs
    df1 = pl.DataFrame(unique_rows)
    df2 = pl.DataFrame(repeated_rows)
    for join_type in join_types:
        for val in ["m:1", "1:1"]:
            with pytest.raises(ComputeError):
                df1.join(df2, on=on, how=join_type, validate=val)
847
848
849
def test_full_outer_join_bool() -> None:
    """Full join on a Boolean key, ordered by the right frame."""
    df1 = pl.DataFrame({"id": [True, False], "val": [1, 2]})
    df2 = pl.DataFrame({"id": [True, False], "val": [0, -1]})

    result = df1.join(df2, on="id", how="full", maintain_order="right")
    assert result.to_dict(as_series=False) == {
        "id": [True, False],
        "val": [1, 2],
        "id_right": [True, False],
        "val_right": [0, -1],
    }
860
861
862
def test_full_outer_join_coalesce_different_names_13450() -> None:
    """Coalescing full join whose key columns have different names."""
    df1 = pl.DataFrame({"L1": ["a", "b", "c"], "L3": ["b", "c", "d"], "L2": [1, 2, 3]})
    df2 = pl.DataFrame({"L3": ["a", "c", "d"], "R2": [7, 8, 9]})

    result = df1.join(df2, left_on="L1", right_on="L3", how="full", coalesce=True)

    # the coalesced key keeps the left name "L1"; the right "L3" key is dropped
    expected = pl.DataFrame(
        {
            "L1": ["a", "c", "d", "b"],
            "L3": ["b", "d", None, "c"],
            "L2": [1, 3, None, 2],
            "R2": [7, 8, 9, None],
        }
    )
    assert_frame_equal(result, expected, check_row_order=False)
877
878
879
# https://github.com/pola-rs/polars/issues/10663
def test_join_on_wildcard_error() -> None:
    """A wildcard `pl.all()` join key is rejected with InvalidOperationError."""
    lhs = pl.DataFrame({"x": [1]})
    rhs = pl.DataFrame({"x": [1], "y": [2]})

    with pytest.raises(InvalidOperationError):
        lhs.join(rhs, on=pl.all())
887
888
889
def test_join_on_nth_error() -> None:
    """A positional `pl.first()` join key is rejected with InvalidOperationError."""
    lhs = pl.DataFrame({"x": [1]})
    rhs = pl.DataFrame({"x": [1], "y": [2]})

    with pytest.raises(InvalidOperationError):
        lhs.join(rhs, on=pl.first())
896
897
898
def test_join_results_in_duplicate_names() -> None:
    """A join whose suffixed column collides with an existing one raises.

    The lazy path must fail both at schema resolution and at collect time, and
    the eager path must fail on the join itself. All three errors carry a hint.
    """
    df = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": [4, 5, 6],
            "c": [1, 2, 3],
            "c_right": [1, 2, 3],
        }
    )

    def f(x: Any) -> Any:
        return x.join(x, on=["a", "b"], how="left")

    # Ensure it also contains the hint
    match_str = "(?s)column with name 'c_right' already exists.*You may want to try"

    # Ensure it fails immediately when resolving schema.
    with pytest.raises(DuplicateError, match=match_str):
        f(df.lazy()).collect_schema()

    with pytest.raises(DuplicateError, match=match_str):
        f(df.lazy()).collect()

    # The eager join must raise by itself. Chaining `.collect()` (a LazyFrame
    # method) onto its result was unreachable and obscured what is tested.
    with pytest.raises(DuplicateError, match=match_str):
        f(df)
923
924
925
def test_join_duplicate_suffixed_columns_from_join_key_column_21048() -> None:
    """Self-join on "a" where suffixing "b" collides with "b_right" raises."""
    df = pl.DataFrame({"a": 1, "b": 1, "b_right": 1})
    # the message must also carry the hint
    match_str = "(?s)column with name 'b_right' already exists.*You may want to try"

    def self_join(x: Any) -> Any:
        return x.join(x, on="a")

    # lazy: schema resolution alone must already fail, and so must collecting
    with pytest.raises(DuplicateError, match=match_str):
        self_join(df.lazy()).collect_schema()
    with pytest.raises(DuplicateError, match=match_str):
        self_join(df.lazy()).collect()

    # eager
    with pytest.raises(DuplicateError, match=match_str):
        self_join(df)
943
944
945
def test_join_projection_invalid_name_contains_suffix_15243() -> None:
    """Referencing a non-existent suffixed column after a join raises."""
    lhs = pl.DataFrame({"a": [1, 2, 3]}).lazy()
    rhs = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).lazy()
    missing_column = pl.col("foo_right")

    with pytest.raises(ColumnNotFoundError):
        lhs.join(rhs, on="a").select(
            pl.col("b").filter(pl.col("b") == missing_column)
        ).collect()
955
956
957
def test_join_list_non_numeric() -> None:
    """Grouping by a list-of-strings key counts identical lists together.

    NOTE(review): despite its name, this test performs no join, only a group-by
    on a List[str] column.
    """
    df = pl.DataFrame(
        {
            "lists": [
                ["a", "b", "c"],
                ["a", "c", "b"],
                ["a", "c", "b"],
                ["a", "c", "d"],
            ]
        }
    )

    counts = df.group_by("lists", maintain_order=True).agg(pl.len().alias("count"))
    assert counts.to_dict(as_series=False) == {
        "lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],
        "count": [1, 2, 1],
    }
975
976
977
@pytest.mark.slow
def test_join_4_columns_with_validity() -> None:
    """A four-column self-join with nulls, with and without `nulls_equal`."""
    # Joining on four columns triggers the combined-validity code path. The
    # length 138 is two full 64-bit words (128) plus a remainder of 10.
    values = [None if i % 6 == 0 else i for i in range(138)]
    frame = pl.DataFrame({"a": values}).with_columns(
        b=pl.col("a"),
        c=pl.col("a"),
        d=pl.col("a"),
    )
    keys = ["a", "b", "c", "d"]

    # 23 of the 138 rows are null and 115 are not. Non-null rows only match
    # themselves (115 rows). With nulls_equal the 23 null rows also match each
    # other: 115 + 23 * 23 = 644.
    assert frame.join(frame, on=keys, how="inner", nulls_equal=True).shape == (644, 4)
    assert frame.join(frame, on=keys, how="inner", nulls_equal=False).shape == (115, 4)
997
998
999
@pytest.mark.release
def test_cross_join() -> None:
    """A cross join above 100 rows agrees with the smaller one on the common prefix."""
    # triggers > 100 rows implementation
    # https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L34
    df1 = pl.DataFrame({"col1": ["a"], "col2": ["d"]})

    small = pl.DataFrame({"frame2": pl.arange(0, 100, eager=True)})
    expected = small.join(df1, how="cross")

    large = pl.DataFrame({"frame2": pl.arange(0, 101, eager=True)})
    assert_frame_equal(large.join(df1, how="cross").slice(0, 100), expected)
1008
1009
1010
@pytest.mark.release
def test_cross_join_slice_pushdown() -> None:
    """Slicing a very large lazy cross join returns the requested rows."""
    # this will likely go out of memory if we did not pushdown the slice
    df = (
        pl.Series("x", pl.arange(0, 2**16 - 1, eager=True, dtype=pl.UInt16) % 2**15)
    ).to_frame()
    cross = df.lazy().join(df.lazy(), how="cross", suffix="_")

    # negative offset: only the last five rows exist past it
    tail = cross.slice(-5, 10).collect()
    assert_frame_equal(
        tail,
        pl.DataFrame(
            {
                "x": [32766, 32766, 32766, 32766, 32766],
                "x_": [32762, 32763, 32764, 32765, 32766],
            },
            schema={"x": pl.UInt16, "x_": pl.UInt16},
        ),
    )

    # positive offset near the start
    head = cross.slice(2, 10).collect()
    assert_frame_equal(
        head,
        pl.DataFrame(
            {
                "x": [0] * 10,
                "x_": list(range(2, 12)),
            },
            schema={"x": pl.UInt16, "x_": pl.UInt16},
        ),
    )
1036
1037
1038
@pytest.mark.parametrize("how", ["left", "inner"])
def test_join_coalesce(how: JoinStrategy) -> None:
    """`coalesce` controls whether the right-hand key columns are kept.

    For each case the predicted schema (`collect_schema`) must equal the
    schema of the collected result.
    """
    a = pl.LazyFrame({"a": [1, 2], "b": [1, 2]})
    b = pl.LazyFrame(
        {
            "a": [1, 2, 1, 2],
            "b": [5, 7, 8, 9],
            "c": [1, 2, 1, 2],
        }
    )

    # Use the parametrized `how` as-is. Overriding it with "inner" meant the
    # "left" case never exercised a left join.
    q = a.join(b, on="a", coalesce=False, how=how)
    out = q.collect()
    assert q.collect_schema() == out.schema
    assert out.columns == ["a", "b", "a_right", "b_right", "c"]

    q = a.join(b, on=["a", "b"], coalesce=False, how=how)
    out = q.collect()
    assert q.collect_schema() == out.schema
    assert out.columns == ["a", "b", "a_right", "b_right", "c"]

    q = a.join(b, on=["a", "b"], coalesce=True, how=how)
    out = q.collect()
    assert q.collect_schema() == out.schema
    assert out.columns == ["a", "b", "c"]
1064
1065
1066
@pytest.mark.parametrize("how", ["left", "inner", "full"])
def test_join_empties(how: JoinStrategy) -> None:
    """Joining two empty frames yields an empty frame."""
    lhs = pl.DataFrame({"col1": [], "col2": [], "col3": []})
    rhs = pl.DataFrame({"col2": [], "col4": [], "col5": []})

    assert lhs.join(rhs, on="col2", how=how).height == 0
1073
1074
1075
def test_join_raise_on_redundant_keys() -> None:
    """Listing the same key column twice is rejected."""
    left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})
    right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})
    repeated_keys = ["a", "a"]

    with pytest.raises(InvalidOperationError, match="already joined on"):
        left.join(right, on=repeated_keys, how="full", coalesce=True)
1080
1081
1082
@pytest.mark.parametrize("coalesce", [False, True])
def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None:
    """Two key expressions that both resolve to the name "a" are rejected."""
    left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})
    right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})
    keys = [pl.col("a"), pl.col("a") % 2]

    with (  # noqa: PT012
        pytest.raises(InvalidOperationError, match="already joined on"),
        warnings.catch_warnings(),
    ):
        # silence any UserWarning emitted before the error is raised
        warnings.simplefilter(action="ignore", category=UserWarning)
        left.join(right, on=keys, how="full", coalesce=coalesce)
1094
1095
1096
def test_join_lit_panic_11410() -> None:
    """Joining on equal literal keys pairs every symbol with every date."""
    source = pl.LazyFrame({"date": [1, 2, 3], "symbol": [4, 5, 6]})
    dates = source.select("date").unique(maintain_order=True)
    symbols = source.select("symbol").unique(maintain_order=True)

    joined = symbols.join(
        dates, left_on=pl.lit(1), right_on=pl.lit(1), maintain_order="left_right"
    ).collect()

    expected = {
        "symbol": [symbol for symbol in (4, 5, 6) for _ in range(3)],
        "date": [1, 2, 3] * 3,
    }
    assert joined.to_dict(as_series=False) == expected
1107
1108
1109
def test_join_empty_literal_17027() -> None:
    """Literal-key joins against an empty right frame, eager and streaming.

    The left join keeps the single left row; the inner join produces nothing.
    """
    lhs = pl.DataFrame({"a": [1]})
    rhs = pl.DataFrame(schema={"a": pl.Int64})
    key = pl.lit(0)

    # In-memory (eager) engine.
    assert lhs.join(rhs, on=key, how="left").height == 1
    assert lhs.join(rhs, on=key, how="inner").height == 0

    # Streaming engine.
    lazy_lhs, lazy_rhs = lhs.lazy(), rhs.lazy()
    streamed_inner = lazy_lhs.join(lazy_rhs, on=key, how="inner").collect(
        engine="streaming"
    )
    assert streamed_inner.height == 0
    streamed_left = lazy_lhs.join(lazy_rhs, on=key, how="left").collect(
        engine="streaming"
    )
    assert streamed_left.height == 1
1129
1130
1131
@pytest.mark.parametrize(
    ("left_on", "right_on"),
    [
        (pl.col("a"), pl.col("a").slice(0, 2) * 2),
        (pl.col("a").sort(), pl.col("b")),
        ([pl.col("a"), pl.col("b")], [pl.col("a"), pl.col("b").head()]),
    ],
)
def test_join_non_elementwise_keys_raises(left_on: pl.Expr, right_on: pl.Expr) -> None:
    """Join keys that are not elementwise expressions must fail at collect time."""
    # https://github.com/pola-rs/polars/issues/17184
    data = {"a": [1, 2, 3], "b": [3, 4, 5]}
    query = pl.LazyFrame(data).join(
        pl.LazyFrame(data),
        left_on=left_on,
        right_on=right_on,
        how="inner",
    )

    with pytest.raises(pl.exceptions.InvalidOperationError):
        query.collect()
1152
1153
1154
def test_join_coalesce_not_supported_warning() -> None:
    """Requesting coalescing on expression keys warns and then disables it."""
    # https://github.com/pola-rs/polars/issues/17184
    columns = {"a": [1, 2, 3], "b": [3, 4, 5]}
    doubled_a = [pl.col("a") * 2]
    query = pl.LazyFrame(columns).join(
        pl.LazyFrame(columns),
        left_on=doubled_a,
        right_on=doubled_a,
        how="inner",
        coalesce=True,
    )

    with pytest.warns(UserWarning, match="turning off key coalescing"):
        got = query.collect()

    # With coalescing off, the right-hand columns are kept with a "_right" suffix.
    expect = pl.DataFrame(
        {**columns, **{f"{name}_right": values for name, values in columns.items()}}
    )
    assert_frame_equal(expect, got, check_row_order=False)
1173
1174
1175
@pytest.mark.parametrize(
    "on_args",
    [
        {"on": "a", "left_on": "a"},
        {"on": "a", "right_on": "a"},
        {"on": "a", "left_on": "a", "right_on": "a"},
    ],
)
def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None:
    """``on`` cannot be combined with ``left_on`` and/or ``right_on``."""
    lhs = pl.DataFrame({"a": [1], "b": [2]})
    rhs = pl.DataFrame({"a": [1], "c": [3]})
    expected_msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"

    with pytest.raises(ValueError, match=expected_msg):
        lhs.join(rhs, **on_args)  # type: ignore[arg-type]
1189
1190
1191
@pytest.mark.parametrize(
    "on_args",
    [
        {"left_on": "a"},
        {"right_on": "a"},
    ],
)
def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None:
    """Supplying only one of ``left_on``/``right_on`` is an error."""
    lhs = pl.DataFrame({"a": [1]})
    rhs = pl.DataFrame({"a": [1]})
    expected_msg = "'left_on' requires corresponding 'right_on'"

    with pytest.raises(ValueError, match=expected_msg):
        lhs.join(rhs, **on_args)  # type: ignore[arg-type]
1204
1205
1206
@pytest.mark.parametrize(
    "on_args",
    [
        {"on": "a"},
        {"left_on": "a", "right_on": "a"},
    ],
)
def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:
    """A cross join rejects any join-key arguments."""
    lhs = pl.DataFrame({"a": [1, 2]})
    rhs = pl.DataFrame({"b": [3, 4]})
    expected_msg = "cross join should not pass join keys"

    with pytest.raises(ValueError, match=expected_msg):
        lhs.join(rhs, how="cross", **on_args)  # type: ignore[arg-type]
1219
1220
1221
@pytest.mark.parametrize("set_sorted", [True, False])
def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None:
    """``head`` after an order-preserving left join returns the first rows."""
    lhs = pl.LazyFrame({"k": [1, 2, 3, 4, 0]})
    rhs = pl.LazyFrame({"k": [1, 1, 1, 1, 0]})

    if set_sorted:
        # The data isn't actually sorted on purpose to ensure we default to a
        # hash join unless we set the sorted flag here, in case there is new
        # code in the future that automatically identifies sortedness during
        # Series construction from Python.
        lhs, rhs = lhs.set_sorted("k"), rhs.set_sorted("k")

    first_five = lhs.join(rhs, on="k", how="left", maintain_order="left_right").head(5)
    assert_frame_equal(first_five.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]}))
1236
1237
1238
def test_join_key_type_coercion_19597() -> None:
    """Float64 and Int64 join keys are not silently coerced during schema resolution."""
    lhs = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Float64)})
    rhs = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int64)})

    # Check both a plain column key and a derived expression key.
    for key in (pl.col("a"), pl.col("a") * 2):
        with pytest.raises(SchemaError, match="datatypes of join keys don't match"):
            lhs.join(rhs, left_on=key, right_on=key).collect_schema()
1249
1250
1251
def test_array_explode_join_19763() -> None:
    """A key column produced by exploding an Array literal can be joined on."""
    array_values = pl.Series([[1], [2]], dtype=pl.Array(pl.Int64, 1))
    exploded = pl.LazyFrame().select(pl.lit(array_values).explode().alias("k"))

    joined = exploded.join(pl.LazyFrame({"k": [1, 2]}), on="k")

    assert_frame_equal(joined.collect().sort("k"), pl.DataFrame({"k": [1, 2]}))
1259
1260
1261
def test_join_full_19814() -> None:
    """A coalesced full join keeps an all-null Categorical payload column."""
    schema = {"a": pl.Int64, "c": pl.Categorical}
    lhs = pl.LazyFrame({"a": [1], "c": [None]}, schema=schema)
    rhs = pl.LazyFrame({"a": [1, 3, 4]})

    result = lhs.join(rhs, on="a", how="full", coalesce=True).collect()
    expected = pl.DataFrame({"a": [1, 3, 4], "c": [None] * 3}, schema=schema)
    assert_frame_equal(result, expected, check_row_order=False)
1270
1271
1272
def test_join_preserve_order_inner() -> None:
    """Inner joins honour each ``maintain_order`` mode.

    ``"left"``/``"right"`` fix the row order by one side. ``"left_right"`` and
    ``"right_left"`` must yield the same key column as ``"left"`` and
    ``"right"`` respectively: every row produced for one primary-side row
    carries that row's key, so the secondary ordering cannot change column
    ``a``.
    """
    left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    right = pl.LazyFrame({"a": [1, 1, None, 2], "b": [6, 7, 8, 9]})

    # Inner joins

    inner_left = left.join(right, on="a", how="inner", maintain_order="left").collect()
    assert inner_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, 1, 1]
    # Previously this also used maintain_order="left", which made the
    # comparison vacuous and never exercised "left_right".
    inner_left_right = left.join(
        right, on="a", how="inner", maintain_order="left_right"
    ).collect()
    assert inner_left.get_column("a").equals(inner_left_right.get_column("a"))

    inner_right = left.join(
        right, on="a", how="inner", maintain_order="right"
    ).collect()
    assert inner_right.get_column("a").cast(pl.UInt32).to_list() == [1, 1, 1, 1, 2]
    # Previously this used maintain_order="right"; exercise "right_left" instead.
    inner_right_left = left.join(
        right, on="a", how="inner", maintain_order="right_left"
    ).collect()
    assert inner_right.get_column("a").equals(inner_right_left.get_column("a"))
1293
1294
1295
# The new streaming engine does not provide the same maintain_order="none"
# ordering guarantee that is currently kept for compatibility on the in-memory
# engine.
@pytest.mark.may_fail_auto_streaming
def test_join_preserve_order_left() -> None:
    """Check row order of left and right joins under each ``maintain_order`` mode."""
    left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})

    def keys_of(frame: pl.DataFrame) -> list[int | None]:
        # Key column "a" as a plain Python list (via a UInt32 cast).
        return frame.get_column("a").cast(pl.UInt32).to_list()

    # Right now the left join algorithm is ordered without explicitly setting any order
    # This behaviour is deprecated but can only be removed in 2.0
    left_none = left.join(right, on="a", how="left", maintain_order="none").collect()
    assert keys_of(left_none) == [None, 2, 1, 1, 5]

    left_left = left.join(right, on="a", how="left", maintain_order="left").collect()
    assert keys_of(left_left) == [None, 2, 1, 1, 5]

    left_left_right = left.join(
        right, on="a", how="left", maintain_order="left_right"
    ).collect()
    # If the left order is preserved then there are no unsorted right rows
    assert left_left.get_column("a").equals(left_left_right.get_column("a"))

    left_right = left.join(right, on="a", how="left", maintain_order="right").collect()
    assert keys_of(left_right)[:5] == [1, 1, 2, None, 5]

    left_right_left = left.join(
        right, on="a", how="left", maintain_order="right_left"
    ).collect()
    assert keys_of(left_right_left) == [1, 1, 2, None, 5]

    right_left = left.join(right, on="a", how="right", maintain_order="left").collect()
    assert keys_of(right_left) == [2, 1, 1, None, 6]

    right_right = left.join(
        right, on="a", how="right", maintain_order="right"
    ).collect()
    assert keys_of(right_right) == [1, 1, None, 2, 6]
1362
1363
1364
def test_join_preserve_order_full() -> None:
    """Check row order of full joins under each ``maintain_order`` mode."""
    left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})

    def as_uint_list(frame: pl.DataFrame, column: str = "a") -> list[int | None]:
        # Selected column as a plain Python list (via a UInt32 cast).
        return frame.get_column(column).cast(pl.UInt32).to_list()

    full_left = left.join(right, on="a", how="full", maintain_order="left").collect()
    assert as_uint_list(full_left)[:5] == [None, 2, 1, 1, 5]

    full_right = left.join(right, on="a", how="full", maintain_order="right").collect()
    assert as_uint_list(full_right)[:5] == [1, 1, None, 2, None]

    full_left_right = left.join(
        right, on="a", how="full", maintain_order="left_right"
    ).collect()
    assert as_uint_list(full_left_right, "a_right") == [None, 2, 1, 1, None, None, 6]

    full_right_left = left.join(
        right, on="a", how="full", maintain_order="right_left"
    ).collect()
    assert as_uint_list(full_right_left) == [1, 1, None, 2, None, None, 5]
1410
1411
1412
@pytest.mark.parametrize(
    "dtypes",
    [
        ["Int128", "Int128", "Int64"],
        ["Int128", "Int128", "Int32"],
        ["Int128", "Int128", "Int16"],
        ["Int128", "Int128", "Int8"],
        ["Int128", "UInt64", "Int128"],
        ["Int128", "UInt64", "Int64"],
        ["Int128", "UInt64", "Int32"],
        ["Int128", "UInt64", "Int16"],
        ["Int128", "UInt64", "Int8"],
        ["Int128", "UInt32", "Int128"],
        ["Int128", "UInt16", "Int128"],
        ["Int128", "UInt8", "Int128"],

        ["Int64", "Int64", "Int32"],
        ["Int64", "Int64", "Int16"],
        ["Int64", "Int64", "Int8"],
        ["Int64", "UInt32", "Int64"],
        ["Int64", "UInt32", "Int32"],
        ["Int64", "UInt32", "Int16"],
        ["Int64", "UInt32", "Int8"],
        ["Int64", "UInt16", "Int64"],
        ["Int64", "UInt8", "Int64"],

        ["Int32", "Int32", "Int16"],
        ["Int32", "Int32", "Int8"],
        ["Int32", "UInt16", "Int32"],
        ["Int32", "UInt16", "Int16"],
        ["Int32", "UInt16", "Int8"],
        ["Int32", "UInt8", "Int32"],

        ["Int16", "Int16", "Int8"],
        ["Int16", "UInt8", "Int16"],
        ["Int16", "UInt8", "Int8"],

        ["UInt64", "UInt64", "UInt32"],
        ["UInt64", "UInt64", "UInt16"],
        ["UInt64", "UInt64", "UInt8"],

        ["UInt32", "UInt32", "UInt16"],
        ["UInt32", "UInt32", "UInt8"],

        ["UInt16", "UInt16", "UInt8"],

        ["Float64", "Float64", "Float32"],
    ],
)  # fmt: skip
@pytest.mark.parametrize("swap", [True, False])
def test_join_numeric_key_upcast_15338(
    dtypes: tuple[str, str, str], swap: bool
) -> None:
    """Numeric join keys of different dtypes are matched via their supertype.

    Each case is ``[supertype, ltype, rtype]``; ``swap`` exchanges the two key
    dtypes. Non-coalesced key columns keep each side's own dtype, while the
    coalesced full-join key takes the supertype.
    """
    supertype, ltype, rtype = (getattr(pl, name) for name in dtypes)
    if swap:
        ltype, rtype = rtype, ltype

    left = pl.select(pl.Series("a", [1, 1, 3]).cast(ltype)).lazy()
    right = pl.select(pl.Series("a", [1]).cast(rtype), b=pl.lit("A")).lazy()

    left_keys = pl.Series([1, 1, 3]).cast(ltype)
    payload = pl.Series(["A", "A", None])
    expected_left = pl.select(a=left_keys, b=payload)

    # Left join, coalesced (default) and with the right key dropped.
    assert_frame_equal(
        left.join(right, on="a", how="left").collect(),
        expected_left,
        check_row_order=False,
    )
    uncoalesced = left.join(right, on="a", how="left", coalesce=False)
    assert_frame_equal(
        uncoalesced.drop("a_right").collect(),
        expected_left,
        check_row_order=False,
    )

    # Full join: both key columns are kept in their own dtypes...
    assert_frame_equal(
        left.join(right, on="a", how="full").collect(),
        pl.select(
            a=left_keys,
            a_right=pl.Series([1, 1, None]).cast(rtype),
            b=payload,
        ),
        check_row_order=False,
    )
    # ...while coalescing merges them into a single supertype column.
    assert_frame_equal(
        left.join(right, on="a", how="full", coalesce=True).collect(),
        pl.select(a=pl.Series([1, 1, 3]).cast(supertype), b=payload),
        check_row_order=False,
    )

    assert_frame_equal(
        left.join(right, on="a", how="semi").collect(),
        pl.select(a=pl.Series([1, 1]).cast(ltype)),
    )

    # join_where
    expected_where = pl.select(
        a=pl.Series([1, 1]).cast(ltype),
        a_right=pl.lit(1, dtype=rtype),
        b=pl.Series(["A", "A"]),
    )
    for optimizations in (pl.QueryOptFlags(), pl.QueryOptFlags.none()):
        result = left.join_where(right, pl.col("a") == pl.col("a_right")).collect(
            optimizations=optimizations,
        )
        assert_frame_equal(result, expected_where)
1519
1520
1521
def test_join_numeric_key_upcast_forbid_float_int() -> None:
    """Float64 and Int128 keys are not upcast in ``join``/``join_where``."""
    float_type = pl.Float64
    int_type = pl.Int128

    left = pl.LazyFrame({"a": [1.0, 0.0]}, schema={"a": float_type})
    right = pl.LazyFrame({"a": [1, 2]}, schema={"a": int_type})

    # Establish baseline: In a non-join context, comparisons between ltype and
    # rtype succeed even if the upcast is lossy.
    baseline = (
        left.with_columns(right.collect()["a"].alias("a_right"))
        .select(pl.col("a") == pl.col("a_right"))
        .collect()
    )
    assert_frame_equal(baseline, pl.DataFrame({"a": [True, False]}))

    with pytest.raises(SchemaError, match="datatypes of join keys don't match"):
        left.join(right, on="a", how="left").collect()

    # The second predicate nests the cross-table comparison one level deeper.
    predicates = (
        pl.col("a") == pl.col("a_right"),
        pl.col("a") == (pl.col("a") == pl.col("a_right")),
    )
    for optimizations in (pl.QueryOptFlags(), pl.QueryOptFlags.none()):
        for predicate in predicates:
            with pytest.raises(
                SchemaError, match="'join_where' cannot compare Float64 with Int128"
            ):
                left.join_where(right, predicate).collect(optimizations=optimizations)
1554
1555
1556
def test_join_numeric_key_upcast_order() -> None:
    """A key expression is evaluated in its own dtype before any upcast."""
    # E.g. when we are joining on this expression:
    # * col('a') + 127
    #
    # and we want to upcast, ensure that we upcast like this:
    # * ( col('a') + 127 ) .cast(<type>)
    #
    # and *not* like this:
    # * ( col('a').cast(<type>) + lit(127).cast(<type>) )
    #
    # as otherwise the results would be different.

    left = pl.select(pl.Series("a", [1], dtype=pl.Int8)).lazy()
    right = pl.select(
        pl.Series("a", [1, 128, -128], dtype=pl.Int64), b=pl.lit("A")
    ).lazy()

    # col('a') in `left` is Int8, the result will overflow to become -128
    left_expr = pl.col("a") + 127

    only_overflow_match = pl.DataFrame(
        {
            "a": pl.Series([1], dtype=pl.Int8),
            "a_right": pl.Series([-128], dtype=pl.Int64),
            "b": "A",
        }
    )

    inner = left.join(right, left_on=left_expr, right_on="a", how="inner").collect()
    assert_frame_equal(inner, only_overflow_match)

    where = left.join_where(right, left_expr == pl.col("a_right")).collect()
    assert_frame_equal(where, only_overflow_match)

    full = (
        left.join(right, left_on=left_expr, right_on="a", how="full")
        .collect()
        .sort(pl.all())
    )
    expected_full = pl.DataFrame(
        {
            "a": pl.Series([1, None, None], dtype=pl.Int8),
            "a_right": pl.Series([-128, 1, 128], dtype=pl.Int64),
            "b": ["A", "A", "A"],
        }
    ).sort(pl.all())
    assert_frame_equal(full, expected_full)
1612
1613
1614
def test_no_collapse_join_when_maintain_order_20725() -> None:
    """Filtering after ordered cross joins gives the same result lazily and eagerly."""
    fractions_1 = pl.LazyFrame({"Fraction_1": [0, 25, 50, 75, 100]})
    fractions_2 = pl.LazyFrame({"Fraction_2": [0, 1]})
    fractions_3 = pl.LazyFrame({"Fraction_3": [0, 1]})

    crossed = fractions_1.join(
        fractions_2, how="cross", maintain_order="left_right"
    ).join(fractions_3, how="cross", maintain_order="left_right")

    is_hundred = pl.col("Fraction_1") == 100
    filtered_lazily = crossed.filter(is_hundred).collect()
    filtered_eagerly = crossed.collect().filter(is_hundred)

    assert_frame_equal(filtered_lazily, filtered_eagerly)
1627
1628
1629
def test_join_where_predicate_type_coercion_21009() -> None:
    """A bare equality and one wrapped in ``all_horizontal`` plan and run alike."""
    left_frame = pl.LazyFrame(
        {"left_match": list("ABCDEF"), "left_date_start": range(6)}
    )
    right_frame = pl.LazyFrame(
        {"right_match": list("DEFGHI"), "right_date": range(6)}
    )

    match_predicate = pl.col("left_match") == pl.col("right_match")
    date_predicate = pl.col("right_date") >= pl.col("left_date_start")

    # Note: Cannot eq the plans as the operand sides are non-deterministic
    plain = left_frame.join_where(right_frame, match_predicate, date_predicate)
    wrapped = left_frame.join_where(
        right_frame, pl.all_horizontal(match_predicate), date_predicate
    )

    for query in (plain, wrapped):
        plan_lines = query.explain().splitlines()
        assert plan_lines[0].strip().startswith("FILTER")
        assert plan_lines[1] == "FROM"
        assert plan_lines[2].strip().startswith("INNER JOIN")

    assert_frame_equal(plain.collect(), wrapped.collect())
1669
1670
1671
def test_join_right_predicate_pushdown_21142() -> None:
    """Predicate pushdown through a right join gives correct results."""
    lhs = pl.LazyFrame({"key": [1, 2, 4], "values": ["a", "b", "c"]})
    rhs = pl.LazyFrame({"key": [1, 2, 3], "values": ["d", "e", "f"]})
    right_joined = lhs.join(rhs, on="key", how="right")

    expect = pl.select(
        pl.Series("values", [None], pl.String),
        pl.Series("key", [3], pl.Int64),
        pl.Series("values_right", ["f"], pl.String),
    )

    # Right row with key 3 has no left partner, so its left-side "values" is null.
    missing_left = right_joined.filter(pl.col("values").is_null())
    assert_frame_equal(missing_left.collect(), expect)

    # For a right join, a filter on a right-hand-side column (here the non-key
    # "values_right") must appear inside the right input's plan.
    missing_right = right_joined.filter(pl.col("values_right").is_null())
    plan = missing_right.explain()
    assert plan.index("FILTER") > plan.index("RIGHT PLAN ON")

    assert_frame_equal(missing_right.collect(), expect.clear())
1694
1695
1696
def test_join_where_nested_expr_21066() -> None:
    """``join_where`` handles arithmetic on the right-hand operand."""
    lhs = pl.LazyFrame({"a": [1, 2]})
    rhs = pl.LazyFrame({"a": [1]})
    shifted_right = pl.col("a_right") + 1

    result = lhs.join_where(rhs, pl.col("a") == shifted_right).collect()

    assert_frame_equal(result, pl.DataFrame({"a": 2, "a_right": 1}))
1703
1704
1705
def test_select_after_join_where_20831() -> None:
    """Projections after ``join_where`` match an equivalent cross join + filter."""
    left = pl.LazyFrame(
        {
            "a": [1, 2, 3, 1, None],
            "b": [1, 2, 3, 4, 5],
            "c": [2, 3, 4, 5, 6],
        }
    )
    right = pl.LazyFrame(
        {
            "a": [1, 4, 3, 7, None, None, 1],
            "c": [2, 3, 4, 5, 6, 7, 8],
            "d": [6, None, 7, 8, -1, 2, 4],
        }
    )

    first_predicate = pl.col("b") * 2 <= pl.col("a_right")
    second_predicate = pl.col("a") < pl.col("c_right")
    expected_d = pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame()

    via_join_where = left.join_where(right, first_predicate, second_predicate)
    via_cross_join = (
        left.join(right, how="cross")
        .filter(first_predicate)
        .filter(second_predicate)
    )

    for query in (via_join_where, via_cross_join):
        assert_frame_equal(query.select("d").collect().sort("d"), expected_d)
        assert query.select(pl.len()).collect().item() == 6
1745
1746
1747
@pytest.mark.parametrize(
    ("dtype", "data"),
    [
        (pl.Struct, [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]),
        (pl.List, [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]),
        (pl.Array(pl.Int64, 2), [[1, 1], [2, 2], [3, 3], [4, 4]]),
    ],
)
def test_join_on_nested(dtype: PolarsDataType, data: list[Any]) -> None:
    """Every join strategy works with a nested (struct/list/array) key column.

    The left frame holds the first three values; the right frame holds the
    fourth and the second, so only the second value matches.
    """
    first, second, third, fourth = data
    lhs = pl.DataFrame({"a": [first, second, third], "b": [1, 2, 3]})
    rhs = pl.DataFrame({"a": [fourth, second], "c": [4, 2]})

    left_joined = lhs.join(rhs, on="a", how="left", maintain_order="left")
    assert_frame_equal(
        left_joined,
        pl.select(
            a=pl.Series([first, second, third]),
            b=pl.Series([1, 2, 3]),
            c=pl.Series([None, 2, None]),
        ),
    )

    right_joined = lhs.join(rhs, on="a", how="right", maintain_order="right")
    assert_frame_equal(
        right_joined,
        pl.select(
            b=pl.Series([None, 2]),
            a=pl.Series([fourth, second]),
            c=pl.Series([4, 2]),
        ),
    )

    inner_joined = lhs.join(rhs, on="a", how="inner")
    assert_frame_equal(
        inner_joined,
        pl.select(
            a=pl.Series([second]),
            b=pl.Series([2]),
            c=pl.Series([2]),
        ),
    )

    full_joined = lhs.join(rhs, on="a", how="full", maintain_order="left_right")
    assert_frame_equal(
        full_joined,
        pl.select(
            a=pl.Series([first, second, third, None]),
            b=pl.Series([1, 2, 3, None]),
            a_right=pl.Series([None, second, None, fourth]),
            c=pl.Series([None, 2, None, 4]),
        ),
    )

    semi_joined = lhs.join(rhs, on="a", how="semi")
    assert_frame_equal(
        semi_joined,
        pl.select(
            a=pl.Series([second]),
            b=pl.Series([2]),
        ),
    )

    anti_joined = lhs.join(rhs, on="a", how="anti", maintain_order="left")
    assert_frame_equal(
        anti_joined,
        pl.select(
            a=pl.Series([first, third]),
            b=pl.Series([1, 3]),
        ),
    )

    cross_joined = lhs.join(rhs, how="cross", maintain_order="left_right")
    assert_frame_equal(
        cross_joined,
        pl.select(
            a=pl.Series([first, first, second, second, third, third]),
            b=pl.Series([1, 1, 2, 2, 3, 3]),
            a_right=pl.Series([fourth, second] * 3),
            c=pl.Series([4, 2] * 3),
        ),
    )
1825
1826
1827
def test_empty_join_result_with_array_15474() -> None:
    """An empty join result keeps the Array payload column's dtype."""
    array_dtype = pl.Array(pl.Int64, 3)
    lhs = pl.DataFrame(
        {
            "x": [1, 2],
            "y": pl.Series([[1, 2, 3], [4, 5, 6]], dtype=array_dtype),
        }
    )
    rhs = pl.DataFrame({"x": [0]})

    joined = lhs.join(rhs, on="x")
    assert_frame_equal(joined, pl.DataFrame(schema={"x": pl.Int64, "y": array_dtype}))
1838
1839
1840
@pytest.mark.slow
def test_join_where_eager_perf_21145() -> None:
    """Eager ``join_where`` must not be drastically slower than the lazy path.

    The same ``is_between`` inequality join is timed through the eager and the
    lazy API with ``time_func``. The test raises if the eager/lazy runtime
    ratio exceeds ``threshold``.
    """
    lhs = pl.Series("left", range(3_000)).to_frame()
    rhs = pl.Series("right", range(1_000)).to_frame()
    predicate = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right"))

    # These names appear verbatim in the failure message via `{name = }`.
    runtime_eager = time_func(lambda: lhs.join_where(rhs, predicate))
    runtime_lazy = time_func(
        lambda: lhs.lazy().join_where(rhs.lazy(), predicate).collect()
    )
    runtime_ratio = runtime_eager / runtime_lazy

    # `threshold` bounds the eager/lazy runtime *ratio*. Pick as high as
    # reasonably possible for CI stability; the eager path was observed to take
    # >=5 seconds on the bugged version, so 3x is a safe bet.
    threshold = 3

    if runtime_ratio > threshold:
        msg = f"runtime_ratio ({runtime_ratio}) > {threshold}x ({runtime_eager = }, {runtime_lazy = })"
        raise ValueError(msg)
1857
1858
1859
def test_select_len_after_semi_anti_join_21343() -> None:
    """``select(pl.len())`` after an anti join counts the surviving rows."""
    lhs = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    rhs = pl.LazyFrame({"a": [1, 2, 3]})

    # Every left key appears on the right, so the anti join keeps nothing.
    row_count = lhs.join(rhs, on="a", how="anti").select(pl.len()).collect().item()
    assert row_count == 0
1866
1867
1868
def test_multi_leftjoin_empty_right_21701() -> None:
    """Chained left joins against empty frames leave the parent rows intact."""
    parent_df = pl.LazyFrame({"id": [1, 30, 80], "parent_field1": [3, 20, 17]})
    child_df = pl.LazyFrame(
        [],
        schema={"id": pl.Int32(), "parent_id": pl.Int32(), "child_field1": pl.Int32()},
    )
    subchild_df = pl.LazyFrame(
        [], schema={"child_id": pl.Int32(), "subchild_field1": pl.Int32()}
    )

    children = child_df.join(
        subchild_df, left_on=pl.col("id"), right_on=pl.col("child_id"), how="left"
    )
    projected = parent_df.join(
        children,
        left_on=pl.col("id"),
        right_on=pl.col("parent_id"),
        how="left",
    ).select("id", "parent_field1")

    assert_frame_equal(projected.collect(), parent_df.collect(), check_row_order=False)
1892
1893
1894
@pytest.mark.parametrize("order", ["none", "left_right", "right_left"])
def test_join_null_equal(order: Literal["none", "left_right", "right_left"]) -> None:
    """With ``nulls_equal=True``, null keys match each other across join types."""
    lhs = pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3]})
    with_null = pl.DataFrame({"x": [1, None], "z": [1, 2]})
    without_null = pl.DataFrame({"x": [1, 3], "z": [1, 3]})
    check_row_order = order != "none"

    # Each lhs row, including the null-keyed ones, finds a partner in `with_null`.
    every_row_matched = pl.DataFrame(
        {"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}
    )

    # Inner join.
    assert_frame_equal(
        lhs.join(with_null, on="x", nulls_equal=True, maintain_order=order),
        every_row_matched,
        check_row_order=check_row_order,
    )
    assert_frame_equal(
        lhs.join(without_null, on="x", nulls_equal=True),
        pl.DataFrame({"x": [1], "y": [1], "z": [1]}),
    )

    # Left join.
    assert_frame_equal(
        lhs.join(with_null, on="x", how="left", nulls_equal=True, maintain_order=order),
        every_row_matched,
        check_row_order=check_row_order,
    )
    assert_frame_equal(
        lhs.join(
            without_null, on="x", how="left", nulls_equal=True, maintain_order=order
        ),
        pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, None, None]}),
        check_row_order=check_row_order,
    )

    # Full join.
    assert_frame_equal(
        lhs.join(
            with_null,
            on="x",
            how="full",
            nulls_equal=True,
            coalesce=True,
            maintain_order=order,
        ),
        every_row_matched,
        check_row_order=check_row_order,
    )

    left_first = order == "left_right"
    expected = pl.DataFrame(
        {
            "x": [1, None, None, None],
            "x_right": [1, None, None, 3] if left_first else [1, 3, None, None],
            "y": [1, 2, 3, None] if left_first else [1, None, 2, 3],
            "z": [1, None, None, 3] if left_first else [1, 3, None, None],
        }
    )
    assert_frame_equal(
        lhs.join(
            without_null, on="x", how="full", nulls_equal=True, maintain_order=order
        ),
        expected,
        check_row_order=check_row_order,
        check_column_order=False,
    )
1965
1966
1967
def test_join_categorical_21815() -> None:
    """Full joins work with a Categorical column as the key and as payload."""

    def with_categorical(values: list[str]) -> pl.DataFrame:
        # String column "x" plus its Categorical copy "xc".
        return pl.DataFrame({"x": values}).with_columns(
            xc=pl.col.x.cast(pl.Categorical)
        )

    left = with_categorical(["a", "b", "c", "d"])
    right = with_categorical(["c", "d", "e", "f"])

    expected = pl.DataFrame(
        {
            "x": ["a", "b", "c", "d", None, None],
            "x_right": [None, None, "c", "d", "e", "f"],
        }
    ).with_columns(
        xc=pl.col.x.cast(pl.Categorical),
        xc_right=pl.col.x_right.cast(pl.Categorical),
    )

    # "xc" exercises Categorical as the key; "x" leaves it as payload.
    for key in ("xc", "x"):
        assert_frame_equal(
            left.join(right, on=key, how="full"),
            expected,
            check_row_order=False,
            check_column_order=False,
        )
1997
1998
1999
def test_join_where_nested_boolean() -> None:
    """A boolean comparison cast to Int32 can feed a cross-table predicate."""
    df1 = pl.DataFrame({"a": [1, 9, 22], "b": [6, 4, 50]})
    df2 = pl.DataFrame({"c": [1]})

    a_lt_b = (pl.col("a") < pl.col("b")).cast(pl.Int32)
    result = df1.join_where(df2, a_lt_b < pl.col("c"))

    assert_frame_equal(result, pl.DataFrame({"a": [9], "b": [4], "c": [1]}))
2013
2014
2015
def test_join_where_dtype_upcast() -> None:
    """Int8 plus a boolean, compared against the other table, is accepted."""
    df1 = pl.DataFrame(
        {
            "a": pl.Series([1, 9, 22], dtype=pl.Int8),
            "b": [6, 4, 50],
        }
    )
    df2 = pl.DataFrame({"c": [10]})

    lhs_value = pl.col("a") + (pl.col("b") > 0)
    result = df1.join_where(df2, lhs_value < pl.col("c"))

    expected = pl.DataFrame(
        {"a": pl.Series([1], dtype=pl.Int8), "b": [6], "c": [10]}
    )
    assert_frame_equal(result, expected)
2034
2035
2036
def test_join_where_valid_dtype_upcast_same_side() -> None:
    """Mixed-dtype comparisons confined to one table are allowed."""
    # Unsafe comparisons are all contained entirely within one table (LHS)
    # Safe comparisons across both tables.
    df1 = pl.DataFrame(
        {
            "a": pl.Series([1, 9, 22], dtype=pl.Float32),
            "b": [6, 4, 50],
        }
    )
    df2 = pl.DataFrame({"c": [10, 1, 5]})

    # Only this Int32 value (3 or 4) is compared with column "c".
    lhs_score = (pl.col("a") < pl.col("b")).cast(pl.Int32) + 3
    result = df1.join_where(df2, lhs_score < pl.col("c")).sort("a", "b", "c")

    expected = pl.DataFrame(
        {
            "a": pl.Series([1, 1, 9, 9, 22, 22], dtype=pl.Float32),
            "b": [6, 6, 4, 4, 50, 50],
            "c": [5, 10, 5, 10, 5, 10],
        }
    )
    assert_frame_equal(result, expected)
2057
2058
2059
def test_join_where_invalid_dtype_upcast_different_side() -> None:
    """A lossy cross-table comparison is rejected until an explicit cast is added."""
    # Unsafe comparisons exist across tables.
    df1 = pl.DataFrame(
        {
            "a": pl.Series([1, 9, 22], dtype=pl.Float32),
            "b": pl.Series([6, 4, 50], dtype=pl.Float64),
        }
    )
    df2 = pl.DataFrame({"c": [10, 1, 5]})

    def predicate_for(a: pl.Expr) -> pl.Expr:
        # ((a >= c) + 3) < 4 holds exactly when a < c.
        return ((a >= pl.col("c")) + 3) < 4

    with pytest.raises(
        SchemaError, match="'join_where' cannot compare Float32 with Int64"
    ):
        df1.join_where(df2, predicate_for(pl.col("a")))

    # add in a cast to predicate to fix
    result = df1.join_where(df2, predicate_for(pl.col("a").cast(pl.UInt8))).sort(
        "a", "b", "c"
    )
    expected = pl.DataFrame(
        {
            "a": pl.Series([1, 1, 9], dtype=pl.Float32),
            "b": pl.Series([6, 6, 4], dtype=pl.Float64),
            "c": [5, 10, 10],
        }
    )
    assert_frame_equal(result, expected)
2086
2087
2088
@pytest.mark.parametrize("dtype", [pl.Int32, pl.Float32])
def test_join_where_literals(dtype: PolarsDataType) -> None:
    """A cross-table sum compared with a literal works for int and float columns."""
    df1 = pl.DataFrame({"a": pl.Series([0, 1], dtype=dtype)})
    df2 = pl.DataFrame({"b": pl.Series([1, 2], dtype=dtype)})

    pair_sum = pl.col("a") + pl.col("b")
    result = df1.join_where(df2, pair_sum < 2)

    expected = pl.DataFrame(
        {"a": pl.Series([0], dtype=dtype), "b": pl.Series([1], dtype=dtype)}
    )
    assert_frame_equal(result, expected)
2100
2101
2102
def test_join_where_categorical_string_compare() -> None:
    """An Enum ``is_in`` string check can be combined with a right-side filter."""
    enum_dtype = pl.Enum(["a", "b", "c"])
    df1 = pl.DataFrame({"a": pl.Series(["a", "a", "b", "c"], dtype=enum_dtype)})
    df2 = pl.DataFrame({"b": [1, 6, 4]})

    in_ab = pl.col("a").is_in(["a", "b"])
    below_five = pl.col("b") < 5
    result = df1.join_where(df2, in_ab & below_five).sort("a", "b")

    expected = pl.DataFrame(
        {
            "a": pl.Series(["a", "a", "a", "a", "b", "b"], dtype=enum_dtype),
            "b": [1, 1, 4, 4, 1, 4],
        }
    )
    assert_frame_equal(result, expected)
2115
2116
2117
def test_join_where_nonboolean_predicate() -> None:
    """A non-boolean ``join_where`` predicate raises ``ComputeError``."""
    df1 = pl.DataFrame({"a": [1, 2, 3]})
    df2 = pl.DataFrame({"b": [1, 2, 3]})
    numeric_expr = pl.col("a") * 2

    with pytest.raises(
        ComputeError, match="'join_where' predicates must resolve to boolean"
    ):
        df1.join_where(df2, numeric_expr)
2124
2125
2126
def test_empty_outer_join_22206() -> None:
    """A coalesced full join with an empty frame on either side returns the other."""
    df = pl.LazyFrame({"a": [5, 6], "b": [1, 2]})
    empty = pl.LazyFrame(schema=df.collect_schema())

    for lhs, rhs in ((df, empty), (empty, df)):
        assert_frame_equal(
            lhs.join(rhs, on=["a", "b"], how="full", coalesce=True),
            df,
            check_row_order=False,
        )
2139
2140
2141
def test_join_coalesce_22498() -> None:
    """A coalescing lazy full join on a shared key keeps a single `y` column."""
    left = pl.DataFrame({"y": [2]}).lazy()
    right = pl.DataFrame({"x": [1], "y": [2]}).lazy()
    joined = left.join(right, how="full", on="y", coalesce=True).collect()
    assert_frame_equal(joined, pl.DataFrame({"y": [2], "x": [1]}))
2146
2147
2148
def _extract_plan_joins_and_filters(plan: str) -> list[str]:
2149
return [
2150
x
2151
for x in (x.strip() for x in plan.splitlines())
2152
if x.startswith("LEFT PLAN") # noqa: PIE810
2153
or x.startswith("RIGHT PLAN")
2154
or x.startswith("FILTER")
2155
]
2156
2157
2158
def test_join_filter_pushdown_inner_join() -> None:
    """Check which post-join filters are pushed into the inputs of an inner join."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Filter on key output column is pushed to both sides.
    query = lhs.join(
        rhs, on=["a", "b"], how="inner", maintain_order="left_right"
    ).filter(pl.col("b") <= 2)
    expected = pl.DataFrame(
        {"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # Side-specific filters are all pushed for inner join.
    query = (
        lhs.join(rhs, on=["a", "b"], how="inner", maintain_order="left_right")
        .filter(pl.col("b") <= 2)
        .filter(pl.col("c") == "a", pl.col("c_right") == "A")
    )
    expected = pl.DataFrame({"a": [1], "b": [1], "c": ["a"], "c_right": ["A"]})
    plan_lines = _extract_plan_joins_and_filters(query.explain())
    # Predicates merged into one FILTER node are checked by substring, so the
    # order in which they are printed does not matter.
    assert plan_lines[0] == 'LEFT PLAN ON: [col("a"), col("b")]'
    for fragment in ('col("c")) == ("a")', 'col("b")) <= (2)'):
        assert fragment in plan_lines[1]
    assert plan_lines[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'
    for fragment in ('col("b")) <= (2)', 'col("c")) == ("A")'):
        assert fragment in plan_lines[3]
    check_result(query, expected)

    # A filter on the non-coalesced `_right` key column of an inner join is
    # also pushed to the left input table.
    query = lhs.join(
        rhs, on=["a", "b"], how="inner", coalesce=False, maintain_order="left_right"
    ).filter(pl.col("a_right") <= 2)
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("a")) <= (2)]',
    ]
    check_result(query, expected)

    # Different names in left_on and right_on
    query = lhs.join(
        rhs, left_on="a", right_on="b", how="inner", maintain_order="left_right"
    ).filter(pl.col("a") <= 2)
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # Different names in left_on and right_on, coalesce=False
    query = lhs.join(
        rhs,
        left_on="a",
        right_on="b",
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a") <= 2)
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # left_on=col(A), right_on=lit(1). Filters referencing col(A) can only push
    # to the left side.
    query = lhs.join(
        rhs,
        left_on=["a", pl.lit(1)],
        right_on=[pl.lit(1), "b"],
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(
        pl.col("a") == 1,
        pl.col("b") >= 1,
        pl.col("a_right") <= 1,
        pl.col("b_right") >= 0,
    )
    expected = pl.DataFrame(
        {
            "a": [1],
            "b": [1],
            "c": ["a"],
            "a_right": [1],
            "b_right": [1],
            "c_right": ["A"],
        }
    )
    plan_lines = _extract_plan_joins_and_filters(query.explain())
    assert (
        plan_lines[0]
        == 'LEFT PLAN ON: [col("a").cast(Int64), col("_POLARS_0").cast(Int64)]'
    )
    for fragment in ('(col("a")) == (1)', '(col("b")) >= (1)'):
        assert fragment in plan_lines[1]
    assert (
        plan_lines[2]
        == 'RIGHT PLAN ON: [col("_POLARS_1").cast(Int64), col("b").cast(Int64)]'
    )
    for fragment in ('(col("b")) >= (0)', 'col("a")) <= (1)'):
        assert fragment in plan_lines[3]
    check_result(query, expected)

    # Filters don't pass if they refer to columns from both tables
    # TODO: In the optimizer we can add additional equalities into the join
    # condition itself for some cases.
    query = lhs.join(
        rhs, on=["a"], how="inner", maintain_order="left_right"
    ).filter(pl.col("b") == pl.col("b_right"))
    expected = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": [1, 2, 3],
            "c": ["a", "b", "c"],
            "b_right": [1, 2, 3],
            "c_right": ["A", "B", "C"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("b")) == (col("b_right"))]',
        'LEFT PLAN ON: [col("a")]',
        'RIGHT PLAN ON: [col("a")]',
    ]
    check_result(query, expected)

    # Duplicate filter removal - https://github.com/pola-rs/polars/issues/23243
    query = (
        pl.LazyFrame({"x": [1, 2, 3]})
        .join(pl.LazyFrame({"x": [1, 2, 3]}), on="x", how="inner", coalesce=False)
        .filter(
            pl.col("x") == 2,
            pl.col("x_right") == 2,
        )
    )
    expected = pl.DataFrame(
        [
            pl.Series("x", [2], dtype=pl.Int64),
            pl.Series("x_right", [2], dtype=pl.Int64),
        ]
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("x")]',
        'FILTER [(col("x")) == (2)]',
        'RIGHT PLAN ON: [col("x")]',
        'FILTER [(col("x")) == (2)]',
    ]
    check_result(query, expected)
2410
2411
2412
def test_join_filter_pushdown_left_join() -> None:
    """Check which post-join filters are pushed into the inputs of a left join."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Filter on key output column is pushed to both sides.
    query = lhs.join(
        rhs, on=["a", "b"], how="left", maintain_order="left_right"
    ).filter(pl.col("b") <= 2)
    expected = pl.DataFrame(
        {"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # Filter on key output column is pushed to both sides.
    # This tests joins on differing left/right names.
    query = lhs.join(
        rhs, left_on="a", right_on="b", how="left", maintain_order="left_right"
    ).filter(pl.col("a") <= 2)
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # Filters referring to columns that exist only in the left table can be pushed.
    query = lhs.join(
        rhs, on=["a", "b"], how="left", maintain_order="left_right"
    ).filter(pl.col("c") == "b")
    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("c")) == ("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]
    check_result(query, expected)

    # Filters referring to columns that exist only in the right table cannot be
    # pushed for left-join.
    query = lhs.join(
        rhs, on=["a", "b"], how="left", maintain_order="left_right"
    ).filter(
        # Note: `eq_missing` to block join downgrade.
        pl.col("c_right").eq_missing("B")
    )
    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("c_right")) ==v ("B")]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]
    check_result(query, expected)

    # Filters referring to a non-coalesced key column originating from the right
    # table cannot be pushed.
    #
    # Note, technically it's possible to push these filters if we can guarantee that
    # they do not remove NULLs (or otherwise if we also apply the filter on the
    # result table). But this is not something we do at the moment.
    query = lhs.join(
        rhs, on=["a", "b"], how="left", coalesce=False, maintain_order="left_right"
    ).filter(pl.col("b_right").eq_missing(2))
    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "b_right": [2],
            "c_right": ["B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("b_right")) ==v (2)]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]
    check_result(query, expected)
2543
2544
2545
def test_join_filter_pushdown_right_join() -> None:
    """Check which post-join filters are pushed into the inputs of a right join."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Filter on key output column is pushed to both sides.
    query = lhs.join(
        rhs, on=["a", "b"], how="right", maintain_order="left_right"
    ).filter(pl.col("b") <= 2)
    expected = pl.DataFrame(
        {"c": ["a", "b"], "a": [1, 2], "b": [1, 2], "c_right": ["A", "B"]}
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # Filter on key output column is pushed to both sides.
    # This tests joins on differing left/right names.
    # col(A) is coalesced into col(B) (from right), but col(B) is named as
    # col(B_right) in the output because the LHS table also has a col(B).
    query = lhs.join(
        rhs, left_on="a", right_on="b", how="right", maintain_order="left_right"
    ).filter(pl.col("b_right") <= 2)
    expected = pl.DataFrame(
        {
            "b": [1, 2],
            "c": ["a", "b"],
            "a": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]
    check_result(query, expected)

    # Filters referring to columns that exist only in the right table can be pushed.
    query = lhs.join(
        rhs, on=["a", "b"], how="right", maintain_order="left_right"
    ).filter(pl.col("c_right") == "B")
    expected = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("c")) == ("B")]',
    ]
    check_result(query, expected)

    # Filters referring to columns that exist only in the left table cannot be
    # pushed for right-join.
    query = lhs.join(
        rhs, on=["a", "b"], how="right", maintain_order="left_right"
    ).filter(
        # Note: eq_missing to block join downgrade
        pl.col("c").eq_missing("b")
    )
    expected = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("c")) ==v ("b")]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]
    check_result(query, expected)

    # Filters referring to a non-coalesced key column originating from the left
    # table cannot be pushed for right-join.
    query = lhs.join(
        rhs, on=["a", "b"], how="right", coalesce=False, maintain_order="left_right"
    ).filter(pl.col("b").eq_missing(2))
    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "b_right": [2],
            "c_right": ["B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("b")) ==v (2)]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]
    check_result(query, expected)
2674
2675
2676
def test_join_filter_pushdown_full_join() -> None:
    """Check filter pushdown through full joins with and without key coalescing."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Full join can only push filters that refer to coalesced key columns.
    query = lhs.join(
        rhs,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=True,
        maintain_order="left_right",
    ).filter(pl.col("a") == 2)
    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "c_right": ["B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) == (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) == (2)]',
    ]
    check_result(query, expected)

    # Non-coalescing full-join cannot push any filters.
    # Note: We add fill_null to bypass non-NULL filter mask detection.
    query = lhs.join(
        rhs,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=False,
        maintain_order="left_right",
    ).filter(
        pl.col("a").fill_null(0) >= 2,
        pl.col("a").fill_null(0) <= 2,
    )
    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "b_right": [2],
            "c_right": ["B"],
        }
    )
    plan_lines = _extract_plan_joins_and_filters(query.explain())
    assert plan_lines[0].startswith("FILTER ")
    assert plan_lines[1:] == [
        'LEFT PLAN ON: [col("a")]',
        'RIGHT PLAN ON: [col("b")]',
    ]
    check_result(query, expected)
2753
2754
2755
def test_join_filter_pushdown_semi_join() -> None:
    """Check filter pushdown through a semi join with a literal right-side key."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    query = lhs.join(
        rhs,
        left_on=["a", "b"],
        right_on=["b", pl.lit(2)],
        how="semi",
        maintain_order="left_right",
    ).filter(pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b")
    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"]})

    plan_lines = _extract_plan_joins_and_filters(query.explain())

    # * filter on col(a) is pushed to both sides (renamed to col(b) in the right side)
    # * filter on col(b) is pushed only to left, as the right join key is a literal
    # * filter on col(c) is pushed only to left, as the column does not exist in
    #   the right.
    assert plan_lines[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'
    for fragment in ('col("a")) == (2)', 'col("b")) == (2)', 'col("c")) == ("b")'):
        assert fragment in plan_lines[1]
    assert plan_lines[2:] == [
        'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',
        'FILTER [(col("b")) == (2)]',
    ]

    # Optimized and fully unoptimized execution must give the same frame.
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=pl.QueryOptFlags.none()), expected)
2799
2800
2801
def test_join_filter_pushdown_anti_join() -> None:
    """Check filter pushdown through an anti join with a literal right-side key."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    query = lhs.join(
        rhs,
        left_on=["a", "b"],
        right_on=["b", pl.lit(1)],
        how="anti",
        maintain_order="left_right",
    ).filter(pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b")
    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"]})

    plan_lines = _extract_plan_joins_and_filters(query.explain())

    # All three predicates reach the left input; only the key filter on col(a)
    # (renamed to col(b)) is pushed to the right input.
    assert plan_lines[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'
    for fragment in ('col("a")) == (2)', 'col("b")) == (2)', 'col("c")) == ("b")'):
        assert fragment in plan_lines[1]
    assert plan_lines[2:] == [
        'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',
        'FILTER [(col("b")) == (2)]',
    ]

    # Optimized and fully unoptimized execution must give the same frame.
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=pl.QueryOptFlags.none()), expected)
2840
2841
2842
def test_join_filter_pushdown_cross_join() -> None:
    """Check filter pushdown for cross joins rewritten to nested-loop/inner joins."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [0, 0, 0, 0, 0], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Nested loop join for `!=`
    query = (
        lhs.with_row_index()
        .join(rhs, how="cross")
        .filter(
            pl.col("a") <= 4, pl.col("c_right") <= "B", pl.col("a") != pl.col("a_right")
        )
        .sort("index")
    )
    expected = pl.DataFrame(
        [
            pl.Series("index", [0, 0, 1, 1, 2, 2, 3, 3], dtype=pl.get_index_type()),
            pl.Series("a", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),
            pl.Series("b", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),
            pl.Series("c", ["a", "a", "b", "b", "c", "c", "d", "d"], dtype=pl.String),
            pl.Series("a_right", [0, 0, 0, 0, 0, 0, 0, 0], dtype=pl.Int64),
            pl.Series("b_right", [1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),
            pl.Series(
                "c_right", ["A", "B", "A", "B", "A", "B", "A", "B"], dtype=pl.String
            ),
        ]
    )
    plan_text = query.explain()
    assert 'NESTED LOOP JOIN ON [(col("a")) != (col("a_right"))]' in plan_text
    assert _extract_plan_joins_and_filters(plan_text) == [
        "LEFT PLAN:",
        'FILTER [(col("a")) <= (4)]',
        "RIGHT PLAN:",
        'FILTER [(col("c")) <= ("B")]',
    ]
    check_result(query, expected)

    # Conversion to inner-join for `==`
    query = lhs.join(rhs, how="cross", maintain_order="left_right").filter(
        pl.col("a") <= 4,
        pl.col("c_right") <= "B",
        pl.col("a") == (pl.col("a_right") + 1),
    )
    expected = pl.DataFrame(
        {
            "a": [1, 1],
            "b": [1, 1],
            "c": ["a", "a"],
            "a_right": [0, 0],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (4)]',
        'RIGHT PLAN ON: [[(col("a")) + (1)]]',
        'FILTER [(col("c")) <= ("B")]',
    ]
    check_result(query, expected)
2921
2922
2923
def test_join_filter_pushdown_iejoin() -> None:
    """Check filter pushdown for `join_where` planned as inner join or IEJoin."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # The side-specific predicates are expected in the same place for both
    # plans below.
    expected_plan_lines = [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) >= (1)]',
        'RIGHT PLAN ON: [col("a")]',
        'FILTER [(col("c")) <= ("B")]',
    ]

    # An equality predicate between the two sides yields an inner join.
    query = (
        lhs.with_row_index()
        .join_where(
            rhs,
            pl.col("a") >= 1,
            pl.col("a") == pl.col("a_right"),
            pl.col("c_right") <= "B",
        )
        .sort("index")
    )
    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    ).with_row_index()
    plan_text = query.explain()
    assert "INNER JOIN" in plan_text
    assert _extract_plan_joins_and_filters(plan_text) == expected_plan_lines
    check_result(query, expected)

    # An inequality predicate between the two sides yields an IEJoin.
    query = (
        lhs.with_row_index()
        .join_where(
            rhs,
            pl.col("a") >= 1,
            pl.col("a") >= pl.col("a_right"),
            pl.col("c_right") <= "B",
        )
        .sort("index")
    )
    expected = pl.DataFrame(
        [
            pl.Series("index", [0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.get_index_type()),
            pl.Series("a", [1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=pl.Int64),
            pl.Series("b", [1, 2, 2, 3, 3, 4, 4, None, None], dtype=pl.Int64),
            pl.Series(
                "c", ["a", "b", "b", "c", "c", "d", "d", "e", "e"], dtype=pl.String
            ),
            pl.Series("a_right", [1, 2, 1, 2, 1, 2, 1, 2, 1], dtype=pl.Int64),
            pl.Series("b_right", [1, 2, 1, 2, 1, 2, 1, 2, 1], dtype=pl.Int64),
            pl.Series(
                "c_right",
                ["A", "B", "A", "B", "A", "B", "A", "B", "A"],
                dtype=pl.String,
            ),
        ]
    )
    plan_text = query.explain()
    assert "IEJOIN" in plan_text
    assert _extract_plan_joins_and_filters(plan_text) == expected_plan_lines
    check_result(query, expected)
3013
3014
3015
def test_join_filter_pushdown_asof_join() -> None:
    """Check filter pushdown through as-of joins, with and without `by` columns."""
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Filters on left columns are pushed left; the right-only filter stays above.
    query = lhs.join_asof(
        rhs,
        left_on=pl.col("a").set_sorted(),
        right_on=pl.col("b").set_sorted(),
        tolerance=0,
    ).filter(
        pl.col("a") >= 2,
        pl.col("b") >= 3,
        pl.col("c") >= "A",
        pl.col("c_right") >= "B",
    )
    expected = pl.DataFrame(
        {
            "a": [3],
            "b": [3],
            "c": ["c"],
            "a_right": [3],
            "b_right": [3],
            "c_right": ["C"],
        }
    )
    plan_lines = _extract_plan_joins_and_filters(query.explain())
    assert plan_lines[:2] == [
        'FILTER [(col("c_right")) >= ("B")]',
        'LEFT PLAN ON: [col("a").set_sorted()]',
    ]
    for fragment in ('col("b")) >= (3)', 'col("c")) >= ("A")', 'col("a")) >= (2)'):
        assert fragment in plan_lines[2]
    assert plan_lines[3:] == ['RIGHT PLAN ON: [col("b").set_sorted()]']
    check_result(query, expected)

    # With "by" columns
    query = lhs.join_asof(
        rhs,
        left_on="a",
        right_on="b",
        tolerance=99,
        by_left="b",
        by_right="a",
    ).filter(
        pl.col("a") >= 2,
        pl.col("b") >= 3,
        pl.col("c") >= "A",
        pl.col("c_right") >= "B",
    )
    expected = pl.DataFrame(
        {
            "a": [3],
            "b": [3],
            "c": ["c"],
            "b_right": [3],
            "c_right": ["C"],
        }
    )
    plan_lines = _extract_plan_joins_and_filters(query.explain())
    assert plan_lines[:2] == [
        'FILTER [(col("c_right")) >= ("B")]',
        'LEFT PLAN ON: [col("a")]',
    ]
    for fragment in ('col("a")) >= (2)', 'col("b")) >= (3)'):
        assert fragment in plan_lines[2]
    assert plan_lines[3:] == [
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("a")) >= (3)]',
    ]
    check_result(query, expected)
3105
3106
3107
def test_join_filter_pushdown_full_join_rewrite() -> None:
    """Check that filters after a full join downgrade it and are pushed down.

    Depending on which columns the filters touch, the full join is expected to
    be rewritten into a left, right or inner join (or kept as a full join), and
    the pushed-down predicates are checked in the explained plan.
    """
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4, None],
            "b": [1, 2, 3, None, 5],
            "c": ["A", "B", "C", "D", "E"],
        }
    )

    def check_result(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Optimized and fully unoptimized execution must give the same frame.
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    # Downgrades to left-join
    query = lhs.join(
        rhs, on=["a", "b"], how="full", maintain_order="left_right"
    ).filter(pl.col("b") >= 3)
    expected = pl.DataFrame(
        [
            pl.Series("a", [3, 4], dtype=pl.Int64),
            pl.Series("b", [3, 4], dtype=pl.Int64),
            pl.Series("c", ["c", "d"], dtype=pl.String),
            pl.Series("a_right", [3, None], dtype=pl.Int64),
            pl.Series("b_right", [3, None], dtype=pl.Int64),
            pl.Series("c_right", ["C", None], dtype=pl.String),
        ]
    )
    plan = query.explain()
    assert "FULL JOIN" not in plan
    assert plan.startswith("LEFT JOIN")
    assert _extract_plan_joins_and_filters(plan) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
    ]
    check_result(query, expected)

    # Downgrades to right-join
    query = lhs.join(
        rhs, left_on="a", right_on="b", how="full", maintain_order="left_right"
    ).filter(pl.col("b_right") >= 3)
    expected = pl.DataFrame(
        [
            pl.Series("a", [3, 5], dtype=pl.Int64),
            pl.Series("b", [3, None], dtype=pl.Int64),
            pl.Series("c", ["c", "e"], dtype=pl.String),
            pl.Series("a_right", [3, None], dtype=pl.Int64),
            pl.Series("b_right", [3, 5], dtype=pl.Int64),
            pl.Series("c_right", ["C", "E"], dtype=pl.String),
        ]
    )
    plan = query.explain()
    assert "FULL JOIN" not in plan
    assert "RIGHT JOIN" in plan
    assert _extract_plan_joins_and_filters(plan) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) >= (3)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) >= (3)]',
    ]
    check_result(query, expected)

    # Downgrades to right-join (coalescing keys)
    query = lhs.join(
        rhs,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=True,
        maintain_order="left_right",
    ).filter(
        # Coalesced key column (LHS col(a) / RHS col(b), output as col(a)).
        (pl.col("a") >= 1) | pl.col("a").is_null(),
        pl.col("a_right") >= 3,  # col(a) from RHS
        (pl.col("b") >= 2) | pl.col("b").is_null(),  # col(b) from LHS
        pl.col("c_right") >= "C",  # col(c) from RHS
    )
    expected = pl.DataFrame(
        [
            pl.Series("a", [3, None], dtype=pl.Int64),
            pl.Series("b", [3, None], dtype=pl.Int64),
            pl.Series("c", ["c", None], dtype=pl.String),
            pl.Series("a_right", [3, 4], dtype=pl.Int64),
            pl.Series("c_right", ["C", "D"], dtype=pl.String),
        ]
    )
    plan = query.explain()
    assert "FULL JOIN" not in plan
    assert "RIGHT JOIN" in plan

    extract = _extract_plan_joins_and_filters(plan)

    # Previously this was `assert [...]`, which is always true for a non-empty
    # list; compare against the extracted plan lines instead.
    assert extract[:4] == [
        'FILTER [([(col("b")) >= (2)]) | (col("b").is_null())]',
        'LEFT PLAN ON: [col("a")]',
        'FILTER [([(col("a")) >= (1)]) | (col("a").is_null())]',
        'RIGHT PLAN ON: [col("b")]',
    ]

    # The right-side FILTER combines several predicates; check each by substring.
    assert 'col("a")) >= (3)' in extract[4]
    assert '(col("b")) >= (1)]) | (col("b").alias("a").is_null())' in extract[4]
    assert 'col("c")) >= ("C")' in extract[4]

    assert len(extract) == 5

    check_result(query, expected)

    # Downgrades to inner-join
    query = lhs.join(
        rhs, on=["a", "b"], how="full", maintain_order="left_right"
    ).filter(pl.col("b").is_not_null(), pl.col("b_right").is_not_null())
    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2, 3], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3], dtype=pl.Int64),
            pl.Series("c", ["a", "b", "c"], dtype=pl.String),
            pl.Series("a_right", [1, 2, 3], dtype=pl.Int64),
            pl.Series("b_right", [1, 2, 3], dtype=pl.Int64),
            pl.Series("c_right", ["A", "B", "C"], dtype=pl.String),
        ]
    )
    plan = query.explain()
    assert "FULL JOIN" not in plan
    assert plan.startswith("INNER JOIN")

    extract = _extract_plan_joins_and_filters(plan)

    assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b")]'
    assert 'col("b").is_not_null()' in extract[1]
    assert 'col("b").alias("b_right").is_not_null()' in extract[1]

    assert extract[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'
    assert 'col("b").is_not_null()' in extract[3]
    assert 'col("b").alias("b_right").is_not_null()' in extract[3]

    assert len(extract) == 4

    check_result(query, expected)

    # Does not downgrade because col(b) is a coalesced key-column, but the filter
    # is still pushed to both sides.
    query = lhs.join(
        rhs, on=["a", "b"], how="full", coalesce=True, maintain_order="left_right"
    ).filter(pl.col("b") >= 3)
    expected = pl.DataFrame(
        [
            pl.Series("a", [3, 4, None], dtype=pl.Int64),
            pl.Series("b", [3, 4, 5], dtype=pl.Int64),
            pl.Series("c", ["c", "d", None], dtype=pl.String),
            pl.Series("c_right", ["C", None, "E"], dtype=pl.String),
        ]
    )
    plan = query.explain()
    assert plan.startswith("FULL JOIN")
    assert _extract_plan_joins_and_filters(plan) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
    ]
    check_result(query, expected)
3298
3299
3300
def test_join_filter_pushdown_right_join_rewrite() -> None:
    # The predicates below reject the null-padded rows a right join produces for
    # unmatched RHS keys, so the plan must contain an INNER JOIN instead of a
    # RIGHT JOIN. The key-column predicate reaches both inputs; the others reach
    # the input that owns the referenced column.
    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    joined = left.join(
        right,
        left_on="a",
        right_on="b",
        how="right",
        coalesce=True,
        maintain_order="left_right",
    )
    # Downgrades to inner-join
    query = joined.filter(
        pl.col("a") <= 7,  # col(a) from RHS (LHS col(a) is coalesced into col(b_right))
        pl.col("b_right") <= 10,  # Key-column filter
        pl.col("c") <= "b",  # col(c) from LHS
    )

    expected = pl.DataFrame(
        [
            pl.Series("b", [1, 2], dtype=pl.Int64),
            pl.Series("c", ["a", "b"], dtype=pl.String),
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b_right", [1, 2], dtype=pl.Int64),
            pl.Series("c_right", ["A", "B"], dtype=pl.String),
        ]
    )

    plan = query.explain()
    assert "RIGHT JOIN" not in plan
    assert "INNER JOIN" in plan

    steps = _extract_plan_joins_and_filters(plan)
    lhs_on, lhs_filter, rhs_on, rhs_filter = steps[:4]

    assert lhs_on == 'LEFT PLAN ON: [col("a")]'
    # On the LHS, the key-column filter is expressed through the LHS key col(a).
    assert 'col("a")) <= (10)' in lhs_filter
    assert 'col("c")) <= ("b")' in lhs_filter

    assert rhs_on == 'RIGHT PLAN ON: [col("b")]'
    assert 'col("a")) <= (7)' in rhs_filter
    assert 'col("b")) <= (10)' in rhs_filter

    assert len(steps) == 4

    no_opts = pl.QueryOptFlags.none()
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=no_opts), expected)
3351
3352
3353
def test_join_filter_pushdown_join_rewrite_equality_above_and() -> None:
    # The predicate compares an AND-expression against a literal. The expected
    # frame keeps every joined row, including those where "b" is null, so the
    # predicate must not cause the full join to be narrowed.
    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    keeps_every_row = ((pl.col("b") >= 3) & False) >= False

    query = left.join(
        right,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=False,
        maintain_order="left_right",
    ).filter(keeps_every_row)

    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2, 3, 4, 5, None], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3, 4, None, None], dtype=pl.Int64),
            pl.Series("c", ["a", "b", "c", "d", "e", None], dtype=pl.String),
            pl.Series("a_right", [1, 2, 3, None, 5, 4], dtype=pl.Int64),
            pl.Series("b_right", [1, 2, 3, None, 5, None], dtype=pl.Int64),
            pl.Series("c_right", ["A", "B", "C", None, "E", "D"], dtype=pl.String),
        ]
    )

    no_opts = pl.QueryOptFlags.none()
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=no_opts), expected)
3383
3384
3385
def test_join_filter_pushdown_left_join_rewrite() -> None:
    # Filtering on the RHS column "c_right" after a left join drops the rows the
    # join padded with nulls, since a null never satisfies `<= "B"`. The join is
    # therefore downgraded to an inner join and the filter moves into the RHS input.
    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", None, "D", "E"]}
    )

    # Downgrades to inner-join
    query = left.join(
        right,
        left_on="a",
        right_on="b",
        how="left",
        coalesce=True,
        maintain_order="left_right",
    ).filter(pl.col("c_right") <= "B")

    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b", [1, 2], dtype=pl.Int64),
            pl.Series("c", ["a", "b"], dtype=pl.String),
            pl.Series("a_right", [1, 2], dtype=pl.Int64),
            pl.Series("c_right", ["A", "B"], dtype=pl.String),
        ]
    )

    plan = query.explain()
    assert "LEFT JOIN" not in plan
    assert plan.startswith("INNER JOIN")

    # Inside the RHS input the filter uses the un-suffixed column name col("c").
    expected_steps = [
        'LEFT PLAN ON: [col("a")]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("c")) <= ("B")]',
    ]
    assert _extract_plan_joins_and_filters(plan) == expected_steps

    no_opts = pl.QueryOptFlags.none()
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=no_opts), expected)
3428
3429
3430
def test_join_filter_pushdown_left_join_rewrite_23133() -> None:
    # A left join filtered on the join key ("ham"), an RHS column ("apple") and an
    # LHS column ("foo") is downgraded to an inner join. The key predicate is
    # pushed into both inputs; the other predicates go to the input owning the
    # column.
    # NOTE: the numeric suffix presumably refers to upstream issue #23133.
    lhs = pl.LazyFrame(
        {
            "foo": [1, 2, 3],
            "bar": [6.0, 7.0, 8.0],
            "ham": ["a", "b", "c"],
        }
    )

    rhs = pl.LazyFrame(
        {
            "apple": ["x", "y", "z"],
            "ham": ["a", "b", "d"],
            "bar": ["a", "b", "c"],
            "foo2": [1, 2, 3],
        }
    )

    q = lhs.join(rhs, how="left", on="ham", maintain_order="left_right").filter(
        pl.col("ham") == "a", pl.col("apple") == "x", pl.col("foo") <= 2
    )

    expect = pl.DataFrame(
        [
            pl.Series("foo", [1], dtype=pl.Int64),
            pl.Series("bar", [6.0], dtype=pl.Float64),
            pl.Series("ham", ["a"], dtype=pl.String),
            pl.Series("apple", ["x"], dtype=pl.String),
            pl.Series("bar_right", ["a"], dtype=pl.String),
            pl.Series("foo2", [1], dtype=pl.Int64),
        ]
    )

    plan = q.explain()
    # The query starts out as a LEFT join, so the LEFT JOIN must be gone after
    # the rewrite. Checking for "FULL JOIN" here could never fail.
    assert "LEFT JOIN" not in plan
    assert plan.startswith("INNER JOIN")

    extract = _extract_plan_joins_and_filters(plan)

    assert extract[0] == 'LEFT PLAN ON: [col("ham")]'
    assert '(col("foo")) <= (2)' in extract[1]
    assert 'col("ham")) == ("a")' in extract[1]

    assert extract[2] == 'RIGHT PLAN ON: [col("ham")]'
    assert 'col("ham")) == ("a")' in extract[3]
    assert 'col("apple")) == ("x")' in extract[3]

    assert len(extract) == 4

    assert_frame_equal(q.collect(), expect)
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)
3481
3482
3483
def test_join_rewrite_panic_23307() -> None:
    # Case 1: join keys of different integer widths (Int8 on the left, Int16 on
    # the right), followed by a filter on a right-hand column. Per the test name
    # this combination presumably used to panic.
    narrow_left = pl.select(a=pl.lit(1, dtype=pl.Int8)).lazy()
    wide_right = pl.select(
        a=pl.lit(1, dtype=pl.Int16), x=pl.lit(1, dtype=pl.Int32)
    ).lazy()

    mixed_width = narrow_left.join(
        wide_right, on="a", how="left", coalesce=True
    ).filter(pl.col("x") >= 1)

    assert_frame_equal(
        mixed_width.collect(),
        pl.select(
            a=pl.lit(1, dtype=pl.Int8),
            x=pl.lit(1, dtype=pl.Int32),
        ),
    )

    # Case 2: the join key is an expression that wraps an Int16 value into Int8.
    # Note: -25 matches to (999).overflowing_cast(Int8).
    # This is specially chosen to test that we don't accidentally push the filter
    # to the RHS.
    overflow_left = pl.select(a=pl.lit(999, dtype=pl.Int16)).lazy()
    overflow_right = pl.LazyFrame(
        {"a": [1, -25], "x": [1, 2]}, schema={"a": pl.Int8, "x": pl.Int32}
    )

    wrapping_key = pl.col("a").cast(pl.Int8, strict=False, wrap_numerical=True)
    wrapped = overflow_left.join(
        overflow_right,
        on=wrapping_key,
        how="left",
        coalesce=False,
    ).filter(pl.col("a") >= 0)

    expected = pl.DataFrame(
        {"a": 999, "a_right": -25, "x": 2},
        schema={"a": pl.Int16, "a_right": pl.Int8, "x": pl.Int32},
    )

    # Optimization must not leave a FILTER at the root of the plan. If the filter
    # were applied to the RHS "a" (-25), the matching row would be lost and the
    # result checks below would fail.
    assert not wrapped.explain().startswith("FILTER")

    no_opts = pl.QueryOptFlags.none()
    assert_frame_equal(wrapped.collect(), expected)
    assert_frame_equal(wrapped.collect(optimizations=no_opts), expected)
3524
3525
3526
@pytest.mark.parametrize(
    ("expr_first_input", "expr_func"),
    [
        # Comparisons, and null checks wrapped around a comparison.
        (pl.lit(None, dtype=pl.Int64), lambda col: col >= 1),
        (pl.lit(None, dtype=pl.Int64), lambda col: (col >= 1).is_not_null()),
        (pl.lit(None, dtype=pl.Int64), lambda col: (~(col >= 1)).is_not_null()),
        (pl.lit(None, dtype=pl.Int64), lambda col: ~(col >= 1).is_null()),
        # Membership tests without null equality.
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_in([1])),
        (pl.lit(None, dtype=pl.Int64), lambda col: ~col.is_in([1])),
        # Range checks where either the input or one bound is null.
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_between(1, 1)),
        (1, lambda col: col.is_between(None, 1)),
        (1, lambda col: col.is_between(1, None)),
        # Approximate equality where either operand is null.
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_close(1)),
        (1, lambda col: col.is_close(pl.lit(None, dtype=pl.Int64))),
        # Float predicates applied to an integer null.
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_nan()),
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_not_nan()),
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_finite()),
        (pl.lit(None, dtype=pl.Int64), lambda col: col.is_infinite()),
        # Float predicates applied to a float null.
        (pl.lit(None, dtype=pl.Float64), lambda col: col.is_nan()),
        (pl.lit(None, dtype=pl.Float64), lambda col: col.is_not_nan()),
        (pl.lit(None, dtype=pl.Float64), lambda col: col.is_finite()),
        (pl.lit(None, dtype=pl.Float64), lambda col: col.is_infinite()),
    ],
)
def test_join_rewrite_null_preserving_exprs(
    expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]
) -> None:
    # Each predicate is expected to be null-preserving on column "x". The sanity
    # check below confirms it yields null or False for the parametrized input.
    # The left join can then be downgraded to an inner join; the single joined
    # row fails the predicate, so the result is empty.
    left = pl.LazyFrame({"a": 1})
    right = pl.select(a=1, x=expr_first_input).lazy()

    evaluated = pl.select(expr_first_input).select(expr_func(pl.first()))
    null_or_false = evaluated.select(pl.first().is_null() | ~pl.first())
    assert null_or_false.to_series().item()

    query = left.join(right, on="a", how="left", maintain_order="left_right").filter(
        expr_func(pl.col("x"))
    )
    assert query.explain().startswith("INNER JOIN")

    result = query.collect()
    assert result.height == 0
    assert_frame_equal(result, query.collect(optimizations=pl.QueryOptFlags.none()))
3580
3581
3582
@pytest.mark.parametrize(
    ("expr_first_input", "expr_func"),
    [
        (
            pl.lit(None, dtype=pl.Int64),
            lambda x: ~(x.is_in([1, None], nulls_equal=True)),
        ),
        (
            pl.lit(None, dtype=pl.Int64),
            lambda x: x.is_in([1, None], nulls_equal=True) > True,
        ),
        (
            pl.lit(None, dtype=pl.Int64),
            lambda x: x.is_in([1], nulls_equal=True),
        ),
    ],
)
def test_join_rewrite_forbid_exprs(
    expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]
) -> None:
    # With `nulls_equal=True`, `is_in` can give a non-null result for a null input
    # (per its documentation), so these predicates must not be treated as
    # null-preserving. The left join must stay as is, with the FILTER at the
    # root of the plan.
    left = pl.LazyFrame({"a": 1})
    right = pl.select(a=1, x=expr_first_input).lazy()

    query = left.join(right, on="a", how="left", maintain_order="left_right")
    query = query.filter(expr_func(pl.col("x")))

    assert query.explain().startswith("FILTER")

    optimized = query.collect()
    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())
    assert_frame_equal(optimized, unoptimized)
3613
3614
3615
def test_join_filter_pushdown_iejoin_cse_23469() -> None:
    # Reuses join plans in several ways and checks each result:
    #   - a cross join filtered by an inequality, concatenated with itself,
    #   - the same, with the inequality filter repeated on top,
    #   - a cached join_where followed by a filter.
    # NOTE(review): per the test name, this presumably targets the IE-join rewrite
    # combined with common-subplan elimination; confirm.
    xs = pl.LazyFrame({"x": [1, 2, 3]})
    ys = pl.LazyFrame({"y": [1, 2, 3]})

    x_gt_y = xs.join(ys, how="cross").filter(pl.col("x") > pl.col("y"))
    doubled_pairs = pl.DataFrame(
        {
            "x": [2, 2, 3, 3, 3, 3],
            "y": [1, 1, 1, 1, 2, 2],
        },
    )

    unfiltered = pl.concat([x_gt_y, x_gt_y])
    assert_frame_equal(unfiltered.collect().sort(pl.all()), doubled_pairs)

    refiltered = pl.concat([x_gt_y, x_gt_y]).filter(pl.col("x") > pl.col("y"))
    assert_frame_equal(refiltered.collect().sort(pl.all()), doubled_pairs)

    cached = (
        xs.join_where(ys, pl.col("x") == pl.col("y"))
        .cache()
        .filter(pl.col("x") >= 0)
    )
    equal_pairs = pl.DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})
    assert_frame_equal(cached.collect().sort(pl.all()), equal_pairs)
3654
3655
3656
def test_join_cast_type_coercion_23236() -> None:
    # The left key is a renamed String column wrapped in a (no-op) String cast.
    # Joining it against the plain String key on the right must still match.
    left = pl.LazyFrame([{"name": "a"}]).rename({"name": "newname"})
    right = pl.LazyFrame([{"name": "a"}])

    left_key = pl.col("newname").cast(pl.String)
    joined = left.join(right, left_on=left_key, right_on="name")

    assert_frame_equal(joined.collect(), pl.DataFrame({"newname": "a", "name": "a"}))
3663
3664
3665
@pytest.mark.parametrize(
    ("how", "expected"),
    [
        (
            "inner",
            pl.DataFrame(schema={"a": pl.Int128, "a_right": pl.Int128}),
        ),
        (
            "left",
            pl.DataFrame(
                {"a": [1, 1, 2], "a_right": None},
                schema={"a": pl.Int128, "a_right": pl.Int128},
            ),
        ),
        (
            "right",
            pl.DataFrame(
                {
                    "a": None,
                    "a_right": [
                        -9223372036854775808,
                        -9223372036854775807,
                        -9223372036854775806,
                    ],
                },
                schema={"a": pl.Int128, "a_right": pl.Int128},
            ),
        ),
        (
            "full",
            pl.DataFrame(
                [
                    pl.Series("a", [None, None, None, 1, 1, 2], dtype=pl.Int128),
                    pl.Series(
                        "a_right",
                        [
                            -9223372036854775808,
                            -9223372036854775807,
                            -9223372036854775806,
                            None,
                            None,
                            None,
                        ],
                        dtype=pl.Int128,
                    ),
                ]
            ),
        ),
        (
            "semi",
            pl.DataFrame([pl.Series("a", [], dtype=pl.Int128)]),
        ),
        (
            "anti",
            pl.DataFrame([pl.Series("a", [1, 1, 2], dtype=pl.Int128)]),
        ),
    ],
)
@pytest.mark.parametrize(
    ("sort_left", "sort_right"),
    [(True, True), (True, False), (False, True), (False, False)],
)
def test_join_i128_23688(
    how: JoinStrategy, expected: pl.DataFrame, sort_left: bool, sort_right: bool
) -> None:
    # Joins on Int128 keys whose LHS and RHS values never overlap. The RHS values
    # sit at the bottom of the Int64 range. Every join strategy is checked, with
    # and without pre-sorted inputs, first on a single key column and then on a
    # two-column key (a, b) carrying the same values.
    # `how` is typed as JoinStrategy (imported under TYPE_CHECKING) so it can be
    # passed to `join` without `# type: ignore[arg-type]`.
    lhs = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128)})

    rhs = pl.LazyFrame(
        {
            "a": pl.Series(
                [
                    -9223372036854775808,
                    -9223372036854775807,
                    -9223372036854775806,
                ],
                dtype=pl.Int128,
            )
        }
    )

    lhs = lhs.collect().sort("a").lazy() if sort_left else lhs
    rhs = rhs.collect().sort("a").lazy() if sort_right else rhs

    q = lhs.join(rhs, on="a", how=how, coalesce=False)

    assert_frame_equal(
        q.collect().sort(pl.all()),
        expected,
    )

    # Same join on a composite key; select the expected columns to drop "b".
    q = (
        lhs.with_columns(b=pl.col("a"))
        .join(
            rhs.with_columns(b=pl.col("a")),
            on=["a", "b"],
            how=how,
            coalesce=False,
        )
        .select(expected.columns)
    )

    assert_frame_equal(
        q.collect().sort(pl.all()),
        expected,
    )
3770
3771
3772
def test_join_asof_by_i128() -> None:
    # join_asof grouped by an Int128 `by` column. None of the LHS "a" values occur
    # in the RHS, so the RHS payload column "b" is null on every output row.
    left = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128), "i": 1})

    near_i64_min = pl.Series(
        [
            -9223372036854775808,
            -9223372036854775807,
            -9223372036854775806,
        ],
        dtype=pl.Int128,
    )
    right = pl.LazyFrame({"a": near_i64_min, "i": 1}).with_columns(b=pl.col("a"))

    query = left.join_asof(right, on="i", by="a")

    expected = pl.DataFrame(
        {"a": [1, 1, 2], "i": 1, "b": None},
        schema={"a": pl.Int128, "i": pl.Int32, "b": pl.Int128},
    )
    assert_frame_equal(query.collect().sort(pl.all()), expected)
)
3798
3799