Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_join.py
8424 views
1
from __future__ import annotations
2
3
import typing
4
import warnings
5
from datetime import date, datetime
6
from typing import TYPE_CHECKING, Any, Literal
7
8
import numpy as np
9
import pandas as pd
10
import pytest
11
12
import polars as pl
13
from polars.exceptions import (
14
ColumnNotFoundError,
15
ComputeError,
16
DuplicateError,
17
InvalidOperationError,
18
SchemaError,
19
)
20
from polars.testing import assert_frame_equal, assert_series_equal
21
from tests.unit.conftest import time_func
22
23
if TYPE_CHECKING:
24
from collections.abc import Callable
25
26
from polars._typing import JoinStrategy, PolarsDataType
27
28
29
def test_semi_anti_join() -> None:
    """Exercise `semi` and `anti` joins on one key and on two keys.

    The single-key case is checked eagerly and lazily; the right frame in
    that case also carries a null key.
    """
    left = pl.DataFrame({"key": [1, 2, 3], "payload": ["f", "i", None]})
    right = pl.DataFrame({"key": [3, 4, 5, None]})

    # Rows of `left` without / with a partner in `right`.
    unmatched = {"key": [1, 2], "payload": ["f", "i"]}
    matched = {"key": [3], "payload": [None]}

    # eager
    eager_anti = left.join(right, on="key", how="anti")
    assert eager_anti.to_dict(as_series=False) == unmatched
    eager_semi = left.join(right, on="key", how="semi")
    assert eager_semi.to_dict(as_series=False) == matched

    # lazy
    lazy_anti = left.lazy().join(right.lazy(), on="key", how="anti").collect()
    assert lazy_anti.to_dict(as_series=False) == unmatched
    lazy_semi = left.lazy().join(right.lazy(), on="key", how="semi").collect()
    assert lazy_semi.to_dict(as_series=False) == matched

    # two join keys, with repeated rows on both sides
    left = pl.DataFrame(
        {"a": [1, 2, 3, 1], "b": ["a", "b", "c", "a"], "payload": [10, 20, 30, 40]}
    )
    right = pl.DataFrame({"a": [3, 3, 4, 5], "b": ["c", "c", "d", "e"]})
    keys = ["a", "b"]

    anti = left.join(right, on=keys, how="anti")
    assert anti.to_dict(as_series=False) == {
        "a": [1, 2, 1],
        "b": ["a", "b", "a"],
        "payload": [10, 20, 40],
    }
    semi = left.join(right, on=keys, how="semi")
    assert semi.to_dict(as_series=False) == {
        "a": [3],
        "b": ["c"],
        "payload": [30],
    }
68
69
70
def test_join_same_cat_src() -> None:
    """Join a frame with its own per-group aggregate on a Categorical key."""
    frame = pl.DataFrame(
        data={"column": ["a", "a", "b"], "more": [1, 2, 3]},
        schema=[("column", pl.Categorical), ("more", pl.Int32)],
    )
    group_means = frame.group_by("column").agg(pl.col("more").mean())

    result = frame.join(group_means, on="column")
    expected = pl.DataFrame(
        {
            "column": ["a", "a", "b"],
            "more": [1, 2, 3],
            "more_right": [1.5, 1.5, 3.0],
        },
        schema=[
            ("column", pl.Categorical),
            ("more", pl.Int32),
            ("more_right", pl.Float64),
        ],
    )
    assert_frame_equal(result, expected, check_row_order=False)
92
93
94
@pytest.mark.parametrize("reverse", [False, True])
def test_sorted_merge_joins(reverse: bool) -> None:
    """Compare joins on keys flagged as sorted with joins on unflagged keys.

    Random, pre-sorted integer keys are joined as int, str and float using the
    `left` and `inner` strategies. With `reverse=True` both frames are
    reversed and the keys are flagged as sorted descending.
    """
    n = 30
    # The left keys are drawn before the right keys (same order as the draws
    # from numpy's global random state have always been made here).
    left_keys = np.sort(np.random.randint(0, n // 2, n))
    right_keys = np.sort(np.random.randint(0, n // 2, n // 2))
    left = pl.DataFrame({"a": left_keys}).with_row_index("row_a")
    right = pl.DataFrame({"a": right_keys}).with_row_index("row_b")

    if reverse:
        left = left.select(pl.all().reverse())
        right = right.select(pl.all().reverse())

    join_strategies: list[JoinStrategy] = ["left", "inner"]
    for cast_to in [int, str, float]:
        # The casts do not depend on the join strategy, so do them once.
        left_cast = left.with_columns(pl.col("a").cast(cast_to))
        right_cast = right.with_columns(pl.col("a").cast(cast_to))

        for how in join_strategies:
            # hash join
            hashed = left_cast.join(right_cast, on="a", how=how)

            # sorted merge join
            left_flagged = left_cast.with_columns(
                pl.col("a").set_sorted(descending=reverse)
            )
            right_flagged = right_cast.with_columns(
                pl.col("a").set_sorted(descending=reverse)
            )
            merged = left_flagged.join(right_flagged, on="a", how=how)

            assert_frame_equal(hashed, merged, check_row_order=False)
129
130
131
def test_join_negative_integers() -> None:
    """Inner-join negative keys for every signed integer width."""
    left = pl.DataFrame({"a": [-1, -6, -3, 0]})
    right = pl.DataFrame(
        {
            "a": [-6, -1, -4, -2, 0],
            "b": [-6, -1, -4, -2, 0],
        }
    )
    expected = pl.DataFrame({"a": [-6, -1, 0], "b": [-6, -1, 0]})

    for dtype in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        as_dtype = pl.all().cast(dtype)
        joined = left.with_columns(as_dtype).join(
            right.with_columns(as_dtype), on="a", how="inner"
        )
        assert_frame_equal(
            joined,
            expected.select(as_dtype),
            check_row_order=False,
        )
154
155
156
def test_deprecated() -> None:
    """Join with `maintain_order="left"` eagerly and lazily; compare as numpy."""
    left = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
    right = pl.DataFrame({"a": [1, 2], "c": [3, 4]})
    expected = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [3, 4]})

    eager = left.join(other=right, on="a", maintain_order="left")
    np.testing.assert_equal(eager.to_numpy(), expected.to_numpy())

    lazy = left.lazy().join(other=right.lazy(), on="a", maintain_order="left")
    np.testing.assert_equal(lazy.collect().to_numpy(), expected.to_numpy())
172
173
174
def test_deprecated_parameter_join_nulls() -> None:
    """Passing `join_nulls` must emit a deprecation warning and still join."""
    frame = pl.DataFrame({"a": [1, None]})
    message = r"the argument `join_nulls` for `DataFrame.join` is deprecated. It was renamed to `nulls_equal`"
    with pytest.deprecated_call(match=message):
        result = frame.join(frame, on="a", join_nulls=True)  # type: ignore[call-arg]
    # With nulls matching each other the self-join reproduces the input.
    assert_frame_equal(result, frame, check_row_order=False)
181
182
183
def test_join_on_expressions() -> None:
    """Join on a computed left key (`a ** 2`) against a plain right column."""
    left = pl.DataFrame({"a": [1, 2, 3]})
    right = pl.DataFrame({"b": [1, 4, 9, 9, 0]})

    squared = (pl.col("a") ** 2).cast(int)
    result = left.join(right, left_on=squared, right_on=pl.col("b"))

    expected = pl.DataFrame({"a": [1, 2, 3, 3], "b": [1, 4, 9, 9]})
    assert_frame_equal(result, expected, check_row_order=False)
193
194
195
def test_join_lazy_frame_on_expression() -> None:
    """Lazy and eager joins on an expression key must agree in shape."""
    # Tests a lazy frame projection pushdown bug
    # https://github.com/pola-rs/polars/issues/19822

    frame = pl.DataFrame(data={"a": [0, 1], "b": [2, 3]})

    lazy_query = frame.lazy().join(
        frame.lazy(), left_on=pl.coalesce("b", "a"), right_on="a"
    )
    lazy_result = lazy_query.select("a").collect()

    eager_joined = frame.join(frame, left_on=pl.coalesce("b", "a"), right_on="a")
    eager_result = eager_joined.select("a")

    assert lazy_result.shape == eager_result.shape
211
212
213
def test_right_join_schema_maintained_22516() -> None:
    """A right join followed by `pl.len()` gives the same count lazily and eagerly."""
    left = pl.DataFrame({"number": [1]})
    right = pl.DataFrame({"invoice_number": [1]})

    eager_joined = left.join(
        right, left_on="number", right_on="invoice_number", how="right"
    )
    eager_count = eager_joined.select(pl.len())

    lazy_joined = left.lazy().join(
        right.lazy(), left_on="number", right_on="invoice_number", how="right"
    )
    lazy_count = lazy_joined.select(pl.len()).collect()

    assert lazy_count.item() == eager_count.item()
228
229
230
def test_join() -> None:
    """Smoke-test inner, left and full joins, key validation and lazy/eager parity."""
    left = pl.DataFrame(
        {
            "a": ["a", "b", "a", "z"],
            "b": [1, 2, 3, 4],
            "c": [6, 5, 4, 3],
        }
    )
    right = pl.DataFrame(
        {
            "a": ["b", "c", "b", "a"],
            "k": [0, 3, 9, 6],
            "c": [1, 0, 2, 1],
        }
    )

    # inner (the default strategy)
    inner = left.join(right, left_on="a", right_on="a", maintain_order="left_right")
    inner = inner.sort("a")
    assert_series_equal(inner["b"], pl.Series("b", [1, 3, 2, 2]))

    # left
    left_joined = left.join(
        right, left_on="a", right_on="a", how="left", maintain_order="left_right"
    )
    left_joined = left_joined.sort("a")
    assert left_joined["c_right"].is_null().sum() == 1
    assert_series_equal(left_joined["b"], pl.Series("b", [1, 3, 2, 2, 4]))

    # full: every column ends up with exactly one null
    full = left.join(right, left_on="a", right_on="a", how="full").sort("a")
    for column in ("c_right", "c", "b", "k", "a"):
        assert full[column].null_count() == 1

    # we need to pass in a column to join on, either by supplying `on`, or both
    # `left_on` and `right_on`
    with pytest.raises(ValueError):
        left.join(right)
    with pytest.raises(ValueError):
        left.join(right, right_on="a")
    with pytest.raises(ValueError):
        left.join(right, left_on="a")

    first = pl.DataFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]})
    second = pl.DataFrame(
        {"foo": [1, 1, 1], "bar": ["a", "c", "c"], "ham": ["let", "var", "const"]}
    )

    # just check if join on multiple columns runs
    first.join(second, left_on=["a", "b"], right_on=["foo", "bar"])

    eager_join = first.join(second, left_on="a", right_on="foo")
    lazy_join = first.lazy().join(second.lazy(), left_on="a", right_on="foo").collect()

    sort_columns = ["a", "b", "bar", "ham"]
    assert lazy_join.shape == eager_join.shape
    assert_frame_equal(
        lazy_join.sort(by=sort_columns), eager_join.sort(by=sort_columns)
    )
286
287
288
def test_joins_dispatch() -> None:
    """Run self-joins over several key combinations; only checks nothing raises."""
    # this just flexes the dispatch a bit

    # don't change the data of this dataframe, this triggered:
    # https://github.com/pola-rs/polars/issues/1688
    raw = pl.DataFrame(
        {
            "a": ["a", "b", "c", "a"],
            "b": [1, 2, 3, 1],
            "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-01"],
            "datetime": [13241324, 12341256, 12341234, 13241324],
        }
    )
    dfa = raw.with_columns(
        pl.col("date").str.strptime(pl.Date), pl.col("datetime").cast(pl.Datetime)
    )

    key_sets = [
        ["a", "b", "date", "datetime"],
        ["date", "datetime"],
        ["date", "datetime", "a"],
        ["date", "a"],
        ["a", "datetime"],
        ["date"],
    ]
    join_strategies: list[JoinStrategy] = ["left", "inner", "full"]
    for how in join_strategies:
        for keys in key_sets:
            dfa.join(dfa, on=keys, how=how)
312
313
314
def test_join_on_cast() -> None:
    """Join an Int32 key against an Int64 key through a cast expression in `on`."""
    left = pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]}).with_row_index()
    left = left.with_columns(pl.col("a").cast(pl.Int32))
    right = pl.DataFrame({"a": [-2, -3, 3, 10]})

    expected = {
        "index": [1, 2, 3, 5],
        "a": [-2, 3, 3, 10],
        "a_right": [-2, 3, 3, 10],
    }

    # eager: row order and dtypes are not compared
    eager = left.join(right, on=pl.col("a").cast(pl.Int64))
    assert_frame_equal(
        eager,
        pl.DataFrame(expected),
        check_row_order=False,
        check_dtypes=False,
    )

    # lazy: left order is requested, so the values can be compared directly
    lazy = left.lazy().join(
        right.lazy(),
        on=pl.col("a").cast(pl.Int64),
        maintain_order="left",
    )
    assert lazy.collect().to_dict(as_series=False) == expected
344
345
346
def test_join_chunks_alignment_4720() -> None:
    """Left-join a cross-join result on three keys, in two different key orders."""
    # https://github.com/pola-rs/polars/issues/4720

    pairs = pl.DataFrame(
        {
            "index1": pl.arange(0, 2, eager=True),
            "index2": pl.arange(10, 12, eager=True),
        }
    )
    thirds = pl.DataFrame({"index3": pl.arange(100, 102, eager=True)})
    triples = pl.DataFrame(
        {
            "index1": pl.arange(0, 2, eager=True),
            "index2": pl.arange(10, 12, eager=True),
            "index3": pl.arange(100, 102, eager=True),
        }
    )

    key_orders = (
        ["index1", "index2", "index3"],
        ["index3", "index1", "index2"],
    )
    for keys in key_orders:
        crossed = pairs.join(thirds, how="cross")
        result = crossed.join(triples, on=keys, how="left")
        expected = pl.DataFrame(
            {
                "index1": [0, 0, 1, 1],
                "index2": [10, 10, 11, 11],
                "index3": [100, 101, 100, 101],
            }
        )
        assert_frame_equal(result, expected, check_row_order=False)
400
401
402
def test_jit_sort_joins() -> None:
    """Cross-check polars joins against pandas `merge` when one side is key-sorted."""
    # Explicitly specify numpy dtype because of different defaults on Windows
    n_big = 200
    pd_big = pd.DataFrame(
        {
            "a": np.random.randint(0, 100, n_big, dtype=np.int64),
            "b": np.arange(0, n_big, dtype=np.int64),
        }
    )
    n_small = 40
    pd_small = pd.DataFrame(
        {
            "a": np.random.randint(0, 100, n_small, dtype=np.int64),
            "b": np.arange(0, n_small, dtype=np.int64),
        }
    )
    pl_big_sorted = pl.from_pandas(pd_big).sort("a")
    pl_small = pl.from_pandas(pd_small)

    columns = ["a", "b", "b_right"]

    def check(
        pd_left: pd.DataFrame,
        pd_right: pd.DataFrame,
        pl_left: pl.DataFrame,
        pl_right: pl.DataFrame,
        how: Literal["left", "inner"],
    ) -> None:
        # pandas result, renamed to the column names polars produces
        reference = pd_left.merge(pd_right, on="a", how=how)
        reference.columns = pd.Index(["a", "b", "b_right"])

        actual = pl_left.join(pl_right, on="a", how=how).sort(columns)
        expected = (
            pl.from_pandas(reference).with_columns(pl.all().cast(int)).sort(columns)
        )
        assert_frame_equal(expected, actual)
        assert actual["a"].flags["SORTED_ASC"]

    join_strategies: list[Literal["left", "inner"]] = ["left", "inner"]
    for how in join_strategies:
        # left key sorted, right is not
        check(pd_big, pd_small, pl_big_sorted, pl_small, how)
        # right key sorted, left is not
        check(pd_small, pd_big, pl_small, pl_big_sorted, how)
450
451
452
def test_join_panic_on_binary_expr_5915() -> None:
    """A lazy join whose left key is a binary expression (`a + 1`) must work."""
    left = pl.DataFrame({"a": [1, 2, 3]}).lazy()
    right = pl.DataFrame({"b": [1, 4, 9, 9, 0]}).lazy()

    shifted = (pl.col("a") + 1).cast(int)
    query = left.join(right, left_on=[shifted], right_on=[pl.col("b")])

    # only a == 3 has a partner: 3 + 1 == 4
    assert query.collect().to_dict(as_series=False) == {"a": [3], "b": [4]}
458
459
460
def test_semi_join_projection_pushdown_6423() -> None:
    """Two chained semi joins followed by a projection keep the selected column."""
    left = pl.DataFrame({"x": [1]}).lazy()
    right = pl.DataFrame({"y": [1], "x": [1]}).lazy()

    once = left.join(right, left_on="x", right_on="y", how="semi")
    twice = once.join(right, left_on="x", right_on="y", how="semi")
    query = twice.select(["x"])

    assert query.collect().to_dict(as_series=False) == {"x": [1]}
469
470
471
def test_semi_join_projection_pushdown_6455() -> None:
    """Semi-join on two keys, then select columns that exclude one of the keys."""
    events = pl.DataFrame(
        {
            "id": [1, 1, 2],
            "timestamp": [
                datetime(2022, 12, 11),
                datetime(2022, 12, 12),
                datetime(2022, 1, 1),
            ],
            "value": [1, 2, 4],
        }
    ).lazy()

    # newest timestamp per id
    latest = events.group_by("id").agg(pl.col("timestamp").max())
    newest_rows = events.join(latest, on=["id", "timestamp"], how="semi")

    result = newest_rows.select(["id", "value"]).collect()
    assert result.to_dict(as_series=False) == {
        "id": [1, 2],
        "value": [2, 4],
    }
490
491
492
def test_update() -> None:
    """Cover `update`: explicit keys, no keys, each `how`, and `include_nulls`."""
    target = pl.DataFrame(
        {
            "key1": [1, 2, 3, 4],
            "key2": [1, 2, 3, 4],
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": ["1", "2", "3", "4"],
            "d": [
                date(2023, 1, 1),
                date(2023, 1, 2),
                date(2023, 1, 3),
                date(2023, 1, 4),
            ],
        }
    )
    patch_data = {
        "key1": [1, 2, 3, 4],
        "key2": [1, 2, 3, 5],
        "a": [1, 1, 1, 1],
        "b": [2, 2, 2, 2],
        "c": ["3", "3", "3", "3"],
        "d": [date(2023, 5, 5)] * 4,
    }
    patch = pl.DataFrame(patch_data)

    # update only on key1: the expected frame holds the same values as the patch
    expected = pl.DataFrame(patch_data)
    assert_frame_equal(target.update(patch, on="key1"), expected)

    # update on key1 using different left/right names
    renamed = patch.rename({"key1": "key1b"})
    assert_frame_equal(
        target.update(renamed, left_on="key1", right_on="key1b"),
        expected,
    )

    # update on key1 and key2. This should fail to match the last item.
    expected = pl.DataFrame(
        {
            "key1": [1, 2, 3, 4],
            "key2": [1, 2, 3, 4],
            "a": [1, 1, 1, 4],
            "b": [2, 2, 2, 4],
            "c": ["3", "3", "3", "4"],
            "d": [
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 5, 5),
                date(2023, 1, 4),
            ],
        }
    )
    assert_frame_equal(target.update(patch, on=["key1", "key2"]), expected)

    # update on key1 and key2 using different left/right names
    renamed = patch.rename({"key1": "key1b", "key2": "key2b"})
    assert_frame_equal(
        target.update(
            renamed,
            left_on=["key1", "key2"],
            right_on=["key1b", "key2b"],
        ),
        expected,
    )

    # no keys given
    frame = pl.DataFrame({"A": [1, 2, 3, 4], "B": [400, 500, 600, 700]})
    new_frame = pl.DataFrame({"B": [4, None, 6], "C": [7, 8, 9]})
    updated = frame.update(new_frame)
    assert updated.to_dict(as_series=False) == {
        "A": [1, 2, 3, 4],
        "B": [4, 500, 6, 700],
    }

    # other frame covers only some of the keys
    base = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    partial = pl.DataFrame({"a": [2, 3], "b": [8, 9]})
    updated = base.update(partial, on="a")
    assert updated.to_dict(as_series=False) == {
        "a": [1, 2, 3],
        "b": [4, 8, 9],
    }

    # lazy frames without any shared column
    a = pl.LazyFrame({"a": [1, 2, 3]})
    b = pl.LazyFrame({"b": [4, 5], "c": [3, 1]})
    c = a.update(b)
    assert_frame_equal(a, c)

    # check behaviour of 'how' param
    default_how = a.update(b, left_on="a", right_on="c")
    assert default_how.collect().to_series().to_list() == [1, 2, 3]

    inner = a.update(b, how="inner", left_on="a", right_on="c")
    assert sorted(inner.collect().to_series().to_list()) == [1, 3]

    full = a.update(b.rename({"b": "a"}), how="full", on="a")
    assert sorted(full.collect().to_series().sort().to_list()) == [1, 2, 3, 4, 5]

    # check behavior of include_nulls=True
    frame = pl.DataFrame(
        {
            "A": [1, 2, 3, 4],
            "B": [400, 500, 600, 700],
        }
    )
    new_frame = pl.DataFrame(
        {
            "B": [-66, None, -99],
            "C": [5, 3, 1],
        }
    )
    out = frame.update(
        new_frame, left_on="A", right_on="C", how="full", include_nulls=True
    )
    expected = pl.DataFrame(
        {
            "A": [1, 2, 3, 4, 5],
            "B": [-99, 500, None, 700, -66],
        }
    )
    assert_frame_equal(out, expected, check_row_order=False)

    # edge-case #11684
    x = pl.DataFrame({"a": [0, 1]})
    y = pl.DataFrame({"a": [2, 3]})
    merged_keys = x.update(y, on="a", how="full")["a"].to_list()
    assert sorted(merged_keys) == [0, 1, 2, 3]

    # disallowed join strategies
    for join_strategy in ("cross", "anti", "semi"):
        with pytest.raises(
            ValueError,
            match=f"`how` must be one of {{'left', 'inner', 'full'}}; found '{join_strategy}'",
        ):
            a.update(b, how=join_strategy)  # type: ignore[arg-type]
647
648
649
def test_join_frame_consistency() -> None:
    """Mixing DataFrame and LazyFrame in `join` / `join_asof` raises TypeError."""
    eager = pl.DataFrame({"A": [1, 2, 3]})
    lazy = pl.DataFrame({"A": [1, 2, 5]}).lazy()

    wants_lazy = r"expected `other`.*LazyFrame"
    wants_eager = r"expected `other`.*DataFrame"

    # join
    with pytest.raises(TypeError, match=wants_lazy):
        _ = lazy.join(eager, on="A")  # type: ignore[arg-type]
    with pytest.raises(TypeError, match=wants_eager):
        _ = eager.join(lazy, on="A")  # type: ignore[arg-type]

    # join_asof
    with pytest.raises(TypeError, match=wants_lazy):
        _ = lazy.join_asof(eager, on="A")  # type: ignore[arg-type]
    with pytest.raises(TypeError, match=wants_eager):
        _ = eager.join_asof(lazy, on="A")  # type: ignore[arg-type]
661
662
663
def test_join_concat_projection_pd_case_7071() -> None:
    """Select one column from a join result concatenated with its right input."""
    first = pl.DataFrame({"id": [1, 2], "value": [100, 200]}).lazy()
    second = pl.DataFrame({"id": [1, 3], "value": [100, 300]}).lazy()

    joined = first.join(second, on=["id", "value"])
    stacked = pl.concat([joined, second])
    result = stacked.select("id")

    expected = pl.DataFrame({"id": [1, 1, 3]}).lazy()
    assert_frame_equal(result, expected)
673
674
675
@pytest.mark.may_fail_auto_streaming  # legacy full join is not order-preserving whereas new-streaming is
def test_join_sorted_fast_paths_null() -> None:
    """Join a sorted left frame with a right frame whose key column holds a null."""
    left = pl.DataFrame({"x": [0, 1, 0]}).sort("x")
    right = pl.DataFrame({"x": [0, None], "y": [0, 1]})

    def joined(how: Any) -> Any:
        # Plain-dict view of the join result for the given strategy.
        return left.join(right, on="x", how=how).to_dict(as_series=False)

    assert joined("inner") == {
        "x": [0, 0],
        "y": [0, 0],
    }
    assert joined("left") == {
        "x": [0, 0, 1],
        "y": [0, 0, None],
    }
    assert joined("anti") == {"x": [1]}
    assert joined("semi") == {"x": [0, 0]}
    assert joined("full") == {
        "x": [0, 0, 1, None],
        "x_right": [0, 0, None, None],
        "y": [0, 0, None, 1],
    }
694
695
696
def test_full_outer_join_list_() -> None:
    """Full join of frames with a List(Float64) payload and no matching ids."""
    schema = {"id": pl.Int64, "vals": pl.List(pl.Float64)}

    # Result schema: the input columns followed by their "_right" twins.
    join_schema = dict(schema)
    for name, dtype in schema.items():
        join_schema[name + "_right"] = dtype

    left = pl.DataFrame({"id": [1], "vals": [[]]}, schema=schema)  # type: ignore[arg-type]
    right = pl.DataFrame({"id": [2, 3], "vals": [[], [4]]}, schema=schema)  # type: ignore[arg-type]

    expected = pl.DataFrame(
        {
            "id": [None, None, 1],
            "vals": [None, None, []],
            "id_right": [2, 3, None],
            "vals_right": [[], [4.0], None],
        },
        schema=join_schema,  # type: ignore[arg-type]
    )
    result = left.join(right, on="id", how="full", maintain_order="right_left")
    assert_frame_equal(result, expected)
712
713
714
@pytest.mark.slow
def test_join_validation() -> None:
    """Check `validate` for each cardinality, in both directions, on several sizes."""

    def check_pair(
        unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy
    ) -> None:
        # `unique` has distinct keys, `duplicate` repeats at least one key.

        def accepted(left: pl.DataFrame, right: pl.DataFrame, validate: Any) -> None:
            left.join(right, on=on, how=how, validate=validate)

        def rejected(left: pl.DataFrame, right: pl.DataFrame, validate: Any) -> None:
            with pytest.raises(ComputeError):
                left.join(right, on=on, how=how, validate=validate)

        # one to many
        accepted(unique, duplicate, "1:m")
        rejected(duplicate, unique, "1:m")

        # one to one
        rejected(unique, duplicate, "1:1")
        rejected(duplicate, unique, "1:1")

        # many to one
        rejected(unique, duplicate, "m:1")
        accepted(duplicate, unique, "m:1")

        # many to many
        accepted(duplicate, unique, "m:m")
        accepted(unique, duplicate, "m:m")

    # test data
    short_unique = pl.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "id_str": ["1", "2", "3", "4"],
            "name": ["hello", "world", "rust", "polars"],
        }
    )
    short_duplicate = pl.DataFrame(
        {"id": [1, 2, 3, 1], "id_str": ["1", "2", "3", "1"], "cnt": [2, 4, 6, 1]}
    )
    long_unique = pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5],
            "id_str": ["1", "2", "3", "4", "5"],
            "name": ["hello", "world", "rust", "polars", "meow"],
        }
    )
    long_duplicate = pl.DataFrame(
        {
            "id": [1, 2, 3, 1, 5],
            "id_str": ["1", "2", "3", "1", "5"],
            "cnt": [2, 4, 6, 1, 8],
        }
    )

    frame_pairs = [
        (long_unique, long_duplicate),  # same size
        (long_unique, short_duplicate),  # left longer
        (short_unique, long_duplicate),  # right longer
    ]
    join_strategies: list[JoinStrategy] = ["inner", "full", "left"]

    for join_col in ["id", "id_str"]:
        for how in join_strategies:
            for unique_frame, duplicate_frame in frame_pairs:
                check_pair(unique_frame, duplicate_frame, join_col, how)
797
798
799
@typing.no_type_check
def test_join_validation_many_keys() -> None:
    """Check `validate` when joining on two key columns at once."""
    keys = ["val1", "val2"]
    join_types = ["inner", "left", "full"]

    def distinct_rows() -> pl.DataFrame:
        # every (val1, val2) pair occurs once
        return pl.DataFrame(
            {
                "val1": [11, 12, 13, 14],
                "val2": [1, 2, 3, 4],
            }
        )

    def repeated_rows() -> pl.DataFrame:
        # the pair (11, 1) occurs twice
        return pl.DataFrame(
            {
                "val1": [11, 11, 12, 13, 14],
                "val2": [1, 1, 2, 3, 4],
            }
        )

    # unique in both
    lhs = distinct_rows()
    rhs = distinct_rows()
    for join_type in join_types:
        for val in ["m:m", "m:1", "1:1", "1:m"]:
            lhs.join(rhs, on=keys, how=join_type, validate=val)

    # many in lhs
    lhs = repeated_rows()
    for join_type in join_types:
        for val in ["1:1", "1:m"]:
            with pytest.raises(ComputeError):
                lhs.join(rhs, on=keys, how=join_type, validate=val)

    # many in rhs
    lhs = distinct_rows()
    rhs = repeated_rows()
    for join_type in join_types:
        for val in ["m:1", "1:1"]:
            with pytest.raises(ComputeError):
                lhs.join(rhs, on=keys, how=join_type, validate=val)
849
850
851
def test_full_outer_join_bool() -> None:
    """Full join on a boolean key where every row has a partner."""
    left = pl.DataFrame({"id": [True, False], "val": [1, 2]})
    right = pl.DataFrame({"id": [True, False], "val": [0, -1]})

    result = left.join(right, on="id", how="full", maintain_order="right")

    assert result.to_dict(as_series=False) == {
        "id": [True, False],
        "val": [1, 2],
        "id_right": [True, False],
        "val_right": [0, -1],
    }
862
863
864
def test_full_outer_join_coalesce_different_names_13450() -> None:
    """Coalescing full join whose left and right key columns have different names."""
    left = pl.DataFrame(
        {"L1": ["a", "b", "c"], "L3": ["b", "c", "d"], "L2": [1, 2, 3]}
    )
    right = pl.DataFrame({"L3": ["a", "c", "d"], "R2": [7, 8, 9]})

    result = left.join(right, left_on="L1", right_on="L3", how="full", coalesce=True)

    expected = pl.DataFrame(
        {
            "L1": ["a", "c", "d", "b"],
            "L3": ["b", "d", None, "c"],
            "L2": [1, 3, None, 2],
            "R2": [7, 8, 9, None],
        }
    )
    assert_frame_equal(result, expected, check_row_order=False)
879
880
881
# https://github.com/pola-rs/polars/issues/10663
def test_join_on_wildcard_error() -> None:
    """Using the `pl.all()` wildcard as a join key is rejected."""
    left = pl.DataFrame({"x": [1]})
    right = pl.DataFrame({"x": [1], "y": [2]})
    wildcard = pl.all()
    with pytest.raises(InvalidOperationError):
        left.join(right, on=wildcard)
889
890
891
def test_join_on_nth_error() -> None:
    """Using the positional selector `pl.first()` as a join key is rejected."""
    left = pl.DataFrame({"x": [1]})
    right = pl.DataFrame({"x": [1], "y": [2]})
    positional = pl.first()
    with pytest.raises(InvalidOperationError):
        left.join(right, on=positional)
898
899
900
def test_join_results_in_duplicate_names() -> None:
    """A join whose suffixed output column already exists raises DuplicateError.

    The frame already has a `c_right` column, so self-joining on `a` and `b`
    would produce a second `c_right`. The error is expected when resolving the
    lazy schema, when collecting, and from the eager join.
    """
    df = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": [4, 5, 6],
            "c": [1, 2, 3],
            "c_right": [1, 2, 3],
        }
    )

    def f(x: Any) -> Any:
        # Self-join; works for both DataFrame and LazyFrame inputs.
        return x.join(x, on=["a", "b"], how="left")

    # Ensure it also contains the hint
    match_str = "(?s)column with name 'c_right' already exists.*You may want to try"

    # Ensure it fails immediately when resolving schema.
    with pytest.raises(DuplicateError, match=match_str):
        f(df.lazy()).collect_schema()

    with pytest.raises(DuplicateError, match=match_str):
        f(df.lazy()).collect()

    # Eager: the join itself must raise. No `.collect()` here, since the join
    # would return a DataFrame, which has no such method.
    with pytest.raises(DuplicateError, match=match_str):
        f(df)
925
926
927
def test_join_duplicate_suffixed_columns_from_join_key_column_21048() -> None:
    """Self-join on `a` when the frame already contains `b_right` must raise."""
    frame = pl.DataFrame({"a": 1, "b": 1, "b_right": 1})

    def self_join(x: Any) -> Any:
        return x.join(x, on="a")

    # Ensure it also contains the hint
    pattern = "(?s)column with name 'b_right' already exists.*You may want to try"

    # Ensure it fails immediately when resolving schema.
    with pytest.raises(DuplicateError, match=pattern):
        self_join(frame.lazy()).collect_schema()

    # lazy, on collect
    with pytest.raises(DuplicateError, match=pattern):
        self_join(frame.lazy()).collect()

    # eager
    with pytest.raises(DuplicateError, match=pattern):
        self_join(frame)
945
946
947
def test_join_projection_invalid_name_contains_suffix_15243() -> None:
    """Referencing a non-existent `*_right` column after a join must raise."""
    left = pl.DataFrame({"a": [1, 2, 3]}).lazy()
    right = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).lazy()

    with pytest.raises(ColumnNotFoundError):
        # neither input frame has a column named "foo"
        keep = pl.col("b") == pl.col("foo_right")
        left.join(right, on="a").select(pl.col("b").filter(keep)).collect()
957
958
959
def test_join_list_non_numeric() -> None:
    """Group by a column of string lists and count each distinct list."""
    frame = pl.DataFrame(
        {
            "lists": [
                ["a", "b", "c"],
                ["a", "c", "b"],
                ["a", "c", "b"],
                ["a", "c", "d"],
            ]
        }
    )
    counts = frame.group_by("lists", maintain_order=True).agg(
        pl.len().alias("count")
    )
    assert counts.to_dict(as_series=False) == {
        "lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],
        "count": [1, 2, 1],
    }
977
978
979
@pytest.mark.slow
def test_join_4_columns_with_validity() -> None:
    """Self-join on four identical nullable columns, with and without `nulls_equal`."""
    # join on 4 columns so we trigger combine validities
    # use 138 as that is 2 u64 and a remainder
    values = [None if i % 6 == 0 else i for i in range(138)]
    frame = pl.DataFrame({"a": values}).with_columns(
        b=pl.col("a"),
        c=pl.col("a"),
        d=pl.col("a"),
    )
    keys = ["a", "b", "c", "d"]

    nulls_match = frame.join(frame, on=keys, how="inner", nulls_equal=True)
    assert nulls_match.shape == (644, 4)

    nulls_differ = frame.join(frame, on=keys, how="inner", nulls_equal=False)
    assert nulls_differ.shape == (115, 4)
999
1000
1001
@pytest.mark.release
def test_cross_join() -> None:
    """The first 100 rows of a 101-row cross join equal the 100-row cross join."""
    # triggers > 100 rows implementation
    # https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L34
    single_row = pl.DataFrame({"col1": ["a"], "col2": ["d"]})

    hundred = pl.DataFrame({"frame2": pl.arange(0, 100, eager=True)})
    expected = hundred.join(single_row, how="cross")

    hundred_and_one = pl.DataFrame({"frame2": pl.arange(0, 101, eager=True)})
    crossed = hundred_and_one.join(
        single_row, how="cross", maintain_order="left_right"
    )
    assert_frame_equal(crossed.slice(0, 100), expected)
1012
1013
1014
@pytest.mark.release
def test_cross_join_slice_pushdown() -> None:
    """Slice a very large lazy cross join from the tail and from the head."""
    # this will likely go out of memory if we did not pushdown the slice
    values = pl.arange(0, 2**16 - 1, eager=True, dtype=pl.UInt16) % 2**15
    df = pl.Series("x", values).to_frame()
    schema = {"x": pl.UInt16, "x_": pl.UInt16}

    def crossed_slice(offset: int) -> pl.DataFrame:
        # Build the query afresh for each slice, then take 10 rows at `offset`.
        query = df.lazy().join(
            df.lazy(), how="cross", maintain_order="left_right", suffix="_"
        )
        return query.slice(offset, 10).collect()

    # negative offset: only the last five rows exist
    result = crossed_slice(-5)
    expected = pl.DataFrame(
        {
            "x": [32766, 32766, 32766, 32766, 32766],
            "x_": [32762, 32763, 32764, 32765, 32766],
        },
        schema=schema,
    )
    assert_frame_equal(result, expected)

    # positive offset near the start
    result = crossed_slice(2)
    expected = pl.DataFrame(
        {
            "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            "x_": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        },
        schema=schema,
    )
    assert_frame_equal(result, expected)
1050
1051
1052
@pytest.mark.parametrize("how", ["left", "inner"])
def test_join_coalesce(how: JoinStrategy) -> None:
    """Check output columns and schema of lazy joins for each `coalesce` setting.

    For every case the schema resolved from the lazy query must equal the
    schema of the collected frame, and the column list must match exactly.

    Parameters
    ----------
    how
        Join strategy supplied by the parametrization. It is deliberately not
        reassigned in the body, so that both the `left` and the `inner`
        strategy are exercised.
    """
    a = pl.LazyFrame({"a": [1, 2], "b": [1, 2]})
    b = pl.LazyFrame(
        {
            "a": [1, 2, 1, 2],
            "b": [5, 7, 8, 9],
            "c": [1, 2, 1, 2],
        }
    )

    # single key, not coalesced: the right key is kept with a suffix
    q = a.join(b, on="a", coalesce=False, how=how)
    out = q.collect()
    assert q.collect_schema() == out.schema
    assert out.columns == ["a", "b", "a_right", "b_right", "c"]

    # two keys, not coalesced: both right keys are kept with a suffix
    q = a.join(b, on=["a", "b"], coalesce=False, how=how)
    out = q.collect()
    assert q.collect_schema() == out.schema
    assert out.columns == ["a", "b", "a_right", "b_right", "c"]

    # two keys, coalesced: the right keys are dropped
    q = a.join(b, on=["a", "b"], coalesce=True, how=how)
    out = q.collect()
    assert q.collect_schema() == out.schema
    assert out.columns == ["a", "b", "c"]
1079
1080
@pytest.mark.parametrize("how", ["left", "inner", "full"])
def test_join_empties(how: JoinStrategy) -> None:
    """Joining two frames that have no rows yields a frame with no rows."""
    left = pl.DataFrame({"col1": [], "col2": [], "col3": []})
    right = pl.DataFrame({"col2": [], "col4": [], "col5": []})

    joined = left.join(right, on="col2", how=how)
    assert joined.height == 0
1087
1088
1089
def test_join_raise_on_redundant_keys() -> None:
    """Listing the same key column twice in `on` is rejected."""
    left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})
    right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})
    repeated_key = ["a", "a"]
    with pytest.raises(InvalidOperationError, match="already joined on"):
        left.join(right, on=repeated_key, how="full", coalesce=True)
1094
1095
1096
@pytest.mark.parametrize("coalesce", [False, True])
def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None:
    """Two key expressions that resolve to the same output name are rejected."""
    left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})
    right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})

    # both expressions are rooted in column "a"
    keys = [pl.col("a"), pl.col("a") % 2]

    with pytest.raises(InvalidOperationError, match="already joined on"):  # noqa: PT012, SIM117
        with warnings.catch_warnings():
            # UserWarnings are silenced for the duration of the join call
            warnings.simplefilter(action="ignore", category=UserWarning)
            left.join(right, on=keys, how="full", coalesce=coalesce)
1108
1109
1110
def test_join_lit_panic_11410() -> None:
    """Joining on a literal on both sides pairs every left row with every right row."""
    source = pl.LazyFrame({"date": [1, 2, 3], "symbol": [4, 5, 6]})
    unique_dates = source.select("date").unique(maintain_order=True)
    unique_symbols = source.select("symbol").unique(maintain_order=True)

    joined = unique_symbols.join(
        unique_dates,
        left_on=pl.lit(1),
        right_on=pl.lit(1),
        maintain_order="left_right",
    )
    expected = {
        "symbol": [4, 4, 4, 5, 5, 5, 6, 6, 6],
        "date": [1, 2, 3, 1, 2, 3, 1, 2, 3],
    }
    assert joined.collect().to_dict(as_series=False) == expected
1121
1122
1123
def test_join_empty_literal_17027() -> None:
    """Literal-key joins against an empty frame: left keeps the row, inner drops it."""
    one_row = pl.DataFrame({"a": [1]})
    no_rows = pl.DataFrame(schema={"a": pl.Int64})

    # In-memory (eager) execution.
    eager_left = one_row.join(no_rows, on=pl.lit(0), how="left")
    assert eager_left.height == 1
    eager_inner = one_row.join(no_rows, on=pl.lit(0), how="inner")
    assert eager_inner.height == 0

    # Same queries through the streaming engine.
    streamed_inner = (
        one_row.lazy()
        .join(no_rows.lazy(), on=pl.lit(0), how="inner")
        .collect(engine="streaming")
    )
    assert streamed_inner.height == 0

    streamed_left = (
        one_row.lazy()
        .join(no_rows.lazy(), on=pl.lit(0), how="left")
        .collect(engine="streaming")
    )
    assert streamed_left.height == 1
1143
1144
1145
@pytest.mark.parametrize(
    ("left_on", "right_on"),
    zip(
        [pl.col("a"), pl.col("a").sort(), [pl.col("a"), pl.col("b")]],
        [pl.col("a").slice(0, 2) * 2, pl.col("b"), [pl.col("a"), pl.col("b").head()]],
        strict=False,
    ),
)
def test_join_non_elementwise_keys_raises(left_on: pl.Expr, right_on: pl.Expr) -> None:
    """Collecting an inner join on the parametrized key pairs must raise."""
    # https://github.com/pola-rs/polars/issues/17184
    columns = {"a": [1, 2, 3], "b": [3, 4, 5]}
    lhs = pl.LazyFrame(columns)
    rhs = pl.LazyFrame(columns)

    query = lhs.join(rhs, left_on=left_on, right_on=right_on, how="inner")

    # Building the query is fine; the error is expected when it is collected.
    with pytest.raises(InvalidOperationError):
        query.collect()
1167
1168
1169
def test_join_coalesce_not_supported_warning() -> None:
    """`coalesce=True` with expression keys warns and returns uncoalesced columns."""
    # https://github.com/pola-rs/polars/issues/17184
    lhs = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    rhs = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    expect = pl.DataFrame(
        {"a": [1, 2, 3], "b": [3, 4, 5], "a_right": [1, 2, 3], "b_right": [3, 4, 5]}
    )

    query = lhs.join(
        rhs,
        left_on=[pl.col("a") * 2],
        right_on=[pl.col("a") * 2],
        how="inner",
        coalesce=True,
    )
    with pytest.warns(UserWarning, match="turning off key coalescing"):
        got = query.collect()

    assert_frame_equal(expect, got, check_row_order=False)
1188
1189
1190
@pytest.mark.parametrize(
    "on_args",
    [
        {"on": "a", "left_on": "a"},
        {"on": "a", "right_on": "a"},
        {"on": "a", "left_on": "a", "right_on": "a"},
    ],
)
def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None:
    """Passing `on` together with `left_on` and/or `right_on` raises ValueError."""
    lhs = pl.DataFrame({"a": [1], "b": [2]})
    rhs = pl.DataFrame({"a": [1], "c": [3]})
    expected_message = "cannot use 'on' in conjunction with 'left_on' or 'right_on'"

    with pytest.raises(ValueError, match=expected_message):
        lhs.join(rhs, **on_args)  # type: ignore[arg-type]
1204
1205
1206
@pytest.mark.parametrize(
    "on_args",
    [
        {"left_on": "a"},
        {"right_on": "a"},
    ],
)
def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None:
    """Supplying only one of `left_on` / `right_on` raises ValueError."""
    lhs = pl.DataFrame({"a": [1]})
    rhs = pl.DataFrame({"a": [1]})
    expected_message = "'left_on' requires corresponding 'right_on'"

    with pytest.raises(ValueError, match=expected_message):
        lhs.join(rhs, **on_args)  # type: ignore[arg-type]
1219
1220
1221
@pytest.mark.parametrize(
    "on_args",
    [
        {"on": "a"},
        {"left_on": "a", "right_on": "a"},
    ],
)
def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:
    """A cross join given any join-key argument raises ValueError."""
    lhs = pl.DataFrame({"a": [1, 2]})
    rhs = pl.DataFrame({"b": [3, 4]})
    expected_message = "cross join should not pass join keys"

    with pytest.raises(ValueError, match=expected_message):
        lhs.join(rhs, how="cross", **on_args)  # type: ignore[arg-type]
1234
1235
1236
@pytest.mark.parametrize("set_sorted", [True, False])
def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None:
    """`head(5)` after an ordered left join returns the first five joined rows."""
    lhs = pl.LazyFrame({"k": [1, 2, 3, 4, 0]})
    rhs = pl.LazyFrame({"k": [1, 1, 1, 1, 0]})

    if set_sorted:
        # The data isn't actually sorted on purpose to ensure we default to a
        # hash join unless we set the sorted flag here, in case there is new
        # code in the future that automatically identifies sortedness during
        # Series construction from Python.
        lhs, rhs = lhs.set_sorted("k"), rhs.set_sorted("k")

    joined = lhs.join(rhs, on="k", how="left", maintain_order="left_right")
    first_five = joined.head(5).collect()

    assert_frame_equal(first_five, pl.DataFrame({"k": [1, 1, 1, 1, 2]}))
1251
1252
1253
def test_join_key_type_coercion_19597() -> None:
    """Float64 vs Int64 join keys raise SchemaError when the schema is resolved."""
    float_side = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Float64)})
    int_side = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int64)})
    mismatch = "datatypes of join keys don't match"

    # Plain column keys.
    with pytest.raises(SchemaError, match=mismatch):
        float_side.join(
            int_side, left_on=pl.col("a"), right_on=pl.col("a")
        ).collect_schema()

    # Keys computed by an expression.
    with pytest.raises(SchemaError, match=mismatch):
        float_side.join(
            int_side, left_on=pl.col("a") * 2, right_on=pl.col("a") * 2
        ).collect_schema()
1264
1265
1266
def test_array_explode_join_19763() -> None:
    """A key column produced by exploding an Array literal can be joined on."""
    array_literal = pl.lit(pl.Series([[1], [2]], dtype=pl.Array(pl.Int64, 1)))
    exploded = pl.LazyFrame().select(array_literal.explode().alias("k"))

    joined = exploded.join(pl.LazyFrame({"k": [1, 2]}), on="k")

    expected = pl.DataFrame({"k": [1, 2]})
    assert_frame_equal(joined.collect().sort("k"), expected)
1274
1275
1276
def test_join_full_19814() -> None:
    """Coalescing full join with an all-null Categorical payload column."""
    schema = {"a": pl.Int64, "c": pl.Categorical}
    lhs = pl.LazyFrame({"a": [1], "c": [None]}, schema=schema)
    rhs = pl.LazyFrame({"a": [1, 3, 4]})

    result = lhs.join(rhs, on="a", how="full", coalesce=True).collect()
    expected = pl.DataFrame({"a": [1, 3, 4], "c": [None, None, None]}, schema=schema)

    assert_frame_equal(result, expected, check_row_order=False)
1285
1286
1287
def test_join_preserve_order_inner() -> None:
    """
    Check the key-column order of inner joins for each `maintain_order` mode.

    The "left" and "right" modes are compared against explicit expected lists.
    The "left_right" and "right_left" modes are then required to give the same
    key column as their single-sided counterpart.
    """
    left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    right = pl.LazyFrame({"a": [1, 1, None, 2], "b": [6, 7, 8, 9]})

    # Inner joins

    inner_left = left.join(right, on="a", how="inner", maintain_order="left").collect()
    assert inner_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, 1, 1]
    # Fixed: this used maintain_order="left" again, so the comparison below was
    # between two identical queries and "left_right" was never run.
    inner_left_right = left.join(
        right, on="a", how="inner", maintain_order="left_right"
    ).collect()
    assert inner_left.get_column("a").equals(inner_left_right.get_column("a"))

    inner_right = left.join(
        right, on="a", how="inner", maintain_order="right"
    ).collect()
    assert inner_right.get_column("a").cast(pl.UInt32).to_list() == [1, 1, 1, 1, 2]
    # Fixed: same issue as above for "right_left".
    inner_right_left = left.join(
        right, on="a", how="inner", maintain_order="right_left"
    ).collect()
    assert inner_right.get_column("a").equals(inner_right_left.get_column("a"))
1308
1309
1310
# The new streaming engine does not provide the same maintain_order="none"
1311
# ordering guarantee that is currently kept for compatibility on the in-memory
1312
# engine.
1313
@pytest.mark.may_fail_auto_streaming
def test_join_preserve_order_left() -> None:
    """Check key-column order of left and right joins per `maintain_order` mode."""
    lhs = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    rhs = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})

    def key_values(frame: pl.DataFrame) -> list[Any]:
        # Key column "a" as a plain Python list, after the same UInt32 cast
        # every comparison below uses.
        return frame.get_column("a").cast(pl.UInt32).to_list()

    in_left_order = [None, 2, 1, 1, 5]
    in_right_order = [1, 1, 2, None, 5]

    # Right now the left join algorithm is ordered without explicitly setting any order
    # This behaviour is deprecated but can only be removed in 2.0
    unordered = lhs.join(rhs, on="a", how="left", maintain_order="none").collect()
    assert key_values(unordered) == in_left_order

    by_left = lhs.join(rhs, on="a", how="left", maintain_order="left").collect()
    assert key_values(by_left) == in_left_order

    by_left_then_right = lhs.join(
        rhs, on="a", how="left", maintain_order="left_right"
    ).collect()
    # If the left order is preserved then there are no unsorted right rows
    assert by_left.get_column("a").equals(by_left_then_right.get_column("a"))

    by_right = lhs.join(rhs, on="a", how="left", maintain_order="right").collect()
    assert key_values(by_right)[:5] == in_right_order

    by_right_then_left = lhs.join(
        rhs, on="a", how="left", maintain_order="right_left"
    ).collect()
    assert key_values(by_right_then_left) == in_right_order

    # Right joins.
    right_join_by_left = lhs.join(
        rhs, on="a", how="right", maintain_order="left"
    ).collect()
    assert key_values(right_join_by_left) == [2, 1, 1, None, 6]

    right_join_by_right = lhs.join(
        rhs, on="a", how="right", maintain_order="right"
    ).collect()
    assert key_values(right_join_by_right) == [1, 1, None, 2, 6]
1377
1378
1379
def test_join_preserve_order_full() -> None:
    """Check key-column order of full joins for each `maintain_order` mode."""
    lhs = pl.LazyFrame({"a": [None, 2, 1, 1, 5]})
    rhs = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})

    def column_values(frame: pl.DataFrame, name: str) -> list[Any]:
        # Named column as a Python list after the UInt32 cast used throughout.
        return frame.get_column(name).cast(pl.UInt32).to_list()

    # Single-sided orders: only the first five rows are compared.
    by_left = lhs.join(rhs, on="a", how="full", maintain_order="left").collect()
    assert column_values(by_left, "a")[:5] == [None, 2, 1, 1, 5]

    by_right = lhs.join(rhs, on="a", how="full", maintain_order="right").collect()
    assert column_values(by_right, "a")[:5] == [1, 1, None, 2, None]

    # Two-sided orders: the whole column is compared.
    by_left_then_right = lhs.join(
        rhs, on="a", how="full", maintain_order="left_right"
    ).collect()
    expected_right_keys = [None, 2, 1, 1, None, None, 6]
    assert column_values(by_left_then_right, "a_right") == expected_right_keys

    by_right_then_left = lhs.join(
        rhs, on="a", how="full", maintain_order="right_left"
    ).collect()
    expected_left_keys = [1, 1, None, 2, None, None, 5]
    assert column_values(by_right_then_left, "a") == expected_left_keys
1425
1426
1427
@pytest.mark.parametrize(
    "dtypes",
    [
        ["Int128", "Int128", "Int64"],
        ["Int128", "Int128", "Int32"],
        ["Int128", "Int128", "Int16"],
        ["Int128", "Int128", "Int8"],
        ["Int128", "UInt64", "Int128"],
        ["Int128", "UInt64", "Int64"],
        ["Int128", "UInt64", "Int32"],
        ["Int128", "UInt64", "Int16"],
        ["Int128", "UInt64", "Int8"],
        ["Int128", "UInt32", "Int128"],
        ["Int128", "UInt16", "Int128"],
        ["Int128", "UInt8", "Int128"],

        ["Int64", "Int64", "Int32"],
        ["Int64", "Int64", "Int16"],
        ["Int64", "Int64", "Int8"],
        ["Int64", "UInt32", "Int64"],
        ["Int64", "UInt32", "Int32"],
        ["Int64", "UInt32", "Int16"],
        ["Int64", "UInt32", "Int8"],
        ["Int64", "UInt16", "Int64"],
        ["Int64", "UInt8", "Int64"],

        ["Int32", "Int32", "Int16"],
        ["Int32", "Int32", "Int8"],
        ["Int32", "UInt16", "Int32"],
        ["Int32", "UInt16", "Int16"],
        ["Int32", "UInt16", "Int8"],
        ["Int32", "UInt8", "Int32"],

        ["Int16", "Int16", "Int8"],
        ["Int16", "UInt8", "Int16"],
        ["Int16", "UInt8", "Int8"],

        ["UInt128", "UInt128", "UInt64"],
        ["UInt128", "UInt128", "UInt32"],
        ["UInt128", "UInt128", "UInt16"],
        ["UInt128", "UInt128", "UInt8"],
        ["UInt128", "UInt64", "UInt128"],
        ["UInt128", "UInt32", "UInt128"],
        ["UInt128", "UInt16", "UInt128"],
        ["UInt128", "UInt8", "UInt128"],

        ["UInt64", "UInt64", "UInt32"],
        ["UInt64", "UInt64", "UInt16"],
        ["UInt64", "UInt64", "UInt8"],

        ["UInt32", "UInt32", "UInt16"],
        ["UInt32", "UInt32", "UInt8"],

        ["UInt16", "UInt16", "UInt8"],

        ["Float64", "Float64", "Float32"],
        ["Float32", "Float32", "Float16"],
    ],
)  # fmt: skip
@pytest.mark.parametrize("swap", [True, False])
def test_join_numeric_key_upcast_15338(
    dtypes: tuple[str, str, str], swap: bool
) -> None:
    """
    Join on numeric keys whose dtypes differ between the two sides.

    Each `dtypes` entry names (expected coalesced dtype, left dtype, right dtype);
    `swap` exchanges the left and right dtypes.
    """
    supertype, ltype, rtype = (getattr(pl, dtype_name) for dtype_name in dtypes)
    if swap:
        ltype, rtype = rtype, ltype

    lhs = pl.select(pl.Series("a", [1, 1, 3]).cast(ltype)).lazy()
    rhs = pl.select(pl.Series("a", [1]).cast(rtype), b=pl.lit("A")).lazy()

    # Left join, coalesced and uncoalesced: the key keeps the left dtype.
    expected_left = pl.select(
        a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])
    )
    left_coalesced = lhs.join(rhs, on="a", how="left").collect()
    assert_frame_equal(left_coalesced, expected_left, check_row_order=False)

    left_uncoalesced = (
        lhs.join(rhs, on="a", how="left", coalesce=False).drop("a_right").collect()
    )
    assert_frame_equal(left_uncoalesced, expected_left, check_row_order=False)

    # Full join without coalescing: each key column keeps its own dtype.
    expected_full = pl.select(
        a=pl.Series([1, 1, 3]).cast(ltype),
        a_right=pl.Series([1, 1, None]).cast(rtype),
        b=pl.Series(["A", "A", None]),
    )
    full_result = lhs.join(rhs, on="a", how="full").collect()
    assert_frame_equal(full_result, expected_full, check_row_order=False)

    # Full join with coalescing: the single key column has the supertype.
    expected_full_coalesced = pl.select(
        a=pl.Series([1, 1, 3]).cast(supertype),
        b=pl.Series(["A", "A", None]),
    )
    full_coalesced = lhs.join(rhs, on="a", how="full", coalesce=True).collect()
    assert_frame_equal(
        full_coalesced, expected_full_coalesced, check_row_order=False
    )

    # Semi join.
    semi_result = lhs.join(rhs, on="a", how="semi").collect()
    assert_frame_equal(semi_result, pl.select(a=pl.Series([1, 1]).cast(ltype)))

    # join_where
    expected_where = pl.select(
        a=pl.Series([1, 1]).cast(ltype),
        a_right=pl.lit(1, dtype=rtype),
        b=pl.Series(["A", "A"]),
    )
    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        where_result = lhs.join_where(rhs, pl.col("a") == pl.col("a_right")).collect(
            optimizations=optimizations,
        )
        assert_frame_equal(where_result, expected_where)
1544
1545
1546
def test_join_numeric_key_upcast_forbid_float_int() -> None:
    """Float64 vs Int128 keys are rejected by both `join` and `join_where`."""
    ltype = pl.Float64
    rtype = pl.Int128

    lhs = pl.LazyFrame({"a": [1.0, 0.0]}, schema={"a": ltype})
    rhs = pl.LazyFrame({"a": [1, 2]}, schema={"a": rtype})

    # Establish baseline: In a non-join context, comparisons between ltype and
    # rtype succeed even if the upcast is lossy.
    side_by_side = lhs.with_columns(rhs.collect()["a"].alias("a_right"))
    baseline = side_by_side.select(pl.col("a") == pl.col("a_right")).collect()
    assert_frame_equal(baseline, pl.DataFrame({"a": [True, False]}))

    with pytest.raises(SchemaError, match="datatypes of join keys don't match"):
        lhs.join(rhs, on="a", how="left").collect()

    # A direct comparison, and the same comparison nested in another one.
    predicates = (
        pl.col("a") == pl.col("a_right"),
        pl.col("a") == (pl.col("a") == pl.col("a_right")),
    )
    for optimizations in [pl.QueryOptFlags(), pl.QueryOptFlags.none()]:
        for predicate in predicates:
            with pytest.raises(
                SchemaError, match="'join_where' cannot compare Float64 with Int128"
            ):
                lhs.join_where(rhs, predicate).collect(optimizations=optimizations)
1579
1580
1581
def test_join_numeric_key_upcast_order() -> None:
    """The key expression is evaluated in its own dtype before any upcast."""
    # E.g. when we are joining on this expression:
    # * col('a') + 127
    #
    # and we want to upcast, ensure that we upcast like this:
    # * ( col('a') + 127 ) .cast(<type>)
    #
    # and *not* like this:
    # * ( col('a').cast(<type>) + lit(127).cast(<type>) )
    #
    # as otherwise the results would be different.

    lhs = pl.select(pl.Series("a", [1], dtype=pl.Int8)).lazy()
    rhs = pl.select(
        pl.Series("a", [1, 128, -128], dtype=pl.Int64), b=pl.lit("A")
    ).lazy()

    # col('a') in `left` is Int8, the result will overflow to become -128
    overflowing_key = pl.col("a") + 127

    expected_single_match = pl.DataFrame(
        {
            "a": pl.Series([1], dtype=pl.Int8),
            "a_right": pl.Series([-128], dtype=pl.Int64),
            "b": "A",
        }
    )

    inner = lhs.join(rhs, left_on=overflowing_key, right_on="a", how="inner")
    assert_frame_equal(inner.collect(), expected_single_match)

    where = lhs.join_where(rhs, overflowing_key == pl.col("a_right"))
    assert_frame_equal(where.collect(), expected_single_match)

    full = lhs.join(rhs, left_on=overflowing_key, right_on="a", how="full")
    expected_full = pl.DataFrame(
        {
            "a": pl.Series([1, None, None], dtype=pl.Int8),
            "a_right": pl.Series([-128, 1, 128], dtype=pl.Int64),
            "b": ["A", "A", "A"],
        }
    )
    # Both sides are sorted on every column so row order does not matter.
    assert_frame_equal(
        full.collect().sort(pl.all()),
        expected_full.sort(pl.all()),
    )
1637
1638
1639
def test_no_collapse_join_when_maintain_order_20725() -> None:
    """Filtering before or after collecting chained cross joins gives equal frames."""
    fraction_1 = pl.LazyFrame({"Fraction_1": [0, 25, 50, 75, 100]})
    fraction_2 = pl.LazyFrame({"Fraction_2": [0, 1]})
    fraction_3 = pl.LazyFrame({"Fraction_3": [0, 1]})

    first_cross = fraction_1.join(
        fraction_2, how="cross", maintain_order="left_right"
    )
    chained = first_cross.join(fraction_3, how="cross", maintain_order="left_right")

    # Filter inside the lazy query vs. on the collected frame.
    filtered_lazily = chained.filter(pl.col("Fraction_1") == 100).collect()
    filtered_eagerly = chained.collect().filter(pl.col("Fraction_1") == 100)

    assert_frame_equal(filtered_lazily, filtered_eagerly)
1652
1653
1654
def test_join_where_predicate_type_coercion_21009() -> None:
    """An equality wrapped in `all_horizontal` plans and evaluates like the bare one."""
    lhs = pl.LazyFrame(
        {
            "left_match": ["A", "B", "C", "D", "E", "F"],
            "left_date_start": range(6),
        }
    )
    rhs = pl.LazyFrame(
        {
            "right_match": ["D", "E", "F", "G", "H", "I"],
            "right_date": range(6),
        }
    )

    def assert_filter_over_inner_join(query: pl.LazyFrame) -> None:
        # Note: Cannot eq the plans as the operand sides are non-deterministic
        plan_lines = query.explain().splitlines()
        assert plan_lines[0].strip().startswith("FILTER")
        assert plan_lines[1] == "FROM"
        assert plan_lines[2].strip().startswith("INNER JOIN")

    bare_equality = lhs.join_where(
        rhs,
        pl.col("left_match") == pl.col("right_match"),
        pl.col("right_date") >= pl.col("left_date_start"),
    )
    assert_filter_over_inner_join(bare_equality)

    wrapped_equality = lhs.join_where(
        rhs,
        pl.all_horizontal(pl.col("left_match") == pl.col("right_match")),
        pl.col("right_date") >= pl.col("left_date_start"),
    )
    assert_filter_over_inner_join(wrapped_equality)

    assert_frame_equal(bare_equality.collect(), wrapped_equality.collect())
1694
1695
1696
def test_join_right_predicate_pushdown_21142() -> None:
    """Filters applied after a right join return the expected rows."""
    lhs = pl.LazyFrame({"key": [1, 2, 4], "values": ["a", "b", "c"]})
    rhs = pl.LazyFrame({"key": [1, 2, 3], "values": ["d", "e", "f"]})
    right_join = lhs.join(rhs, on="key", how="right")

    expect = pl.select(
        pl.Series("values", [None], pl.String),
        pl.Series("key", [3], pl.Int64),
        pl.Series("values_right", ["f"], pl.String),
    )

    left_values_null = right_join.filter(pl.col("values").is_null())
    assert_frame_equal(left_values_null.collect(), expect)

    # Ensure for right join, filter on RHS key-columns are pushed down.
    right_values_null = right_join.filter(pl.col("values_right").is_null())

    plan_text = right_values_null.explain()
    assert plan_text.index("FILTER") > plan_text.index("RIGHT PLAN ON")

    # No row has a null `values_right`, so only the schema is expected.
    assert_frame_equal(right_values_null.collect(), expect.clear())
1719
1720
1721
def test_join_where_nested_expr_21066() -> None:
    """`join_where` accepts arithmetic on the right-hand column of the predicate."""
    lhs = pl.LazyFrame({"a": [1, 2]})
    rhs = pl.LazyFrame({"a": [1]})

    predicate = pl.col("a") == (pl.col("a_right") + 1)
    result = lhs.join_where(rhs, predicate).collect()

    assert_frame_equal(result, pl.DataFrame({"a": 2, "a_right": 1}))
1728
1729
1730
def test_select_after_join_where_20831() -> None:
    """Selecting after `join_where` matches selecting after cross join + filters."""
    lhs = pl.LazyFrame(
        {
            "a": [1, 2, 3, 1, None],
            "b": [1, 2, 3, 4, 5],
            "c": [2, 3, 4, 5, 6],
        }
    )
    rhs = pl.LazyFrame(
        {
            "a": [1, 4, 3, 7, None, None, 1],
            "c": [2, 3, 4, 5, 6, 7, 8],
            "d": [6, None, 7, 8, -1, 2, 4],
        }
    )

    via_join_where = lhs.join_where(
        rhs, pl.col("b") * 2 <= pl.col("a_right"), pl.col("a") < pl.col("c_right")
    )
    via_cross_and_filter = (
        lhs.join(rhs, how="cross")
        .filter(pl.col("b") * 2 <= pl.col("a_right"))
        .filter(pl.col("a") < pl.col("c_right"))
    )

    # Both formulations must satisfy the same two checks.
    for query in (via_join_where, via_cross_and_filter):
        assert_frame_equal(
            query.select("d").collect().sort("d"),
            pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame(),
        )
        assert query.select(pl.len()).collect().item() == 6
1770
1771
1772
@pytest.mark.parametrize(
    ("dtype", "data"),
    [
        (pl.Struct, [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]),
        (pl.List, [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]),
        (pl.Array(pl.Int64, 2), [[1, 1], [2, 2], [3, 3], [4, 4]]),
    ],
)
def test_join_on_nested(dtype: PolarsDataType, data: list[Any]) -> None:
    """Run every join strategy with a nested-valued key column."""
    v0, v1, v2, v3 = data

    lhs = pl.DataFrame({"a": data[:3], "b": [1, 2, 3]})
    rhs = pl.DataFrame({"a": [v3, v1], "c": [4, 2]})

    # Left join.
    expected_left = pl.select(
        a=pl.Series(data[:3]),
        b=pl.Series([1, 2, 3]),
        c=pl.Series([None, 2, None]),
    )
    assert_frame_equal(
        lhs.join(rhs, on="a", how="left", maintain_order="left"), expected_left
    )

    # Right join.
    expected_right = pl.select(
        b=pl.Series([None, 2]),
        a=pl.Series([v3, v1]),
        c=pl.Series([4, 2]),
    )
    assert_frame_equal(
        lhs.join(rhs, on="a", how="right", maintain_order="right"), expected_right
    )

    # Inner join.
    expected_inner = pl.select(
        a=pl.Series([v1]),
        b=pl.Series([2]),
        c=pl.Series([2]),
    )
    assert_frame_equal(lhs.join(rhs, on="a", how="inner"), expected_inner)

    # Full join.
    expected_full = pl.select(
        a=pl.Series(data[:3] + [None]),
        b=pl.Series([1, 2, 3, None]),
        a_right=pl.Series([None, v1, None, v3]),
        c=pl.Series([None, 2, None, 4]),
    )
    assert_frame_equal(
        lhs.join(rhs, on="a", how="full", maintain_order="left_right"),
        expected_full,
    )

    # Semi join.
    expected_semi = pl.select(
        a=pl.Series([v1]),
        b=pl.Series([2]),
    )
    assert_frame_equal(lhs.join(rhs, on="a", how="semi"), expected_semi)

    # Anti join.
    expected_anti = pl.select(
        a=pl.Series([v0, v2]),
        b=pl.Series([1, 3]),
    )
    assert_frame_equal(
        lhs.join(rhs, on="a", how="anti", maintain_order="left"), expected_anti
    )

    # Cross join.
    expected_cross = pl.select(
        a=pl.Series([v0, v0, v1, v1, v2, v2]),
        b=pl.Series([1, 1, 2, 2, 3, 3]),
        a_right=pl.Series([v3, v1, v3, v1, v3, v1]),
        c=pl.Series([4, 2, 4, 2, 4, 2]),
    )
    assert_frame_equal(
        lhs.join(rhs, how="cross", maintain_order="left_right"), expected_cross
    )
1850
1851
1852
def test_empty_join_result_with_array_15474() -> None:
    """A join with no matches keeps the Array column's dtype in the empty result."""
    array_dtype = pl.Array(pl.Int64, 3)
    lhs = pl.DataFrame(
        {
            "x": [1, 2],
            "y": pl.Series([[1, 2, 3], [4, 5, 6]], dtype=array_dtype),
        }
    )
    rhs = pl.DataFrame({"x": [0]})

    joined = lhs.join(rhs, on="x")

    assert_frame_equal(
        joined, pl.DataFrame(schema={"x": pl.Int64, "y": pl.Array(pl.Int64, 3)})
    )
1863
1864
1865
@pytest.mark.slow
def test_join_where_eager_perf_21145() -> None:
    """Fail if eager `join_where` is more than `threshold` times slower than lazy."""
    left = pl.Series("left", range(3_000)).to_frame()
    right = pl.Series("right", range(1_000)).to_frame()
    predicate = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right"))

    def run_eager() -> pl.DataFrame:
        return left.join_where(right, predicate)

    def run_lazy() -> pl.DataFrame:
        return left.lazy().join_where(right.lazy(), predicate).collect()

    runtime_eager = time_func(run_eager)
    runtime_lazy = time_func(run_lazy)
    runtime_ratio = runtime_eager / runtime_lazy

    # Pick as high as reasonably possible for CI stability
    # * Was observed to be >=5 seconds on the bugged version, so 3 is a safe bet.
    threshold = 3

    if not runtime_ratio > threshold:
        return

    msg = f"runtime_ratio ({runtime_ratio}) > {threshold}x ({runtime_eager = }, {runtime_lazy = })"
    raise ValueError(msg)
1882
1883
1884
def test_select_len_after_semi_anti_join_21343() -> None:
    """`select(pl.len())` after an anti join where every left key matches gives 0."""
    left_frame = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    right_frame = pl.LazyFrame({"a": [1, 2, 3]})

    anti_joined = left_frame.join(right_frame, on="a", how="anti")
    row_count = anti_joined.select(pl.len()).collect().item()

    assert row_count == 0
1891
1892
1893
def test_multi_leftjoin_empty_right_21701() -> None:
    """Nested left joins onto empty frames leave the parent columns unchanged."""
    parents = pl.LazyFrame(
        {
            "id": [1, 30, 80],
            "parent_field1": [3, 20, 17],
        }
    )
    children = pl.LazyFrame(
        [],
        schema={"id": pl.Int32(), "parent_id": pl.Int32(), "child_field1": pl.Int32()},
    )
    subchildren = pl.LazyFrame(
        [], schema={"child_id": pl.Int32(), "subchild_field1": pl.Int32()}
    )

    # Inner-most join first: children with their sub-children.
    children_with_subchildren = children.join(
        subchildren, left_on=pl.col("id"), right_on=pl.col("child_id"), how="left"
    )
    parents_with_children = parents.join(
        children_with_subchildren,
        left_on=pl.col("id"),
        right_on=pl.col("parent_id"),
        how="left",
    )

    parent_columns_only = parents_with_children.select("id", "parent_field1")
    assert_frame_equal(
        parent_columns_only.collect(), parents.collect(), check_row_order=False
    )
1917
1918
1919
@pytest.mark.parametrize("order", ["none", "left_right", "right_left"])
def test_join_null_equal(order: Literal["none", "left_right", "right_left"]) -> None:
    """Inner, left and full joins with `nulls_equal=True` for each row order."""
    lhs = pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3]})
    rhs_with_null = pl.DataFrame({"x": [1, None], "z": [1, 2]})
    rhs_without_null = pl.DataFrame({"x": [1, 3], "z": [1, 3]})

    # Row order is only compared when an order was requested.
    check_row_order = order != "none"

    # Expected when every left row, including the null keys, finds a match.
    every_row_matched = pl.DataFrame(
        {"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}
    )

    # Inner join.
    inner_with_null = lhs.join(
        rhs_with_null, on="x", nulls_equal=True, maintain_order=order
    )
    assert_frame_equal(
        inner_with_null, every_row_matched, check_row_order=check_row_order
    )

    inner_without_null = lhs.join(rhs_without_null, on="x", nulls_equal=True)
    assert_frame_equal(
        inner_without_null, pl.DataFrame({"x": [1], "y": [1], "z": [1]})
    )

    # Left join.
    left_with_null = lhs.join(
        rhs_with_null, on="x", how="left", nulls_equal=True, maintain_order=order
    )
    assert_frame_equal(
        left_with_null, every_row_matched, check_row_order=check_row_order
    )

    left_without_null = lhs.join(
        rhs_without_null, on="x", how="left", nulls_equal=True, maintain_order=order
    )
    assert_frame_equal(
        left_without_null,
        pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, None, None]}),
        check_row_order=check_row_order,
    )

    # Full join.
    full_with_null = lhs.join(
        rhs_with_null,
        on="x",
        how="full",
        nulls_equal=True,
        coalesce=True,
        maintain_order=order,
    )
    assert_frame_equal(
        full_with_null, every_row_matched, check_row_order=check_row_order
    )

    if order == "left_right":
        x_right = [1, None, None, 3]
        y = [1, 2, 3, None]
        z = [1, None, None, 3]
    else:
        x_right = [1, 3, None, None]
        y = [1, None, 2, 3]
        z = [1, 3, None, None]
    expected = pl.DataFrame(
        {
            "x": [1, None, None, None],
            "x_right": x_right,
            "y": y,
            "z": z,
        }
    )
    full_without_null = lhs.join(
        rhs_without_null, on="x", how="full", nulls_equal=True, maintain_order=order
    )
    assert_frame_equal(
        full_without_null,
        expected,
        check_row_order=check_row_order,
        check_column_order=False,
    )
1990
1991
1992
def test_join_categorical_21815() -> None:
    """Full joins with a Categorical column used as the key and as payload."""
    lhs = pl.DataFrame({"x": ["a", "b", "c", "d"]}).with_columns(
        xc=pl.col("x").cast(pl.Categorical)
    )
    rhs = pl.DataFrame({"x": ["c", "d", "e", "f"]}).with_columns(
        xc=pl.col("x").cast(pl.Categorical)
    )

    expected = pl.DataFrame(
        {
            "x": ["a", "b", "c", "d", None, None],
            "x_right": [None, None, "c", "d", "e", "f"],
        }
    ).with_columns(
        xc=pl.col("x").cast(pl.Categorical),
        xc_right=pl.col("x_right").cast(pl.Categorical),
    )

    # As key.
    joined_on_categorical = lhs.join(rhs, on="xc", how="full")
    assert_frame_equal(
        joined_on_categorical,
        expected,
        check_row_order=False,
        check_column_order=False,
    )

    # As payload.
    joined_on_string = lhs.join(rhs, on="x", how="full")
    assert_frame_equal(
        joined_on_string,
        expected,
        check_row_order=False,
        check_column_order=False,
    )
2022
2023
2024
def test_join_where_nested_boolean() -> None:
    """`join_where` with a comparison nested inside a cast inside a comparison."""
    lhs = pl.DataFrame({"a": [1, 9, 22], "b": [6, 4, 50]})
    rhs = pl.DataFrame({"c": [1]})

    a_below_b = (pl.col("a") < pl.col("b")).cast(pl.Int32)
    predicate = a_below_b < pl.col("c")

    result = lhs.join_where(rhs, predicate)

    assert_frame_equal(result, pl.DataFrame({"a": [9], "b": [4], "c": [1]}))
2038
2039
2040
def test_join_where_dtype_upcast() -> None:
    """`join_where` with an Int8 column added to a boolean before comparing."""
    lhs = pl.DataFrame(
        {
            "a": pl.Series([1, 9, 22], dtype=pl.Int8),
            "b": [6, 4, 50],
        }
    )
    rhs = pl.DataFrame({"c": [10]})

    b_is_positive = pl.col("b") > 0
    predicate = (pl.col("a") + b_is_positive) < pl.col("c")

    result = lhs.join_where(rhs, predicate)

    expected = pl.DataFrame(
        {
            "a": pl.Series([1], dtype=pl.Int8),
            "b": [6],
            "c": [10],
        }
    )
    assert_frame_equal(result, expected)
2059
2060
2061
def test_join_where_valid_dtype_upcast_same_side() -> None:
    """A Float32-vs-integer comparison within one table is accepted."""
    # Unsafe comparisons are all contained entirely within one table (LHS)
    # Safe comparisons across both tables.
    lhs = pl.DataFrame(
        {
            "a": pl.Series([1, 9, 22], dtype=pl.Float32),
            "b": [6, 4, 50],
        }
    )
    rhs = pl.DataFrame({"c": [10, 1, 5]})

    lhs_only_term = (pl.col("a") < pl.col("b")).cast(pl.Int32) + 3
    predicate = lhs_only_term < pl.col("c")

    result = lhs.join_where(rhs, predicate).sort("a", "b", "c")

    expected = pl.DataFrame(
        {
            "a": pl.Series([1, 1, 9, 9, 22, 22], dtype=pl.Float32),
            "b": [6, 6, 4, 4, 50, 50],
            "c": [5, 10, 5, 10, 5, 10],
        }
    )
    assert_frame_equal(result, expected)
2082
2083
2084
def test_join_where_invalid_dtype_upcast_different_side() -> None:
    """A Float32-vs-Int64 comparison across tables raises until a cast is added."""
    # Unsafe comparisons exist across tables.
    lhs = pl.DataFrame(
        {
            "a": pl.Series([1, 9, 22], dtype=pl.Float32),
            "b": pl.Series([6, 4, 50], dtype=pl.Float64),
        }
    )
    rhs = pl.DataFrame({"c": [10, 1, 5]})

    rejected_predicate = ((pl.col("a") >= pl.col("c")) + 3) < 4
    with pytest.raises(
        SchemaError, match="'join_where' cannot compare Float32 with Int64"
    ):
        lhs.join_where(rhs, rejected_predicate)

    # add in a cast to predicate to fix
    accepted_predicate = ((pl.col("a").cast(pl.UInt8) >= pl.col("c")) + 3) < 4
    result = lhs.join_where(rhs, accepted_predicate).sort("a", "b", "c")

    expected = pl.DataFrame(
        {
            "a": pl.Series([1, 1, 9], dtype=pl.Float32),
            "b": pl.Series([6, 6, 4], dtype=pl.Float64),
            "c": [5, 10, 10],
        }
    )
    assert_frame_equal(result, expected)
2111
2112
2113
@pytest.mark.parametrize("dtype", [pl.Int32, pl.Float32])
def test_join_where_literals(dtype: PolarsDataType) -> None:
    """`join_where` comparing a two-table sum against a Python integer literal."""
    lhs = pl.DataFrame({"a": pl.Series([0, 1], dtype=dtype)})
    rhs = pl.DataFrame({"b": pl.Series([1, 2], dtype=dtype)})

    column_sum = pl.col("a") + pl.col("b")
    result = lhs.join_where(rhs, column_sum < 2)

    expected = pl.DataFrame(
        {
            "a": pl.Series([0], dtype=dtype),
            "b": pl.Series([1], dtype=dtype),
        }
    )
    assert_frame_equal(result, expected)
2125
2126
2127
def test_join_where_categorical_string_compare() -> None:
    """`join_where` with `is_in` on an Enum column given plain strings."""
    enum_dtype = pl.Enum(["a", "b", "c"])
    lhs = pl.DataFrame({"a": pl.Series(["a", "a", "b", "c"], dtype=enum_dtype)})
    rhs = pl.DataFrame({"b": [1, 6, 4]})

    a_is_a_or_b = pl.col("a").is_in(["a", "b"])
    b_below_five = pl.col("b") < 5

    result = lhs.join_where(rhs, a_is_a_or_b & b_below_five).sort("a", "b")

    expected = pl.DataFrame(
        {
            "a": pl.Series(["a", "a", "a", "a", "b", "b"], dtype=enum_dtype),
            "b": [1, 1, 4, 4, 1, 4],
        }
    )
    assert_frame_equal(result, expected)
2140
2141
2142
def test_join_where_nonboolean_predicate() -> None:
    """A `join_where` predicate that is not boolean raises ComputeError."""
    lhs = pl.DataFrame({"a": [1, 2, 3]})
    rhs = pl.DataFrame({"b": [1, 2, 3]})
    integer_valued = pl.col("a") * 2

    with pytest.raises(
        ComputeError, match="'join_where' predicates must resolve to boolean"
    ):
        lhs.join_where(rhs, integer_valued)
2149
2150
2151
def test_empty_outer_join_22206() -> None:
    """Check a coalescing full join against an empty frame with the same schema."""
    data = pl.LazyFrame({"a": [5, 6], "b": [1, 2]})
    empty = pl.LazyFrame(schema=data.collect_schema())

    # The empty frame is tried on the right first, then on the left; both
    # joins are expected to equal `data` (row order ignored).
    for left, right in ((data, empty), (empty, data)):
        joined = left.join(right, on=["a", "b"], how="full", coalesce=True)
        assert_frame_equal(joined, data, check_row_order=False)
2164
2165
2166
def test_join_coalesce_22498() -> None:
    """Check the output of a lazy coalescing full join on the shared column `y`."""
    left = pl.DataFrame({"y": [2]}).lazy()
    right = pl.DataFrame({"x": [1], "y": [2]}).lazy()

    joined = left.join(right, how="full", on="y", coalesce=True).collect()

    expected = pl.DataFrame({"y": [2], "x": [1]})
    assert_frame_equal(joined, expected)
2171
2172
2173
def _extract_plan_joins_and_filters(plan: str) -> list[str]:
2174
return [
2175
x
2176
for x in (x.strip() for x in plan.splitlines())
2177
if x.startswith("LEFT PLAN") # noqa: PIE810
2178
or x.startswith("RIGHT PLAN")
2179
or x.startswith("FILTER")
2180
]
2181
2182
2183
def test_join_filter_pushdown_inner_join() -> None:
    """Check where filters land in the plan of an inner join, and its results."""

    def assert_collects_to(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Collect once with default settings and once with QueryOptFlags.none().
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Filter on key output column is pushed to both sides.
    query = left.join(
        right, on=["a", "b"], how="inner", maintain_order="left_right"
    ).filter(pl.col("b") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}
    )
    assert_collects_to(query, expected)

    # Side-specific filters are all pushed for inner join.
    joined = left.join(right, on=["a", "b"], how="inner", maintain_order="left_right")
    query = joined.filter(pl.col("b") <= 2).filter(
        pl.col("c") == "a", pl.col("c_right") == "A"
    )

    nodes = _extract_plan_joins_and_filters(query.explain())

    assert nodes[0] == 'LEFT PLAN ON: [col("a"), col("b")]'
    for fragment in ('col("c")) == ("a")', 'col("b")) <= (2)'):
        assert fragment in nodes[1]

    assert nodes[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'
    for fragment in ('col("b")) <= (2)', 'col("c")) == ("A")'):
        assert fragment in nodes[3]

    expected = pl.DataFrame({"a": [1], "b": [1], "c": ["a"], "c_right": ["A"]})
    assert_collects_to(query, expected)

    # Filter applied to the non-coalesced `_right` column of an inner-join is also
    # pushed to the left input table.
    query = left.join(
        right,
        on=["a", "b"],
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a_right") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("a")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert_collects_to(query, expected)

    # Different names in left_on and right_on
    query = left.join(
        right,
        left_on="a",
        right_on="b",
        how="inner",
        maintain_order="left_right",
    ).filter(pl.col("a") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert_collects_to(query, expected)

    # Different names in left_on and right_on, coalesce=False
    query = left.join(
        right,
        left_on="a",
        right_on="b",
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert_collects_to(query, expected)

    # left_on=col(A), right_on=lit(1). Filters referencing col(A) can only push
    # to the left side.
    query = left.join(
        right,
        left_on=["a", pl.lit(1)],
        right_on=[pl.lit(1), "b"],
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(
        pl.col("a") == 1,
        pl.col("b") >= 1,
        pl.col("a_right") <= 1,
        pl.col("b_right") >= 0,
    )

    nodes = _extract_plan_joins_and_filters(query.explain())

    assert (
        nodes[0]
        == 'LEFT PLAN ON: [col("a").cast(Int64), col("_POLARS_0").cast(Int64)]'
    )
    for fragment in ('(col("a")) == (1)', '(col("b")) >= (1)'):
        assert fragment in nodes[1]
    assert (
        nodes[2]
        == 'RIGHT PLAN ON: [col("_POLARS_1").cast(Int64), col("b").cast(Int64)]'
    )
    for fragment in ('(col("b")) >= (0)', 'col("a")) <= (1)'):
        assert fragment in nodes[3]

    expected = pl.DataFrame(
        {
            "a": [1],
            "b": [1],
            "c": ["a"],
            "a_right": [1],
            "b_right": [1],
            "c_right": ["A"],
        }
    )
    assert_collects_to(query, expected)

    # Filters don't pass if they refer to columns from both tables
    # TODO: In the optimizer we can add additional equalities into the join
    # condition itself for some cases.
    query = left.join(
        right, on=["a"], how="inner", maintain_order="left_right"
    ).filter(pl.col("b") == pl.col("b_right"))

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("b")) == (col("b_right"))]',
        'LEFT PLAN ON: [col("a")]',
        'RIGHT PLAN ON: [col("a")]',
    ]

    expected = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": [1, 2, 3],
            "c": ["a", "b", "c"],
            "b_right": [1, 2, 3],
            "c_right": ["A", "B", "C"],
        }
    )
    assert_collects_to(query, expected)

    # Duplicate filter removal - https://github.com/pola-rs/polars/issues/23243
    joined = pl.LazyFrame({"x": [1, 2, 3]}).join(
        pl.LazyFrame({"x": [1, 2, 3]}), on="x", how="inner", coalesce=False
    )
    query = joined.filter(
        pl.col("x") == 2,
        pl.col("x_right") == 2,
    )

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("x")]',
        'FILTER [(col("x")) == (2)]',
        'RIGHT PLAN ON: [col("x")]',
        'FILTER [(col("x")) == (2)]',
    ]

    expected = pl.DataFrame(
        [
            pl.Series("x", [2], dtype=pl.Int64),
            pl.Series("x_right", [2], dtype=pl.Int64),
        ]
    )
    assert_collects_to(query, expected)
2435
2436
2437
def test_join_filter_pushdown_left_join() -> None:
    """Check where filters land in the plan of a left join, and its results."""

    def assert_collects_to(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Collect once with default settings and once with QueryOptFlags.none().
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Filter on key output column is pushed to both sides.
    query = left.join(
        right, on=["a", "b"], how="left", maintain_order="left_right"
    ).filter(pl.col("b") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {"a": [1, 2], "b": [1, 2], "c": ["a", "b"], "c_right": ["A", "B"]}
    )
    assert_collects_to(query, expected)

    # Filter on key output column is pushed to both sides.
    # This tests joins on differing left/right names.
    query = left.join(
        right,
        left_on="a",
        right_on="b",
        how="left",
        maintain_order="left_right",
    ).filter(pl.col("a") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert_collects_to(query, expected)

    # Filters referring to columns that exist only in the left table can be pushed.
    query = left.join(
        right, on=["a", "b"], how="left", maintain_order="left_right"
    ).filter(pl.col("c") == "b")

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("c")) == ("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]

    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})
    assert_collects_to(query, expected)

    # Filters referring to columns that exist only in the right table cannot be
    # pushed for left-join
    # Note: `eq_missing` to block join downgrade.
    right_only_filter = pl.col("c_right").eq_missing("B")
    query = left.join(
        right, on=["a", "b"], how="left", maintain_order="left_right"
    ).filter(right_only_filter)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("c_right")) ==v ("B")]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]

    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"], "c_right": ["B"]})
    assert_collects_to(query, expected)

    # Filters referring to a non-coalesced key column originating from the right
    # table cannot be pushed.
    #
    # Note, technically it's possible to push these filters if we can guarantee that
    # they do not remove NULLs (or otherwise if we also apply the filter on the
    # result table). But this is not something we do at the moment.
    query = left.join(
        right,
        on=["a", "b"],
        how="left",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("b_right").eq_missing(2))

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("b_right")) ==v (2)]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]

    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "b_right": [2],
            "c_right": ["B"],
        }
    )
    assert_collects_to(query, expected)
2568
2569
2570
def test_join_filter_pushdown_right_join() -> None:
    """Check where filters land in the plan of a right join, and its results."""

    def assert_collects_to(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Collect once with default settings and once with QueryOptFlags.none().
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Filter on key output column is pushed to both sides.
    query = left.join(
        right, on=["a", "b"], how="right", maintain_order="left_right"
    ).filter(pl.col("b") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {"c": ["a", "b"], "a": [1, 2], "b": [1, 2], "c_right": ["A", "B"]}
    )
    assert_collects_to(query, expected)

    # Filter on key output column is pushed to both sides.
    # This tests joins on differing left/right names.
    # col(A) is coalesced into col(B) (from right), but col(B) is named as
    # col(B_right) in the output because the LHS table also has a col(B).
    query = left.join(
        right,
        left_on="a",
        right_on="b",
        how="right",
        maintain_order="left_right",
    ).filter(pl.col("b_right") <= 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) <= (2)]',
    ]

    expected = pl.DataFrame(
        {
            "b": [1, 2],
            "c": ["a", "b"],
            "a": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert_collects_to(query, expected)

    # Filters referring to columns that exist only in the right table can be pushed.
    query = left.join(
        right, on=["a", "b"], how="right", maintain_order="left_right"
    ).filter(pl.col("c_right") == "B")

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("c")) == ("B")]',
    ]

    expected = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})
    assert_collects_to(query, expected)

    # Filters referring to columns that exist only in the left table cannot be
    # pushed for right-join
    # Note: eq_missing to block join downgrade
    left_only_filter = pl.col("c").eq_missing("b")
    query = left.join(
        right, on=["a", "b"], how="right", maintain_order="left_right"
    ).filter(left_only_filter)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("c")) ==v ("b")]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]

    expected = pl.DataFrame({"c": ["b"], "a": [2], "b": [2], "c_right": ["B"]})
    assert_collects_to(query, expected)

    # Filters referring to a non-coalesced key column originating from the left
    # table cannot be pushed for right-join.
    query = left.join(
        right,
        on=["a", "b"],
        how="right",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("b").eq_missing(2))

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'FILTER [(col("b")) ==v (2)]',
        'LEFT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
    ]

    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "b_right": [2],
            "c_right": ["B"],
        }
    )
    assert_collects_to(query, expected)
2699
2700
2701
def test_join_filter_pushdown_full_join() -> None:
    """Check where filters land in the plan of a full join, and its results."""

    def assert_collects_to(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Collect once with default settings and once with QueryOptFlags.none().
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Full join can only push filters that refer to coalesced key columns.
    coalesced = left.join(
        right,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=True,
        maintain_order="left_right",
    )
    query = coalesced.filter(pl.col("a") == 2)

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) == (2)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) == (2)]',
    ]

    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "c_right": ["B"],
        }
    )
    assert_collects_to(query, expected)

    # Non-coalescing full-join cannot push any filters
    # Note: We add fill_null to bypass non-NULL filter mask detection.
    non_coalesced = left.join(
        right,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=False,
        maintain_order="left_right",
    )
    query = non_coalesced.filter(
        pl.col("a").fill_null(0) >= 2,
        pl.col("a").fill_null(0) <= 2,
    )

    nodes = _extract_plan_joins_and_filters(query.explain())

    assert nodes[0].startswith("FILTER ")
    assert nodes[1:] == [
        'LEFT PLAN ON: [col("a")]',
        'RIGHT PLAN ON: [col("b")]',
    ]

    expected = pl.DataFrame(
        {
            "a": [2],
            "b": [2],
            "c": ["b"],
            "a_right": [2],
            "b_right": [2],
            "c_right": ["B"],
        }
    )
    assert_collects_to(query, expected)
2778
2779
2780
def test_join_filter_pushdown_semi_join() -> None:
    """Check where filters land in the plan of a semi join, and its results."""
    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    semi_joined = left.join(
        right,
        left_on=["a", "b"],
        right_on=["b", pl.lit(2)],
        how="semi",
        maintain_order="left_right",
    )
    query = semi_joined.filter(
        pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b"
    )

    nodes = _extract_plan_joins_and_filters(query.explain())

    # * filter on col(a) is pushed to both sides (renamed to col(b) in the right side)
    # * filter on col(b) is pushed only to left, as the right join key is a literal
    # * filter on col(c) is pushed only to left, as the column does not exist in
    #   the right.

    assert nodes[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'
    for fragment in (
        'col("a")) == (2)',
        'col("b")) == (2)',
        'col("c")) == ("b")',
    ):
        assert fragment in nodes[1]

    assert nodes[2:] == [
        'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',
        'FILTER [(col("b")) == (2)]',
    ]

    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"]})

    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(
        query.collect(optimizations=pl.QueryOptFlags.none()), expected
    )
2824
2825
2826
def test_join_filter_pushdown_anti_join() -> None:
    """Check where filters land in the plan of an anti join, and its results."""
    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    anti_joined = left.join(
        right,
        left_on=["a", "b"],
        right_on=["b", pl.lit(1)],
        how="anti",
        maintain_order="left_right",
    )
    query = anti_joined.filter(
        pl.col("a") == 2, pl.col("b") == 2, pl.col("c") == "b"
    )

    nodes = _extract_plan_joins_and_filters(query.explain())

    # All three filters are expected in the left-side filter node; only the
    # filter on the first key shows up on the right side, as col("b").
    assert nodes[0] == 'LEFT PLAN ON: [col("a"), col("b").cast(Int64)]'
    for fragment in (
        'col("a")) == (2)',
        'col("b")) == (2)',
        'col("c")) == ("b")',
    ):
        assert fragment in nodes[1]

    assert nodes[2:] == [
        'RIGHT PLAN ON: [col("b"), col("_POLARS_0").cast(Int64)]',
        'FILTER [(col("b")) == (2)]',
    ]

    expected = pl.DataFrame({"a": [2], "b": [2], "c": ["b"]})

    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(
        query.collect(optimizations=pl.QueryOptFlags.none()), expected
    )
2865
2866
2867
def test_join_filter_pushdown_cross_join() -> None:
    """Check how filters on a cross join show up in the plan, and the results."""

    def assert_collects_to(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Collect once with default settings and once with QueryOptFlags.none().
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [0, 0, 0, 0, 0], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Nested loop join for `!=`
    crossed = left.with_row_index().join(right, how="cross")
    query = crossed.filter(
        pl.col("a") <= 4,
        pl.col("c_right") <= "B",
        pl.col("a") != pl.col("a_right"),
    ).sort("index")

    plan = query.explain()

    assert 'NESTED LOOP JOIN ON [(col("a")) != (col("a_right"))]' in plan
    assert _extract_plan_joins_and_filters(plan) == [
        "LEFT PLAN:",
        'FILTER [(col("a")) <= (4)]',
        "RIGHT PLAN:",
        'FILTER [(col("c")) <= ("B")]',
    ]

    expected = pl.DataFrame(
        [
            pl.Series("index", [0, 0, 1, 1, 2, 2, 3, 3], dtype=pl.get_index_type()),
            pl.Series("a", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),
            pl.Series("b", [1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.Int64),
            pl.Series("c", ["a", "a", "b", "b", "c", "c", "d", "d"], dtype=pl.String),
            pl.Series("a_right", [0, 0, 0, 0, 0, 0, 0, 0], dtype=pl.Int64),
            pl.Series("b_right", [1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),
            pl.Series(
                "c_right", ["A", "B", "A", "B", "A", "B", "A", "B"], dtype=pl.String
            ),
        ]
    )
    assert_collects_to(query, expected)

    # Conversion to inner-join for `==`
    query = left.join(right, how="cross", maintain_order="left_right").filter(
        pl.col("a") <= 4,
        pl.col("c_right") <= "B",
        pl.col("a") == (pl.col("a_right") + 1),
    )

    assert _extract_plan_joins_and_filters(query.explain()) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) <= (4)]',
        'RIGHT PLAN ON: [[(col("a")) + (1)]]',
        'FILTER [(col("c")) <= ("B")]',
    ]

    expected = pl.DataFrame(
        {
            "a": [1, 1],
            "b": [1, 1],
            "c": ["a", "a"],
            "a_right": [0, 0],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    )
    assert_collects_to(query, expected)

    # Avoid conversion for order maintaining cross-join
    ordered_left = pl.LazyFrame(
        [
            pl.Series("a", [2, 4, 8, 9, 11], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3, 4, 5], dtype=pl.Int64),
        ]
    )
    ordered_right = pl.LazyFrame(
        {
            "c": [0, 1, 2, 3, 4],
        }
    )
    query = ordered_left.join(
        ordered_right,
        how="cross",
        maintain_order="left_right",
    ).filter(pl.col("c") <= pl.col("b"))

    plan = query.explain()

    assert plan.startswith('NESTED LOOP JOIN ON [(col("c")) <= (col("b"))]:')

    expected = pl.DataFrame(
        [
            pl.Series(
                "a",
                [2, 2, 4, 4, 4, 8, 8, 8, 8, 9, 9, 9, 9, 9, 11, 11, 11, 11, 11],
                dtype=pl.Int64,
            ),
            pl.Series(
                "b",
                [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5],
                dtype=pl.Int64,
            ),
            pl.Series(
                "c",
                [0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
                dtype=pl.Int64,
            ),
        ]
    )
    assert_collects_to(query, expected)
2993
2994
2995
def test_join_filter_pushdown_iejoin() -> None:
    """Check where `join_where` predicates land in the plan, and the results."""

    def assert_collects_to(
        query: pl.LazyFrame, expected: pl.DataFrame, *, sort_rows: bool
    ) -> None:
        # Collect once with default settings and once with QueryOptFlags.none();
        # `sort_rows` sorts each result by all columns before comparing.
        result = query.collect()
        assert_frame_equal(result.sort(pl.all()) if sort_rows else result, expected)
        result = query.collect(optimizations=pl.QueryOptFlags.none())
        assert_frame_equal(result.sort(pl.all()) if sort_rows else result, expected)

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Equality predicate between the two sides: the plan shows an inner join.
    query = (
        left.with_row_index()
        .join_where(
            right,
            pl.col("a") >= 1,
            pl.col("a") == pl.col("a_right"),
            pl.col("c_right") <= "B",
        )
        .sort("index")
    )

    plan = query.explain()

    assert "INNER JOIN" in plan

    nodes = _extract_plan_joins_and_filters(plan)

    assert nodes[:3] == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) >= (1)]',
        'RIGHT PLAN ON: [col("a")]',
    ]

    assert nodes[3].startswith("FILTER")
    for fragment in ('col("c")) <= ("B")', '(col("a")) >= (1)'):
        assert fragment in nodes[3]
    assert len(nodes) == 4

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 2],
            "c": ["a", "b"],
            "a_right": [1, 2],
            "b_right": [1, 2],
            "c_right": ["A", "B"],
        }
    ).with_row_index()
    assert_collects_to(query, expected, sort_rows=False)

    # Inequality predicate between the two sides: the plan shows an IEJOIN.
    query = (
        left.with_row_index()
        .join_where(
            right,
            pl.col("a") >= 1,
            pl.col("a") >= pl.col("a_right"),
            pl.col("c_right") <= "B",
        )
        .sort("index")
    )

    plan = query.explain()

    assert "IEJOIN" in plan
    assert _extract_plan_joins_and_filters(plan) == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) >= (1)]',
        'RIGHT PLAN ON: [col("a")]',
        'FILTER [(col("c")) <= ("B")]',
    ]

    expected = pl.DataFrame(
        [
            pl.Series("index", [0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=pl.get_index_type()),
            pl.Series("a", [1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=pl.Int64),
            pl.Series("b", [1, 2, 2, 3, 3, 4, 4, None, None], dtype=pl.Int64),
            pl.Series(
                "c", ["a", "b", "b", "c", "c", "d", "d", "e", "e"], dtype=pl.String
            ),
            pl.Series("a_right", [1, 1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),
            pl.Series("b_right", [1, 1, 2, 1, 2, 1, 2, 1, 2], dtype=pl.Int64),
            pl.Series(
                "c_right",
                ["A", "A", "B", "A", "B", "A", "B", "A", "B"],
                dtype=pl.String,
            ),
        ]
    )
    assert_collects_to(query, expected, sort_rows=True)

    # Single-column frames joined on `>` with an extra left-side filter.
    query = pl.LazyFrame({"x": [1, 2, 3]}).join_where(
        pl.LazyFrame({"x": [1, 2, 3]}),
        pl.col("x") > pl.col("x_right"),
        pl.col("x") > 1,
    )

    plan = query.explain()

    assert "IEJOIN" in plan
    assert _extract_plan_joins_and_filters(plan) == [
        'LEFT PLAN ON: [col("x")]',
        'FILTER [(col("x")) > (1)]',
        'RIGHT PLAN ON: [col("x")]',
    ]

    expected = pl.DataFrame(
        [
            pl.Series("x", [2, 3, 3], dtype=pl.Int64),
            pl.Series("x_right", [1, 1, 2], dtype=pl.Int64),
        ]
    )
    assert_collects_to(query, expected, sort_rows=True)

    # Join filter pushdown inside CSE - https://github.com/pola-rs/polars/issues/23489

    lf_x = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
    lf_y = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 3, 3]})

    lf_xy = lf_x.join_where(lf_y, pl.col("a") > pl.col("a_right"))

    query = pl.concat([lf_xy, lf_xy]).filter(
        pl.col("b") < pl.col("b_right"), pl.col("a") > 0
    )

    plan = query.explain()

    assert "IEJOIN" in plan

    nodes = _extract_plan_joins_and_filters(plan)

    # The join keys may be listed in either order.
    left_key_lines = {
        'LEFT PLAN ON: [col("a"), col("b")]',
        'LEFT PLAN ON: [col("b"), col("a")]',
    }
    right_key_lines = {
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'RIGHT PLAN ON: [col("b"), col("a")]',
    }

    # The same three nodes are expected once per concatenated input.
    for offset in (0, 3):
        assert nodes[offset] in left_key_lines
        assert nodes[offset + 1] == 'FILTER [(col("a")) > (0)]'
        assert nodes[offset + 2] in right_key_lines
    assert len(nodes) == 6

    expected = pl.DataFrame(
        [
            pl.Series("a", [2, 2], dtype=pl.Int64),
            pl.Series("b", [2, 2], dtype=pl.Int64),
            pl.Series("a_right", [1, 1], dtype=pl.Int64),
            pl.Series("b_right", [3, 3], dtype=pl.Int64),
        ]
    )
    assert_collects_to(query, expected, sort_rows=True)
3175
3176
3177
def test_join_filter_pushdown_asof_join() -> None:
    """Check where filters land in the plan of an asof join, and its results."""

    def assert_collects_to(query: pl.LazyFrame, expected: pl.DataFrame) -> None:
        # Collect once with default settings and once with QueryOptFlags.none().
        assert_frame_equal(query.collect(), expected)
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()), expected
        )

    left = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 2, 3, None, None],
            "c": ["A", "B", "C", "D", "E"],
        }
    )

    # The same four filters are applied to both asof joins below.
    filters = (
        pl.col("a") >= 2,
        pl.col("b") >= 3,
        pl.col("c") >= "A",
        pl.col("c_right") >= "B",
    )

    # Without "by" columns
    query = left.join_asof(
        right,
        left_on=pl.col("a").set_sorted(),
        right_on=pl.col("b").set_sorted(),
        tolerance=0,
    ).filter(*filters)

    nodes = _extract_plan_joins_and_filters(query.explain())

    assert nodes[:2] == [
        'FILTER [(col("c_right")) >= ("B")]',
        'LEFT PLAN ON: [col("a").set_sorted()]',
    ]

    for fragment in (
        'col("b")) >= (3)',
        'col("c")) >= ("A")',
        'col("a")) >= (2)',
    ):
        assert fragment in nodes[2]

    assert nodes[3:] == ['RIGHT PLAN ON: [col("b").set_sorted()]']

    expected = pl.DataFrame(
        {
            "a": [3],
            "b": [3],
            "c": ["c"],
            "a_right": [3],
            "b_right": [3],
            "c_right": ["C"],
        }
    )
    assert_collects_to(query, expected)

    # With "by" columns
    query = left.join_asof(
        right,
        left_on="a",
        right_on="b",
        tolerance=99,
        by_left="b",
        by_right="a",
    ).filter(*filters)

    nodes = _extract_plan_joins_and_filters(query.explain())

    assert nodes[:2] == [
        'FILTER [(col("c_right")) >= ("B")]',
        'LEFT PLAN ON: [col("a")]',
    ]
    for fragment in ('col("a")) >= (2)', 'col("b")) >= (3)'):
        assert fragment in nodes[2]

    assert nodes[3:] == [
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("a")) >= (3)]',
    ]

    expected = pl.DataFrame(
        {
            "a": [3],
            "b": [3],
            "c": ["c"],
            "b_right": [3],
            "c_right": ["C"],
        }
    )
    assert_collects_to(query, expected)
3271
3272
3273
def test_join_filter_pushdown_full_join_rewrite() -> None:
    """Check how post-join filters rewrite a full join in the optimized plan.

    Five scenarios are exercised on the same pair of frames. Each one inspects
    the text returned by `explain()` (join type and the join/filter nodes
    returned by `_extract_plan_joins_and_filters`) and verifies that the
    optimized and the unoptimized query both produce `expect`.
    """
    lhs = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    rhs = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4, None],
            "b": [1, 2, 3, None, 5],
            "c": ["A", "B", "C", "D", "E"],
        }
    )

    # Downgrades to left-join
    q = lhs.join(rhs, on=["a", "b"], how="full", maintain_order="left_right").filter(
        pl.col("b") >= 3
    )

    expect = pl.DataFrame(
        [
            pl.Series("a", [3, 4], dtype=pl.Int64),
            pl.Series("b", [3, 4], dtype=pl.Int64),
            pl.Series("c", ["c", "d"], dtype=pl.String),
            pl.Series("a_right", [3, None], dtype=pl.Int64),
            pl.Series("b_right", [3, None], dtype=pl.Int64),
            pl.Series("c_right", ["C", None], dtype=pl.String),
        ]
    )

    plan = q.explain()

    assert "FULL JOIN" not in plan
    assert plan.startswith("LEFT JOIN")

    extract = _extract_plan_joins_and_filters(plan)

    assert extract == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
    ]

    assert_frame_equal(q.collect(), expect)
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)

    # Downgrades to right-join
    q = lhs.join(
        rhs, left_on="a", right_on="b", how="full", maintain_order="left_right"
    ).filter(pl.col("b_right") >= 3)

    expect = pl.DataFrame(
        [
            pl.Series("a", [3, 5], dtype=pl.Int64),
            pl.Series("b", [3, None], dtype=pl.Int64),
            pl.Series("c", ["c", "e"], dtype=pl.String),
            pl.Series("a_right", [3, None], dtype=pl.Int64),
            pl.Series("b_right", [3, 5], dtype=pl.Int64),
            pl.Series("c_right", ["C", "E"], dtype=pl.String),
        ]
    )

    plan = q.explain()

    assert "FULL JOIN" not in plan
    assert "RIGHT JOIN" in plan

    extract = _extract_plan_joins_and_filters(plan)

    assert extract == [
        'LEFT PLAN ON: [col("a")]',
        'FILTER [(col("a")) >= (3)]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("b")) >= (3)]',
    ]

    assert_frame_equal(q.collect(), expect)
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)

    # Downgrades to right-join
    q = lhs.join(
        rhs,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=True,
        maintain_order="left_right",
    ).filter(
        (pl.col("a") >= 1) | pl.col("a").is_null(),  # col(a) from LHS
        pl.col("a_right") >= 3,  # col(a) from RHS
        (pl.col("b") >= 2) | pl.col("b").is_null(),  # col(b) from LHS
        pl.col("c_right") >= "C",  # col(c) from RHS
    )

    expect = pl.DataFrame(
        [
            pl.Series("a", [3, None], dtype=pl.Int64),
            pl.Series("b", [3, None], dtype=pl.Int64),
            pl.Series("c", ["c", None], dtype=pl.String),
            pl.Series("a_right", [3, 4], dtype=pl.Int64),
            pl.Series("c_right", ["C", "D"], dtype=pl.String),
        ]
    )

    plan = q.explain()

    assert "FULL JOIN" not in plan
    assert "RIGHT JOIN" in plan

    extract = _extract_plan_joins_and_filters(plan)

    # Compare against the extracted plan nodes. A bare `assert [...]` on a
    # non-empty list literal is always true and would verify nothing.
    assert extract[:4] == [
        'FILTER [([(col("b")) >= (2)]) | (col("b").is_null())]',
        'LEFT PLAN ON: [col("a")]',
        'FILTER [([(col("a")) >= (1)]) | (col("a").is_null())]',
        'RIGHT PLAN ON: [col("b")]',
    ]

    # The last node combines several predicates; only membership is checked,
    # not the order in which they are rendered.
    assert 'col("a")) >= (3)' in extract[4]
    assert '(col("b")) >= (1)]) | (col("b").alias("a").is_null())' in extract[4]
    assert 'col("c")) >= ("C")' in extract[4]

    assert len(extract) == 5

    assert_frame_equal(q.collect(), expect)
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)

    # Downgrades to inner-join
    q = lhs.join(rhs, on=["a", "b"], how="full", maintain_order="left_right").filter(
        pl.col("b").is_not_null(), pl.col("b_right").is_not_null()
    )

    expect = pl.DataFrame(
        [
            pl.Series("a", [1, 2, 3], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3], dtype=pl.Int64),
            pl.Series("c", ["a", "b", "c"], dtype=pl.String),
            pl.Series("a_right", [1, 2, 3], dtype=pl.Int64),
            pl.Series("b_right", [1, 2, 3], dtype=pl.Int64),
            pl.Series("c_right", ["A", "B", "C"], dtype=pl.String),
        ]
    )

    plan = q.explain()

    assert "FULL JOIN" not in plan
    assert plan.startswith("INNER JOIN")

    extract = _extract_plan_joins_and_filters(plan)

    assert extract[0] == 'LEFT PLAN ON: [col("a"), col("b")]'
    assert 'col("b").is_not_null()' in extract[1]
    assert 'col("b").alias("b_right").is_not_null()' in extract[1]

    assert extract[2] == 'RIGHT PLAN ON: [col("a"), col("b")]'
    assert 'col("b").is_not_null()' in extract[3]
    assert 'col("b").alias("b_right").is_not_null()' in extract[3]

    assert len(extract) == 4

    assert_frame_equal(q.collect(), expect)
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)

    # Does not downgrade because col(b) is a coalesced key-column, but the filter
    # is still pushed to both sides.
    q = lhs.join(
        rhs, on=["a", "b"], how="full", coalesce=True, maintain_order="left_right"
    ).filter(pl.col("b") >= 3)

    expect = pl.DataFrame(
        [
            pl.Series("a", [3, 4, None], dtype=pl.Int64),
            pl.Series("b", [3, 4, 5], dtype=pl.Int64),
            pl.Series("c", ["c", "d", None], dtype=pl.String),
            pl.Series("c_right", ["C", None, "E"], dtype=pl.String),
        ]
    )

    plan = q.explain()
    assert plan.startswith("FULL JOIN")

    extract = _extract_plan_joins_and_filters(plan)

    assert extract == [
        'LEFT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
        'RIGHT PLAN ON: [col("a"), col("b")]',
        'FILTER [(col("b")) >= (3)]',
    ]

    assert_frame_equal(q.collect(), expect)
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)
3464
3465
3466
def test_join_filter_pushdown_right_join_rewrite() -> None:
    """Filtered coalescing right join: the plan must use INNER JOIN, not RIGHT JOIN.

    Also checks which predicates end up under each side of the join and that
    optimized and unoptimized execution agree with the expected frame.
    """
    left_lf = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right_lf = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    # Downgrades to inner-join
    joined = left_lf.join(
        right_lf,
        left_on="a",
        right_on="b",
        how="right",
        coalesce=True,
        maintain_order="left_right",
    )
    query = joined.filter(
        pl.col("a") <= 7,  # col(a) from RHS (LHS col(a) is coalesced into col(b_right))
        pl.col("b_right") <= 10,  # Key-column filter
        pl.col("c") <= "b",  # col(c) from LHS
    )

    expected_columns = [
        pl.Series("b", [1, 2], dtype=pl.Int64),
        pl.Series("c", ["a", "b"], dtype=pl.String),
        pl.Series("a", [1, 2], dtype=pl.Int64),
        pl.Series("b_right", [1, 2], dtype=pl.Int64),
        pl.Series("c_right", ["A", "B"], dtype=pl.String),
    ]
    expected = pl.DataFrame(expected_columns)

    plan_text = query.explain()

    assert "RIGHT JOIN" not in plan_text
    assert "INNER JOIN" in plan_text

    nodes = _extract_plan_joins_and_filters(plan_text)

    # Left input and the predicates listed beneath it.
    assert nodes[0] == 'LEFT PLAN ON: [col("a")]'
    left_filter = nodes[1]
    assert 'col("a")) <= (10)' in left_filter
    assert 'col("c")) <= ("b")' in left_filter

    # Right input and the predicates listed beneath it.
    assert nodes[2] == 'RIGHT PLAN ON: [col("b")]'
    right_filter = nodes[3]
    assert 'col("a")) <= (7)' in right_filter
    assert 'col("b")) <= (10)' in right_filter

    assert len(nodes) == 4

    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(unoptimized, expected)
3517
3518
3519
def test_join_filter_pushdown_join_rewrite_equality_above_and() -> None:
    """Full join filtered by a comparison wrapped around an AND expression.

    Only results are checked: optimized and unoptimized execution must both
    equal the expected frame.
    """
    left_lf = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right_lf = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", "C", "D", "E"]}
    )

    predicate = ((pl.col("b") >= 3) & False) >= False
    query = left_lf.join(
        right_lf,
        left_on="a",
        right_on="b",
        how="full",
        coalesce=False,
        maintain_order="left_right",
    ).filter(predicate)

    expected_columns = [
        pl.Series("a", [1, 2, 3, 4, 5, None], dtype=pl.Int64),
        pl.Series("b", [1, 2, 3, 4, None, None], dtype=pl.Int64),
        pl.Series("c", ["a", "b", "c", "d", "e", None], dtype=pl.String),
        pl.Series("a_right", [1, 2, 3, None, 5, 4], dtype=pl.Int64),
        pl.Series("b_right", [1, 2, 3, None, 5, None], dtype=pl.Int64),
        pl.Series("c_right", ["A", "B", "C", None, "E", "D"], dtype=pl.String),
    ]
    expected = pl.DataFrame(expected_columns)

    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=pl.QueryOptFlags.none()), expected)
3549
3550
3551
def test_join_filter_pushdown_left_join_rewrite() -> None:
    """Left join filtered on a right-side column: plan must start with INNER JOIN.

    Verifies the extracted join/filter nodes and that the optimized and
    unoptimized results both match the expected frame.
    """
    left_lf = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, None], "c": ["a", "b", "c", "d", "e"]}
    )
    right_lf = pl.LazyFrame(
        {"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, None, 5], "c": ["A", "B", None, "D", "E"]}
    )

    # Downgrades to inner-join
    joined = left_lf.join(
        right_lf,
        left_on="a",
        right_on="b",
        how="left",
        coalesce=True,
        maintain_order="left_right",
    )
    query = joined.filter(pl.col("c_right") <= "B")

    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b", [1, 2], dtype=pl.Int64),
            pl.Series("c", ["a", "b"], dtype=pl.String),
            pl.Series("a_right", [1, 2], dtype=pl.Int64),
            pl.Series("c_right", ["A", "B"], dtype=pl.String),
        ]
    )

    plan_text = query.explain()

    assert "LEFT JOIN" not in plan_text
    assert plan_text.startswith("INNER JOIN")

    expected_nodes = [
        'LEFT PLAN ON: [col("a")]',
        'RIGHT PLAN ON: [col("b")]',
        'FILTER [(col("c")) <= ("B")]',
    ]
    assert _extract_plan_joins_and_filters(plan_text) == expected_nodes

    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(unoptimized, expected)
3594
3595
3596
def test_join_filter_pushdown_left_join_rewrite_23133() -> None:
    """Left join on `ham` with three filters: plan must start with INNER JOIN.

    Checks which predicates appear under each side of the join and that the
    optimized and unoptimized results both match the expected single row.
    """
    left_lf = pl.LazyFrame(
        {
            "foo": [1, 2, 3],
            "bar": [6.0, 7.0, 8.0],
            "ham": ["a", "b", "c"],
        }
    )
    right_lf = pl.LazyFrame(
        {
            "apple": ["x", "y", "z"],
            "ham": ["a", "b", "d"],
            "bar": ["a", "b", "c"],
            "foo2": [1, 2, 3],
        }
    )

    joined = left_lf.join(right_lf, how="left", on="ham", maintain_order="left_right")
    query = joined.filter(
        pl.col("ham") == "a", pl.col("apple") == "x", pl.col("foo") <= 2
    )

    expected_columns = [
        pl.Series("foo", [1], dtype=pl.Int64),
        pl.Series("bar", [6.0], dtype=pl.Float64),
        pl.Series("ham", ["a"], dtype=pl.String),
        pl.Series("apple", ["x"], dtype=pl.String),
        pl.Series("bar_right", ["a"], dtype=pl.String),
        pl.Series("foo2", [1], dtype=pl.Int64),
    ]
    expected = pl.DataFrame(expected_columns)

    plan_text = query.explain()
    assert "FULL JOIN" not in plan_text
    assert plan_text.startswith("INNER JOIN")

    nodes = _extract_plan_joins_and_filters(plan_text)

    # Left side: key plus the predicates rendered beneath it.
    assert nodes[0] == 'LEFT PLAN ON: [col("ham")]'
    left_filter = nodes[1]
    assert '(col("foo")) <= (2)' in left_filter
    assert 'col("ham")) == ("a")' in left_filter

    # Right side: key plus the predicates rendered beneath it.
    assert nodes[2] == 'RIGHT PLAN ON: [col("ham")]'
    right_filter = nodes[3]
    assert 'col("ham")) == ("a")' in right_filter
    assert 'col("apple")) == ("x")' in right_filter

    assert len(nodes) == 4

    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())
    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(unoptimized, expected)
3647
3648
3649
def test_join_rewrite_panic_23307() -> None:
    """Left joins with mismatched key dtypes followed by a filter.

    Part one joins Int8 to Int16 keys and filters on a right-side column.
    Part two joins on a wrapping cast to Int8 and requires that the optimized
    plan does not start with FILTER and that results match with and without
    optimizations.
    """
    # Part one: Int8 key on the left, Int16 key on the right.
    narrow_left = pl.select(a=pl.lit(1, dtype=pl.Int8)).lazy()
    wide_right = pl.select(
        a=pl.lit(1, dtype=pl.Int16), x=pl.lit(1, dtype=pl.Int32)
    ).lazy()

    query = narrow_left.join(wide_right, on="a", how="left", coalesce=True).filter(
        pl.col("x") >= 1
    )

    expected_first = pl.select(
        a=pl.lit(1, dtype=pl.Int8),
        x=pl.lit(1, dtype=pl.Int32),
    )
    assert_frame_equal(query.collect(), expected_first)

    # Part two: join key is a wrapping cast to Int8.
    left_lf = pl.select(a=pl.lit(999, dtype=pl.Int16)).lazy()

    # Note: -25 matches to (999).overflowing_cast(Int8).
    # This is specially chosen to test that we don't accidentally push the filter
    # to the RHS.
    right_lf = pl.LazyFrame(
        {"a": [1, -25], "x": [1, 2]}, schema={"a": pl.Int8, "x": pl.Int32}
    )

    wrapped_key = pl.col("a").cast(pl.Int8, strict=False, wrap_numerical=True)
    query = left_lf.join(
        right_lf,
        on=wrapped_key,
        how="left",
        coalesce=False,
    ).filter(pl.col("a") >= 0)

    expected = pl.DataFrame(
        {"a": 999, "a_right": -25, "x": 2},
        schema={"a": pl.Int16, "a_right": pl.Int8, "x": pl.Int32},
    )

    assert not query.explain().startswith("FILTER")

    assert_frame_equal(query.collect(), expected)
    assert_frame_equal(query.collect(optimizations=pl.QueryOptFlags.none()), expected)
3690
3691
3692
@pytest.mark.parametrize(
    ("expr_first_input", "expr_func"),
    [
        # Comparisons and null checks layered on top of them.
        (pl.lit(None, dtype=pl.Int64), lambda e: e >= 1),
        (pl.lit(None, dtype=pl.Int64), lambda e: (e >= 1).is_not_null()),
        (pl.lit(None, dtype=pl.Int64), lambda e: (~(e >= 1)).is_not_null()),
        (pl.lit(None, dtype=pl.Int64), lambda e: ~(e >= 1).is_null()),
        # is_in
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_in([1])),
        (pl.lit(None, dtype=pl.Int64), lambda e: ~e.is_in([1])),
        # is_between
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_between(1, 1)),
        (1, lambda e: e.is_between(None, 1)),
        (1, lambda e: e.is_between(1, None)),
        # is_close
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_close(1)),
        (1, lambda e: e.is_close(pl.lit(None, dtype=pl.Int64))),
        # NaN / finiteness checks on an integer null
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_nan()),
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_not_nan()),
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_finite()),
        (pl.lit(None, dtype=pl.Int64), lambda e: e.is_infinite()),
        # NaN / finiteness checks on a float null
        (pl.lit(None, dtype=pl.Float64), lambda e: e.is_nan()),
        (pl.lit(None, dtype=pl.Float64), lambda e: e.is_not_nan()),
        (pl.lit(None, dtype=pl.Float64), lambda e: e.is_finite()),
        (pl.lit(None, dtype=pl.Float64), lambda e: e.is_infinite()),
    ],
)
def test_join_rewrite_null_preserving_exprs(
    expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]
) -> None:
    """Left join filtered by `expr_func(col("x"))` must be planned as INNER JOIN.

    `expr_first_input` is the value of the right-side column `x`. The test
    first asserts that `expr_func` applied to that input is null or false,
    then checks the plan prefix, that the result is empty, and that the
    optimized result equals the unoptimized one.
    """
    left_lf = pl.LazyFrame({"a": 1})
    right_lf = pl.select(a=1, x=expr_first_input).lazy()

    # Precondition on the parametrized pair: the predicate yields null or false.
    evaluated = pl.select(expr_first_input).select(expr_func(pl.first()))
    null_or_false = evaluated.select(pl.first().is_null() | ~pl.first())
    assert null_or_false.to_series().item()

    predicate = expr_func(pl.col("x"))
    query = left_lf.join(
        right_lf, on="a", how="left", maintain_order="left_right"
    ).filter(predicate)

    assert query.explain().startswith("INNER JOIN")

    optimized = query.collect()

    assert optimized.height == 0
    assert_frame_equal(optimized, query.collect(optimizations=pl.QueryOptFlags.none()))
3746
3747
3748
@pytest.mark.parametrize(
    ("expr_first_input", "expr_func"),
    [
        (
            pl.lit(None, dtype=pl.Int64),
            lambda e: ~(e.is_in([1, None], nulls_equal=True)),
        ),
        (
            pl.lit(None, dtype=pl.Int64),
            lambda e: e.is_in([1, None], nulls_equal=True) > True,
        ),
        (
            pl.lit(None, dtype=pl.Int64),
            lambda e: e.is_in([1], nulls_equal=True),
        ),
    ],
)
def test_join_rewrite_forbid_exprs(
    expr_first_input: Any, expr_func: Callable[[pl.Expr], pl.Expr]
) -> None:
    """Left join filtered by `expr_func(col("x"))` must keep FILTER on top.

    `expr_first_input` is the value of the right-side column `x`. The plan is
    required to start with FILTER, and the optimized result must equal the
    unoptimized one.
    """
    left_lf = pl.LazyFrame({"a": 1})
    right_lf = pl.select(a=1, x=expr_first_input).lazy()

    predicate = expr_func(pl.col("x"))
    query = left_lf.join(
        right_lf, on="a", how="left", maintain_order="left_right"
    ).filter(predicate)

    assert query.explain().startswith("FILTER")

    optimized = query.collect()
    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())
    assert_frame_equal(optimized, unoptimized)
3779
3780
3781
def test_join_coalesce_column_order_23177() -> None:
    """Coalescing full join on two keys whose column order differs per frame."""
    time_first = pl.DataFrame({"time": ["09:00:21"], "symbol": [5253]})
    symbol_first = pl.DataFrame({"symbol": [5253], "time": ["09:00:21"]})

    query = time_first.lazy().join(
        symbol_first.lazy(), on=["time", "symbol"], how="full", coalesce=True
    )

    # The result is expected to follow the left frame's column order.
    expected = pl.DataFrame({"time": ["09:00:21"], "symbol": [5253]})

    assert_frame_equal(query.collect(), expected)
3790
3791
3792
def test_join_filter_pushdown_iejoin_cse_23469() -> None:
    """Filtered cross join reused twice in a concat, plus a cached join_where.

    Three queries are collected, sorted by all columns, and compared with
    their expected frames.
    """
    xs = pl.LazyFrame({"x": [1, 2, 3]})
    ys = pl.LazyFrame({"y": [1, 2, 3]})

    x_gt_y = pl.col("x") > pl.col("y")
    pairs = xs.join(ys, how="cross").filter(x_gt_y)

    # Same filtered sub-plan used twice.
    doubled = pl.concat([pairs, pairs])

    assert_frame_equal(
        doubled.collect().sort(pl.all()),
        pl.DataFrame(
            {
                "x": [2, 2, 3, 3, 3, 3],
                "y": [1, 1, 1, 1, 2, 2],
            },
        ),
    )

    # Same again, with the predicate repeated above the concat.
    doubled_refiltered = pl.concat([pairs, pairs]).filter(pl.col("x") > pl.col("y"))

    assert_frame_equal(
        doubled_refiltered.collect().sort(pl.all()),
        pl.DataFrame(
            {
                "x": [2, 2, 3, 3, 3, 3],
                "y": [1, 1, 1, 1, 2, 2],
            },
        ),
    )

    # Equality join_where behind a cache node, followed by a filter.
    cached_equal = xs.join_where(ys, pl.col("x") == pl.col("y")).cache()
    filtered = cached_equal.filter(pl.col("x") >= 0)

    assert_frame_equal(
        filtered.collect().sort(pl.all()),
        pl.DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]}),
    )
3831
3832
3833
def test_join_cast_type_coercion_23236() -> None:
    """Join where the left key is a cast expression on a renamed column."""
    renamed_left = pl.LazyFrame([{"name": "a"}]).rename({"name": "newname"})
    plain_right = pl.LazyFrame([{"name": "a"}])

    left_key = pl.col("newname").cast(pl.String)
    query = renamed_left.join(plain_right, left_on=left_key, right_on="name")

    expected = pl.DataFrame({"newname": "a", "name": "a"})
    assert_frame_equal(query.collect(), expected)
3840
3841
3842
@pytest.mark.parametrize(
    ("how", "expected"),
    [
        (
            "inner",
            pl.DataFrame(schema={"a": pl.Int128, "a_right": pl.Int128}),
        ),
        (
            "left",
            pl.DataFrame(
                {"a": [1, 1, 2], "a_right": None},
                schema={"a": pl.Int128, "a_right": pl.Int128},
            ),
        ),
        (
            "right",
            pl.DataFrame(
                {
                    "a": None,
                    "a_right": [
                        -9223372036854775808,
                        -9223372036854775807,
                        -9223372036854775806,
                    ],
                },
                schema={"a": pl.Int128, "a_right": pl.Int128},
            ),
        ),
        (
            "full",
            pl.DataFrame(
                [
                    pl.Series("a", [None, None, None, 1, 1, 2], dtype=pl.Int128),
                    pl.Series(
                        "a_right",
                        [
                            -9223372036854775808,
                            -9223372036854775807,
                            -9223372036854775806,
                            None,
                            None,
                            None,
                        ],
                        dtype=pl.Int128,
                    ),
                ]
            ),
        ),
        (
            "semi",
            pl.DataFrame([pl.Series("a", [], dtype=pl.Int128)]),
        ),
        (
            "anti",
            pl.DataFrame([pl.Series("a", [1, 1, 2], dtype=pl.Int128)]),
        ),
    ],
)
@pytest.mark.parametrize(
    ("sort_left", "sort_right"),
    [(True, True), (True, False), (False, True), (False, False)],
)
def test_join_i128_23688(
    how: str, expected: pl.DataFrame, sort_left: bool, sort_right: bool
) -> None:
    """Join two Int128 frames with disjoint values for every join type.

    `sort_left` / `sort_right` select whether each input is collected and
    sorted on `a` before joining. The join is run once on the single key `a`
    and once on the key pair (`a`, `b`) where `b` duplicates `a`; both results
    are sorted by all columns and compared with `expected`.
    """
    left_lf = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128)})

    right_values = pl.Series(
        [
            -9223372036854775808,
            -9223372036854775807,
            -9223372036854775806,
        ],
        dtype=pl.Int128,
    )
    right_lf = pl.LazyFrame({"a": right_values})

    if sort_left:
        left_lf = left_lf.collect().sort("a").lazy()
    if sort_right:
        right_lf = right_lf.collect().sort("a").lazy()

    # Single-key join.
    single_key = left_lf.join(right_lf, on="a", how=how, coalesce=False)  # type: ignore[arg-type]

    assert_frame_equal(
        single_key.collect().sort(pl.all()),
        expected,
    )

    # Two-key join, with `b` a copy of `a` on both sides.
    left_with_b = left_lf.with_columns(b=pl.col("a"))
    right_with_b = right_lf.with_columns(b=pl.col("a"))
    two_keys = left_with_b.join(
        right_with_b,
        on=["a", "b"],
        how=how,  # type: ignore[arg-type]
        coalesce=False,
    ).select(expected.columns)

    assert_frame_equal(
        two_keys.collect().sort(pl.all()),
        expected,
    )
3947
3948
3949
def test_join_asof_by_i128() -> None:
    """As-of join on `i` grouped by an Int128 `a` with no matching groups."""
    left_lf = pl.LazyFrame({"a": pl.Series([1, 1, 2], dtype=pl.Int128), "i": 1})

    right_keys = pl.Series(
        [
            -9223372036854775808,
            -9223372036854775807,
            -9223372036854775806,
        ],
        dtype=pl.Int128,
    )
    right_lf = pl.LazyFrame({"a": right_keys, "i": 1}).with_columns(b=pl.col("a"))

    query = left_lf.join_asof(right_lf, on="i", by="a")

    # Every `b` is expected to be null.
    expected = pl.DataFrame(
        {"a": [1, 1, 2], "i": 1, "b": None},
        schema={"a": pl.Int128, "i": pl.Int32, "b": pl.Int128},
    )

    assert_frame_equal(query.collect().sort(pl.all()), expected)
3975
3976
3977
def test_join_lazyframe_with_itself_after_sort_25395() -> None:
    """Join a sorted LazyFrame with the unsorted original of itself."""
    base = pl.LazyFrame({"a": [1]})
    sorted_base = base.sort("a")

    joined = sorted_base.join(base, on="a").collect()

    assert_frame_equal(joined, pl.DataFrame({"a": [1]}))
3982
3983
3984
def test_join_right_with_cast_predicate_pushdown() -> None:
    """Right join of Int64 `x` to Int32 `y`, filtered on a left-side column.

    The filter `z >= 6` is expected to leave an empty frame with columns
    `z` and `y`.
    """
    left_lf = pl.LazyFrame({"x": [0, 1], "z": [4, 5]})
    right_lf = pl.LazyFrame({"y": [2, 3]}).cast(pl.Int32)

    joined = left_lf.join(right_lf, left_on="x", right_on="y", how="right")
    result = joined.filter(pl.col("z") >= 6).collect()

    expected = pl.DataFrame(
        {
            "z": [],
            "y": [],
        },
        schema={"z": pl.Int64, "y": pl.Int64},
    )
    assert_frame_equal(
        result, expected, check_column_order=True, check_row_order=False
    )
4002
4003
4004
def test_full_join_rewrite_to_right_with_cast() -> None:
    """Full join of Int64 `x` to Int32 `y`, filtered on a right-side column.

    The filter `b >= 0` is expected to keep both right rows, with null
    left-side columns; row order is not checked.
    """
    left_lf = pl.LazyFrame({"x": [0, 1], "a": [10, 20]})
    right_lf = pl.LazyFrame({"y": [2, 3], "b": [30, 40]}).cast(pl.Int32)

    joined = left_lf.join(right_lf, left_on="x", right_on="y", how="full")
    result = joined.filter(pl.col("b") >= 0).collect()

    expected_schema = {
        "x": pl.Int64,
        "a": pl.Int64,
        "y": pl.Int32,
        "b": pl.Int32,
    }
    expected = pl.DataFrame(
        {
            "x": [None, None],
            "a": [None, None],
            "y": [2, 3],
            "b": [30, 40],
        },
        schema=expected_schema,
    )
    assert_frame_equal(
        result, expected, check_column_order=True, check_row_order=False
    )
4029
4030